Compare commits

...

103 Commits

Author SHA1 Message Date
Wesley Liddick
0236b97d49 Merge pull request #1134 from yasu-dev221/fix/openai-compat-prompt-cache-key
fix(openai): add fallback prompt_cache_key for compat codex OAuth requests
2026-03-19 22:02:08 +08:00
Wesley Liddick
26f6b1eeff Merge pull request #1142 from StarryKira/fix/failover-exhausted-upstream-status-code
fix: record original upstream status code when failover exhausted (#1128)
2026-03-19 21:56:58 +08:00
Wesley Liddick
dc447ccebe Merge pull request #1153 from hging/main
feat: add ungrouped filter to account
2026-03-19 21:55:28 +08:00
Wesley Liddick
7ec29638f4 Merge pull request #1147 from DaydreamCoding/feat/persisted-page-size
feat(frontend): 分页 pageSize 持久化到 localStorage,刷新后自动恢复
2026-03-19 21:53:54 +08:00
Wesley Liddick
4c9562af20 Merge pull request #1148 from weak-fox/ci/sync-version-file-after-release
ci: sync VERSION file back to default branch after release
2026-03-19 21:46:45 +08:00
Wesley Liddick
71942fd322 Merge pull request #1132 from touwaeriol/pr/virtual-scroll
perf(frontend): add virtual scrolling to DataTable
2026-03-19 21:46:16 +08:00
Wesley Liddick
550b979ac5 Merge pull request #1146 from DaydreamCoding/fix/test-403-error-status
fix(test): 测试连接收到 403 时将账号标记为 error 状态
2026-03-19 21:44:57 +08:00
Wesley Liddick
3878a5a46f Merge pull request #1164 from GuangYiDing/fix/normalize-tool-parameters-schema
fix: Anthropic tool schema 转换时补充缺失的 properties 字段
2026-03-19 21:44:18 +08:00
Rose Ding
e443a6a1ea fix: 移除 staticcheck S1005 警告的多余 blank identifier
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 21:14:29 +08:00
Rose Ding
963494ec6f fix: Anthropic tool schema 转 Responses API 时补充缺失的 properties 字段
当 Claude Code 发来的 MCP tool 的 input_schema 为 {"type":"object"} 且缺少
properties 字段时,OpenAI Codex 后端会拒绝并报错:
Invalid schema for function '...': object schema missing properties.

新增 normalizeToolParameters 函数,在 convertAnthropicToolsToResponses 中
对每个 tool 的 InputSchema 做规范化处理后再赋给 Parameters。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 21:08:20 +08:00
shaw
525cdb8830 feat: Anthropic 账号被动用量采样,页面默认展示被动数据
从上游 /v1/messages 响应头被动采集 5h/7d utilization 并存储到
Account.Extra,页面加载时直接读取本地数据而非调用外部 Usage API。
用户可点击"查询"按钮主动拉取最新数据,主动查询结果自动回写被动缓存。

后端:
- UpdateSessionWindow 合并采集 5h + 7d headers 为单次 DB 写入
- 新增 GetPassiveUsage 从 Extra 构建 UsageInfo (复用 estimateSetupTokenUsage)
- GetUsage 主动查询后 syncActiveToPassive 回写被动缓存
- passive_usage_ 前缀注册为 scheduler-neutral

前端:
- Anthropic 账号 mount/refresh 默认 source=passive
- 新增"被动采样"标签和"查询"按钮 (带 loading 动画)
2026-03-19 17:42:59 +08:00
shaw
a6764e82f2 修复 OAuth/SetupToken 转发请求体重排并增加调试开关 2026-03-19 16:56:18 +08:00
Hg
8027531d07 feat: add ungrouped filter to account 2026-03-19 15:42:21 +08:00
weak-fox
30706355a4 ci: sync VERSION file back to default branch after release 2026-03-19 12:53:28 +08:00
QTom
dfe99507b8 feat(frontend): 分页 pageSize 持久化到 localStorage,刷新后自动恢复
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-19 12:44:03 +08:00
QTom
c1717c9a6c fix(test): 测试连接收到 403 时将账号标记为 error 状态
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-19 12:36:40 +08:00
haruka
1fd1a58a7a fix: record original upstream status code when failover exhausted (#1128)
When all failover accounts are exhausted, handleFailoverExhausted maps
the upstream status code (e.g. 403) to a client-facing code (e.g. 502)
but did not write the original code to the gin context. This caused ops
error logs to show the mapped code instead of the real upstream code.

Call SetOpsUpstreamError before mapUpstreamError in all failover-
exhausted paths so that ops_error_logger captures the true upstream
status code and message.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-19 11:15:02 +08:00
jimmy-coder
fad07507be fix(openai): inject stable compat prompt_cache_key for codex oauth chat-completions path 2026-03-19 03:24:31 +08:00
erio
a20c211162 perf(frontend): add virtual scrolling to DataTable
Replace direct row rendering with @tanstack/vue-virtual. The table
now only renders visible rows (~20) via padding <tr> placeholders,
eliminating the rendering bottleneck when displaying 100+ rows with
heavy cell components.

Key changes:
- DataTable.vue: integrate useVirtualizer (always-on), virtual row
  template with measureElement for variable row heights, defineExpose
  virtualizer/sortedData for external access, overflow-y/flex CSS
- useSwipeSelect.ts: dual-mode support via optional
  SwipeSelectVirtualContext — data-driven row index lookup and
  selection range when virtualizer is present, original DOM-based
  path preserved for callers that don't pass virtualContext
2026-03-18 23:11:49 +08:00
Wesley Liddick
9f6ab6b817 Merge pull request #1090 from laukkw/main
fix(setup): align install validation and expose backend errors
2026-03-18 16:23:06 +08:00
shaw
bf3d6c0e6e feat: add 529 overload cooldown toggle and duration settings in admin gateway page
Move 529 overload cooldown configuration from config file to admin
settings UI. Adds an enable/disable toggle and configurable cooldown
duration (1-120 min) under /admin/settings gateway tab, stored as
JSON in the settings table.

When disabled, 529 errors are logged but accounts are no longer
paused from scheduling. Falls back to config file value when DB
is unreachable or settingService is nil.
2026-03-18 16:22:19 +08:00
Wesley Liddick
241023f3fc Merge pull request #1097 from Ethan0x0000/pr/upstream-model-tracking
feat(usage): 新增 upstream_model 追踪,支持按模型来源统计与展示
2026-03-18 15:36:00 +08:00
Wesley Liddick
1292c44b41 Merge pull request #1118 from touwaeriol/worktree-fix/anti_mapping
feat: map claude-haiku-4-5 variants to claude-sonnet-4-6
2026-03-18 15:13:19 +08:00
Wesley Liddick
b4fce47049 Merge pull request #1116 from wucm667/fix/inject-site-title-in-html
fix: 直接访问或刷新页面时浏览器标签页显示自定义站点名称
2026-03-18 15:12:07 +08:00
Wesley Liddick
e7780cd8c8 Merge pull request #1117 from alfadb/fix/empty-text-block-retry
fix: 修复空 text block 导致上游 400 错误未被重试捕获的问题
2026-03-18 15:10:46 +08:00
erio
af96c8ea53 feat: map claude-haiku-4-5 variants to claude-sonnet-4-6
Update model mapping target for claude-haiku-4-5 and
claude-haiku-4-5-20251001 from claude-sonnet-4-5 to claude-sonnet-4-6.
Includes migration script, default constants, and test updates.
2026-03-18 15:03:24 +08:00
alfadb
7d26b81075 fix: address review - add missing whitespace patterns and narrow error matching 2026-03-18 14:31:57 +08:00
alfadb
b8ada63ac3 fix: strip empty text blocks in retry filter and fix error pattern matching
Empty text blocks ({"type":"text","text":""}) cause Anthropic upstream to
return 400: "text content blocks must be non-empty". This was not caught
by the existing error detection pattern in isThinkingBlockSignatureError,
nor handled by FilterThinkingBlocksForRetry.

- Add empty text block stripping to FilterThinkingBlocksForRetry
- Fix isThinkingBlockSignatureError to match new Anthropic error format
- Add fast-path byte patterns to avoid unnecessary JSON parsing

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 14:20:00 +08:00
Ethan0x0000
cfaac12af1 Merge upstream/main into pr/upstream-model-tracking
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-18 14:16:50 +08:00
wucm667
6028efd26c test: 添加 injectSiteTitle 函数的单元测试 2026-03-18 14:13:52 +08:00
shaw
62a566ef2c fix: 修复 config.yaml 以只读方式挂载时容器启动失败 (#1113)
entrypoint 中 chown -R /app/data 在遇到 :ro 挂载的文件时报错退出,
添加错误容忍处理;同时去掉 compose 文件注释中的 :ro 建议。
2026-03-18 14:11:51 +08:00
wucm667
94419f434c fix: 直接访问或刷新页面时浏览器标签页显示自定义站点名称
后端 HTML 注入时同步替换 <title> 标签为自定义站点名称,
前端 fetchPublicSettings 完成后重新设置 document.title,
解决路由守卫先于设置加载导致标题回退为默认值的时序问题。
2026-03-18 14:02:00 +08:00
Wesley Liddick
21f349c032 Merge pull request #1095 from LvyuanW/lvyuan/dev
fix(admin/accounts): reset edit modal state on reopen
2026-03-18 11:37:07 +08:00
Wesley Liddick
28e36f7925 Merge pull request #1096 from Ethan0x0000/pr/fix-idle-usage-windows
fix(ui): 会话窗口空闲时显示“现在”,避免重置时间缺失
2026-03-18 11:32:50 +08:00
Wesley Liddick
6c02076333 Merge pull request #1106 from geminiwen/feat/subscription-platform-filter
feat: add platform type filter to subscription management
2026-03-18 11:32:35 +08:00
shaw
7414bdf0e3 fix: 修复 hotpath 测试中 metadata.user_id 格式不合法导致 CI 失败
测试数据使用的 session ID "abc-123" 不符合 ParseMetadataUserID
要求的 36 字符 UUID 格式,替换为合法 UUID。
2026-03-18 11:31:32 +08:00
Wesley Liddick
e6326b2929 Merge pull request #1108 from DaydreamCoding/feat/admin-group-capacity-and-usage
feat(admin): 分组管理列表新增用量、账号分类与容量列
2026-03-18 11:12:43 +08:00
Wesley Liddick
17cdcebd04 Merge pull request #1109 from GuangYiDing/feat/subscription-guide
feat(subscriptions): 订阅管理页面添加教程指南弹窗
2026-03-18 11:12:33 +08:00
shaw
a14babdc73 fix: 兼容 Claude Code v2.1.78+ 新 JSON 格式 metadata.user_id
Claude Code v2.1.78 起将 metadata.user_id 从拼接字符串改为 JSON:
旧: user_{hex}_account_{uuid}_session_{uuid}
新: {"device_id":"...","account_uuid":"...","session_id":"..."}

新增集中解析/格式化模块 metadata_userid.go:
- ParseMetadataUserID: 自动识别两种格式,提取 DeviceID/AccountUUID/SessionID
- FormatMetadataUserID: 根据 UA 版本输出对应格式(>= 2.1.78 输出 JSON)
- ExtractCLIVersion: 从 UA 提取版本号,消除与 ClaudeCodeValidator.ExtractVersion 的重复

修改消费者统一使用新模块:
- claude_code_validator: 用 ParseMetadataUserID 替代只匹配旧格式的 userIDPattern
- identity_service: RewriteUserID/WithMasking 增加 fingerprintUA 参数,
  解析用 ParseMetadataUserID,输出用 FormatMetadataUserID(版本感知)
- gateway_service: GenerateSessionHash 用 ParseMetadataUserID 提取 session_id,
  buildOAuthMetadataUserID 用 FormatMetadataUserID 输出版本匹配格式,
  两处 RewriteUserIDWithMasking 调用传入 fp.UserAgent
- account_test_service: generateSessionString 改用 FormatMetadataUserID,
  自动跟随 DefaultHeaders UA 版本

删除三个旧正则: userIDPattern, userIDRegex, sessionIDRegex
统一 hex 匹配为 [a-fA-F0-9],修复旧 userIDRegex 只匹配小写的不一致
2026-03-18 11:08:58 +08:00
Rose Ding
aadc6a763a feat(subscriptions): 订阅管理页面添加教程指南弹窗
在订阅管理页面工具栏添加教程指南按钮(? 图标),点击弹出模态框,
引导管理员完成订阅功能的完整使用流程:

- 步骤一:创建订阅分组(含跳转分组管理链接)
- 步骤二:分配订阅给用户(搜索用户、选择分组、设置有效期)
- 步骤三:管理已有订阅(调整/重置配额/撤销操作说明表格)
- 底部提示:说明下拉列表为空时的解决方案

弹窗样式参照 BackupView 的 R2 Guide 模态框实现,保持 UI 一致性。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 10:49:41 +08:00
Rose Ding
f16af8bf88 feat(i18n): 添加订阅管理教程指南英文翻译
在 en.ts 中为订阅管理页面新增 guide 相关翻译词条,
与中文翻译保持结构一致,支持中英文切换。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 10:49:32 +08:00
Rose Ding
5ceaef4500 feat(i18n): 添加订阅管理教程指南中文翻译
在 zh.ts 中为订阅管理页面新增 guide 相关翻译词条,包括:
- 教程弹窗标题与副标题
- 三步操作引导文案(创建分组、分配订阅、管理订阅)
- 操作说明表格(调整/重置配额/撤销)
- 底部提示信息

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 10:49:13 +08:00
Gemini Wen
1ac7219a92 fix: add missing platform parameter to List calls in integration tests
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 10:35:03 +08:00
QTom
d4cc9871c4 feat(admin): 分组管理新增容量列(并发/会话/RPM 实时聚合)
复用 GroupCapacityService,在 admin 分组列表中添加容量列,
显示每个分组的实时并发/会话/RPM 使用量和上限。

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 10:06:35 +08:00
QTom
961c30e7c0 feat(admin): 分组管理列表新增用量列与账号数分类
分组管理列表增强:

1. 今日/累计用量列:
   - 新增独立端点 GET /admin/groups/usage-summary
   - 一次查询返回所有分组的今日费用和累计费用(actual_cost)
   - 前端异步加载后合并显示在分组列表中

2. 账号数区分可用/限流/总量:
   - 将账号数列从单一总量改为 badge 内多行展示
   - 可用: active + schedulable 的账号数(绿色)
   - 限流: rate_limit/overload/temp_unschedulable 的账号数(橙色,无限流时隐藏)
   - 总量: 全部关联账号数

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 10:06:35 +08:00
Gemini Wen
13e85b3147 fix: update remaining test stubs for List interface signature
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 09:35:08 +08:00
Gemini Wen
50a3c7fa0b feat: add platform type filter to subscription management page
Add a platform filter dropdown to the admin subscriptions view, allowing
filtering subscriptions by platform (Anthropic, OpenAI, Gemini, etc.)
through the group association.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 09:23:19 +08:00
Ethan0x0000
bd9d2671d7 chore(deps): go mod tidy to remove stale indirect dependencies 2026-03-17 20:46:12 +08:00
Ethan0x0000
62b40636e0 feat(frontend): display upstream model in usage table and distribution charts
Show upstream model mapping (requested -> upstream) in UsageTable with arrow notation. Add requested/upstream/mapping source toggle to ModelDistributionChart with lazy loading — only fetches data when user switches tab, with per-source cache invalidation on filter changes. Include upstream_model column in Excel export and i18n for en/zh.
2026-03-17 19:26:48 +08:00
Ethan0x0000
eeff451bc5 test(backend): add tests for upstream model tracking and model source filtering
Cover IsValidModelSource/NormalizeModelSource, resolveModelDimensionExpression SQL expressions, invalid model_source 400 responses on both GetModelStats and GetUserBreakdown, upstream_model in scan/insert SQL mock expectations, and updated passthrough/billing test signatures.
2026-03-17 19:26:30 +08:00
Ethan0x0000
56fcb20f94 feat(api): expose model_source filter in dashboard endpoints
Add model_source query parameter to GetModelStats and GetUserBreakdown handlers with explicit IsValidModelSource validation. Include model_source in cache key to prevent cross-source cache hits. Expose upstream_model in usage log DTO with omitempty semantics.
2026-03-17 19:26:11 +08:00
Ethan0x0000
7134266acf feat(dashboard): add model source dimension to stats queries
Support querying model statistics by 'requested', 'upstream', or 'mapping' dimension. Add resolveModelDimensionExpression for safe SQL expression generation, IsValidModelSource whitelist validator, and NormalizeModelSource fallback. Repository persists and scans upstream_model in all insert/select paths.
2026-03-17 19:25:52 +08:00
Ethan0x0000
2e4ac88ad9 feat(service): record upstream model across all gateway paths
Propagate UpstreamModel through ForwardResult and OpenAIForwardResult in Anthropic direct, API-key passthrough, Bedrock, and OpenAI gateway flows. Extract optionalNonEqualStringPtr and optionalTrimmedStringPtr into usage_log_helpers.go. Store upstream_model only when it differs from the requested model.

Also introduces anthropicPassthroughForwardInput struct to reduce parameter count.
2026-03-17 19:25:35 +08:00
Ethan0x0000
51547fa216 feat(db): add upstream_model column to usage_logs
Add nullable VARCHAR(100) column to record the actual model sent to upstream providers when model mapping is applied. NULL means no mapping — the requested model was used as-is.

Includes migration, concurrent index for aggregation queries, Ent schema regeneration, and migration README correction (forward-only runner, not goose).
2026-03-17 19:25:17 +08:00
Ethan0x0000
2005fc97a8 fix(ui): show 'now' for idle OpenAI usage windows
Use utilization-based idle detection instead of local request counts so newly imported OAuth accounts keep countdowns when usage is non-zero.
2026-03-17 19:23:35 +08:00
Wang Lvyuan
0772d9250e fix(admin/accounts): reset edit modal state on reopen 2026-03-17 18:44:10 +08:00
laukkw
aa6047c460 fix(setup): align install validation and expose backend errors
Make setup password requirements consistent with backend rules and show API-provided error messages so install failures are actionable. Trim admin email before validation to avoid false invalid-email rejections from surrounding whitespace.
2026-03-17 15:38:18 +08:00
Wesley Liddick
045cba78b4 Merge pull request #1083 from StarryKira/fix/claude-code-version-pattern-validation
fix(settings): remove pattern attribute blocking Claude Code version save fix issue #1081
2026-03-17 14:49:34 +08:00
Wesley Liddick
8989d0d4b6 Merge pull request #1085 from protondrift/main
feat: 个人资料弹窗 GitHub 链接仅对管理员可见
2026-03-17 14:49:07 +08:00
Wesley Liddick
c521117b99 Merge pull request #1074 from StarryKira/fix/session-window-reset-from-header
fix(usage): use real reset header for 5h session window countdown fix issue #1064 #1065
2026-03-17 14:48:16 +08:00
Eric
e0f52a8ab8 feat: 个人资料弹窗 GitHub 链接仅对管理员可见
目前作者已有商业站信息,面向管理可提供赞助渠道,面向普通用户请考虑提供信息隐藏措施
2026-03-17 12:51:34 +08:00
haruka
6c23fadf7e fix(settings): remove pattern attribute blocking Claude Code version save
The `pattern="\d+\.\d+\.\d+"` on the min_claude_code_version input caused
the browser's native HTML5 form validation to silently block form submission
when the value was invalid or when the hidden gateway tab was active. This
resulted in no network request being sent when clicking Save on any tab.

Backend already validates semver format and returns a proper 400 error,
so the frontend pattern attribute is redundant.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 11:33:57 +08:00
haruka
869952d113 fix(review): address Copilot PR feedback
- Add compile-time interface assertion for sessionWindowMockRepo
- Fix flaky fallback test by capturing time.Now() before calling UpdateSessionWindow
- Replace stale hardcoded timestamps with dynamic future values
- Add millisecond detection and bounds validation for reset header timestamp
- Use pause/resume pattern for interval in UsageProgressBar to avoid idle timers on large lists
- Fix gofmt comment alignment

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 10:19:20 +08:00
Wesley Liddick
07ab051ee4 Merge pull request #1078 from luxiang0412/main
fix(proxy): encode special chars in proxy credentials
2026-03-17 09:33:02 +08:00
Wesley Liddick
f2d98fc0c7 Merge pull request #1077 from Clov614/main
fix(auto setup): 修复初始测试连接硬编码问题导致使用自定义数据库测试失败无法执行 auto setup流程
2026-03-17 09:32:52 +08:00
Wesley Liddick
2b41cec840 Merge pull request #1076 from touwaeriol/pr/antigravity-test-connection-unify
refactor(antigravity): unify TestConnection with dispatch retry loop
2026-03-17 09:26:06 +08:00
Wesley Liddick
6cf77040e7 Merge pull request #1075 from touwaeriol/feat/dashboard-user-breakdown
feat(dashboard): add per-user drill-down for distribution charts
2026-03-17 09:25:43 +08:00
Wesley Liddick
20b70bc5fd Merge pull request #1070 from StarryKira/fix/oauth-system-role-to-instructions
fix(oauth): extract system-role input items into instructions field fix issue #1066
2026-03-17 09:24:17 +08:00
Wesley Liddick
4905e7193a Merge pull request #1069 from Ethan0x0000/pr/codex-usage-single-source
fix(openai): use /usage as single source and zero expired codex windows
2026-03-17 09:09:16 +08:00
Wesley Liddick
9c1f4b8e72 Merge pull request #1068 from Ethan0x0000/pr/frontend-last24h
feat(frontend): set last 24h as default range in Usage and Dashboard
2026-03-17 09:06:52 +08:00
Wesley Liddick
9857c17631 Merge pull request #1067 from DaydreamCoding/feat/async-backup
feat(backup): 备份/恢复异步化,解决 504 超时
2026-03-17 08:59:36 +08:00
luxiang
7e34bb946f fix(proxy): encode special chars in proxy credentials 2026-03-17 08:40:08 +08:00
clover614
47b748851b fix(auto setup): 修复初始测试连接硬编码问题导致使用自定义数据库测试失败无法执行 auto setup流程 2026-03-17 06:34:20 +08:00
erio
a6f99cf534 refactor(antigravity): unify TestConnection with dispatch retry loop
TestConnection now reuses antigravityRetryLoop instead of a standalone
HTTP loop, gaining credits overages, smart retry, and 429/503 backoff
for free. AccountSwitchError is caught and surfaced as a friendly
message. Also populates RateLimitedModel in TempUnscheduled switch error.

Test fixes:
- Use RATE_LIMIT_EXCEEDED in 503 short-delay test to avoid 60x1s timeout
- Clamp waitDuration=0 instead of 999s to avoid 15s max-wait timeout
- Enhance mockSmartRetryUpstream with repeatLast and body caching
2026-03-17 01:47:08 +08:00
erio
a120a6bc32 fix(ui): remove redundant sub-table header in user breakdown
The expanded user breakdown rows already align with the parent table
columns (Requests, Token, Actual, Standard), so the repeated sub-header
wastes vertical space. Remove the <thead> from UserBreakdownSubTable.
2026-03-17 00:49:43 +08:00
erio
d557d1a190 fix(ui): restore original max-h-48 height for distribution tables 2026-03-17 00:47:45 +08:00
erio
e0286e5085 test(dashboard): add unit tests for user-breakdown API
Handler tests (9 cases): group_id/model/endpoint filters, default
endpoint_type, custom limit, limit clamping, response format,
empty result, no-filter pass-through.

Repository test: resolveEndpointColumn mapping for inbound/upstream/path.
2026-03-17 00:47:33 +08:00
erio
4b41e898a4 feat(dashboard): add per-user drill-down for group, model, and endpoint distributions
Click on a group name, model name, or endpoint name in the distribution
tables to expand and show per-user usage breakdown (requests, tokens,
actual cost, standard cost).

Backend: new GET /admin/dashboard/user-breakdown API with group_id,
model, endpoint, endpoint_type filters.
Frontend: clickable rows with expand/collapse sub-table in all three
distribution charts.
2026-03-17 00:47:20 +08:00
Elysia
668e164793 fix(usage): use real reset header for session window instead of prediction
The 5h window reset time displayed for Setup Token accounts was inaccurate
because UpdateSessionWindow predicted the window end as "current hour + 5h"
instead of reading the actual `anthropic-ratelimit-unified-5h-reset` response
header. This caused the countdown to differ from the official Claude page.

Backend: parse the reset header (Unix timestamp) and use it as the real
window end, falling back to the hour-truncated prediction only when the
header is absent. Also correct stale predictions when a subsequent request
provides the real reset time.

Frontend: add a reactive 60s timer so the reset countdown in
UsageProgressBar ticks down in real-time instead of freezing at the
initial value.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 00:13:45 +08:00
Elysia
fa2e6188d0 fix(oauth): extract system-role input items into instructions field
OAuth upstreams (ChatGPT) reject requests containing role:"system" in
the input array with HTTP 400 "System messages are not allowed". Extract
such items before forwarding and merge their content into the top-level
instructions field, prepending to any existing value.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-16 21:20:46 +08:00
Ethan0x0000
7fde9ebbc2 fix: zero expired codex windows in backend, use /usage API as single frontend data source 2026-03-16 21:14:52 +08:00
Ethan0x0000
aef7c3b9bb feat: set last 24 hours as default date range in DashboardView 2026-03-16 21:14:26 +08:00
Ethan0x0000
a0b76bd608 feat: implement last 24 hours date range preset and update filters in UsageView 2026-03-16 21:14:26 +08:00
QTom
c1fab7f8d8 feat(backup): 备份/恢复异步化,解决 504 超时
POST /backups 和 POST /backups/:id/restore 改为异步:立即返回 HTTP 202,
后台 goroutine 独立执行 pg_dump → gzip → S3 上传,前端每 2s 轮询状态。

后端:
- 新增 StartBackup/StartRestore 方法,后台 goroutine 不依赖 HTTP 连接
- Graceful shutdown 等待活跃操作完成,启动时清理孤立 running 记录
- BackupRecord 新增 progress/restore_status 字段支持进度和恢复状态追踪

前端:
- 创建备份/恢复后轮询 GET /backups/:id 直到完成或失败
- 标签页切换暂停/恢复轮询,组件卸载清理定时器
- 正确处理 409(备份进行中)和轮询超时

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-16 20:22:10 +08:00
Wesley Liddick
f42c8f2abe Merge pull request #1062 from kunish/fix/antigravity-stream-keepalive
fix(antigravity): add stream keepalive to prevent connection drops
2026-03-16 19:57:13 +08:00
shaw
aa5846b282 fix(docker): resolve /app/data permission denied on volume mounts
Docker named volumes and host bind-mounts may be owned by root,
causing "open data/model_pricing.sha256: permission denied" when
the container runs as the non-root sub2api user.

Add an entrypoint script that fixes /app/data ownership before
dropping to sub2api via su-exec. Replace USER directive with the
entrypoint approach across all three Dockerfiles and update both
GoReleaser configs to include the script in Docker build contexts.
2026-03-16 19:52:14 +08:00
Wesley Liddick
594a0ade38 Merge pull request #1063 from touwaeriol/fix/usage-label-semantic
fix(i18n): correct usage label from "Total" to "Last 30d"
2026-03-16 19:04:36 +08:00
erio
d45cc23171 fix(i18n): correct usage label from "Total" to "Last 30d"
The usage stats query defaults to a 30-day rolling window, but the
UI label said "Total"/"累计" implying lifetime aggregation. Rename
to "Last 30d"/"近30天" so the label matches the actual query semantics.

Closes #1060
2026-03-16 18:25:41 +08:00
kunish
d795734352 fix(antigravity): add stream keepalive to prevent connection drops
Antigravity streaming handlers were missing the keepalive mechanism
that exists in the standard gateway, causing proxy/CDN idle timeouts
to break connections during long thinking phases (e.g. claude-opus-4-6).
This resulted in truncated responses with missing tool calls.

Add StreamKeepaliveInterval support to all three Antigravity streaming
paths: Claude SSE, Gemini SSE, and upstream passthrough.
2026-03-16 17:37:15 +08:00
Wesley Liddick
4da9fdd1d5 Merge pull request #1058 from Ethan0x0000/main
fix(admin/accounts): make usage window refresh deterministic and restore missing stats
2026-03-16 17:06:13 +08:00
Wesley Liddick
6b218caa21 Merge pull request #1053 from touwaeriol/chore/antigravity-ua-1.20.5
chore(antigravity): bump default User-Agent version to 1.20.5
2026-03-16 16:57:22 +08:00
shaw
5c138007d0 chore: update docs 2026-03-16 16:56:42 +08:00
Ethan0x0000
1acfc46f46 fix: always show usage stats for OpenAI OAuth and hide zero-value badges
- Simplify OpenAI rendering: always fetch /usage, prefer fetched data over
  codex snapshot (snapshot serves as loading placeholder only)
- Remove dead code: preferFetchedOpenAIUsage, isOpenAICodexSnapshotStale,
  and unreachable template branch
- Add today-stats support for key accounts (req/tokens/A/U badges)
- Use formatCompactNumber for consistent number formatting
- Add A/U badge titles for clarity
- Filter zero-value window stats in UsageProgressBar to avoid empty badges
- Update tests to match new fetched-data-first behavior
2026-03-16 16:23:13 +08:00
Ethan0x0000
fbffb08aae feat: add today-stats and manual refresh token propagation to usage cells
- Pass todayStats/todayStatsLoading to AccountUsageCell for key accounts
- Propagate usageManualRefreshToken to force usage reload on explicit refresh
- Refresh today stats when toggling usage/today_stats columns visible
2026-03-16 16:23:00 +08:00
Ethan0x0000
8640a62319 refactor: extract formatCompactNumber util and add last_used_at to refresh key
- Add formatCompactNumber() for consistent large-number formatting (K/M/B)
- Include last_used_at in OpenAI usage refresh key for better change detection
- Add .gitattributes eol=lf rules for frontend source files
2026-03-16 16:22:51 +08:00
Ethan0x0000
fa782e70a4 fix: always attach OpenAI 5h/7d window stats regardless of zero values
Removes hasMeaningfulWindowStats guard so the /usage endpoint consistently
returns WindowStats for both time windows. The frontend now controls
zero-value display filtering at the component level.
2026-03-16 16:22:42 +08:00
Ethan0x0000
afd72abc6e fix: allow empty extra payload to clear account quota limits
UpdateAccount previously required len(input.Extra) > 0, causing explicit
empty payloads (extra:{}) to be silently skipped. Change condition to
input.Extra != nil so clearing quota keys actually persists.
2026-03-16 16:22:31 +08:00
erio
71f72e167e chore(antigravity): bump default User-Agent version to 1.20.5 2026-03-16 15:47:32 +08:00
Wesley Liddick
6595c7601e Merge pull request #1050 from touwaeriol/fix/rate-limit-redis-window-reset
fix(billing): add window expiration check to Redis rate limit Lua script
2026-03-16 14:17:41 +08:00
erio
67c0506290 fix(billing): add window expiration check to Redis rate limit Lua script
The updateRateLimitUsageScript Lua script previously performed
unconditional HINCRBYFLOAT on all usage counters without checking
whether the rate limit window had expired. This caused usage to
accumulate across window boundaries in Redis while the DB correctly
reset on expiration, leading to incorrect 429 rate limiting that
could persist for up to 24 hours.

The Lua script now checks each window timestamp before incrementing:
- If the window has expired, usage is reset to the current cost and
  the window timestamp is updated (matching DB-side semantics)
- If the window is still valid, usage is accumulated normally

This also resolves the async race condition where stale HINCRBYFLOAT
tasks from the worker queue could pollute a freshly rebuilt cache
after invalidation, since the script now self-corrects expired windows.

Closes #1049
2026-03-16 13:39:50 +08:00
Wesley Liddick
6447be4534 Merge pull request #1047 from DaydreamCoding/fix/codex-stream-isolation
fix(gateway): 防止 OpenAI Codex 跨用户串流 + WS 连接池条件式 MarkBroken
2026-03-16 11:00:07 +08:00
QTom
3741617ebd fix(gateway): WS 连接池条件式 MarkBroken 防止跨请求串流
正常终端事件(response.completed 等)退出后连接归还复用,
仅异常路径(读写错误、error 事件、客户端断连)MarkBroken 销毁。

Generate 模式:
- 引入 cleanExit 标记,仅在 isTerminalEvent break 时设置 true
- defer 中根据 cleanExit 决定是否 MarkBroken
- 所有异常路径已在各自分支中提前调用 MarkBroken

Ingress 模式:
- 引入 lastTurnClean 标记,sendAndRelay 正常完成时设为 true
- releaseSessionLease 根据 lastTurnClean 决定是否 MarkBroken
- 错误路径重置 lastTurnClean = false
- 客户端断连后 drain 仍保守 MarkBroken(L2916)
2026-03-16 10:50:02 +08:00
QTom
ab4e8b2cf0 fix(gateway): 防止 OpenAI Codex 跨用户串流
根因:多个用户共享同一 OAuth 账号时,conversation_id/session_id 头
未做用户隔离,导致上游 chatgpt.com 将不同用户的请求关联到同一会话。

HTTP SSE 修复:
- 新增 isolateOpenAISessionID(apiKeyID, raw),将 API Key ID 混入
  session 标识符(xxhash),确保不同 Key 的用户产生不同上游会话
- buildUpstreamRequest: OAuth 分支先 Del 客户端透传的 session 头,
  再用隔离值覆盖
- buildUpstreamRequestOpenAIPassthrough: 透传路径同样隔离
- ForwardAsAnthropic: Anthropic Messages 兼容路径同步修复
- buildOpenAIWSHeaders: WS 路径的 OAuth session 头同步隔离
2026-03-16 10:28:51 +08:00
197 changed files with 8613 additions and 1881 deletions

7
.gitattributes vendored
View File

@@ -4,6 +4,13 @@ backend/migrations/*.sql text eol=lf
# Go 源代码文件
*.go text eol=lf
# 前端 源代码文件
*.ts text eol=lf
*.tsx text eol=lf
*.js text eol=lf
*.jsx text eol=lf
*.vue text eol=lf
# Shell 脚本
*.sh text eol=lf

View File

@@ -271,3 +271,36 @@ jobs:
parse_mode: "Markdown",
disable_web_page_preview: true
}')"
sync-version-file:
needs: [release]
if: ${{ needs.release.result == 'success' }}
runs-on: ubuntu-latest
steps:
- name: Checkout default branch
uses: actions/checkout@v6
with:
ref: ${{ github.event.repository.default_branch }}
- name: Sync VERSION file to released tag
run: |
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
VERSION=${{ github.event.inputs.tag }}
VERSION=${VERSION#v}
else
VERSION=${GITHUB_REF#refs/tags/v}
fi
CURRENT_VERSION=$(tr -d '\r\n' < backend/cmd/server/VERSION || true)
if [ "$CURRENT_VERSION" = "$VERSION" ]; then
echo "VERSION file already matches $VERSION"
exit 0
fi
echo "$VERSION" > backend/cmd/server/VERSION
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git add backend/cmd/server/VERSION
git commit -m "chore: sync VERSION to ${VERSION} [skip ci]"
git push origin HEAD:${{ github.event.repository.default_branch }}

View File

@@ -47,6 +47,8 @@ dockers:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
dockerfile: Dockerfile.goreleaser
use: buildx
extra_files:
- deploy/docker-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"

View File

@@ -63,6 +63,8 @@ dockers:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
dockerfile: Dockerfile.goreleaser
use: buildx
extra_files:
- deploy/docker-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"
@@ -76,6 +78,8 @@ dockers:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
dockerfile: Dockerfile.goreleaser
use: buildx
extra_files:
- deploy/docker-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{ .Version }}"
@@ -89,6 +93,8 @@ dockers:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
dockerfile: Dockerfile.goreleaser
use: buildx
extra_files:
- deploy/docker-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"
@@ -102,6 +108,8 @@ dockers:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
dockerfile: Dockerfile.goreleaser
use: buildx
extra_files:
- deploy/docker-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{ .Version }}"

View File

@@ -92,6 +92,7 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
RUN apk add --no-cache \
ca-certificates \
tzdata \
su-exec \
libpq \
zstd-libs \
lz4-libs \
@@ -120,8 +121,9 @@ COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/
# Create data directory
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
# Switch to non-root user
USER sub2api
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
RUN chmod +x /app/docker-entrypoint.sh
# Expose port (can be overridden by SERVER_PORT env var)
EXPOSE 8080
@@ -130,5 +132,6 @@ EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
# Run the application
ENTRYPOINT ["/app/sub2api"]
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
ENTRYPOINT ["/app/docker-entrypoint.sh"]
CMD ["/app/sub2api"]

View File

@@ -21,6 +21,7 @@ RUN apk add --no-cache \
ca-certificates \
tzdata \
curl \
su-exec \
libpq \
zstd-libs \
lz4-libs \
@@ -47,11 +48,15 @@ COPY sub2api /app/sub2api
# Create data directory
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
USER sub2api
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
RUN chmod +x /app/docker-entrypoint.sh
EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
ENTRYPOINT ["/app/sub2api"]
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
ENTRYPOINT ["/app/docker-entrypoint.sh"]
CMD ["/app/sub2api"]

View File

@@ -8,27 +8,31 @@
[![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/)
[![Docker](https://img.shields.io/badge/Docker-Ready-2496ED.svg)](https://www.docker.com/)
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
**AI API Gateway Platform for Subscription Quota Distribution**
English | [中文](README_CN.md)
</div>
> **Sub2API officially uses only the domains `sub2api.org` and `pincc.ai`. Other websites using the Sub2API name may be third-party deployments or services and are not affiliated with this project. Please verify and exercise your own judgment.**
---
## Demo
Try Sub2API online: **https://demo.sub2api.org/**
Try Sub2API online: **[https://demo.sub2api.org/](https://demo.sub2api.org/)**
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
| Email | Password |
|-------|----------|
| admin@sub2api.com | admin123 |
| admin@sub2api.org | admin123 |
## Overview
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions. Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
## Features
@@ -41,6 +45,15 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
- **Admin Dashboard** - Web interface for monitoring and management
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
## Don't Want to Self-Host?
<table>
<tr>
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.</td>
</tr>
</table>
## Ecosystem
Community projects that extend or integrate with Sub2API:
@@ -61,10 +74,15 @@ Community projects that extend or integrate with Sub2API:
---
## Documentation
## Nginx Reverse Proxy Note
- Dependency Security: `docs/dependency-security.md`
- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md`
When using Nginx as a reverse proxy for Sub2API (or CRS) with Codex CLI, add the following to the `http` block in your Nginx configuration:
```nginx
underscores_in_headers on;
```
Nginx drops headers containing underscores by default (e.g. `session_id`), which breaks sticky session routing in multi-account setups.
---

View File

@@ -8,27 +8,30 @@
[![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/)
[![Docker](https://img.shields.io/badge/Docker-Ready-2496ED.svg)](https://www.docker.com/)
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
**AI API 网关平台 - 订阅配额分发管理**
[English](README.md) | 中文
</div>
> **Sub2API 官方仅使用 `sub2api.org` 与 `pincc.ai` 两个域名。其他使用 Sub2API 名义的网站可能为第三方部署或服务,与本项目无关,请自行甄别。**
---
## 在线体验
体验地址:**https://v2.pincc.ai/**
体验地址:**[https://demo.sub2api.org/](https://demo.sub2api.org/)**
演示账号(共享演示环境;自建部署不会自动创建该账号):
| 邮箱 | 密码 |
|------|------|
| admin@sub2api.com | admin123 |
| admin@sub2api.org | admin123 |
## 项目概述
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
## 核心功能
@@ -41,6 +44,15 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
- **管理后台** - Web 界面进行监控和管理
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
## 不想自建?试试官方中转
<table>
<tr>
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。</td>
</tr>
</table>
## 生态项目
围绕 Sub2API 的社区扩展与集成项目:
@@ -61,17 +73,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
---
## 文档
## Nginx 反向代理注意事项
- 依赖安全:`docs/dependency-security.md`
通过 Nginx 反向代理 Sub2API或 CRS 服务)并搭配 Codex CLI 使用时,需要在 Nginx 配置的 `http` 块中添加:
```nginx
underscores_in_headers on;
```
Nginx 默认会丢弃名称中含下划线的请求头(如 `session_id`),这会导致多账号环境下的粘性会话功能失效。
---
## OpenAI Responses 兼容注意事项
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id``tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
## 部署方式
### 方式一:脚本安装(推荐)

Binary file not shown.

After

Width:  |  Height:  |  Size: 171 KiB

View File

@@ -110,7 +110,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
@@ -143,6 +142,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()

View File

@@ -716,6 +716,7 @@ var (
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "request_id", Type: field.TypeString, Size: 64},
{Name: "model", Type: field.TypeString, Size: 100},
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
@@ -755,31 +756,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[31]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[32]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[33]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -788,32 +789,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[32]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[31]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[33]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_model",
@@ -828,17 +829,17 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]},
},
{
Name: "usagelog_group_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]},
},
},
}

View File

@@ -18239,6 +18239,7 @@ type UsageLogMutation struct {
id *int64
request_id *string
model *string
upstream_model *string
input_tokens *int
addinput_tokens *int
output_tokens *int
@@ -18576,6 +18577,55 @@ func (m *UsageLogMutation) ResetModel() {
m.model = nil
}
// SetUpstreamModel sets the "upstream_model" field.
func (m *UsageLogMutation) SetUpstreamModel(s string) {
m.upstream_model = &s
}
// UpstreamModel returns the value of the "upstream_model" field in the mutation.
func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) {
v := m.upstream_model
if v == nil {
return
}
return *v, true
}
// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUpstreamModel requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err)
}
return oldValue.UpstreamModel, nil
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (m *UsageLogMutation) ClearUpstreamModel() {
m.upstream_model = nil
m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{}
}
// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation.
func (m *UsageLogMutation) UpstreamModelCleared() bool {
_, ok := m.clearedFields[usagelog.FieldUpstreamModel]
return ok
}
// ResetUpstreamModel resets all changes to the "upstream_model" field.
func (m *UsageLogMutation) ResetUpstreamModel() {
m.upstream_model = nil
delete(m.clearedFields, usagelog.FieldUpstreamModel)
}
// SetGroupID sets the "group_id" field.
func (m *UsageLogMutation) SetGroupID(i int64) {
m.group = &i
@@ -20197,7 +20247,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 32)
fields := make([]string, 0, 33)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -20213,6 +20263,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.model != nil {
fields = append(fields, usagelog.FieldModel)
}
if m.upstream_model != nil {
fields = append(fields, usagelog.FieldUpstreamModel)
}
if m.group != nil {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -20312,6 +20365,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.RequestID()
case usagelog.FieldModel:
return m.Model()
case usagelog.FieldUpstreamModel:
return m.UpstreamModel()
case usagelog.FieldGroupID:
return m.GroupID()
case usagelog.FieldSubscriptionID:
@@ -20385,6 +20440,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldRequestID(ctx)
case usagelog.FieldModel:
return m.OldModel(ctx)
case usagelog.FieldUpstreamModel:
return m.OldUpstreamModel(ctx)
case usagelog.FieldGroupID:
return m.OldGroupID(ctx)
case usagelog.FieldSubscriptionID:
@@ -20483,6 +20540,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetModel(v)
return nil
case usagelog.FieldUpstreamModel:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUpstreamModel(v)
return nil
case usagelog.FieldGroupID:
v, ok := value.(int64)
if !ok {
@@ -20921,6 +20985,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
// mutation.
func (m *UsageLogMutation) ClearedFields() []string {
var fields []string
if m.FieldCleared(usagelog.FieldUpstreamModel) {
fields = append(fields, usagelog.FieldUpstreamModel)
}
if m.FieldCleared(usagelog.FieldGroupID) {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -20962,6 +21029,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool {
// error if the field is not defined in the schema.
func (m *UsageLogMutation) ClearField(name string) error {
switch name {
case usagelog.FieldUpstreamModel:
m.ClearUpstreamModel()
return nil
case usagelog.FieldGroupID:
m.ClearGroupID()
return nil
@@ -21012,6 +21082,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldModel:
m.ResetModel()
return nil
case usagelog.FieldUpstreamModel:
m.ResetUpstreamModel()
return nil
case usagelog.FieldGroupID:
m.ResetGroupID()
return nil

View File

@@ -821,92 +821,96 @@ func init() {
return nil
}
}()
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
usagelogDescUpstreamModel := usagelogFields[5].Descriptor()
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
usagelogDescInputTokens := usagelogFields[7].Descriptor()
usagelogDescInputTokens := usagelogFields[8].Descriptor()
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
usagelogDescOutputTokens := usagelogFields[8].Descriptor()
usagelogDescOutputTokens := usagelogFields[9].Descriptor()
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor()
usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor()
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
usagelogDescCacheReadTokens := usagelogFields[10].Descriptor()
usagelogDescCacheReadTokens := usagelogFields[11].Descriptor()
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor()
usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor()
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor()
usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor()
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
// usagelogDescInputCost is the schema descriptor for input_cost field.
usagelogDescInputCost := usagelogFields[13].Descriptor()
usagelogDescInputCost := usagelogFields[14].Descriptor()
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
// usagelogDescOutputCost is the schema descriptor for output_cost field.
usagelogDescOutputCost := usagelogFields[14].Descriptor()
usagelogDescOutputCost := usagelogFields[15].Descriptor()
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
usagelogDescCacheCreationCost := usagelogFields[15].Descriptor()
usagelogDescCacheCreationCost := usagelogFields[16].Descriptor()
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
usagelogDescCacheReadCost := usagelogFields[16].Descriptor()
usagelogDescCacheReadCost := usagelogFields[17].Descriptor()
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
// usagelogDescTotalCost is the schema descriptor for total_cost field.
usagelogDescTotalCost := usagelogFields[17].Descriptor()
usagelogDescTotalCost := usagelogFields[18].Descriptor()
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
// usagelogDescActualCost is the schema descriptor for actual_cost field.
usagelogDescActualCost := usagelogFields[18].Descriptor()
usagelogDescActualCost := usagelogFields[19].Descriptor()
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
usagelogDescRateMultiplier := usagelogFields[19].Descriptor()
usagelogDescRateMultiplier := usagelogFields[20].Descriptor()
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[21].Descriptor()
usagelogDescBillingType := usagelogFields[22].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[22].Descriptor()
usagelogDescStream := usagelogFields[23].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field.
usagelogDescUserAgent := usagelogFields[25].Descriptor()
usagelogDescUserAgent := usagelogFields[26].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[26].Descriptor()
usagelogDescIPAddress := usagelogFields[27].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[27].Descriptor()
usagelogDescImageCount := usagelogFields[28].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[28].Descriptor()
usagelogDescImageSize := usagelogFields[29].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescMediaType is the schema descriptor for media_type field.
usagelogDescMediaType := usagelogFields[29].Descriptor()
usagelogDescMediaType := usagelogFields[30].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
usagelogDescCreatedAt := usagelogFields[32].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field {
field.String("model").
MaxLen(100).
NotEmpty(),
// UpstreamModel stores the actual upstream model name when model mapping
// is applied. NULL means no mapping — the requested model was used as-is.
field.String("upstream_model").
MaxLen(100).
Optional().
Nillable(),
field.Int64("group_id").
Optional().
Nillable(),

View File

@@ -32,6 +32,8 @@ type UsageLog struct {
RequestID string `json:"request_id,omitempty"`
// Model holds the value of the "model" field.
Model string `json:"model,omitempty"`
// UpstreamModel holds the value of the "upstream_model" field.
UpstreamModel *string `json:"upstream_model,omitempty"`
// GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"`
// SubscriptionID holds the value of the "subscription_id" field.
@@ -175,7 +177,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@@ -230,6 +232,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Model = value.String
}
case usagelog.FieldUpstreamModel:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
} else if value.Valid {
_m.UpstreamModel = new(string)
*_m.UpstreamModel = value.String
}
case usagelog.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i])
@@ -477,6 +486,11 @@ func (_m *UsageLog) String() string {
builder.WriteString("model=")
builder.WriteString(_m.Model)
builder.WriteString(", ")
if v := _m.UpstreamModel; v != nil {
builder.WriteString("upstream_model=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.GroupID; v != nil {
builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))

View File

@@ -24,6 +24,8 @@ const (
FieldRequestID = "request_id"
// FieldModel holds the string denoting the model field in the database.
FieldModel = "model"
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
FieldUpstreamModel = "upstream_model"
// FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id"
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
@@ -135,6 +137,7 @@ var Columns = []string{
FieldAccountID,
FieldRequestID,
FieldModel,
FieldUpstreamModel,
FieldGroupID,
FieldSubscriptionID,
FieldInputTokens,
@@ -179,6 +182,8 @@ var (
RequestIDValidator func(string) error
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
ModelValidator func(string) error
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
UpstreamModelValidator func(string) error
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
DefaultInputTokens int
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
@@ -258,6 +263,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModel, opts...).ToFunc()
}
// ByUpstreamModel orders the results by the upstream_model field.
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
}
// ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc()

View File

@@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
}
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
func UpstreamModel(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
}
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
@@ -405,6 +410,81 @@ func ModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
}
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
func UpstreamModelEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
}
// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field.
func UpstreamModelNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v))
}
// UpstreamModelIn applies the In predicate on the "upstream_model" field.
func UpstreamModelIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...))
}
// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field.
func UpstreamModelNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...))
}
// UpstreamModelGT applies the GT predicate on the "upstream_model" field.
func UpstreamModelGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v))
}
// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field.
func UpstreamModelGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v))
}
// UpstreamModelLT applies the LT predicate on the "upstream_model" field.
func UpstreamModelLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v))
}
// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field.
func UpstreamModelLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v))
}
// UpstreamModelContains applies the Contains predicate on the "upstream_model" field.
func UpstreamModelContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v))
}
// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field.
func UpstreamModelHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v))
}
// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field.
func UpstreamModelHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v))
}
// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field.
func UpstreamModelIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel))
}
// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field.
func UpstreamModelNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel))
}
// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field.
func UpstreamModelEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v))
}
// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field.
func UpstreamModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
}
// GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))

View File

@@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
return _c
}
// SetUpstreamModel sets the "upstream_model" field.
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
_c.mutation.SetUpstreamModel(v)
return _c
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
if v != nil {
_c.SetUpstreamModel(*v)
}
return _c
}
// SetGroupID sets the "group_id" field.
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
_c.mutation.SetGroupID(v)
@@ -596,6 +610,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
}
}
if v, ok := _c.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if _, ok := _c.mutation.InputTokens(); !ok {
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
}
@@ -714,6 +733,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
_node.Model = value
}
if value, ok := _c.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
_node.UpstreamModel = &value
}
if value, ok := _c.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
_node.InputTokens = value
@@ -1011,6 +1034,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
return u
}
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
u.Set(usagelog.FieldUpstreamModel, v)
return u
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldUpstreamModel)
return u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
u.SetNull(usagelog.FieldUpstreamModel)
return u
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
u.Set(usagelog.FieldGroupID, v)
@@ -1600,6 +1641,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
})
}
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetUpstreamModel(v)
})
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateUpstreamModel()
})
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearUpstreamModel()
})
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2434,6 +2496,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
})
}
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetUpstreamModel(v)
})
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateUpstreamModel()
})
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearUpstreamModel()
})
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {

View File

@@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
return _u
}
// SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
_u.mutation.SetUpstreamModel(v)
return _u
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate {
if v != nil {
_u.SetUpstreamModel(*v)
}
return _u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
_u.mutation.ClearUpstreamModel()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
_u.mutation.SetGroupID(v)
@@ -745,6 +765,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
}
}
if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -795,6 +820,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
}
if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
}
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}
@@ -1177,6 +1208,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
return _u
}
// SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
_u.mutation.SetUpstreamModel(v)
return _u
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetUpstreamModel(*v)
}
return _u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
_u.mutation.ClearUpstreamModel()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
_u.mutation.SetGroupID(v)
@@ -1833,6 +1884,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
}
}
if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -1900,6 +1956,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
}
if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
}
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}

View File

@@ -22,8 +22,6 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
@@ -60,8 +58,6 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWA
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
@@ -98,10 +94,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
@@ -238,8 +230,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -273,8 +263,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -326,8 +314,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=

View File

@@ -82,8 +82,8 @@ var DefaultAntigravityModelMapping = map[string]string{
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
// Claude Haiku → Sonnet无 Haiku 支持)
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-6",
"claude-haiku-4-5-20251001": "claude-sonnet-4-6",
// Gemini 2.5 白名单
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-image": "gemini-2.5-flash-image",

View File

@@ -165,6 +165,8 @@ type AccountWithConcurrency struct {
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
}
const accountListGroupUngroupedQueryValue = "ungrouped"
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
item := AccountWithConcurrency{
Account: dto.AccountFromService(account),
@@ -226,7 +228,20 @@ func (h *AccountHandler) List(c *gin.Context) {
var groupID int64
if groupIDStr := c.Query("group"); groupIDStr != "" {
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
if groupIDStr == accountListGroupUngroupedQueryValue {
groupID = service.AccountListGroupUngrouped
} else {
parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64)
if parseErr != nil {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
return
}
if parsedGroupID < 0 {
response.ErrorFrom(c, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter"))
return
}
groupID = parsedGroupID
}
}
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
@@ -1496,7 +1511,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
}
// GetUsage handles getting account usage information
// GET /api/v1/admin/accounts/:id/usage
// GET /api/v1/admin/accounts/:id/usage?source=passive|active
func (h *AccountHandler) GetUsage(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@@ -1504,7 +1519,14 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
return
}
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
source := c.DefaultQuery("source", "active")
var usage *service.UsageInfo
if source == "passive" {
usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID)
} else {
usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID)
}
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc, nil)
groupHandler := NewGroupHandler(adminSvc)
groupHandler := NewGroupHandler(adminSvc, nil, nil)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc, nil)

View File

@@ -98,12 +98,12 @@ func (h *BackupHandler) CreateBackup(c *gin.Context) {
expireDays = *req.ExpireDays
}
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
record, err := h.backupService.StartBackup(c.Request.Context(), "manual", expireDays)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, record)
response.Accepted(c, record)
}
func (h *BackupHandler) ListBackups(c *gin.Context) {
@@ -196,9 +196,10 @@ func (h *BackupHandler) RestoreBackup(c *gin.Context) {
return
}
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
record, err := h.backupService.StartRestore(c.Request.Context(), backupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"restored": true})
response.Accepted(c, record)
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -272,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
modelSource := usagestats.ModelSourceRequested
var requestType *int16
var stream *bool
var billingType *int8
@@ -296,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
groupID = id
}
}
if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" {
if !usagestats.IsValidModelSource(rawModelSource) {
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
return
}
modelSource = rawModelSource
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
@@ -322,7 +331,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
@@ -604,3 +613,47 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// GetUserBreakdown handles getting per-user usage breakdown within a dimension.
// GET /api/v1/admin/dashboard/user-breakdown
// Query params: start_date, end_date, group_id, model, endpoint, endpoint_type, limit
func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
dim := usagestats.UserBreakdownDimension{}
if v := c.Query("group_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.GroupID = id
}
}
dim.Model = c.Query("model")
rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested))
if !usagestats.IsValidModelSource(rawModelSource) {
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
return
}
dim.ModelType = rawModelSource
dim.Endpoint = c.Query("endpoint")
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
limit := 50
if v := c.Query("limit"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
limit = n
}
}
stats, err := h.dashboardService.GetUserBreakdownStats(
c.Request.Context(), startTime, endTime, dim, limit,
)
if err != nil {
response.Error(c, 500, "Failed to get user breakdown stats")
return
}
response.Success(c, gin.H{
"users": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}

View File

@@ -149,6 +149,28 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsInvalidModelSource(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsValidModelSource(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
repo := &dashboardUsageRepoCapture{

View File

@@ -0,0 +1,229 @@
package admin
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// --- mock repo ---
type userBreakdownRepoCapture struct {
service.UsageLogRepository
capturedDim usagestats.UserBreakdownDimension
capturedLimit int
result []usagestats.UserBreakdownItem
}
func (r *userBreakdownRepoCapture) GetUserBreakdownStats(
_ context.Context, _, _ time.Time,
dim usagestats.UserBreakdownDimension, limit int,
) ([]usagestats.UserBreakdownItem, error) {
r.capturedDim = dim
r.capturedLimit = limit
if r.result != nil {
return r.result, nil
}
return []usagestats.UserBreakdownItem{}, nil
}
func newUserBreakdownRouter(repo *userBreakdownRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
svc := service.NewDashboardService(repo, nil, nil, nil)
h := NewDashboardHandler(svc, nil)
router := gin.New()
router.GET("/admin/dashboard/user-breakdown", h.GetUserBreakdown)
return router
}
// --- tests ---
func TestGetUserBreakdown_GroupIDFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=42", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, int64(42), repo.capturedDim.GroupID)
require.Empty(t, repo.capturedDim.Model)
require.Empty(t, repo.capturedDim.Endpoint)
require.Equal(t, 50, repo.capturedLimit) // default limit
}
func TestGetUserBreakdown_ModelFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType)
require.Equal(t, int64(0), repo.capturedDim.GroupID)
}
func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType)
}
func TestGetUserBreakdown_InvalidModelSource(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code)
}
func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/v1/messages&endpoint_type=upstream", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "/v1/messages", repo.capturedDim.Endpoint)
require.Equal(t, "upstream", repo.capturedDim.EndpointType)
}
func TestGetUserBreakdown_DefaultEndpointType(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/chat", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "inbound", repo.capturedDim.EndpointType)
}
func TestGetUserBreakdown_CustomLimit(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=100", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, 100, repo.capturedLimit)
}
func TestGetUserBreakdown_LimitClamped(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
// limit > 200 should fall back to default 50
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=999", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, 50, repo.capturedLimit)
}
func TestGetUserBreakdown_ResponseFormat(t *testing.T) {
repo := &userBreakdownRepoCapture{
result: []usagestats.UserBreakdownItem{
{UserID: 1, Email: "alice@test.com", Requests: 100, TotalTokens: 50000, Cost: 1.5, ActualCost: 1.2},
{UserID: 2, Email: "bob@test.com", Requests: 50, TotalTokens: 25000, Cost: 0.8, ActualCost: 0.6},
},
}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=1", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Users []usagestats.UserBreakdownItem `json:"users"`
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
} `json:"data"`
}
err := json.Unmarshal(w.Body.Bytes(), &resp)
require.NoError(t, err)
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Users, 2)
require.Equal(t, int64(1), resp.Data.Users[0].UserID)
require.Equal(t, "alice@test.com", resp.Data.Users[0].Email)
require.Equal(t, int64(100), resp.Data.Users[0].Requests)
require.InDelta(t, 1.2, resp.Data.Users[0].ActualCost, 0.001)
require.Equal(t, "2026-03-01", resp.Data.StartDate)
require.Equal(t, "2026-03-16", resp.Data.EndDate)
}
func TestGetUserBreakdown_EmptyResult(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=999", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var resp struct {
Data struct {
Users []usagestats.UserBreakdownItem `json:"users"`
} `json:"data"`
}
err := json.Unmarshal(w.Body.Bytes(), &resp)
require.NoError(t, err)
require.Empty(t, resp.Data.Users)
}
func TestGetUserBreakdown_NoFilters(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, int64(0), repo.capturedDim.GroupID)
require.Empty(t, repo.capturedDim.Model)
require.Empty(t, repo.capturedDim.Endpoint)
}

View File

@@ -38,6 +38,7 @@ type dashboardModelGroupCacheKey struct {
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
ModelSource string `json:"model_source,omitempty"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
@@ -111,6 +112,7 @@ func (h *DashboardHandler) getModelStatsCached(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
modelSource string,
requestType *int16,
stream *bool,
billingType *int8,
@@ -122,12 +124,13 @@ func (h *DashboardHandler) getModelStatsCached(
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
ModelSource: usagestats.NormalizeModelSource(modelSource),
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource)
})
if err != nil {
return nil, hit, err

View File

@@ -200,6 +200,7 @@ func (h *DashboardHandler) buildSnapshotV2Response(
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
usagestats.ModelSourceRequested,
filters.RequestType,
filters.Stream,
filters.BillingType,

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -16,7 +17,9 @@ import (
// GroupHandler handles admin group management
type GroupHandler struct {
adminService service.AdminService
adminService service.AdminService
dashboardService *service.DashboardService
groupCapacityService *service.GroupCapacityService
}
type optionalLimitField struct {
@@ -69,9 +72,11 @@ func (f optionalLimitField) ToServiceInput() *float64 {
}
// NewGroupHandler creates a new admin group handler
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler {
return &GroupHandler{
adminService: adminService,
adminService: adminService,
dashboardService: dashboardService,
groupCapacityService: groupCapacityService,
}
}
@@ -363,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) {
_ = groupID // TODO: implement actual stats
}
// GetUsageSummary returns today's and cumulative cost for all groups.
// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai
func (h *GroupHandler) GetUsageSummary(c *gin.Context) {
userTZ := c.Query("timezone")
now := timezone.NowInUserLocation(userTZ)
todayStart := timezone.StartOfDayInUserLocation(now, userTZ)
results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart)
if err != nil {
response.Error(c, 500, "Failed to get group usage summary")
return
}
response.Success(c, results)
}
// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups.
// GET /api/v1/admin/groups/capacity-summary
func (h *GroupHandler) GetCapacitySummary(c *gin.Context) {
results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get group capacity summary")
return
}
response.Success(c, results)
}
// GetGroupAPIKeys handles getting API keys in a group
// GET /api/v1/admin/groups/:id/api-keys
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {

View File

@@ -977,6 +977,58 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
response.Success(c, gin.H{"message": "Admin API key deleted"})
}
// GetOverloadCooldownSettings 获取529过载冷却配置
// GET /api/v1/admin/settings/overload-cooldown
func (h *SettingHandler) GetOverloadCooldownSettings(c *gin.Context) {
settings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.OverloadCooldownSettings{
Enabled: settings.Enabled,
CooldownMinutes: settings.CooldownMinutes,
})
}
// UpdateOverloadCooldownSettingsRequest 更新529过载冷却配置请求
type UpdateOverloadCooldownSettingsRequest struct {
Enabled bool `json:"enabled"`
CooldownMinutes int `json:"cooldown_minutes"`
}
// UpdateOverloadCooldownSettings 更新529过载冷却配置
// PUT /api/v1/admin/settings/overload-cooldown
func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) {
var req UpdateOverloadCooldownSettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
settings := &service.OverloadCooldownSettings{
Enabled: req.Enabled,
CooldownMinutes: req.CooldownMinutes,
}
if err := h.settingService.SetOverloadCooldownSettings(c.Request.Context(), settings); err != nil {
response.BadRequest(c, err.Error())
return
}
updatedSettings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.OverloadCooldownSettings{
Enabled: updatedSettings.Enabled,
CooldownMinutes: updatedSettings.CooldownMinutes,
})
}
// GetStreamTimeoutSettings 获取流超时处理配置
// GET /api/v1/admin/settings/stream-timeout
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {

View File

@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
}
}
status := c.Query("status")
platform := c.Query("platform")
// Parse sorting parameters
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
SortOrder: g.SortOrder,
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
ActiveAccountCount: g.ActiveAccountCount,
RateLimitedAccountCount: g.RateLimitedAccountCount,
SortOrder: g.SortOrder,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
@@ -521,6 +523,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
UpstreamModel: l.UpstreamModel,
ServiceTier: l.ServiceTier,
ReasoningEffort: l.ReasoningEffort,
InboundEndpoint: l.InboundEndpoint,

View File

@@ -157,6 +157,12 @@ type ListSoraS3ProfilesResponse struct {
Items []SoraS3Profile `json:"items"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO
type OverloadCooldownSettings struct {
Enabled bool `json:"enabled"`
CooldownMinutes int `json:"cooldown_minutes"`
}
// StreamTimeoutSettings 流超时处理配置 DTO
type StreamTimeoutSettings struct {
Enabled bool `json:"enabled"`

View File

@@ -122,9 +122,11 @@ type AdminGroup struct {
DefaultMappedModel string `json:"default_mapped_model"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
// 分组排序
SortOrder int `json:"sort_order"`
@@ -332,6 +334,9 @@ type UsageLog struct {
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"`
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string `json:"service_tier,omitempty"`
// ReasoningEffort is the request's reasoning effort level.

View File

@@ -1219,6 +1219,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
}
}
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
// 使用默认的错误映射
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
@@ -1227,6 +1231,7 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

View File

@@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service
return nil, nil
}
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil }
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}

View File

@@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
return []byte(`{
"model":"claude-3-5-sonnet-20241022",
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
}`)
}
@@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
System: []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
},
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
}
// body 非法 JSON如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
@@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
"system": []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
},
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
})
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)

View File

@@ -593,6 +593,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE
}
}
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
// 使用默认的错误映射
status, message := mapGeminiUpstreamError(statusCode)
googleError(c, status, message)

View File

@@ -1435,6 +1435,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
}
}
// 记录原始上游状态码,以便 ops 错误日志捕获真实的上游错误
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
// 使用默认的错误映射
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
@@ -1443,6 +1447,7 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
service.SetOpsUpstreamError(c, statusCode, errMsg, "")
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

View File

@@ -484,6 +484,9 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
}
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

View File

@@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
@@ -345,6 +345,12 @@ func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Conte
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil
}

View File

@@ -49,8 +49,8 @@ const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
var defaultUserAgentVersion = "1.20.4"
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
var defaultUserAgentVersion = "1.20.5"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var defaultClientSecret = "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"

View File

@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
if RedirectURI != "http://localhost:8085/callback" {
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
}
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
if GetUserAgent() != "antigravity/1.20.5 windows/amd64" {
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
}
if SessionTTL != 30*time.Minute {

View File

@@ -275,21 +275,6 @@ func filterOpenCodePrompt(text string) string {
return ""
}
// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表
var systemBlockFilterPrefixes = []string{
"x-anthropic-billing-header",
}
// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串
func filterSystemBlockByPrefix(text string) string {
for _, prefix := range systemBlockFilterPrefixes {
if strings.HasPrefix(text, prefix) {
return ""
}
}
return text
}
// buildSystemInstruction 构建 systemInstruction与 Antigravity-Manager 保持一致)
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
var parts []GeminiPart
@@ -306,8 +291,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if strings.Contains(sysStr, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词和黑名单前缀
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr))
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(sysStr)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}
@@ -321,8 +306,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if strings.Contains(block.Text, "You are Antigravity") {
userHasAntigravityIdentity = true
}
// 过滤 OpenCode 默认提示词和黑名单前缀
filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text))
// 过滤 OpenCode 默认提示词
filtered := filterOpenCodePrompt(block.Text)
if filtered != "" {
userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
}

View File

@@ -2,7 +2,10 @@ package antigravity
import (
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
@@ -349,3 +352,51 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
})
}
}
func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) {
tests := []struct {
name string
system json.RawMessage
}{
{
name: "system array",
system: json.RawMessage(`[{"type":"text","text":"x-anthropic-billing-header keep"}]`),
},
{
name: "system string",
system: json.RawMessage(`"x-anthropic-billing-header keep"`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claudeReq := &ClaudeRequest{
Model: "claude-3-5-sonnet-latest",
System: tt.system,
Messages: []ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
},
},
}
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
require.NoError(t, err)
var req V1InternalRequest
require.NoError(t, json.Unmarshal(body, &req))
require.NotNil(t, req.Request.SystemInstruction)
found := false
for _, part := range req.Request.SystemInstruction.Parts {
if strings.Contains(part.Text, "x-anthropic-billing-header keep") {
found = true
break
}
}
require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容")
})
}
}

View File

@@ -1008,3 +1008,114 @@ func TestAnthropicToResponses_ImageEmptyMediaType(t *testing.T) {
// Should default to image/png when media_type is empty.
assert.Equal(t, "data:image/png;base64,iVBOR", parts[0].ImageURL)
}
// ---------------------------------------------------------------------------
// normalizeToolParameters tests
// ---------------------------------------------------------------------------
func TestNormalizeToolParameters(t *testing.T) {
tests := []struct {
name string
input json.RawMessage
expected string
}{
{
name: "nil input",
input: nil,
expected: `{"type":"object","properties":{}}`,
},
{
name: "empty input",
input: json.RawMessage(``),
expected: `{"type":"object","properties":{}}`,
},
{
name: "null input",
input: json.RawMessage(`null`),
expected: `{"type":"object","properties":{}}`,
},
{
name: "object without properties",
input: json.RawMessage(`{"type":"object"}`),
expected: `{"type":"object","properties":{}}`,
},
{
name: "object with properties",
input: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
expected: `{"type":"object","properties":{"city":{"type":"string"}}}`,
},
{
name: "non-object type",
input: json.RawMessage(`{"type":"string"}`),
expected: `{"type":"string"}`,
},
{
name: "object with additional fields preserved",
input: json.RawMessage(`{"type":"object","required":["name"]}`),
expected: `{"type":"object","required":["name"],"properties":{}}`,
},
{
name: "invalid JSON passthrough",
input: json.RawMessage(`not json`),
expected: `not json`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := normalizeToolParameters(tt.input)
if tt.name == "invalid JSON passthrough" {
assert.Equal(t, tt.expected, string(result))
} else {
assert.JSONEq(t, tt.expected, string(result))
}
})
}
}
func TestAnthropicToResponses_ToolWithoutProperties(t *testing.T) {
req := &AnthropicRequest{
Model: "gpt-5.2",
MaxTokens: 1024,
Messages: []AnthropicMessage{
{Role: "user", Content: json.RawMessage(`"Hello"`)},
},
Tools: []AnthropicTool{
{Name: "mcp__pencil__get_style_guide_tags", Description: "Get style tags", InputSchema: json.RawMessage(`{"type":"object"}`)},
},
}
resp, err := AnthropicToResponses(req)
require.NoError(t, err)
require.Len(t, resp.Tools, 1)
assert.Equal(t, "function", resp.Tools[0].Type)
assert.Equal(t, "mcp__pencil__get_style_guide_tags", resp.Tools[0].Name)
// Parameters must have "properties" field after normalization.
var params map[string]json.RawMessage
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, &params))
assert.Contains(t, params, "properties")
}
func TestAnthropicToResponses_ToolWithNilSchema(t *testing.T) {
req := &AnthropicRequest{
Model: "gpt-5.2",
MaxTokens: 1024,
Messages: []AnthropicMessage{
{Role: "user", Content: json.RawMessage(`"Hello"`)},
},
Tools: []AnthropicTool{
{Name: "simple_tool", Description: "A tool"},
},
}
resp, err := AnthropicToResponses(req)
require.NoError(t, err)
require.Len(t, resp.Tools, 1)
var params map[string]json.RawMessage
require.NoError(t, json.Unmarshal(resp.Tools[0].Parameters, &params))
assert.JSONEq(t, `"object"`, string(params["type"]))
assert.JSONEq(t, `{}`, string(params["properties"]))
}

View File

@@ -409,8 +409,41 @@ func convertAnthropicToolsToResponses(tools []AnthropicTool) []ResponsesTool {
Type: "function",
Name: t.Name,
Description: t.Description,
Parameters: t.InputSchema,
Parameters: normalizeToolParameters(t.InputSchema),
})
}
return out
}
// normalizeToolParameters ensures the tool parameter schema is valid for
// OpenAI's Responses API, which requires "properties" on object schemas.
//
// - nil/empty → {"type":"object","properties":{}}
// - type=object without properties → adds "properties": {}
// - otherwise → returned unchanged
func normalizeToolParameters(schema json.RawMessage) json.RawMessage {
if len(schema) == 0 || string(schema) == "null" {
return json.RawMessage(`{"type":"object","properties":{}}`)
}
var m map[string]json.RawMessage
if err := json.Unmarshal(schema, &m); err != nil {
return schema
}
typ := m["type"]
if string(typ) != `"object"` {
return schema
}
if _, ok := m["properties"]; ok {
return schema
}
m["properties"] = json.RawMessage(`{}`)
out, err := json.Marshal(m)
if err != nil {
return schema
}
return out
}

View File

@@ -47,6 +47,15 @@ func Created(c *gin.Context, data any) {
})
}
// Accepted 返回异步接受响应 (HTTP 202)
func Accepted(c *gin.Context, data any) {
c.JSON(http.StatusAccepted, Response{
Code: 0,
Message: "accepted",
Data: data,
})
}
// Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{

View File

@@ -3,6 +3,28 @@ package usagestats
import "time"
const (
ModelSourceRequested = "requested"
ModelSourceUpstream = "upstream"
ModelSourceMapping = "mapping"
)
func IsValidModelSource(source string) bool {
switch source {
case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping:
return true
default:
return false
}
}
func NormalizeModelSource(source string) string {
if IsValidModelSource(source) {
return source
}
return ModelSourceRequested
}
// DashboardStats 仪表盘统计
type DashboardStats struct {
// 用户统计
@@ -90,6 +112,13 @@ type EndpointStat struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// GroupUsageSummary represents today's and cumulative cost for a single group.
type GroupUsageSummary struct {
GroupID int64 `json:"group_id"`
TodayCost float64 `json:"today_cost"`
TotalCost float64 `json:"total_cost"`
}
// GroupStat represents usage statistics for a single group
type GroupStat struct {
GroupID int64 `json:"group_id"`
@@ -129,6 +158,25 @@ type UserSpendingRankingResponse struct {
TotalTokens int64 `json:"total_tokens"`
}
// UserBreakdownItem represents per-user usage breakdown within a dimension (group, model, endpoint).
type UserBreakdownItem struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
type UserBreakdownDimension struct {
GroupID int64 // filter by group_id (>0 to enable)
Model string // filter by model name (non-empty to enable)
ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path"
}
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint struct {
Date string `json:"date"`

View File

@@ -0,0 +1,47 @@
package usagestats
import "testing"
func TestIsValidModelSource(t *testing.T) {
tests := []struct {
name string
source string
want bool
}{
{name: "requested", source: ModelSourceRequested, want: true},
{name: "upstream", source: ModelSourceUpstream, want: true},
{name: "mapping", source: ModelSourceMapping, want: true},
{name: "invalid", source: "foobar", want: false},
{name: "empty", source: "", want: false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := IsValidModelSource(tc.source); got != tc.want {
t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want)
}
})
}
}
func TestNormalizeModelSource(t *testing.T) {
tests := []struct {
name string
source string
want string
}{
{name: "requested", source: ModelSourceRequested, want: ModelSourceRequested},
{name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream},
{name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping},
{name: "invalid falls back", source: "foobar", want: ModelSourceRequested},
{name: "empty falls back", source: "", want: ModelSourceRequested},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := NormalizeModelSource(tc.source); got != tc.want {
t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want)
}
})
}
}

View File

@@ -56,6 +56,7 @@ var schedulerNeutralExtraKeyPrefixes = []string{
"codex_secondary_",
"codex_5h_",
"codex_7d_",
"passive_usage_",
}
var schedulerNeutralExtraKeys = map[string]struct{}{
@@ -473,7 +474,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
if search != "" {
q = q.Where(dbaccount.NameContainsFold(search))
}
if groupID > 0 {
if groupID == service.AccountListGroupUngrouped {
q = q.Where(dbaccount.Not(dbaccount.HasAccountGroups()))
} else if groupID > 0 {
q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
}

View File

@@ -214,6 +214,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
accType string
status string
search string
groupID int64
wantCount int
validate func(accounts []service.Account)
}{
@@ -265,6 +266,21 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Require().Contains(accounts[0].Name, "alpha")
},
},
{
name: "filter_by_ungrouped",
setup: func(client *dbent.Client) {
group := mustCreateGroup(s.T(), client, &service.Group{Name: "g-ungrouped"})
grouped := mustCreateAccount(s.T(), client, &service.Account{Name: "grouped-account"})
mustCreateAccount(s.T(), client, &service.Account{Name: "ungrouped-account"})
mustBindAccountToGroup(s.T(), client, grouped.ID, group.ID, 1)
},
groupID: service.AccountListGroupUngrouped,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal("ungrouped-account", accounts[0].Name)
s.Require().Empty(accounts[0].GroupIDs)
},
},
}
for _, tt := range tests {
@@ -277,7 +293,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
tt.setup(client)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID)
s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil {

View File

@@ -57,6 +57,7 @@ func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
// 读取全部内容以获取大小S3 PutObject 需要知道内容长度)
// 注意:阿里云 OSS 不兼容 s3manager 分片上传的签名方式,因此使用 PutObject
data, err := io.ReadAll(body)
if err != nil {
return 0, fmt.Errorf("read body: %w", err)

View File

@@ -20,6 +20,11 @@ const (
billingCacheTTL = 5 * time.Minute
billingCacheJitter = 30 * time.Second
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
// Rate limit window durations — must match service.RateLimitWindow* constants.
rateLimitWindow5h = 5 * time.Hour
rateLimitWindow1d = 24 * time.Hour
rateLimitWindow7d = 7 * 24 * time.Hour
)
// jitteredTTL 返回带随机抖动的 TTL防止缓存雪崩
@@ -90,17 +95,40 @@ var (
return 1
`)
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
// updateRateLimitUsageScript atomically increments all three rate limit usage counters
// with window expiration checking. If a window has expired, its usage is reset to cost
// (instead of accumulated) and the window timestamp is updated, matching the DB-side
// IncrementRateLimitUsage semantics.
//
// ARGV: [1]=cost, [2]=ttl_seconds, [3]=now_unix, [4]=window_5h_seconds, [5]=window_1d_seconds, [6]=window_7d_seconds
updateRateLimitUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
local now = tonumber(ARGV[3])
local win5h = tonumber(ARGV[4])
local win1d = tonumber(ARGV[5])
local win7d = tonumber(ARGV[6])
-- Helper: check if window is expired and update usage + window accordingly
-- Returns nothing, modifies the hash in-place.
local function update_window(usage_field, window_field, window_duration)
local w = tonumber(redis.call('HGET', KEYS[1], window_field) or 0)
if w == 0 or (now - w) >= window_duration then
-- Window expired or never started: reset usage to cost, start new window
redis.call('HSET', KEYS[1], usage_field, tostring(cost))
redis.call('HSET', KEYS[1], window_field, tostring(now))
else
-- Window still valid: accumulate
redis.call('HINCRBYFLOAT', KEYS[1], usage_field, cost)
end
end
update_window('usage_5h', 'window_5h', win5h)
update_window('usage_1d', 'window_1d', win1d)
update_window('usage_7d', 'window_7d', win7d)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
@@ -280,7 +308,15 @@ func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
key := billingRateLimitKey(keyID)
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
now := time.Now().Unix()
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key},
cost,
int(rateLimitCacheTTL.Seconds()),
now,
int(rateLimitWindow5h.Seconds()),
int(rateLimitWindow1d.Seconds()),
int(rateLimitWindow7d.Seconds()),
).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
return err

View File

@@ -88,8 +88,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
if err != nil {
return nil, err
}
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
total, active, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = total
out.ActiveAccountCount = active
return out, nil
}
@@ -256,7 +257,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
}
}
@@ -283,7 +287,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
}
}
@@ -310,7 +317,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform str
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
}
}
@@ -369,12 +379,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int
return result, nil
}
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
return 0, err
}
return count, nil
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
var rateLimited int64
err = scanSingleRow(ctx, r.sql,
`SELECT COUNT(*),
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true),
COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
))
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = $1`,
[]any{groupID}, &total, &active, &rateLimited)
return
}
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -500,15 +518,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return affectedUserIDs, nil
}
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) {
counts = make(map[int64]int64, len(groupIDs))
type groupAccountCounts struct {
Total int64
Active int64
RateLimited int64
}
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
counts = make(map[int64]groupAccountCounts, len(groupIDs))
if len(groupIDs) == 0 {
return counts, nil
}
rows, err := r.sql.QueryContext(
ctx,
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id",
`SELECT ag.group_id,
COUNT(*) AS total,
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active,
COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
)) AS rate_limited
FROM account_groups ag
JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = ANY($1)
GROUP BY ag.group_id`,
pq.Array(groupIDs),
)
if err != nil {
@@ -523,11 +558,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
for rows.Next() {
var groupID int64
var count int64
if err = rows.Scan(&groupID, &count); err != nil {
var c groupAccountCounts
if err = rows.Scan(&groupID, &c.Total, &c.Active, &c.RateLimited); err != nil {
return nil, err
}
counts[groupID] = count
counts[groupID] = c
}
if err = rows.Err(); err != nil {
return nil, err

View File

@@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
s.Require().NoError(err)
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(2), count)
}
@@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
}
s.Require().NoError(s.repo.Create(s.ctx, group))
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
@@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
count, _, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(0), count, "expected 0 account groups")
}
@@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
s.Require().NoError(err)
s.Require().Equal(int64(3), affected)
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}

View File

@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
var usageLogInsertArgTypes = [...]string{
"bigint",
@@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{
"bigint",
"text",
"text",
"text",
"bigint",
"bigint",
"integer",
@@ -277,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -311,12 +313,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
$1, $2, $3, $4, $5, $6,
$7, $8,
$9, $10, $11, $12,
$13, $14,
$15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -707,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -742,7 +745,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*38)
args := make([]any, 0, len(keys)*39)
argPos := 1
for idx, key := range keys {
if idx > 0 {
@@ -776,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -816,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -896,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -931,7 +937,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*38)
args := make([]any, 0, len(preparedList)*39)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -962,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -1002,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -1050,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -1084,12 +1093,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
$1, $2, $3, $4, $5, $6,
$7, $8,
$9, $10, $11, $12,
$13, $14,
$15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1121,6 +1130,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint)
upstreamModel := nullString(log.UpstreamModel)
var requestIDArg any
if requestID != "" {
@@ -1138,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.AccountID,
requestIDArg,
log.Model,
upstreamModel,
groupID,
subscriptionID,
log.InputTokens,
@@ -2864,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested)
}
// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension.
// source: requested | upstream | mapping.
func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source)
}
func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
modelExpr := resolveModelDimensionExpression(source)
query := fmt.Sprintf(`
SELECT
model,
%s as model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
@@ -2883,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, actualCostExpr)
`, modelExpr, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
@@ -2907,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY model ORDER BY total_tokens DESC"
query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr)
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
@@ -3000,6 +3022,132 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
return results, nil
}
// GetUserBreakdownStats returns per-user usage breakdown within a specific dimension.
func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) (results []usagestats.UserBreakdownItem, err error) {
query := `
SELECT
COALESCE(ul.user_id, 0) as user_id,
COALESCE(u.email, '') as email,
COUNT(*) as requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(ul.total_cost), 0) as cost,
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
FROM usage_logs ul
LEFT JOIN users u ON u.id = ul.user_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
`
args := []any{startTime, endTime}
if dim.GroupID > 0 {
query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1)
args = append(args, dim.GroupID)
}
if dim.Model != "" {
query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1)
args = append(args, dim.Model)
}
if dim.Endpoint != "" {
col := resolveEndpointColumn(dim.EndpointType)
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
args = append(args, dim.Endpoint)
}
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
}
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results = make([]usagestats.UserBreakdownItem, 0)
for rows.Next() {
var row usagestats.UserBreakdownItem
if err := rows.Scan(
&row.UserID,
&row.Email,
&row.Requests,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group.
// todayStart is the start-of-day in the caller's timezone (UTC-based).
// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation.
// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s)
// or a materialized view / pre-aggregation table for cumulative costs.
func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
query := `
SELECT
g.id AS group_id,
COALESCE(SUM(ul.actual_cost), 0) AS total_cost,
COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost
FROM groups g
LEFT JOIN usage_logs ul ON ul.group_id = g.id
GROUP BY g.id
`
rows, err := r.sql.QueryContext(ctx, query, todayStart)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var results []usagestats.GroupUsageSummary
for rows.Next() {
var row usagestats.GroupUsageSummary
if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// resolveModelDimensionExpression maps model source type to a safe SQL expression.
func resolveModelDimensionExpression(modelType string) string {
switch usagestats.NormalizeModelSource(modelType) {
case usagestats.ModelSourceUpstream:
return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"
case usagestats.ModelSourceMapping:
return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"
default:
return "model"
}
}
// resolveEndpointColumn maps endpoint type to the corresponding DB column name.
func resolveEndpointColumn(endpointType string) string {
switch endpointType {
case "upstream":
return "ul.upstream_endpoint"
case "path":
return "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"
default:
return "ul.inbound_endpoint"
}
}
// GetGlobalStats gets usage statistics for all users within a time range
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
query := `
@@ -3740,6 +3888,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
accountID int64
requestID sql.NullString
model string
upstreamModel sql.NullString
groupID sql.NullInt64
subscriptionID sql.NullInt64
inputTokens int
@@ -3782,6 +3931,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&accountID,
&requestID,
&model,
&upstreamModel,
&groupID,
&subscriptionID,
&inputTokens,
@@ -3894,6 +4044,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamEndpoint.Valid {
log.UpstreamEndpoint = &upstreamEndpoint.String
}
if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String
}
return log, nil
}

View File

@@ -0,0 +1,50 @@
//go:build unit
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/require"
)
func TestResolveEndpointColumn(t *testing.T) {
tests := []struct {
endpointType string
want string
}{
{"inbound", "ul.inbound_endpoint"},
{"upstream", "ul.upstream_endpoint"},
{"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"},
{"", "ul.inbound_endpoint"}, // default
{"unknown", "ul.inbound_endpoint"}, // fallback
}
for _, tc := range tests {
t.Run(tc.endpointType, func(t *testing.T) {
got := resolveEndpointColumn(tc.endpointType)
require.Equal(t, tc.want, got)
})
}
}
func TestResolveModelDimensionExpression(t *testing.T) {
tests := []struct {
modelType string
want string
}{
{usagestats.ModelSourceRequested, "model"},
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"},
{usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"},
{"", "model"},
{"invalid", "model"},
}
for _, tc := range tests {
t.Run(tc.modelType, func(t *testing.T) {
got := resolveModelDimensionExpression(tc.modelType)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.AccountID,
log.RequestID,
log.Model,
sqlmock.AnyArg(), // upstream_model
sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id
log.InputTokens,
@@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.Model,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.InputTokens,
log.OutputTokens,
log.CacheCreationTokens,
@@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model
sql.NullString{}, // upstream_model
sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id
1, // input_tokens
@@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(31),
sql.NullString{Valid: true, String: "req-2"},
"gpt-5",
sql.NullString{},
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
@@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(32),
sql.NullString{Valid: true, String: "req-3"},
"gpt-5.4",
sql.NullString{},
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,

View File

@@ -5,6 +5,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
@@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID))
}
if platform != "" {
q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform)))
}
// Status filtering with real-time expiration check
now := time.Now()

View File

@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)

View File

@@ -924,8 +924,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error
return false, errors.New("not implemented")
}
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, errors.New("not implemented")
}
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -1289,7 +1289,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
@@ -1637,6 +1637,10 @@ func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
@@ -1782,6 +1786,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, errors.New("not implemented")
}
type stubSettingRepo struct {
all map[string]string

View File

@@ -135,7 +135,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {

View File

@@ -646,7 +646,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}

View File

@@ -198,6 +198,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
dashboard.GET("/user-breakdown", h.Admin.Dashboard.GetUserBreakdown)
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
}
}
@@ -226,6 +227,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary)
groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary)
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
@@ -399,6 +402,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
// 529过载冷却配置
adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings)
adminSettings.PUT("/overload-cooldown", h.Admin.Setting.UpdateOverloadCooldownSettings)
// 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)

View File

@@ -14,6 +14,8 @@ var (
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
)
const AccountListGroupUngrouped int64 = -1
type AccountRepository interface {
Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*Account, error)

View File

@@ -113,15 +113,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error)
return normalized, nil
}
// generateSessionString generates a Claude Code style session string
// generateSessionString generates a Claude Code style session string.
// The output format is determined by the UA version in claude.DefaultHeaders,
// ensuring consistency between the user_id format and the UA sent to upstream.
func generateSessionString() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
hex64 := hex.EncodeToString(bytes)
hex64 := hex.EncodeToString(b)
sessionUUID := uuid.New().String()
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
uaVersion := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"])
return FormatMetadataUserID(hex64, "", sessionUUID, uaVersion), nil
}
// createTestPayload creates a Claude Code style test request payload
@@ -305,7 +308,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))
// 403 表示账号被上游封禁,标记为 error 状态
if resp.StatusCode == http.StatusForbidden {
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
}
return s.sendErrorAndEnd(c, errMsg)
}
// Process SSE stream

View File

@@ -48,6 +48,8 @@ type UsageLogRepository interface {
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error)
GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
@@ -175,6 +177,7 @@ type AICredit struct {
// UsageInfo 账号使用量信息
type UsageInfo struct {
Source string `json:"source,omitempty"` // "passive" or "active"
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
@@ -391,6 +394,9 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
// 4. 添加窗口统计有独立缓存1 分钟)
s.addWindowStats(ctx, account, usage)
// 5. 将主动查询结果同步到被动缓存,下次 passive 加载即为最新值
s.syncActiveToPassive(ctx, account.ID, usage)
s.tryClearRecoverableAccountError(ctx, account)
return usage, nil
}
@@ -407,6 +413,81 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
}
// GetPassiveUsage 从 Account.Extra 中的被动采样数据构建 UsageInfo不调用外部 API。
// 仅适用于 Anthropic OAuth / SetupToken 账号。
func (s *AccountUsageService) GetPassiveUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account failed: %w", err)
}
if !account.IsAnthropicOAuthOrSetupToken() {
return nil, fmt.Errorf("passive usage only supported for Anthropic OAuth/SetupToken accounts")
}
// 复用 estimateSetupTokenUsage 构建 5h 窗口OAuth 和 SetupToken 逻辑一致)
info := s.estimateSetupTokenUsage(account)
info.Source = "passive"
// 设置采样时间
if raw, ok := account.Extra["passive_usage_sampled_at"]; ok {
if str, ok := raw.(string); ok {
if t, err := time.Parse(time.RFC3339, str); err == nil {
info.UpdatedAt = &t
}
}
}
// 构建 7d 窗口(从被动采样数据)
util7d := parseExtraFloat64(account.Extra["passive_usage_7d_utilization"])
reset7dRaw := parseExtraFloat64(account.Extra["passive_usage_7d_reset"])
if util7d > 0 || reset7dRaw > 0 {
var resetAt *time.Time
var remaining int
if reset7dRaw > 0 {
t := time.Unix(int64(reset7dRaw), 0)
resetAt = &t
remaining = int(time.Until(t).Seconds())
if remaining < 0 {
remaining = 0
}
}
info.SevenDay = &UsageProgress{
Utilization: util7d * 100,
ResetsAt: resetAt,
RemainingSeconds: remaining,
}
}
// 添加窗口统计
s.addWindowStats(ctx, account, info)
return info, nil
}
// syncActiveToPassive 将主动查询的最新数据回写到 Extra 被动缓存,
// 这样下次被动加载时能看到最新值。
func (s *AccountUsageService) syncActiveToPassive(ctx context.Context, accountID int64, usage *UsageInfo) {
extraUpdates := make(map[string]any, 4)
if usage.FiveHour != nil {
extraUpdates["session_window_utilization"] = usage.FiveHour.Utilization / 100
}
if usage.SevenDay != nil {
extraUpdates["passive_usage_7d_utilization"] = usage.SevenDay.Utilization / 100
if usage.SevenDay.ResetsAt != nil {
extraUpdates["passive_usage_7d_reset"] = usage.SevenDay.ResetsAt.Unix()
}
}
if len(extraUpdates) > 0 {
extraUpdates["passive_usage_sampled_at"] = time.Now().UTC().Format(time.RFC3339)
if err := s.accountRepo.UpdateExtra(ctx, accountID, extraUpdates); err != nil {
slog.Warn("sync_active_to_passive_failed", "account_id", accountID, "error", err)
}
}
}
func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
now := time.Now()
usage := &UsageInfo{UpdatedAt: &now}
@@ -446,23 +527,17 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
}
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil {
windowStats := windowStatsFromAccountStats(stats)
if hasMeaningfulWindowStats(windowStats) {
if usage.FiveHour == nil {
usage.FiveHour = &UsageProgress{Utilization: 0}
}
usage.FiveHour.WindowStats = windowStats
if usage.FiveHour == nil {
usage.FiveHour = &UsageProgress{Utilization: 0}
}
usage.FiveHour.WindowStats = windowStatsFromAccountStats(stats)
}
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil {
windowStats := windowStatsFromAccountStats(stats)
if hasMeaningfulWindowStats(windowStats) {
if usage.SevenDay == nil {
usage.SevenDay = &UsageProgress{Utilization: 0}
}
usage.SevenDay.WindowStats = windowStats
if usage.SevenDay == nil {
usage.SevenDay = &UsageProgress{Utilization: 0}
}
usage.SevenDay.WindowStats = windowStatsFromAccountStats(stats)
}
return usage, nil
@@ -992,13 +1067,6 @@ func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
}
}
func hasMeaningfulWindowStats(stats *WindowStats) bool {
if stats == nil {
return false
}
return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0
}
func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress {
if len(extra) == 0 {
return nil
@@ -1055,6 +1123,11 @@ func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now t
}
}
// 窗口已过期resetAt 在 now 之前)→ 额度已重置,归零
if progress.ResetsAt != nil && !now.Before(*progress.ResetsAt) {
progress.Utilization = 0
}
return progress
}

View File

@@ -148,3 +148,54 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
t.Fatal("waiting for codex probe rate limit persistence timed out")
}
}
func TestBuildCodexUsageProgressFromExtra_ZerosExpiredWindow(t *testing.T) {
t.Parallel()
now := time.Date(2026, 3, 16, 12, 0, 0, 0, time.UTC)
t.Run("expired 5h window zeroes utilization", func(t *testing.T) {
extra := map[string]any{
"codex_5h_used_percent": 42.0,
"codex_5h_reset_at": "2026-03-16T10:00:00Z", // 2h ago
}
progress := buildCodexUsageProgressFromExtra(extra, "5h", now)
if progress == nil {
t.Fatal("expected non-nil progress")
}
if progress.Utilization != 0 {
t.Fatalf("expected Utilization=0 for expired window, got %v", progress.Utilization)
}
if progress.RemainingSeconds != 0 {
t.Fatalf("expected RemainingSeconds=0, got %v", progress.RemainingSeconds)
}
})
t.Run("active 5h window keeps utilization", func(t *testing.T) {
resetAt := now.Add(2 * time.Hour).Format(time.RFC3339)
extra := map[string]any{
"codex_5h_used_percent": 42.0,
"codex_5h_reset_at": resetAt,
}
progress := buildCodexUsageProgressFromExtra(extra, "5h", now)
if progress == nil {
t.Fatal("expected non-nil progress")
}
if progress.Utilization != 42.0 {
t.Fatalf("expected Utilization=42, got %v", progress.Utilization)
}
})
t.Run("expired 7d window zeroes utilization", func(t *testing.T) {
extra := map[string]any{
"codex_7d_used_percent": 88.0,
"codex_7d_reset_at": "2026-03-15T00:00:00Z", // yesterday
}
progress := buildCodexUsageProgressFromExtra(extra, "7d", now)
if progress == nil {
t.Fatal("expected non-nil progress")
}
if progress.Utilization != 0 {
t.Fatalf("expected Utilization=0 for expired 7d window, got %v", progress.Utilization)
}
})
}

View File

@@ -1530,7 +1530,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if len(input.Credentials) > 0 {
account.Credentials = input.Credentials
}
if len(input.Extra) > 0 {
// Extra 使用 map需要区分“未提供(nil)”与“显式清空({})”。
// 关闭配额限制时前端会删除 quota_* 键并提交 extra:{},此时也必须落库。
if input.Extra != nil {
// 保留配额用量字段,防止编辑账号时意外重置
for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} {
if v, ok := account.Extra[key]; ok {

View File

@@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {

View File

@@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er
panic("unexpected ExistsByName call")
}
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}

View File

@@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool,
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
@@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
@@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context,
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}

View File

@@ -121,3 +121,35 @@ func TestUpdateAccount_EnableOveragesClearsModelRateLimitsBeforePersist(t *testi
_, exists := repo.account.Extra[modelRateLimitsKey]
require.False(t, exists, "开启 overages 时应在持久化前清掉旧模型限流")
}
func TestUpdateAccount_EmptyExtraPayloadCanClearQuotaLimits(t *testing.T) {
accountID := int64(103)
repo := &updateAccountOveragesRepoStub{
account: &Account{
ID: accountID,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Status: StatusActive,
Extra: map[string]any{
"quota_limit": 100.0,
"quota_daily_limit": 10.0,
"quota_weekly_limit": 40.0,
},
},
}
svc := &adminServiceImpl{accountRepo: repo}
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
// 显式空对象:语义是“清空 extra 中的可配置键”(例如关闭配额限制)
Extra: map[string]any{},
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 1, repo.updateCalls)
require.NotNil(t, repo.account.Extra)
require.NotContains(t, repo.account.Extra, "quota_limit")
require.NotContains(t, repo.account.Extra, "quota_daily_limit")
require.NotContains(t, repo.account.Extra, "quota_weekly_limit")
require.Len(t, repo.account.Extra, 0)
}

View File

@@ -930,7 +930,7 @@ func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParam
case ErrorPolicyTempUnscheduled:
slog.Info("temp_unschedulable_matched",
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, RateLimitedModel: p.requestedModel, IsStickySession: p.isStickySession}
}
return false, statusCode, nil
}
@@ -1001,8 +1001,9 @@ type TestConnectionResult struct {
MappedModel string // 实际使用的模型
}
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
// TestConnection 测试 Antigravity 账号连接
// 复用 antigravityRetryLoop 的完整重试 / credits overages / 智能重试逻辑,
// 与真实调度行为一致。差异:不做账号切换(测试指定账号)、不记录 ops 错误。
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
// 获取 token
@@ -1026,10 +1027,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
// 构建请求体
var requestBody []byte
if strings.HasPrefix(modelID, "gemini-") {
// Gemini 模型:直接使用 Gemini 格式
requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel)
} else {
// Claude 模型:使用协议转换
requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel)
}
if err != nil {
@@ -1042,64 +1041,63 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
proxyURL = account.Proxy.URL()
}
baseURL := resolveAntigravityForwardBaseURL()
if baseURL == "" {
return nil, errors.New("no antigravity forward base url configured")
}
availableURLs := []string{baseURL}
var lastErr error
for urlIdx, baseURL := range availableURLs {
// 构建 HTTP 请求(总是使用流式 endpoint与官方客户端一致
req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
if err != nil {
lastErr = err
continue
}
// 调试日志Test 请求信息
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
lastErr = fmt.Errorf("请求失败: %w", err)
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
return nil, lastErr
}
// 读取响应
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
// 检查是否需要 URL 降级
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
}
// 解析流式响应,提取文本
text := extractTextFromSSEResponse(respBody)
// 标记成功的 URL下次优先使用
antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
return &TestConnectionResult{
Text: text,
MappedModel: mappedModel,
}, nil
// 复用 antigravityRetryLoop完整的重试 / credits overages / 智能重试
prefix := fmt.Sprintf("[antigravity-Test] account=%d(%s)", account.ID, account.Name)
p := antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: "streamGenerateContent",
body: requestBody,
c: nil, // 无 gin.Context → 跳过 ops 追踪
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
requestedModel: modelID,
handleError: testConnectionHandleError,
}
return nil, lastErr
result, err := s.antigravityRetryLoop(p)
if err != nil {
// AccountSwitchError → 测试时不切换账号,返回友好提示
var switchErr *AntigravityAccountSwitchError
if errors.As(err, &switchErr) {
return nil, fmt.Errorf("该账号模型 %s 当前限流中,请稍后重试", switchErr.RateLimitedModel)
}
return nil, err
}
if result == nil || result.resp == nil {
return nil, errors.New("upstream returned empty response")
}
defer func() { _ = result.resp.Body.Close() }()
respBody, err := io.ReadAll(io.LimitReader(result.resp.Body, 2<<20))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if result.resp.StatusCode >= 400 {
return nil, fmt.Errorf("API 返回 %d: %s", result.resp.StatusCode, string(respBody))
}
text := extractTextFromSSEResponse(respBody)
return &TestConnectionResult{Text: text, MappedModel: mappedModel}, nil
}
// testConnectionHandleError 是 TestConnection 使用的轻量 handleError 回调。
// 仅记录日志,不做 ops 错误追踪或粘性会话清除。
func testConnectionHandleError(
_ context.Context, prefix string, account *Account,
statusCode int, _ http.Header, body []byte,
requestedModel string, _ int64, _ string, _ bool,
) *handleModelRateLimitResult {
logger.LegacyPrintf("service.antigravity_gateway",
"%s test_handle_error status=%d model=%s account=%d body=%s",
prefix, statusCode, requestedModel, account.ID, truncateForLog(body, 200))
return nil
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求
@@ -3079,6 +3077,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
intervalCh = intervalTicker.C
}
// 下游 keepalive防止代理/Cloudflare Tunnel 因连接空闲而断开
keepaliveInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
// 仅发送一次错误事件,避免多次写入导致协议混乱
@@ -3111,6 +3125,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return nil, ev.err
}
lastDataAt = time.Now()
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
@@ -3170,6 +3186,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if cw.Disconnected() {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// SSE ping/keepalive保持连接活跃防止 Cloudflare Tunnel 等代理断开
if !cw.Fprintf(":\n\n") {
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity gemini), continuing to drain upstream for billing")
continue
}
}
}
}
@@ -3895,6 +3924,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
intervalCh = intervalTicker.C
}
// 下游 keepalive防止代理/Cloudflare Tunnel 因连接空闲而断开
keepaliveInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
// 仅发送一次错误事件,避免多次写入导致协议混乱
@@ -3947,6 +3992,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
return nil, fmt.Errorf("stream read error: %w", ev.err)
}
lastDataAt = time.Now()
// 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
if len(claudeEvents) > 0 {
@@ -3969,6 +4016,20 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if cw.Disconnected() {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// SSE ping 事件Anthropic 原生格式,客户端会正确处理,
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") {
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity claude), continuing to drain upstream for billing")
continue
}
}
}
}
@@ -4299,6 +4360,22 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
intervalCh = intervalTicker.C
}
// 下游 keepalive防止代理/Cloudflare Tunnel 因连接空闲而断开
keepaliveInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")
@@ -4316,6 +4393,8 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
}
lastDataAt = time.Now()
line := ev.line
// 记录首 token 时间
@@ -4341,6 +4420,20 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
}
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
case <-keepaliveCh:
if cw.Disconnected() {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// SSE ping 事件Anthropic 原生格式,客户端会正确处理,
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") {
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity upstream), continuing to drain upstream for billing")
continue
}
}
}
}

View File

@@ -57,16 +57,16 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-opus-4-6-thinking",
},
{
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5",
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-6",
requestedModel: "claude-haiku-4-5",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-sonnet-4-6",
},
{
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-6",
requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-sonnet-4-6",
},
{
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",

View File

@@ -260,14 +260,15 @@ func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *test
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503不设限流
// 使用 RATE_LIMIT_EXCEEDED走 1 次智能重试),避免 MODEL_CAPACITY_EXHAUSTED 的 60 次重试导致测试超时
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
// 智能重试也返回 503
failRespBody := `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
@@ -278,8 +279,9 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
responses: []*http.Response{failResp},
errors: []error{nil},
repeatLast: true,
}
repo := &stubAntigravityAccountRepo{}
@@ -294,9 +296,9 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
@@ -569,8 +571,9 @@ func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) {
svc := &AntigravityGatewayService{}
// 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro")
// waitDuration=0 会被 clamp 到 antigravitySmartRetryMinWait=1s。
// 首次重试即成功200总耗时 ~1s。
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 0, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)

View File

@@ -32,11 +32,13 @@ func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID
// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream
type mockSmartRetryUpstream struct {
responses []*http.Response
errors []error
callIdx int
calls []string
requestBodies [][]byte
responses []*http.Response
responseBodies [][]byte // 缓存的 response body 字节(用于 repeatLast 重建)
errors []error
callIdx int
calls []string
requestBodies [][]byte
repeatLast bool // 超出范围时重复最后一个响应
}
func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
@@ -50,10 +52,45 @@ func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountI
m.requestBodies = append(m.requestBodies, nil)
}
m.callIdx++
if idx < len(m.responses) {
return m.responses[idx], m.errors[idx]
// 确定使用哪个索引
respIdx := idx
if respIdx >= len(m.responses) {
if !m.repeatLast || len(m.responses) == 0 {
return nil, nil
}
respIdx = len(m.responses) - 1
}
return nil, nil
resp := m.responses[respIdx]
respErr := m.errors[respIdx]
if resp == nil {
return nil, respErr
}
// 首次调用时缓存 body 字节
if respIdx >= len(m.responseBodies) {
for len(m.responseBodies) <= respIdx {
m.responseBodies = append(m.responseBodies, nil)
}
}
if m.responseBodies[respIdx] == nil && resp.Body != nil {
bodyBytes, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
m.responseBodies[respIdx] = bodyBytes
}
// 用缓存的 body 字节重建新的 reader
var body io.ReadCloser
if m.responseBodies[respIdx] != nil {
body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx]))
}
return &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: body,
}, respErr
}
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {

View File

@@ -4,11 +4,13 @@ import (
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -84,17 +86,21 @@ type BackupScheduleConfig struct {
// BackupRecord 备份记录
type BackupRecord struct {
ID string `json:"id"`
Status string `json:"status"` // pending, running, completed, failed
BackupType string `json:"backup_type"` // postgres
FileName string `json:"file_name"`
S3Key string `json:"s3_key"`
SizeBytes int64 `json:"size_bytes"`
TriggeredBy string `json:"triggered_by"` // manual, scheduled
ErrorMsg string `json:"error_message,omitempty"`
StartedAt string `json:"started_at"`
FinishedAt string `json:"finished_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
ID string `json:"id"`
Status string `json:"status"` // pending, running, completed, failed
BackupType string `json:"backup_type"` // postgres
FileName string `json:"file_name"`
S3Key string `json:"s3_key"`
SizeBytes int64 `json:"size_bytes"`
TriggeredBy string `json:"triggered_by"` // manual, scheduled
ErrorMsg string `json:"error_message,omitempty"`
StartedAt string `json:"started_at"`
FinishedAt string `json:"finished_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
Progress string `json:"progress,omitempty"` // "dumping", "uploading", ""
RestoreStatus string `json:"restore_status,omitempty"` // "", "running", "completed", "failed"
RestoreError string `json:"restore_error,omitempty"`
RestoredAt string `json:"restored_at,omitempty"`
}
// BackupService 数据库备份恢复服务
@@ -105,17 +111,24 @@ type BackupService struct {
storeFactory BackupObjectStoreFactory
dumper DBDumper
mu sync.Mutex
store BackupObjectStore
s3Cfg *BackupS3Config
opMu sync.Mutex // 保护 backingUp/restoring 标志
backingUp bool
restoring bool
storeMu sync.Mutex // 保护 store/s3Cfg 缓存
store BackupObjectStore
s3Cfg *BackupS3Config
recordsMu sync.Mutex // 保护 records 的 load/save 操作
cronMu sync.Mutex
cronSched *cron.Cron
cronEntryID cron.EntryID
wg sync.WaitGroup // 追踪活跃的备份/恢复 goroutine
shuttingDown atomic.Bool // 阻止新备份启动
bgCtx context.Context // 所有后台操作的 parent context
bgCancel context.CancelFunc // 取消所有活跃后台操作
}
func NewBackupService(
@@ -125,20 +138,26 @@ func NewBackupService(
storeFactory BackupObjectStoreFactory,
dumper DBDumper,
) *BackupService {
bgCtx, bgCancel := context.WithCancel(context.Background())
return &BackupService{
settingRepo: settingRepo,
dbCfg: &cfg.Database,
encryptor: encryptor,
storeFactory: storeFactory,
dumper: dumper,
bgCtx: bgCtx,
bgCancel: bgCancel,
}
}
// Start 启动定时备份调度器
// Start 启动定时备份调度器并清理孤立记录
func (s *BackupService) Start() {
s.cronSched = cron.New()
s.cronSched.Start()
// 清理重启后孤立的 running 记录
s.recoverStaleRecords()
// 加载已有的定时配置
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -154,13 +173,65 @@ func (s *BackupService) Start() {
}
}
// Stop 停止定时备份
// recoverStaleRecords 启动时将孤立的 running 记录标记为 failed
func (s *BackupService) recoverStaleRecords() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
records, err := s.loadRecords(ctx)
if err != nil {
return
}
for i := range records {
if records[i].Status == "running" {
records[i].Status = "failed"
records[i].ErrorMsg = "interrupted by server restart"
records[i].Progress = ""
records[i].FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, &records[i])
logger.LegacyPrintf("service.backup", "[Backup] recovered stale running record: %s", records[i].ID)
}
if records[i].RestoreStatus == "running" {
records[i].RestoreStatus = "failed"
records[i].RestoreError = "interrupted by server restart"
_ = s.saveRecord(ctx, &records[i])
logger.LegacyPrintf("service.backup", "[Backup] recovered stale restoring record: %s", records[i].ID)
}
}
}
// Stop 停止定时备份并等待活跃操作完成
func (s *BackupService) Stop() {
s.shuttingDown.Store(true)
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched != nil {
s.cronSched.Stop()
}
s.cronMu.Unlock()
// 等待活跃备份/恢复完成(最多 5 分钟)
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
logger.LegacyPrintf("service.backup", "[Backup] all active operations finished")
case <-time.After(5 * time.Minute):
logger.LegacyPrintf("service.backup", "[Backup] shutdown timeout after 5min, cancelling active operations")
if s.bgCancel != nil {
s.bgCancel() // 取消所有后台操作
}
// 给 goroutine 时间响应取消并完成清理
select {
case <-done:
logger.LegacyPrintf("service.backup", "[Backup] active operations cancelled and cleaned up")
case <-time.After(10 * time.Second):
logger.LegacyPrintf("service.backup", "[Backup] goroutine cleanup timed out")
}
}
}
// ─── S3 配置管理 ───
@@ -203,10 +274,10 @@ func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config)
}
// 清除缓存的 S3 客户端
s.mu.Lock()
s.storeMu.Lock()
s.store = nil
s.s3Cfg = nil
s.mu.Unlock()
s.storeMu.Unlock()
cfg.SecretAccessKey = ""
return &cfg, nil
@@ -314,7 +385,10 @@ func (s *BackupService) removeCronSchedule() {
}
func (s *BackupService) runScheduledBackup() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
s.wg.Add(1)
defer s.wg.Done()
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
defer cancel()
// 读取定时备份配置中的过期天数
@@ -327,7 +401,11 @@ func (s *BackupService) runScheduledBackup() {
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
if err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
if errors.Is(err, ErrBackupInProgress) {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份跳过: 已有备份正在进行中")
} else {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
}
return
}
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
@@ -346,17 +424,21 @@ func (s *BackupService) runScheduledBackup() {
// CreateBackup 创建全量数据库备份并上传到 S3流式处理
// expireDays: 备份过期天数0=永不过期默认14天
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
s.mu.Lock()
if s.shuttingDown.Load() {
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
}
s.opMu.Lock()
if s.backingUp {
s.mu.Unlock()
s.opMu.Unlock()
return nil, ErrBackupInProgress
}
s.backingUp = true
s.mu.Unlock()
s.opMu.Unlock()
defer func() {
s.mu.Lock()
s.opMu.Lock()
s.backingUp = false
s.mu.Unlock()
s.opMu.Unlock()
}()
s3Cfg, err := s.loadS3Config(ctx)
@@ -405,36 +487,47 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
pr, pw := io.Pipe()
var gzipErr error
gzipDone := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck
gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r)
}
}()
gzWriter := gzip.NewWriter(pw)
_, gzipErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
var gzErr error
_, gzErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil {
gzErr = closeErr
}
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil {
gzErr = closeErr
}
if gzipErr != nil {
_ = pw.CloseWithError(gzipErr)
if gzErr != nil {
_ = pw.CloseWithError(gzErr)
} else {
_ = pw.Close()
}
gzipDone <- gzErr
}()
contentType := "application/gzip"
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
if err != nil {
_ = pr.CloseWithError(err) // 确保 gzip goroutine 不会悬挂
gzErr := <-gzipDone // 安全等待 gzip goroutine 完成
record.Status = "failed"
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
if gzipErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
if gzErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzErr)
}
record.ErrorMsg = errMsg
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, record)
return record, fmt.Errorf("backup upload: %w", err)
}
<-gzipDone // 确保 gzip goroutine 已退出
record.SizeBytes = sizeBytes
record.Status = "completed"
@@ -446,19 +539,187 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
return record, nil
}
// StartBackup 异步创建备份,立即返回 running 状态的记录
func (s *BackupService) StartBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
if s.shuttingDown.Load() {
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
}
s.opMu.Lock()
if s.backingUp {
s.opMu.Unlock()
return nil, ErrBackupInProgress
}
s.backingUp = true
s.opMu.Unlock()
// 初始化阶段出错时自动重置标志
launched := false
defer func() {
if !launched {
s.opMu.Lock()
s.backingUp = false
s.opMu.Unlock()
}
}()
// 在返回前加载 S3 配置和创建 store避免 goroutine 中配置被修改
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
if s3Cfg == nil || !s3Cfg.IsConfigured() {
return nil, ErrBackupS3NotConfigured
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return nil, fmt.Errorf("init object store: %w", err)
}
now := time.Now()
backupID := uuid.New().String()[:8]
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
s3Key := s.buildS3Key(s3Cfg, fileName)
var expiresAt string
if expireDays > 0 {
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
}
record := &BackupRecord{
ID: backupID,
Status: "running",
BackupType: "postgres",
FileName: fileName,
S3Key: s3Key,
TriggeredBy: triggeredBy,
StartedAt: now.Format(time.RFC3339),
ExpiresAt: expiresAt,
Progress: "pending",
}
if err := s.saveRecord(ctx, record); err != nil {
return nil, fmt.Errorf("save initial record: %w", err)
}
launched = true
// 在启动 goroutine 前完成拷贝,避免数据竞争
result := *record
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() {
s.opMu.Lock()
s.backingUp = false
s.opMu.Unlock()
}()
defer func() {
if r := recover(); r != nil {
logger.LegacyPrintf("service.backup", "[Backup] panic recovered: %v", r)
record.Status = "failed"
record.ErrorMsg = fmt.Sprintf("internal panic: %v", r)
record.Progress = ""
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(context.Background(), record)
}
}()
s.executeBackup(record, objectStore)
}()
return &result, nil
}
// executeBackup 后台执行备份(独立于 HTTP context
func (s *BackupService) executeBackup(record *BackupRecord, objectStore BackupObjectStore) {
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
defer cancel()
// 阶段1: pg_dump
record.Progress = "dumping"
_ = s.saveRecord(ctx, record)
dumpReader, err := s.dumper.Dump(ctx)
if err != nil {
record.Status = "failed"
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
record.Progress = ""
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(context.Background(), record)
return
}
// 阶段2: gzip + upload
record.Progress = "uploading"
_ = s.saveRecord(ctx, record)
pr, pw := io.Pipe()
gzipDone := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck
gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r)
}
}()
gzWriter := gzip.NewWriter(pw)
var gzErr error
_, gzErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil {
gzErr = closeErr
}
if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil {
gzErr = closeErr
}
if gzErr != nil {
_ = pw.CloseWithError(gzErr)
} else {
_ = pw.Close()
}
gzipDone <- gzErr
}()
contentType := "application/gzip"
sizeBytes, err := objectStore.Upload(ctx, record.S3Key, pr, contentType)
if err != nil {
_ = pr.CloseWithError(err) // 确保 gzip goroutine 不会悬挂
gzErr := <-gzipDone // 安全等待 gzip goroutine 完成
record.Status = "failed"
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
if gzErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzErr)
}
record.ErrorMsg = errMsg
record.Progress = ""
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(context.Background(), record)
return
}
<-gzipDone // 确保 gzip goroutine 已退出
record.SizeBytes = sizeBytes
record.Status = "completed"
record.Progress = ""
record.FinishedAt = time.Now().Format(time.RFC3339)
if err := s.saveRecord(context.Background(), record); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
}
}
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
s.mu.Lock()
s.opMu.Lock()
if s.restoring {
s.mu.Unlock()
s.opMu.Unlock()
return ErrRestoreInProgress
}
s.restoring = true
s.mu.Unlock()
s.opMu.Unlock()
defer func() {
s.mu.Lock()
s.opMu.Lock()
s.restoring = false
s.mu.Unlock()
s.opMu.Unlock()
}()
record, err := s.GetBackupRecord(ctx, backupID)
@@ -500,6 +761,112 @@ func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) erro
return nil
}
// StartRestore 异步恢复备份,立即返回
func (s *BackupService) StartRestore(ctx context.Context, backupID string) (*BackupRecord, error) {
if s.shuttingDown.Load() {
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
}
s.opMu.Lock()
if s.restoring {
s.opMu.Unlock()
return nil, ErrRestoreInProgress
}
s.restoring = true
s.opMu.Unlock()
// 初始化阶段出错时自动重置标志
launched := false
defer func() {
if !launched {
s.opMu.Lock()
s.restoring = false
s.opMu.Unlock()
}
}()
record, err := s.GetBackupRecord(ctx, backupID)
if err != nil {
return nil, err
}
if record.Status != "completed" {
return nil, infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
}
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return nil, fmt.Errorf("init object store: %w", err)
}
record.RestoreStatus = "running"
_ = s.saveRecord(ctx, record)
launched = true
result := *record
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() {
s.opMu.Lock()
s.restoring = false
s.opMu.Unlock()
}()
defer func() {
if r := recover(); r != nil {
logger.LegacyPrintf("service.backup", "[Backup] restore panic recovered: %v", r)
record.RestoreStatus = "failed"
record.RestoreError = fmt.Sprintf("internal panic: %v", r)
_ = s.saveRecord(context.Background(), record)
}
}()
s.executeRestore(record, objectStore)
}()
return &result, nil
}
// executeRestore 后台执行恢复
func (s *BackupService) executeRestore(record *BackupRecord, objectStore BackupObjectStore) {
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
defer cancel()
body, err := objectStore.Download(ctx, record.S3Key)
if err != nil {
record.RestoreStatus = "failed"
record.RestoreError = fmt.Sprintf("S3 download failed: %v", err)
_ = s.saveRecord(context.Background(), record)
return
}
defer func() { _ = body.Close() }()
gzReader, err := gzip.NewReader(body)
if err != nil {
record.RestoreStatus = "failed"
record.RestoreError = fmt.Sprintf("gzip reader: %v", err)
_ = s.saveRecord(context.Background(), record)
return
}
defer func() { _ = gzReader.Close() }()
if err := s.dumper.Restore(ctx, gzReader); err != nil {
record.RestoreStatus = "failed"
record.RestoreError = fmt.Sprintf("pg restore: %v", err)
_ = s.saveRecord(context.Background(), record)
return
}
record.RestoreStatus = "completed"
record.RestoredAt = time.Now().Format(time.RFC3339)
if err := s.saveRecord(context.Background(), record); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 保存恢复记录失败: %v", err)
}
}
// ─── 备份记录管理 ───
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
@@ -614,8 +981,8 @@ func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, erro
}
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.storeMu.Lock()
defer s.storeMu.Unlock()
if s.store != nil && s.s3Cfg != nil {
return s.store, nil

View File

@@ -134,6 +134,30 @@ func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
return nil
}
// blockingDumper 可控延迟的 dumper用于测试异步行为
type blockingDumper struct {
blockCh chan struct{}
data []byte
restErr error
}
func (d *blockingDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
select {
case <-d.blockCh:
case <-ctx.Done():
return nil, ctx.Err()
}
return io.NopCloser(bytes.NewReader(d.data)), nil
}
func (d *blockingDumper) Restore(_ context.Context, data io.Reader) error {
if d.restErr != nil {
return d.restErr
}
_, _ = io.ReadAll(data)
return nil
}
type mockObjectStore struct {
objects map[string][]byte
mu sync.Mutex
@@ -179,7 +203,7 @@ func (m *mockObjectStore) HeadBucket(_ context.Context) error {
return nil
}
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
func newTestBackupService(repo *mockSettingRepo, dumper DBDumper, store *mockObjectStore) *BackupService {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
@@ -361,9 +385,9 @@ func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
svc := newTestBackupService(repo, dumper, store)
// 手动设置 backingUp 标志
svc.mu.Lock()
svc.opMu.Lock()
svc.backingUp = true
svc.mu.Unlock()
svc.opMu.Unlock()
_, err := svc.CreateBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupInProgress)
@@ -526,3 +550,154 @@ func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
require.Error(t, err)
require.Nil(t, cfg)
}
// ─── Async Backup Tests ───
func TestStartBackup_ReturnsImmediately(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.StartBackup(context.Background(), "manual", 14)
require.NoError(t, err)
require.Equal(t, "running", record.Status)
require.NotEmpty(t, record.ID)
// 释放 dumper 让后台完成
close(dumper.blockCh)
svc.wg.Wait()
// 验证最终状态
final, err := svc.GetBackupRecord(context.Background(), record.ID)
require.NoError(t, err)
require.Equal(t, "completed", final.Status)
require.Greater(t, final.SizeBytes, int64(0))
}
func TestStartBackup_ConcurrentBlocked(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// 第一次启动
_, err := svc.StartBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// 第二次应被阻塞
_, err = svc.StartBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupInProgress)
close(dumper.blockCh)
svc.wg.Wait()
}
func TestStartBackup_ShuttingDown(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
svc := newTestBackupService(repo, &mockDumper{dumpData: []byte("data")}, newMockObjectStore())
svc.shuttingDown.Store(true)
_, err := svc.StartBackup(context.Background(), "manual", 14)
require.Error(t, err)
require.Contains(t, err.Error(), "shutting down")
}
func TestRecoverStaleRecords(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// 模拟一条孤立的 running 记录
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: "stale-1",
Status: "running",
StartedAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
})
// 模拟一条孤立的恢复中记录
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: "stale-2",
Status: "completed",
RestoreStatus: "running",
StartedAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
})
svc.recoverStaleRecords()
r1, _ := svc.GetBackupRecord(context.Background(), "stale-1")
require.Equal(t, "failed", r1.Status)
require.Contains(t, r1.ErrorMsg, "server restart")
r2, _ := svc.GetBackupRecord(context.Background(), "stale-2")
require.Equal(t, "failed", r2.RestoreStatus)
require.Contains(t, r2.RestoreError, "server restart")
}
func TestGracefulShutdown(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
_, err := svc.StartBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// Stop 应该等待备份完成
done := make(chan struct{})
go func() {
svc.Stop()
close(done)
}()
// 短暂等待确认 Stop 还在等待
select {
case <-done:
t.Fatal("Stop returned before backup finished")
case <-time.After(100 * time.Millisecond):
// 预期Stop 还在等待
}
// 释放备份
close(dumper.blockCh)
// 现在 Stop 应该完成
select {
case <-done:
// 预期
case <-time.After(5 * time.Second):
t.Fatal("Stop did not return after backup finished")
}
}
func TestStartRestore_Async(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// 先创建一个备份(同步方式)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// 异步恢复
restored, err := svc.StartRestore(context.Background(), record.ID)
require.NoError(t, err)
require.Equal(t, "running", restored.RestoreStatus)
svc.wg.Wait()
// 验证最终状态
final, err := svc.GetBackupRecord(context.Background(), record.ID)
require.NoError(t, err)
require.Equal(t, "completed", final.RestoreStatus)
}

View File

@@ -21,9 +21,6 @@ var (
// 带捕获组的版本提取正则
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
systemPromptThreshold = 0.5
)
@@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
return false
}
if !userIDPattern.MatchString(userID) {
if ParseMetadataUserID(userID) == nil {
return false
}
@@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
return ExtractCLIVersion(ua)
}
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中

View File

@@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) {
normalizedSource := usagestats.NormalizeModelSource(modelSource)
if normalizedSource == usagestats.ModelSourceRequested {
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
type modelStatsBySourceRepo interface {
GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error)
}
if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok {
stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource)
if err != nil {
return nil, fmt.Errorf("get model stats with filters by source: %w", err)
}
return stats, nil
}
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
@@ -148,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
// GetGroupUsageSummary returns today's and cumulative cost for all groups.
func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart)
if err != nil {
return nil, fmt.Errorf("get group usage summary: %w", err)
}
return results, nil
}
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
data, err := s.cache.GetDashboardStats(ctx)
if err != nil {
@@ -335,6 +365,14 @@ func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime
return ranking, nil
}
func (s *DashboardService) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
stats, err := s.usageRepo.GetUserBreakdownStats(ctx, startTime, endTime, dim, limit)
if err != nil {
return nil, fmt.Errorf("get user breakdown stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
if err != nil {

View File

@@ -170,6 +170,13 @@ const (
// SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings.
SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config"
// =========================
// Overload Cooldown (529)
// =========================
// SettingKeyOverloadCooldownSettings stores JSON config for 529 overload cooldown handling.
SettingKeyOverloadCooldownSettings = "overload_cooldown_settings"
// =========================
// Stream Timeout Handling
// =========================

View File

@@ -688,6 +688,83 @@ func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *t
require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta")
}
func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
body string
}{
{
name: "system array",
body: `{"model":"claude-3-5-sonnet-latest","system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
},
{
name: "system string",
body: `{"model":"claude-3-5-sonnet-latest","system":"x-anthropic-billing-header keep","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
parsed, err := ParseGatewayRequest([]byte(tt.body), PlatformAnthropic)
require.NoError(t, err)
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
"x-request-id": []string{"rid-oauth-preserve"},
},
Body: io.NopCloser(strings.NewReader(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":12,"output_tokens":7}}`)),
},
}
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
deferredService: &DeferredService{},
}
account := &Account{
ID: 301,
Name: "anthropic-oauth-preserve",
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
},
Status: StatusActive,
Schedulable: true,
}
result, err := svc.Forward(context.Background(), c, account, parsed)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, upstream.lastReq)
require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("authorization"))
require.Contains(t, upstream.lastReq.Header.Get("anthropic-beta"), claude.BetaOAuth)
system := gjson.GetBytes(upstream.lastBody, "system")
require.True(t, system.Exists())
require.Contains(t, system.Raw, "x-anthropic-billing-header keep")
})
}
}
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -788,7 +865,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc
rateLimitService: &RateLimitService{},
}
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now())
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 12, result.Usage.InputTokens)
@@ -815,7 +892,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp
}
svc := &GatewayService{}
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now())
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now())
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "requires apikey token")
@@ -840,7 +917,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest
}
account := newAnthropicAPIKeyAccountForTest()
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now())
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "upstream request failed")
@@ -873,7 +950,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo
httpUpstream: upstream,
}
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now())
result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now())
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "empty response")

View File

@@ -0,0 +1,72 @@
package service
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/stretchr/testify/require"
)
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
t.Helper()
last := -1
for _, token := range tokens {
pos := strings.Index(body, token)
require.NotEqualf(t, -1, pos, "missing token %s in body %s", token, body)
require.Greaterf(t, pos, last, "token %s should appear after previous tokens in body %s", token, body)
last = pos
}
}
func TestReplaceModelInBody_PreservesTopLevelFieldOrder(t *testing.T) {
svc := &GatewayService{}
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","messages":[],"omega":2}`)
result := svc.replaceModelInBody(body, "claude-3-5-sonnet-20241022")
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"messages"`, `"omega"`)
require.Contains(t, resultStr, `"model":"claude-3-5-sonnet-20241022"`)
}
func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.T) {
body := []byte(`{"alpha":1,"model":"claude-3-5-sonnet-latest","temperature":0.2,"system":"You are OpenCode, the best coding agent on the planet.","messages":[],"tool_choice":{"type":"auto"},"omega":2}`)
result, modelID := normalizeClaudeOAuthRequestBody(body, "claude-3-5-sonnet-latest", claudeOAuthNormalizeOptions{
injectMetadata: true,
metadataUserID: "user-1",
})
resultStr := string(result)
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
require.NotContains(t, resultStr, `"temperature"`)
require.NotContains(t, resultStr, `"tool_choice"`)
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
require.Contains(t, resultStr, `"tools":[]`)
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
}
func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) {
body := []byte(`{"alpha":1,"system":[{"id":"block-1","type":"text","text":"Custom"}],"messages":[],"omega":2}`)
result := injectClaudeCodePrompt(body, []any{
map[string]any{"id": "block-1", "type": "text", "text": "Custom"},
})
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
require.Contains(t, resultStr, `{"id":"block-1","type":"text","text":"`+claudeCodeSystemPrompt+`\n\nCustom"}`)
}
func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
body := []byte(`{"alpha":1,"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"s2","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":[{"type":"text","text":"m1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m2","cache_control":{"type":"ephemeral"}},{"type":"text","text":"m3","cache_control":{"type":"ephemeral"}}]}],"omega":2}`)
result := enforceCacheControlLimit(body)
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
}

View File

@@ -0,0 +1,34 @@
package service
import "testing"
func TestDebugGatewayBodyLoggingEnabled(t *testing.T) {
t.Run("default disabled", func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, "")
if debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be disabled by default")
}
})
t.Run("enabled with true-like values", func(t *testing.T) {
for _, value := range []string{"1", "true", "TRUE", "yes", "on"} {
t.Run(value, func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, value)
if !debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be enabled for %q", value)
}
})
}
})
t.Run("disabled with other values", func(t *testing.T) {
for _, value := range []string{"0", "false", "off", "debug"} {
t.Run(value, func(t *testing.T) {
t.Setenv(debugGatewayBodyEnv, value)
if debugGatewayBodyLoggingEnabled() {
t.Fatalf("expected debug gateway body logging to be disabled for %q", value)
}
})
}
})
}

View File

@@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil

View File

@@ -28,6 +28,12 @@ var (
patternEmptyContentSpaced = []byte(`"content": []`)
patternEmptyContentSp1 = []byte(`"content" : []`)
patternEmptyContentSp2 = []byte(`"content" :[]`)
// Fast-path patterns for empty text blocks: {"type":"text","text":""}
patternEmptyText = []byte(`"text":""`)
patternEmptyTextSpaced = []byte(`"text": ""`)
patternEmptyTextSp1 = []byte(`"text" : ""`)
patternEmptyTextSp2 = []byte(`"text" :""`)
)
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
@@ -233,15 +239,22 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
bytes.Contains(body, patternThinkingField) ||
bytes.Contains(body, patternThinkingFieldSpaced)
// Also check for empty content arrays that need fixing.
// Also check for empty content arrays and empty text blocks that need fixing.
// Note: This is a heuristic check; the actual empty content handling is done below.
hasEmptyContent := bytes.Contains(body, patternEmptyContent) ||
bytes.Contains(body, patternEmptyContentSpaced) ||
bytes.Contains(body, patternEmptyContentSp1) ||
bytes.Contains(body, patternEmptyContentSp2)
// Check for empty text blocks: {"type":"text","text":""}
// These cause upstream 400: "text content blocks must be non-empty"
hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) ||
bytes.Contains(body, patternEmptyTextSpaced) ||
bytes.Contains(body, patternEmptyTextSp1) ||
bytes.Contains(body, patternEmptyTextSp2)
// Fast path: nothing to process
if !hasThinkingContent && !hasEmptyContent {
if !hasThinkingContent && !hasEmptyContent && !hasEmptyTextBlock {
return body
}
@@ -260,7 +273,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
bytes.Contains(body, patternTypeRedactedThinking) ||
bytes.Contains(body, patternTypeRedactedSpaced) ||
bytes.Contains(body, patternThinkingFieldSpaced)
if !hasEmptyContent && !containsThinkingBlocks {
if !hasEmptyContent && !hasEmptyTextBlock && !containsThinkingBlocks {
if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() {
if out, err := sjson.DeleteBytes(body, "thinking"); err == nil {
out = removeThinkingDependentContextStrategies(out)
@@ -320,6 +333,16 @@ func FilterThinkingBlocksForRetry(body []byte) []byte {
blockType, _ := blockMap["type"].(string)
// Strip empty text blocks: {"type":"text","text":""}
// Upstream rejects these with 400: "text content blocks must be non-empty"
if blockType == "text" {
if txt, _ := blockMap["text"].(string); txt == "" {
modifiedThisMsg = true
ensureNewContent(bi)
continue
}
}
// Convert thinking blocks to text (preserve content) and drop redacted_thinking.
switch blockType {
case "thinking":

View File

@@ -404,6 +404,51 @@ func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T)
require.NotEmpty(t, content0["text"])
}
func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) {
// Empty text blocks cause upstream 400: "text content blocks must be non-empty"
input := []byte(`{
"messages":[
{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]},
{"role":"assistant","content":[{"type":"text","text":""}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
var req map[string]any
require.NoError(t, json.Unmarshal(out, &req))
msgs, ok := req["messages"].([]any)
require.True(t, ok)
// First message: empty text block stripped, "hello" preserved
msg0 := msgs[0].(map[string]any)
content0 := msg0["content"].([]any)
require.Len(t, content0, 1)
require.Equal(t, "hello", content0[0].(map[string]any)["text"])
// Second message: only had empty text block → gets placeholder
msg1 := msgs[1].(map[string]any)
content1 := msg1["content"].([]any)
require.Len(t, content1, 1)
block1 := content1[0].(map[string]any)
require.Equal(t, "text", block1["type"])
require.NotEmpty(t, block1["text"])
}
func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) {
// Non-empty text blocks should pass through unchanged
input := []byte(`{
"messages":[
{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]}
]
}`)
out := FilterThinkingBlocksForRetry(input)
// Fast path: no thinking content, no empty content, no empty text blocks → unchanged
require.Equal(t, input, out)
}
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
input := []byte(`{
"thinking":{"type":"enabled","budget_tokens":1024},

File diff suppressed because it is too large Load Diff

View File

@@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil

View File

@@ -24,7 +24,7 @@ func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
@@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
@@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) {
"metadata session_id should take priority over SessionContext")
}
func TestGenerateSessionHash_MetadataJSON_HasHighestPriority(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
MetadataUserID: `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`,
System: "You are a helpful assistant.",
HasSystem: true,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
hash := svc.GenerateSessionHash(parsed)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", hash, "JSON format metadata session_id should have highest priority")
}
func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) {
svc := &GatewayService{}

View File

@@ -64,8 +64,10 @@ type Group struct {
CreatedAt time.Time
UpdatedAt time.Time
AccountGroups []AccountGroup
AccountCount int64
AccountGroups []AccountGroup
AccountCount int64
ActiveAccountCount int64
RateLimitedAccountCount int64
}
func (g *Group) IsActive() bool {

View File

@@ -0,0 +1,131 @@
package service
import (
"context"
"time"
)
// GroupCapacitySummary holds aggregated capacity for a single group.
type GroupCapacitySummary struct {
GroupID int64 `json:"group_id"`
ConcurrencyUsed int `json:"concurrency_used"`
ConcurrencyMax int `json:"concurrency_max"`
SessionsUsed int `json:"sessions_used"`
SessionsMax int `json:"sessions_max"`
RPMUsed int `json:"rpm_used"`
RPMMax int `json:"rpm_max"`
}
// GroupCapacityService aggregates per-group capacity from runtime data.
type GroupCapacityService struct {
accountRepo AccountRepository
groupRepo GroupRepository
concurrencyService *ConcurrencyService
sessionLimitCache SessionLimitCache
rpmCache RPMCache
}
// NewGroupCapacityService creates a new GroupCapacityService.
func NewGroupCapacityService(
accountRepo AccountRepository,
groupRepo GroupRepository,
concurrencyService *ConcurrencyService,
sessionLimitCache SessionLimitCache,
rpmCache RPMCache,
) *GroupCapacityService {
return &GroupCapacityService{
accountRepo: accountRepo,
groupRepo: groupRepo,
concurrencyService: concurrencyService,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
}
}
// GetAllGroupCapacity returns capacity summary for all active groups.
func (s *GroupCapacityService) GetAllGroupCapacity(ctx context.Context) ([]GroupCapacitySummary, error) {
groups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, err
}
results := make([]GroupCapacitySummary, 0, len(groups))
for i := range groups {
cap, err := s.getGroupCapacity(ctx, groups[i].ID)
if err != nil {
// Skip groups with errors, return partial results
continue
}
cap.GroupID = groups[i].ID
results = append(results, cap)
}
return results, nil
}
func (s *GroupCapacityService) getGroupCapacity(ctx context.Context, groupID int64) (GroupCapacitySummary, error) {
accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID)
if err != nil {
return GroupCapacitySummary{}, err
}
if len(accounts) == 0 {
return GroupCapacitySummary{}, nil
}
// Collect account IDs and config values
accountIDs := make([]int64, 0, len(accounts))
sessionTimeouts := make(map[int64]time.Duration)
var concurrencyMax, sessionsMax, rpmMax int
for i := range accounts {
acc := &accounts[i]
accountIDs = append(accountIDs, acc.ID)
concurrencyMax += acc.Concurrency
if ms := acc.GetMaxSessions(); ms > 0 {
sessionsMax += ms
timeout := time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
if timeout <= 0 {
timeout = 5 * time.Minute
}
sessionTimeouts[acc.ID] = timeout
}
if rpm := acc.GetBaseRPM(); rpm > 0 {
rpmMax += rpm
}
}
// Batch query runtime data from Redis
concurrencyMap, _ := s.concurrencyService.GetAccountConcurrencyBatch(ctx, accountIDs)
var sessionsMap map[int64]int
if sessionsMax > 0 && s.sessionLimitCache != nil {
sessionsMap, _ = s.sessionLimitCache.GetActiveSessionCountBatch(ctx, accountIDs, sessionTimeouts)
}
var rpmMap map[int64]int
if rpmMax > 0 && s.rpmCache != nil {
rpmMap, _ = s.rpmCache.GetRPMBatch(ctx, accountIDs)
}
// Aggregate
var concurrencyUsed, sessionsUsed, rpmUsed int
for _, id := range accountIDs {
concurrencyUsed += concurrencyMap[id]
if sessionsMap != nil {
sessionsUsed += sessionsMap[id]
}
if rpmMap != nil {
rpmUsed += rpmMap[id]
}
}
return GroupCapacitySummary{
ConcurrencyUsed: concurrencyUsed,
ConcurrencyMax: concurrencyMax,
SessionsUsed: sessionsUsed,
SessionsMax: sessionsMax,
RPMUsed: rpmUsed,
RPMMax: rpmMax,
}, nil
}

View File

@@ -27,7 +27,7 @@ type GroupRepository interface {
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID去重
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
@@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any,
}
// 获取账号数量
accountCount, err := s.groupRepo.GetAccountCount(ctx, id)
accountCount, _, err := s.groupRepo.GetAccountCount(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account count: %w", err)
}

View File

@@ -5,7 +5,6 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"log/slog"
"net/http"
@@ -15,14 +14,12 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// 预编译正则表达式(避免每次调用重新编译)
var (
// 匹配 user_id 格式:
// 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID)
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID)
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`)
// 匹配 User-Agent 版本号: xxx/x.y.z
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
)
@@ -209,67 +206,57 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
}
// RewriteUserID 重写body中的metadata.user_id
// 输入格式user_{clientId}_account__session_{sessionUUID}
// 输出格式user_{cachedClientID}_account_{accountUUID}_session_{newHash}
// 支持旧拼接格式和新 JSON 格式的 user_id 解析,
// 根据 fingerprintUA 版本选择输出格式。
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
return body, nil
}
// 使用 RawMessage 保留其他字段的原始字节
var reqMap map[string]json.RawMessage
if err := json.Unmarshal(body, &reqMap); err != nil {
metadata := gjson.GetBytes(body, "metadata")
if !metadata.Exists() || metadata.Type == gjson.Null {
return body, nil
}
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
return body, nil
}
// 解析 metadata 字段
metadataRaw, ok := reqMap["metadata"]
if !ok {
userIDResult := metadata.Get("user_id")
if !userIDResult.Exists() || userIDResult.Type != gjson.String {
return body, nil
}
userID := userIDResult.String()
if userID == "" {
return body, nil
}
var metadata map[string]any
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
// 解析 user_id兼容旧拼接格式和新 JSON 格式)
parsed := ParseMetadataUserID(userID)
if parsed == nil {
return body, nil
}
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
return body, nil
}
// 匹配格式:
// 旧格式: user_{64位hex}_account__session_{uuid}
// 新格式: user_{64位hex}_account_{uuid}_session_{uuid}
matches := userIDRegex.FindStringSubmatch(userID)
if matches == nil {
return body, nil
}
// matches[1] = account UUID (可能为空), matches[2] = session UUID
sessionTail := matches[2] // 原始session UUID
sessionTail := parsed.SessionID // 原始session UUID
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
newSessionHash := generateUUIDFromSeed(seed)
// 构建新的user_id
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash}
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
// 根据客户端版本选择输出格式
version := ExtractCLIVersion(fingerprintUA)
newUserID := FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version)
if newUserID == userID {
return body, nil
}
metadata["user_id"] = newUserID
// 只重新序列化 metadata 字段
newMetadataRaw, err := json.Marshal(metadata)
newBody, err := sjson.SetBytes(body, "metadata.user_id", newUserID)
if err != nil {
return body, nil
}
reqMap["metadata"] = newMetadataRaw
return json.Marshal(reqMap)
return newBody, nil
}
// RewriteUserIDWithMasking 重写body中的metadata.user_id支持会话ID伪装
@@ -278,9 +265,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
//
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
// 避免重新序列化导致 thinking 块等内容被修改。
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) {
// 先执行常规的 RewriteUserID 逻辑
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID, fingerprintUA)
if err != nil {
return newBody, err
}
@@ -290,32 +277,26 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
return newBody, nil
}
// 使用 RawMessage 保留其他字段的原始字节
var reqMap map[string]json.RawMessage
if err := json.Unmarshal(newBody, &reqMap); err != nil {
metadata := gjson.GetBytes(newBody, "metadata")
if !metadata.Exists() || metadata.Type == gjson.Null {
return newBody, nil
}
if !strings.HasPrefix(strings.TrimSpace(metadata.Raw), "{") {
return newBody, nil
}
// 解析 metadata 字段
metadataRaw, ok := reqMap["metadata"]
if !ok {
userIDResult := metadata.Get("user_id")
if !userIDResult.Exists() || userIDResult.Type != gjson.String {
return newBody, nil
}
userID := userIDResult.String()
if userID == "" {
return newBody, nil
}
var metadata map[string]any
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
return newBody, nil
}
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
return newBody, nil
}
// 查找 _session_ 的位置,替换其后的内容
const sessionMarker = "_session_"
idx := strings.LastIndex(userID, sessionMarker)
if idx == -1 {
// 解析已重写的 user_id
uidParsed := ParseMetadataUserID(userID)
if uidParsed == nil {
return newBody, nil
}
@@ -337,8 +318,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err)
}
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
// 用 FormatMetadataUserID 重建(保持与 RewriteUserID 相同的格式)
version := ExtractCLIVersion(fingerprintUA)
newUserID := FormatMetadataUserID(uidParsed.DeviceID, uidParsed.AccountUUID, maskedSessionID, version)
slog.Debug("session_id_masking_applied",
"account_id", account.ID,
@@ -346,16 +328,15 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
"after", newUserID,
)
metadata["user_id"] = newUserID
// 只重新序列化 metadata 字段
newMetadataRaw, marshalErr := json.Marshal(metadata)
if marshalErr != nil {
if newUserID == userID {
return newBody, nil
}
reqMap["metadata"] = newMetadataRaw
return json.Marshal(reqMap)
maskedBody, setErr := sjson.SetBytes(newBody, "metadata.user_id", newUserID)
if setErr != nil {
return newBody, nil
}
return maskedBody, nil
}
// generateRandomUUID 生成随机 UUID v4 格式字符串

View File

@@ -0,0 +1,82 @@
package service
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
type identityCacheStub struct {
maskedSessionID string
}
func (s *identityCacheStub) GetFingerprint(_ context.Context, _ int64) (*Fingerprint, error) {
return nil, nil
}
func (s *identityCacheStub) SetFingerprint(_ context.Context, _ int64, _ *Fingerprint) error {
return nil
}
func (s *identityCacheStub) GetMaskedSessionID(_ context.Context, _ int64) (string, error) {
return s.maskedSessionID, nil
}
func (s *identityCacheStub) SetMaskedSessionID(_ context.Context, _ int64, sessionID string) error {
s.maskedSessionID = sessionID
return nil
}
func TestIdentityService_RewriteUserID_PreservesTopLevelFieldOrder(t *testing.T) {
cache := &identityCacheStub{}
svc := NewIdentityService(cache)
originalUserID := FormatMetadataUserID(
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
"",
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
"2.1.78",
)
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
result, err := svc.RewriteUserID(body, 123, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
require.NoError(t, err)
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
require.NotContains(t, resultStr, originalUserID)
require.Contains(t, resultStr, `"metadata":{"user_id":"`)
}
func TestIdentityService_RewriteUserIDWithMasking_PreservesTopLevelFieldOrder(t *testing.T) {
cache := &identityCacheStub{maskedSessionID: "11111111-2222-4333-8444-555555555555"}
svc := NewIdentityService(cache)
originalUserID := FormatMetadataUserID(
"d61f76d0730d2b920763648949bad5c79742155c27037fc77ac3f9805cb90169",
"",
"7578cf37-aaca-46e4-a45c-71285d9dbb83",
"2.1.78",
)
body := []byte(`{"alpha":1,"messages":[],"metadata":{"user_id":` + strconvQuote(originalUserID) + `},"max_tokens":64000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"},"stream":true}`)
account := &Account{
ID: 123,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"session_id_masking_enabled": true,
},
}
result, err := svc.RewriteUserIDWithMasking(context.Background(), body, account, "acc-uuid", "client-xyz", "claude-cli/2.1.78 (external, cli)")
require.NoError(t, err)
resultStr := string(result)
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"messages"`, `"metadata"`, `"max_tokens"`, `"thinking"`, `"output_config"`, `"stream"`)
require.Contains(t, resultStr, cache.maskedSessionID)
require.True(t, strings.Contains(resultStr, `"metadata":{"user_id":"`))
}
func strconvQuote(v string) string {
return `"` + strings.ReplaceAll(strings.ReplaceAll(v, `\`, `\\`), `"`, `\"`) + `"`
}

View File

@@ -0,0 +1,104 @@
package service
import (
"encoding/json"
"regexp"
"strings"
)
// NewMetadataFormatMinVersion is the minimum Claude Code version that uses
// JSON-formatted metadata.user_id instead of the legacy concatenated string.
const NewMetadataFormatMinVersion = "2.1.78"
// ParsedUserID represents the components extracted from a metadata.user_id value.
type ParsedUserID struct {
DeviceID string // 64-char hex (or arbitrary client id)
AccountUUID string // may be empty
SessionID string // UUID
IsNewFormat bool // true if the original was JSON format
}
// legacyUserIDRegex matches the legacy user_id format:
//
// user_{64hex}_account_{optional_uuid}_session_{uuid}
var legacyUserIDRegex = regexp.MustCompile(`^user_([a-fA-F0-9]{64})_account_([a-fA-F0-9-]*)_session_([a-fA-F0-9-]{36})$`)
// jsonUserID is the JSON structure for the new metadata.user_id format.
type jsonUserID struct {
DeviceID string `json:"device_id"`
AccountUUID string `json:"account_uuid"`
SessionID string `json:"session_id"`
}
// ParseMetadataUserID parses a metadata.user_id string in either format.
// Returns nil if the input cannot be parsed.
func ParseMetadataUserID(raw string) *ParsedUserID {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
// Try JSON format first (starts with '{')
if raw[0] == '{' {
var j jsonUserID
if err := json.Unmarshal([]byte(raw), &j); err != nil {
return nil
}
if j.DeviceID == "" || j.SessionID == "" {
return nil
}
return &ParsedUserID{
DeviceID: j.DeviceID,
AccountUUID: j.AccountUUID,
SessionID: j.SessionID,
IsNewFormat: true,
}
}
// Try legacy format
matches := legacyUserIDRegex.FindStringSubmatch(raw)
if matches == nil {
return nil
}
return &ParsedUserID{
DeviceID: matches[1],
AccountUUID: matches[2],
SessionID: matches[3],
IsNewFormat: false,
}
}
// FormatMetadataUserID builds a metadata.user_id string in the format
// appropriate for the given CLI version. Components are the rewritten values
// (not necessarily the originals).
func FormatMetadataUserID(deviceID, accountUUID, sessionID, uaVersion string) string {
if IsNewMetadataFormatVersion(uaVersion) {
b, _ := json.Marshal(jsonUserID{
DeviceID: deviceID,
AccountUUID: accountUUID,
SessionID: sessionID,
})
return string(b)
}
// Legacy format
return "user_" + deviceID + "_account_" + accountUUID + "_session_" + sessionID
}
// IsNewMetadataFormatVersion returns true if the given CLI version uses the
// new JSON metadata.user_id format (>= 2.1.78).
func IsNewMetadataFormatVersion(version string) bool {
if version == "" {
return false
}
return CompareVersions(version, NewMetadataFormatMinVersion) >= 0
}
// ExtractCLIVersion extracts the Claude Code version from a User-Agent string.
// Returns "" if the UA doesn't match the expected pattern.
func ExtractCLIVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
}

View File

@@ -0,0 +1,183 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ============ ParseMetadataUserID Tests ============
func TestParseMetadataUserID_LegacyFormat_WithoutAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_LegacyFormat_WithAccountUUID(t *testing.T) {
raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithoutAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_JSONFormat_WithAccountUUID(t *testing.T) {
raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`
parsed := ParseMetadataUserID(raw)
require.NotNil(t, parsed)
require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID)
require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID)
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseMetadataUserID_InvalidInputs(t *testing.T) {
tests := []struct {
name string
raw string
}{
{"empty string", ""},
{"whitespace only", " "},
{"random text", "not-a-valid-user-id"},
{"partial legacy format", "session_123e4567-e89b-12d3-a456-426614174000"},
{"invalid JSON", `{"device_id":}`},
{"JSON missing device_id", `{"account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON missing session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":""}`},
{"JSON empty device_id", `{"device_id":"","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`},
{"JSON empty session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":""}`},
{"legacy format short hex", "user_a1b2c3d4_account__session_123e4567-e89b-12d3-a456-426614174000"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Nil(t, ParseMetadataUserID(tt.raw), "should return nil for: %s", tt.raw)
})
}
}
func TestParseMetadataUserID_HexCaseInsensitive(t *testing.T) {
// Legacy format should accept both upper and lower case hex
rawUpper := "user_A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2_account__session_123e4567-e89b-12d3-a456-426614174000"
parsed := ParseMetadataUserID(rawUpper)
require.NotNil(t, parsed, "legacy format should accept uppercase hex")
require.Equal(t, "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", parsed.DeviceID)
}
// ============ FormatMetadataUserID Tests ============
func TestFormatMetadataUserID_LegacyVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.77")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account_acc-uuid_session_sess-uuid", result)
}
func TestFormatMetadataUserID_NewVersion(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.78")
require.Equal(t, `{"device_id":"deadbeef00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"sess-uuid"}`, result)
}
func TestFormatMetadataUserID_EmptyVersion_Legacy(t *testing.T) {
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "")
require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account__session_sess-uuid", result)
}
func TestFormatMetadataUserID_EmptyAccountUUID(t *testing.T) {
// Legacy format with empty account UUID → double underscore
result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.22")
require.Contains(t, result, "_account__session_")
// New format with empty account UUID → empty string in JSON
result = FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.78")
require.Contains(t, result, `"account_uuid":""`)
}
// ============ IsNewMetadataFormatVersion Tests ============
func TestIsNewMetadataFormatVersion(t *testing.T) {
tests := []struct {
version string
want bool
}{
{"", false},
{"2.1.77", false},
{"2.1.78", true},
{"2.1.79", true},
{"2.2.0", true},
{"3.0.0", true},
{"2.0.100", false},
{"1.9.99", false},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
require.Equal(t, tt.want, IsNewMetadataFormatVersion(tt.version))
})
}
}
// ============ Round-trip Tests ============
func TestParseFormat_RoundTrip_Legacy(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.False(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_JSON(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
accountUUID := "550e8400-e29b-41d4-a716-446655440000"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.78")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, accountUUID, parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
require.True(t, parsed.IsNewFormat)
}
func TestParseFormat_RoundTrip_EmptyAccountUUID(t *testing.T) {
deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
sessionID := "123e4567-e89b-12d3-a456-426614174000"
// Legacy round-trip with empty account UUID
formatted := FormatMetadataUserID(deviceID, "", sessionID, "2.1.22")
parsed := ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
// JSON round-trip with empty account UUID
formatted = FormatMetadataUserID(deviceID, "", sessionID, "2.1.78")
parsed = ParseMetadataUserID(formatted)
require.NotNil(t, parsed)
require.Equal(t, deviceID, parsed.DeviceID)
require.Equal(t, "", parsed.AccountUUID)
require.Equal(t, sessionID, parsed.SessionID)
}

Some files were not shown because too many files have changed in this diff Show More