Compare commits

..

130 Commits

Author SHA1 Message Date
Wesley Liddick
9f6ab6b817 Merge pull request #1090 from laukkw/main
fix(setup): align install validation and expose backend errors
2026-03-18 16:23:06 +08:00
shaw
bf3d6c0e6e feat: add 529 overload cooldown toggle and duration settings in admin gateway page
Move 529 overload cooldown configuration from config file to admin
settings UI. Adds an enable/disable toggle and configurable cooldown
duration (1-120 min) under /admin/settings gateway tab, stored as
JSON in the settings table.

When disabled, 529 errors are logged but accounts are no longer
paused from scheduling. Falls back to config file value when DB
is unreachable or settingService is nil.
2026-03-18 16:22:19 +08:00
Wesley Liddick
241023f3fc Merge pull request #1097 from Ethan0x0000/pr/upstream-model-tracking
feat(usage): 新增 upstream_model 追踪,支持按模型来源统计与展示
2026-03-18 15:36:00 +08:00
Wesley Liddick
1292c44b41 Merge pull request #1118 from touwaeriol/worktree-fix/anti_mapping
feat: map claude-haiku-4-5 variants to claude-sonnet-4-6
2026-03-18 15:13:19 +08:00
Wesley Liddick
b4fce47049 Merge pull request #1116 from wucm667/fix/inject-site-title-in-html
fix: 直接访问或刷新页面时浏览器标签页显示自定义站点名称
2026-03-18 15:12:07 +08:00
Wesley Liddick
e7780cd8c8 Merge pull request #1117 from alfadb/fix/empty-text-block-retry
fix: 修复空 text block 导致上游 400 错误未被重试捕获的问题
2026-03-18 15:10:46 +08:00
erio
af96c8ea53 feat: map claude-haiku-4-5 variants to claude-sonnet-4-6
Update model mapping target for claude-haiku-4-5 and
claude-haiku-4-5-20251001 from claude-sonnet-4-5 to claude-sonnet-4-6.
Includes migration script, default constants, and test updates.
2026-03-18 15:03:24 +08:00
alfadb
7d26b81075 fix: address review - add missing whitespace patterns and narrow error matching 2026-03-18 14:31:57 +08:00
alfadb
b8ada63ac3 fix: strip empty text blocks in retry filter and fix error pattern matching
Empty text blocks ({"type":"text","text":""}) cause Anthropic upstream to
return 400: "text content blocks must be non-empty". This was not caught
by the existing error detection pattern in isThinkingBlockSignatureError,
nor handled by FilterThinkingBlocksForRetry.

- Add empty text block stripping to FilterThinkingBlocksForRetry
- Fix isThinkingBlockSignatureError to match new Anthropic error format
- Add fast-path byte patterns to avoid unnecessary JSON parsing

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 14:20:00 +08:00
Ethan0x0000
cfaac12af1 Merge upstream/main into pr/upstream-model-tracking
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-18 14:16:50 +08:00
wucm667
6028efd26c test: 添加 injectSiteTitle 函数的单元测试 2026-03-18 14:13:52 +08:00
shaw
62a566ef2c fix: 修复 config.yaml 以只读方式挂载时容器启动失败 (#1113)
entrypoint 中 chown -R /app/data 在遇到 :ro 挂载的文件时报错退出,
添加错误容忍处理;同时去掉 compose 文件注释中的 :ro 建议。
2026-03-18 14:11:51 +08:00
wucm667
94419f434c fix: 直接访问或刷新页面时浏览器标签页显示自定义站点名称
后端 HTML 注入时同步替换 <title> 标签为自定义站点名称,
前端 fetchPublicSettings 完成后重新设置 document.title,
解决路由守卫先于设置加载导致标题回退为默认值的时序问题。
2026-03-18 14:02:00 +08:00
Wesley Liddick
21f349c032 Merge pull request #1095 from LvyuanW/lvyuan/dev
fix(admin/accounts): reset edit modal state on reopen
2026-03-18 11:37:07 +08:00
Wesley Liddick
28e36f7925 Merge pull request #1096 from Ethan0x0000/pr/fix-idle-usage-windows
fix(ui): 会话窗口空闲时显示“现在”,避免重置时间缺失
2026-03-18 11:32:50 +08:00
Wesley Liddick
6c02076333 Merge pull request #1106 from geminiwen/feat/subscription-platform-filter
feat: add platform type filter to subscription management
2026-03-18 11:32:35 +08:00
shaw
7414bdf0e3 fix: 修复 hotpath 测试中 metadata.user_id 格式不合法导致 CI 失败
测试数据使用的 session ID "abc-123" 不符合 ParseMetadataUserID
要求的 36 字符 UUID 格式,替换为合法 UUID。
2026-03-18 11:31:32 +08:00
Wesley Liddick
e6326b2929 Merge pull request #1108 from DaydreamCoding/feat/admin-group-capacity-and-usage
feat(admin): 分组管理列表新增用量、账号分类与容量列
2026-03-18 11:12:43 +08:00
Wesley Liddick
17cdcebd04 Merge pull request #1109 from GuangYiDing/feat/subscription-guide
feat(subscriptions): 订阅管理页面添加教程指南弹窗
2026-03-18 11:12:33 +08:00
shaw
a14babdc73 fix: 兼容 Claude Code v2.1.78+ 新 JSON 格式 metadata.user_id
Claude Code v2.1.78 起将 metadata.user_id 从拼接字符串改为 JSON:
旧: user_{hex}_account_{uuid}_session_{uuid}
新: {"device_id":"...","account_uuid":"...","session_id":"..."}

新增集中解析/格式化模块 metadata_userid.go:
- ParseMetadataUserID: 自动识别两种格式,提取 DeviceID/AccountUUID/SessionID
- FormatMetadataUserID: 根据 UA 版本输出对应格式(>= 2.1.78 输出 JSON)
- ExtractCLIVersion: 从 UA 提取版本号,消除与 ClaudeCodeValidator.ExtractVersion 的重复

修改消费者统一使用新模块:
- claude_code_validator: 用 ParseMetadataUserID 替代只匹配旧格式的 userIDPattern
- identity_service: RewriteUserID/WithMasking 增加 fingerprintUA 参数,
  解析用 ParseMetadataUserID,输出用 FormatMetadataUserID(版本感知)
- gateway_service: GenerateSessionHash 用 ParseMetadataUserID 提取 session_id,
  buildOAuthMetadataUserID 用 FormatMetadataUserID 输出版本匹配格式,
  两处 RewriteUserIDWithMasking 调用传入 fp.UserAgent
- account_test_service: generateSessionString 改用 FormatMetadataUserID,
  自动跟随 DefaultHeaders UA 版本

删除三个旧正则: userIDPattern, userIDRegex, sessionIDRegex
统一 hex 匹配为 [a-fA-F0-9],修复旧 userIDRegex 只匹配小写的不一致
2026-03-18 11:08:58 +08:00
Rose Ding
aadc6a763a feat(subscriptions): 订阅管理页面添加教程指南弹窗
在订阅管理页面工具栏添加教程指南按钮(? 图标),点击弹出模态框,
引导管理员完成订阅功能的完整使用流程:

- 步骤一:创建订阅分组(含跳转分组管理链接)
- 步骤二:分配订阅给用户(搜索用户、选择分组、设置有效期)
- 步骤三:管理已有订阅(调整/重置配额/撤销操作说明表格)
- 底部提示:说明下拉列表为空时的解决方案

弹窗样式参照 BackupView 的 R2 Guide 模态框实现,保持 UI 一致性。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 10:49:41 +08:00
Rose Ding
f16af8bf88 feat(i18n): 添加订阅管理教程指南英文翻译
在 en.ts 中为订阅管理页面新增 guide 相关翻译词条,
与中文翻译保持结构一致,支持中英文切换。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 10:49:32 +08:00
Rose Ding
5ceaef4500 feat(i18n): 添加订阅管理教程指南中文翻译
在 zh.ts 中为订阅管理页面新增 guide 相关翻译词条,包括:
- 教程弹窗标题与副标题
- 三步操作引导文案(创建分组、分配订阅、管理订阅)
- 操作说明表格(调整/重置配额/撤销)
- 底部提示信息

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-18 10:49:13 +08:00
Gemini Wen
1ac7219a92 fix: add missing platform parameter to List calls in integration tests
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 10:35:03 +08:00
QTom
d4cc9871c4 feat(admin): 分组管理新增容量列(并发/会话/RPM 实时聚合)
复用 GroupCapacityService,在 admin 分组列表中添加容量列,
显示每个分组的实时并发/会话/RPM 使用量和上限。

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 10:06:35 +08:00
QTom
961c30e7c0 feat(admin): 分组管理列表新增用量列与账号数分类
分组管理列表增强:

1. 今日/累计用量列:
   - 新增独立端点 GET /admin/groups/usage-summary
   - 一次查询返回所有分组的今日费用和累计费用(actual_cost)
   - 前端异步加载后合并显示在分组列表中

2. 账号数区分可用/限流/总量:
   - 将账号数列从单一总量改为 badge 内多行展示
   - 可用: active + schedulable 的账号数(绿色)
   - 限流: rate_limit/overload/temp_unschedulable 的账号数(橙色,无限流时隐藏)
   - 总量: 全部关联账号数

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 10:06:35 +08:00
Gemini Wen
13e85b3147 fix: update remaining test stubs for List interface signature
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 09:35:08 +08:00
Gemini Wen
50a3c7fa0b feat: add platform type filter to subscription management page
Add a platform filter dropdown to the admin subscriptions view, allowing
filtering subscriptions by platform (Anthropic, OpenAI, Gemini, etc.)
through the group association.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 09:23:19 +08:00
Ethan0x0000
bd9d2671d7 chore(deps): go mod tidy to remove stale indirect dependencies 2026-03-17 20:46:12 +08:00
Ethan0x0000
62b40636e0 feat(frontend): display upstream model in usage table and distribution charts
Show upstream model mapping (requested -> upstream) in UsageTable with arrow notation. Add requested/upstream/mapping source toggle to ModelDistributionChart with lazy loading — only fetches data when user switches tab, with per-source cache invalidation on filter changes. Include upstream_model column in Excel export and i18n for en/zh.
2026-03-17 19:26:48 +08:00
Ethan0x0000
eeff451bc5 test(backend): add tests for upstream model tracking and model source filtering
Cover IsValidModelSource/NormalizeModelSource, resolveModelDimensionExpression SQL expressions, invalid model_source 400 responses on both GetModelStats and GetUserBreakdown, upstream_model in scan/insert SQL mock expectations, and updated passthrough/billing test signatures.
2026-03-17 19:26:30 +08:00
Ethan0x0000
56fcb20f94 feat(api): expose model_source filter in dashboard endpoints
Add model_source query parameter to GetModelStats and GetUserBreakdown handlers with explicit IsValidModelSource validation. Include model_source in cache key to prevent cross-source cache hits. Expose upstream_model in usage log DTO with omitempty semantics.
2026-03-17 19:26:11 +08:00
Ethan0x0000
7134266acf feat(dashboard): add model source dimension to stats queries
Support querying model statistics by 'requested', 'upstream', or 'mapping' dimension. Add resolveModelDimensionExpression for safe SQL expression generation, IsValidModelSource whitelist validator, and NormalizeModelSource fallback. Repository persists and scans upstream_model in all insert/select paths.
2026-03-17 19:25:52 +08:00
Ethan0x0000
2e4ac88ad9 feat(service): record upstream model across all gateway paths
Propagate UpstreamModel through ForwardResult and OpenAIForwardResult in Anthropic direct, API-key passthrough, Bedrock, and OpenAI gateway flows. Extract optionalNonEqualStringPtr and optionalTrimmedStringPtr into usage_log_helpers.go. Store upstream_model only when it differs from the requested model.

Also introduces anthropicPassthroughForwardInput struct to reduce parameter count.
2026-03-17 19:25:35 +08:00
Ethan0x0000
51547fa216 feat(db): add upstream_model column to usage_logs
Add nullable VARCHAR(100) column to record the actual model sent to upstream providers when model mapping is applied. NULL means no mapping — the requested model was used as-is.

Includes migration, concurrent index for aggregation queries, Ent schema regeneration, and migration README correction (forward-only runner, not goose).
2026-03-17 19:25:17 +08:00
Ethan0x0000
2005fc97a8 fix(ui): show 'now' for idle OpenAI usage windows
Use utilization-based idle detection instead of local request counts so newly imported OAuth accounts keep countdowns when usage is non-zero.
2026-03-17 19:23:35 +08:00
Wang Lvyuan
0772d9250e fix(admin/accounts): reset edit modal state on reopen 2026-03-17 18:44:10 +08:00
laukkw
aa6047c460 fix(setup): align install validation and expose backend errors
Make setup password requirements consistent with backend rules and show API-provided error messages so install failures are actionable. Trim admin email before validation to avoid false invalid-email rejections from surrounding whitespace.
2026-03-17 15:38:18 +08:00
Wesley Liddick
045cba78b4 Merge pull request #1083 from StarryKira/fix/claude-code-version-pattern-validation
fix(settings): remove pattern attribute blocking Claude Code version save fix issue #1081
2026-03-17 14:49:34 +08:00
Wesley Liddick
8989d0d4b6 Merge pull request #1085 from protondrift/main
feat: 个人资料弹窗 GitHub 链接仅对管理员可见
2026-03-17 14:49:07 +08:00
Wesley Liddick
c521117b99 Merge pull request #1074 from StarryKira/fix/session-window-reset-from-header
fix(usage): use real reset header for 5h session window countdown fix issue #1064 #1065
2026-03-17 14:48:16 +08:00
Eric
e0f52a8ab8 feat: 个人资料弹窗 GitHub 链接仅对管理员可见
目前作者已有商业站信息,面向管理可提供赞助渠道,面向普通用户请考虑提供信息隐藏措施
2026-03-17 12:51:34 +08:00
haruka
6c23fadf7e fix(settings): remove pattern attribute blocking Claude Code version save
The `pattern="\d+\.\d+\.\d+"` on the min_claude_code_version input caused
the browser's native HTML5 form validation to silently block form submission
when the value was invalid or when the hidden gateway tab was active. This
resulted in no network request being sent when clicking Save on any tab.

Backend already validates semver format and returns a proper 400 error,
so the frontend pattern attribute is redundant.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 11:33:57 +08:00
haruka
869952d113 fix(review): address Copilot PR feedback
- Add compile-time interface assertion for sessionWindowMockRepo
- Fix flaky fallback test by capturing time.Now() before calling UpdateSessionWindow
- Replace stale hardcoded timestamps with dynamic future values
- Add millisecond detection and bounds validation for reset header timestamp
- Use pause/resume pattern for interval in UsageProgressBar to avoid idle timers on large lists
- Fix gofmt comment alignment

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 10:19:20 +08:00
Wesley Liddick
07ab051ee4 Merge pull request #1078 from luxiang0412/main
fix(proxy): encode special chars in proxy credentials
2026-03-17 09:33:02 +08:00
Wesley Liddick
f2d98fc0c7 Merge pull request #1077 from Clov614/main
fix(auto setup): 修复初始测试连接硬编码问题导致使用自定义数据库测试失败无法执行 auto setup流程
2026-03-17 09:32:52 +08:00
Wesley Liddick
2b41cec840 Merge pull request #1076 from touwaeriol/pr/antigravity-test-connection-unify
refactor(antigravity): unify TestConnection with dispatch retry loop
2026-03-17 09:26:06 +08:00
Wesley Liddick
6cf77040e7 Merge pull request #1075 from touwaeriol/feat/dashboard-user-breakdown
feat(dashboard): add per-user drill-down for distribution charts
2026-03-17 09:25:43 +08:00
Wesley Liddick
20b70bc5fd Merge pull request #1070 from StarryKira/fix/oauth-system-role-to-instructions
fix(oauth): extract system-role input items into instructions field fix issue #1066
2026-03-17 09:24:17 +08:00
Wesley Liddick
4905e7193a Merge pull request #1069 from Ethan0x0000/pr/codex-usage-single-source
fix(openai): use /usage as single source and zero expired codex windows
2026-03-17 09:09:16 +08:00
Wesley Liddick
9c1f4b8e72 Merge pull request #1068 from Ethan0x0000/pr/frontend-last24h
feat(frontend): set last 24h as default range in Usage and Dashboard
2026-03-17 09:06:52 +08:00
Wesley Liddick
9857c17631 Merge pull request #1067 from DaydreamCoding/feat/async-backup
feat(backup): 备份/恢复异步化,解决 504 超时
2026-03-17 08:59:36 +08:00
luxiang
7e34bb946f fix(proxy): encode special chars in proxy credentials 2026-03-17 08:40:08 +08:00
clover614
47b748851b fix(auto setup): 修复初始测试连接硬编码问题导致使用自定义数据库测试失败无法执行 auto setup流程 2026-03-17 06:34:20 +08:00
erio
a6f99cf534 refactor(antigravity): unify TestConnection with dispatch retry loop
TestConnection now reuses antigravityRetryLoop instead of a standalone
HTTP loop, gaining credits overages, smart retry, and 429/503 backoff
for free. AccountSwitchError is caught and surfaced as a friendly
message. Also populates RateLimitedModel in TempUnscheduled switch error.

Test fixes:
- Use RATE_LIMIT_EXCEEDED in 503 short-delay test to avoid 60x1s timeout
- Clamp waitDuration=0 instead of 999s to avoid 15s max-wait timeout
- Enhance mockSmartRetryUpstream with repeatLast and body caching
2026-03-17 01:47:08 +08:00
erio
a120a6bc32 fix(ui): remove redundant sub-table header in user breakdown
The expanded user breakdown rows already align with the parent table
columns (Requests, Token, Actual, Standard), so the repeated sub-header
wastes vertical space. Remove the <thead> from UserBreakdownSubTable.
2026-03-17 00:49:43 +08:00
erio
d557d1a190 fix(ui): restore original max-h-48 height for distribution tables 2026-03-17 00:47:45 +08:00
erio
e0286e5085 test(dashboard): add unit tests for user-breakdown API
Handler tests (9 cases): group_id/model/endpoint filters, default
endpoint_type, custom limit, limit clamping, response format,
empty result, no-filter pass-through.

Repository test: resolveEndpointColumn mapping for inbound/upstream/path.
2026-03-17 00:47:33 +08:00
erio
4b41e898a4 feat(dashboard): add per-user drill-down for group, model, and endpoint distributions
Click on a group name, model name, or endpoint name in the distribution
tables to expand and show per-user usage breakdown (requests, tokens,
actual cost, standard cost).

Backend: new GET /admin/dashboard/user-breakdown API with group_id,
model, endpoint, endpoint_type filters.
Frontend: clickable rows with expand/collapse sub-table in all three
distribution charts.
2026-03-17 00:47:20 +08:00
Elysia
668e164793 fix(usage): use real reset header for session window instead of prediction
The 5h window reset time displayed for Setup Token accounts was inaccurate
because UpdateSessionWindow predicted the window end as "current hour + 5h"
instead of reading the actual `anthropic-ratelimit-unified-5h-reset` response
header. This caused the countdown to differ from the official Claude page.

Backend: parse the reset header (Unix timestamp) and use it as the real
window end, falling back to the hour-truncated prediction only when the
header is absent. Also correct stale predictions when a subsequent request
provides the real reset time.

Frontend: add a reactive 60s timer so the reset countdown in
UsageProgressBar ticks down in real-time instead of freezing at the
initial value.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 00:13:45 +08:00
Elysia
fa2e6188d0 fix(oauth): extract system-role input items into instructions field
OAuth upstreams (ChatGPT) reject requests containing role:"system" in
the input array with HTTP 400 "System messages are not allowed". Extract
such items before forwarding and merge their content into the top-level
instructions field, prepending to any existing value.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-16 21:20:46 +08:00
Ethan0x0000
7fde9ebbc2 fix: zero expired codex windows in backend, use /usage API as single frontend data source 2026-03-16 21:14:52 +08:00
Ethan0x0000
aef7c3b9bb feat: set last 24 hours as default date range in DashboardView 2026-03-16 21:14:26 +08:00
Ethan0x0000
a0b76bd608 feat: implement last 24 hours date range preset and update filters in UsageView 2026-03-16 21:14:26 +08:00
QTom
c1fab7f8d8 feat(backup): 备份/恢复异步化,解决 504 超时
POST /backups 和 POST /backups/:id/restore 改为异步:立即返回 HTTP 202,
后台 goroutine 独立执行 pg_dump → gzip → S3 上传,前端每 2s 轮询状态。

后端:
- 新增 StartBackup/StartRestore 方法,后台 goroutine 不依赖 HTTP 连接
- Graceful shutdown 等待活跃操作完成,启动时清理孤立 running 记录
- BackupRecord 新增 progress/restore_status 字段支持进度和恢复状态追踪

前端:
- 创建备份/恢复后轮询 GET /backups/:id 直到完成或失败
- 标签页切换暂停/恢复轮询,组件卸载清理定时器
- 正确处理 409(备份进行中)和轮询超时

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-16 20:22:10 +08:00
Wesley Liddick
f42c8f2abe Merge pull request #1062 from kunish/fix/antigravity-stream-keepalive
fix(antigravity): add stream keepalive to prevent connection drops
2026-03-16 19:57:13 +08:00
shaw
aa5846b282 fix(docker): resolve /app/data permission denied on volume mounts
Docker named volumes and host bind-mounts may be owned by root,
causing "open data/model_pricing.sha256: permission denied" when
the container runs as the non-root sub2api user.

Add an entrypoint script that fixes /app/data ownership before
dropping to sub2api via su-exec. Replace USER directive with the
entrypoint approach across all three Dockerfiles and update both
GoReleaser configs to include the script in Docker build contexts.
2026-03-16 19:52:14 +08:00
Wesley Liddick
594a0ade38 Merge pull request #1063 from touwaeriol/fix/usage-label-semantic
fix(i18n): correct usage label from "Total" to "Last 30d"
2026-03-16 19:04:36 +08:00
erio
d45cc23171 fix(i18n): correct usage label from "Total" to "Last 30d"
The usage stats query defaults to a 30-day rolling window, but the
UI label said "Total"/"累计" implying lifetime aggregation. Rename
to "Last 30d"/"近30天" so the label matches the actual query semantics.

Closes #1060
2026-03-16 18:25:41 +08:00
kunish
d795734352 fix(antigravity): add stream keepalive to prevent connection drops
Antigravity streaming handlers were missing the keepalive mechanism
that exists in the standard gateway, causing proxy/CDN idle timeouts
to break connections during long thinking phases (e.g. claude-opus-4-6).
This resulted in truncated responses with missing tool calls.

Add StreamKeepaliveInterval support to all three Antigravity streaming
paths: Claude SSE, Gemini SSE, and upstream passthrough.
2026-03-16 17:37:15 +08:00
Wesley Liddick
4da9fdd1d5 Merge pull request #1058 from Ethan0x0000/main
fix(admin/accounts): make usage window refresh deterministic and restore missing stats
2026-03-16 17:06:13 +08:00
Wesley Liddick
6b218caa21 Merge pull request #1053 from touwaeriol/chore/antigravity-ua-1.20.5
chore(antigravity): bump default User-Agent version to 1.20.5
2026-03-16 16:57:22 +08:00
shaw
5c138007d0 chore: update docs 2026-03-16 16:56:42 +08:00
Ethan0x0000
1acfc46f46 fix: always show usage stats for OpenAI OAuth and hide zero-value badges
- Simplify OpenAI rendering: always fetch /usage, prefer fetched data over
  codex snapshot (snapshot serves as loading placeholder only)
- Remove dead code: preferFetchedOpenAIUsage, isOpenAICodexSnapshotStale,
  and unreachable template branch
- Add today-stats support for key accounts (req/tokens/A/U badges)
- Use formatCompactNumber for consistent number formatting
- Add A/U badge titles for clarity
- Filter zero-value window stats in UsageProgressBar to avoid empty badges
- Update tests to match new fetched-data-first behavior
2026-03-16 16:23:13 +08:00
Ethan0x0000
fbffb08aae feat: add today-stats and manual refresh token propagation to usage cells
- Pass todayStats/todayStatsLoading to AccountUsageCell for key accounts
- Propagate usageManualRefreshToken to force usage reload on explicit refresh
- Refresh today stats when toggling usage/today_stats columns visible
2026-03-16 16:23:00 +08:00
Ethan0x0000
8640a62319 refactor: extract formatCompactNumber util and add last_used_at to refresh key
- Add formatCompactNumber() for consistent large-number formatting (K/M/B)
- Include last_used_at in OpenAI usage refresh key for better change detection
- Add .gitattributes eol=lf rules for frontend source files
2026-03-16 16:22:51 +08:00
Ethan0x0000
fa782e70a4 fix: always attach OpenAI 5h/7d window stats regardless of zero values
Removes hasMeaningfulWindowStats guard so the /usage endpoint consistently
returns WindowStats for both time windows. The frontend now controls
zero-value display filtering at the component level.
2026-03-16 16:22:42 +08:00
Ethan0x0000
afd72abc6e fix: allow empty extra payload to clear account quota limits
UpdateAccount previously required len(input.Extra) > 0, causing explicit
empty payloads (extra:{}) to be silently skipped. Change condition to
input.Extra != nil so clearing quota keys actually persists.
2026-03-16 16:22:31 +08:00
erio
71f72e167e chore(antigravity): bump default User-Agent version to 1.20.5 2026-03-16 15:47:32 +08:00
Wesley Liddick
6595c7601e Merge pull request #1050 from touwaeriol/fix/rate-limit-redis-window-reset
fix(billing): add window expiration check to Redis rate limit Lua script
2026-03-16 14:17:41 +08:00
erio
67c0506290 fix(billing): add window expiration check to Redis rate limit Lua script
The updateRateLimitUsageScript Lua script previously performed
unconditional HINCRBYFLOAT on all usage counters without checking
whether the rate limit window had expired. This caused usage to
accumulate across window boundaries in Redis while the DB correctly
reset on expiration, leading to incorrect 429 rate limiting that
could persist for up to 24 hours.

The Lua script now checks each window timestamp before incrementing:
- If the window has expired, usage is reset to the current cost and
  the window timestamp is updated (matching DB-side semantics)
- If the window is still valid, usage is accumulated normally

This also resolves the async race condition where stale HINCRBYFLOAT
tasks from the worker queue could pollute a freshly rebuilt cache
after invalidation, since the script now self-corrects expired windows.

Closes #1049
2026-03-16 13:39:50 +08:00
Wesley Liddick
6447be4534 Merge pull request #1047 from DaydreamCoding/fix/codex-stream-isolation
fix(gateway): 防止 OpenAI Codex 跨用户串流 + WS 连接池条件式 MarkBroken
2026-03-16 11:00:07 +08:00
QTom
3741617ebd fix(gateway): WS 连接池条件式 MarkBroken 防止跨请求串流
正常终端事件(response.completed 等)退出后连接归还复用,
仅异常路径(读写错误、error 事件、客户端断连)MarkBroken 销毁。

Generate 模式:
- 引入 cleanExit 标记,仅在 isTerminalEvent break 时设置 true
- defer 中根据 cleanExit 决定是否 MarkBroken
- 所有异常路径已在各自分支中提前调用 MarkBroken

Ingress 模式:
- 引入 lastTurnClean 标记,sendAndRelay 正常完成时设为 true
- releaseSessionLease 根据 lastTurnClean 决定是否 MarkBroken
- 错误路径重置 lastTurnClean = false
- 客户端断连后 drain 仍保守 MarkBroken(L2916)
2026-03-16 10:50:02 +08:00
QTom
ab4e8b2cf0 fix(gateway): 防止 OpenAI Codex 跨用户串流
根因:多个用户共享同一 OAuth 账号时,conversation_id/session_id 头
未做用户隔离,导致上游 chatgpt.com 将不同用户的请求关联到同一会话。

HTTP SSE 修复:
- 新增 isolateOpenAISessionID(apiKeyID, raw),将 API Key ID 混入
  session 标识符(xxhash),确保不同 Key 的用户产生不同上游会话
- buildUpstreamRequest: OAuth 分支先 Del 客户端透传的 session 头,
  再用隔离值覆盖
- buildUpstreamRequestOpenAIPassthrough: 透传路径同样隔离
- ForwardAsAnthropic: Anthropic Messages 兼容路径同步修复
- buildOpenAIWSHeaders: WS 路径的 OAuth session 头同步隔离
2026-03-16 10:28:51 +08:00
Wesley Liddick
474165d7aa Merge pull request #1043 from touwaeriol/pr/antigravity-credits-overages
feat: Antigravity AI Credits overages handling & balance display
2026-03-16 09:22:19 +08:00
Wesley Liddick
94e067a2e2 Merge pull request #1040 from 0xObjc/codex/fix-user-spending-ranking-others
fix(admin): polish spending ranking and usage defaults
2026-03-16 09:19:46 +08:00
Wesley Liddick
4293c89166 Merge pull request #1036 from Ethan0x0000/feat/usage-endpoint-distribution
fix: record endpoint info for all API surfaces & unify normalization via middleware
2026-03-16 09:17:32 +08:00
Wesley Liddick
ec82c37da5 Merge pull request #1042 from touwaeriol/feat/unified-oauth-refresh-api
feat: unified OAuth token refresh API with distributed locking
2026-03-16 09:00:42 +08:00
erio
552a4b998a fix: resolve golangci-lint issues (gofmt, errcheck)
- Fix gofmt alignment in admin_service.go and trailing newline in
  antigravity_credits_overages.go
- Suppress errcheck for fmt.Sscanf in client.go GetMinimumAmount
2026-03-16 05:15:27 +08:00
erio
0d2061b268 fix: remove ClaudeMax references not yet in upstream/main
Remove SimulateClaudeMaxEnabled field and related logic from
admin_service.go, and remove applyClaudeMaxCacheBillingPolicyToUsage,
applyClaudeMaxNonStreamingRewrite, setupClaudeMaxStreamingHook calls
from antigravity_gateway_service.go. These symbols are not yet
available in upstream/main.
2026-03-16 05:01:42 +08:00
erio
8a260defc2 refactor: replace sync.Map credits state with AICredits rate limit key
Replace process-memory sync.Map + per-model runtime state with a single
"AICredits" key in model_rate_limits, making credits exhaustion fully
isomorphic with model-level rate limiting.

Scheduler: rate-limited accounts with overages enabled + credits available
are now scheduled instead of excluded.

Forwarding: when model is rate-limited + credits available, inject credits
proactively without waiting for a 429 round trip.

Storage: credits exhaustion stored as model_rate_limits["AICredits"] with
5h duration, reusing SetModelRateLimit/isRateLimitActiveForKey.

Frontend: show credits_active (yellow ) when model rate-limited but
credits available, credits_exhausted (red) when AICredits key active.

Tests: add unit tests for shouldMarkCreditsExhausted, injectEnabledCreditTypes,
clearCreditsExhausted, and update existing overages tests.
2026-03-16 04:58:58 +08:00
SilentFlower
e14c87597a feat: simplify AI Credits display logic and enhance UI presentation 2026-03-16 04:58:46 +08:00
SilentFlower
f3f19d35aa feat: enhance Antigravity account overages handling and improve UI credit display 2026-03-16 04:58:35 +08:00
SilentFlower
ced90e1d84 feat: add AI Credits balance handling and update model status indicators 2026-03-16 04:58:23 +08:00
SilentFlower
17e4033340 feat: implement resolveCreditsOveragesModelKey function to stabilize model key resolution for credit overages 2026-03-16 04:58:12 +08:00
erio
044d3a013d fix: suppress SA4006 unused value warning in Path A branch 2026-03-16 01:38:06 +08:00
erio
1fc9dd7b68 feat: unified OAuth token refresh API with distributed locking
Introduce OAuthRefreshAPI as the single entry point for all OAuth token
refresh operations, eliminating the race condition where background
refresh and inline refresh could simultaneously use the same
refresh_token (fixes #1035).

Key changes:
- Add OAuthRefreshExecutor interface extending TokenRefresher with CacheKey
- Add OAuthRefreshAPI.RefreshIfNeeded with lock → DB re-read → double-check flow
- Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types
- Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI
- Rewrite TokenRefreshService.refreshWithRetry to use unified API path
- Add MergeCredentials and BuildClaudeAccountCredentials helpers
- Add 40 unit tests covering all new and modified code paths
2026-03-16 01:31:54 +08:00
Peter
8147866c09 fix(admin): polish spending ranking and usage defaults 2026-03-16 00:17:47 +08:00
Ethan0x0000
7bd1972f94 refactor: migrate all handlers to shared endpoint normalization middleware
- Apply InboundEndpointMiddleware to all gateway route groups
- Replace normalizedOpenAIInboundEndpoint/normalizedOpenAIUpstreamEndpoint and normalizedGatewayInboundEndpoint/normalizedGatewayUpstreamEndpoint with GetInboundEndpoint/GetUpstreamEndpoint
- Remove 4 old constants and 4 old normalization functions (-70 lines)
- Migrate existing endpoint normalization test to new API

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-15 22:13:42 +08:00
Ethan0x0000
2c9dcfe27b refactor: add unified endpoint normalization infrastructure
Introduce endpoint.go with shared constants, NormalizeInboundEndpoint, DeriveUpstreamEndpoint, InboundEndpointMiddleware, and context helpers. This replaces the two separate normalization implementations (OpenAI and Gateway) with a single source of truth. Includes comprehensive test coverage.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-15 22:13:31 +08:00
Ethan0x0000
1b79b0f3ff feat: add InboundEndpoint/UpstreamEndpoint fields to non-OpenAI usage records
Extend RecordUsageInput and RecordUsageLongContextInput structs with InboundEndpoint and UpstreamEndpoint so that Claude, Gemini, and Sora handlers can record endpoint info alongside OpenAI handlers.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-15 22:13:22 +08:00
Ethan0x0000
c637e6cf31 fix: use half-open date ranges for DST-safe usage queries
Replace t.Add(24*time.Hour - time.Nanosecond) with t.AddDate(0, 0, 1) and use SQL < instead of <= for end-of-day boundaries. This avoids edge-case misses around DST transitions.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-15 22:13:12 +08:00
Wesley Liddick
d3a9f5bb88 Merge pull request #1027 from touwaeriol/feat/ignore-insufficient-balance-errors
feat(ops): add ignore insufficient balance errors toggle and extract error constants
2026-03-15 19:10:18 +08:00
Wesley Liddick
7eb0415a8a Merge pull request #1028 from IanShaw027/fix/open-issues-cleanup
fix: 修复多个issues - Gemini schema 兼容性、批量编辑白名单、Docker 工具支持和限额字段处理Fix/open issues cleanup
2026-03-15 19:09:49 +08:00
erio
bdbc8fa08f fix(ops): align constant declarations for gofmt compliance 2026-03-15 18:55:14 +08:00
erio
63f3af0f94 fix(ops): match "insufficient account balance" in error filter
The upstream Gemini API returns "Insufficient account balance" which
doesn't contain the substring "insufficient balance". Add explicit
match for the full phrase to ensure the filter works correctly.
2026-03-15 18:45:48 +08:00
IanShaw027
686f890fbf style: 修复 gofmt 格式问题 2026-03-15 18:42:32 +08:00
shaw
220fbe6544 fix: 恢复 UsageProgressBar 中被意外移除的窗口统计数据展示
commit 0debe0a8 在修复 OpenAI WS 用量窗口刷新问题时,意外删除了
UsageProgressBar 中的 window stats 渲染逻辑和格式化函数。

恢复进度条上方的统计行(requests, tokens, account cost, user cost)
及对应的 4 个格式化 computed 属性。
2026-03-15 18:29:23 +08:00
shaw
ae44a94325 fix: 重置密码功能新增UI配置发送邮件域名 2026-03-15 17:52:29 +08:00
IanShaw
3718d6dcd4 Merge branch 'Wei-Shaw:main' into fix/open-issues-cleanup 2026-03-15 17:49:20 +08:00
IanShaw027
90b3838173 fix: 移除 Gemini 不支持的 patternProperties 字段 #795 2026-03-15 17:46:58 +08:00
IanShaw027
19d3ecc76f fix: 修复批量编辑账号时模型白名单显示与实际不一致的问题 #982
修复批量编辑账号时,UI 显示的是 plain 模型名(如 GPT-5),但实际落库的是 dated 模型名的问题。

核心改动:
1. 批量编辑白名单不再使用 BulkEditAccountModal.vue 中手写的过期模型列表
   - 移除了 allModels 和 presetMappings 的硬编码列表(共 200+ 行)
   - 直接复用 ModelWhitelistSelector.vue 组件

2. ModelWhitelistSelector 组件支持多平台联合过滤
   - 新增 platforms 属性支持传入多个平台
   - 添加 normalizedPlatforms 计算属性统一处理单平台和多平台场景
   - availableOptions 根据选中的多个平台动态联合过滤模型列表
   - fillRelated 功能支持一次性填充多个平台的相关模型

3. 模型映射预设改为动态生成
   - filteredPresets 改用 getPresetMappingsByPlatform 从统一模型源按平台动态生成
   - 不再依赖弹窗中的手写预设列表

现在的行为:
- UI 显示什么模型,勾选什么模型,传给后端的就是什么模型
- 彻底解决了批量编辑链路上"显示与实际不一致"的问题
- 模型列表和映射预设始终与系统定义保持同步
2026-03-15 17:46:58 +08:00
IanShaw027
6fba4ebb13 fix: 在 Dockerfile.goreleaser 中添加 pg_dump 和 psql 工具 #1002
为了支持容器内的数据库备份和恢复功能,在运行时镜像中添加 PostgreSQL 客户端工具。

变更内容:
- 使用多阶段构建从 postgres:18-alpine 镜像复制 pg_dump 和 psql 二进制文件
- 添加必要的依赖库(libpq, zstd-libs, lz4-libs, krb5-libs, libldap, libedit)
- 升级基础镜像到 alpine:3.21
- 复制 libpq.so.5 共享库以确保工具正常运行

这样可以在运行时容器中直接执行数据库备份和恢复操作,无需访问 Docker socket。
2026-03-15 17:46:58 +08:00
IanShaw027
c31974c913 fix: 兼容部分限额字段为空的情况 #1021
修复在填写限额时,如果不填写完整的三个限额额度(日限额、周限额、月限额)就会报错的问题。

变更内容:
- 后端:添加 optionalLimitField 类型处理空值和空字符串,兼容部分限额字段为空的情况
- 前端:添加 normalizeOptionalLimit 函数规范化限额输入,将空值、空字符串和无效数字统一处理为 null
2026-03-15 17:46:58 +08:00
erio
6177fa5dd8 fix(i18n): correct insufficient balance error hint text
Remove misleading "upstream" wording - the error is about client API key
user balance, not upstream account balance.
2026-03-15 17:41:51 +08:00
erio
cfe72159d0 feat(ops): add ignore insufficient balance errors toggle and extract error constants
- Add 5th error filter switch IgnoreInsufficientBalanceErrors to suppress
  upstream insufficient balance / insufficient_quota errors from ops log
- Extract hardcoded error strings into package-level constants for
  shouldSkipOpsErrorLog, normalizeOpsErrorType, classifyOpsPhase, and
  classifyOpsIsBusinessLimited
- Define ErrNoAvailableAccounts sentinel error and replace all
  errors.New("no available accounts") call sites
- Update tests to use require.ErrorIs with the sentinel error
2026-03-15 17:26:18 +08:00
Wesley Liddick
8321e4a647 Merge pull request #1023 from YanzheL/fix/claude-output-effort-logging
fix: extract and log Claude output_config.effort in usage records
2026-03-15 16:45:37 +08:00
Wesley Liddick
3084330d0c Merge pull request #1019 from Ethan0x0000/feat/usage-endpoint-distribution
feat: add endpoint metadata and usage endpoint distribution insights
2026-03-15 16:42:03 +08:00
Wesley Liddick
b566649e79 Merge pull request #1025 from touwaeriol/fix/rate-limit-nil-window-reset
fix(billing): treat nil rate limit window as expired to prevent usage accumulation
2026-03-15 16:33:14 +08:00
Wesley Liddick
10a6180e4a Merge pull request #1026 from touwaeriol/fix/group-quota-clear
fix(billing): allow clearing group quota limits and treat 0 as zero-limit
2026-03-15 16:33:00 +08:00
Wesley Liddick
cbe9e78977 Merge pull request #1007 from StarryKira/fix/streaming-failover-corruption
fix(gateway): 防止流式 failover 拼接腐化导致客户端收到双 message_start fix issue #991
2026-03-15 16:29:31 +08:00
Wesley Liddick
74145b1f39 Merge pull request #1017 from SsageParuders/fix/bedrock-account-quota
fix: Bedrock 账户配额限制不生效
2026-03-15 16:28:42 +08:00
Elysia
359e56751b 增加测试 2026-03-15 16:21:49 +08:00
erio
5899784aa4 fix(billing): allow clearing group quota limits and treat 0 as zero-limit
Previously, v-model.number produced "" when input was cleared, causing
JSON decode errors on the backend. Also, normalizeLimit treated 0 as
"unlimited" which prevented setting a zero quota. Now "" is converted
to null (unlimited) in frontend, and 0 is preserved as a valid limit.

Closes Wei-Shaw/sub2api#1021
2026-03-15 16:15:15 +08:00
erio
9e8959c56d fix(billing): treat nil rate limit window as expired to prevent usage accumulation
When Redis cache is populated from DB with a NULL window_1d_start, the
Lua increment script only updates usage counters without setting window
timestamps. IsWindowExpired(nil) previously returned false, so the
accumulated usage was never reset across time windows, effectively
turning usage_1d into a lifetime counter. Once this exceeded
rate_limit_1d the key was incorrectly blocked with "日限额已用完".

Fixes Wei-Shaw/sub2api#1022
2026-03-15 14:04:13 +08:00
YanzheL
1bff2292a6 fix: extract and log Claude output_config.effort in usage records
Claude's output_config.effort parameter (low/medium/high/max) was not
being extracted from requests or logged in the reasoning_effort column
of usage logs. Only the OpenAI path populated this field.

Changes:
- Extract output_config.effort in ParseGatewayRequest
- Add ReasoningEffort field to ForwardResult
- Populate reasoning_effort in both RecordUsage and RecordUsageWithLongContext
- Guard against overwriting service-set effort values in handler
- Update stale comments that described reasoning_effort as OpenAI-only
- Add unit tests for extraction, normalization, and persistence
2026-03-15 12:55:37 +08:00
Ethan0x0000
cf9247754e test: fix usage repo stubs for unit builds 2026-03-15 12:51:34 +08:00
Ethan0x0000
eefab15958 feat: 完善使用记录端点可观测性与分布统计
将入站、上游与路径三类端点分布统一到使用记录页的一致化卡片交互中,并补齐端点元数据与统计链路,提升排障与流量分析效率。
2026-03-15 11:26:42 +08:00
Elysia
0e23732631 fix(gateway): 防止流式 failover 拼接腐化导致客户端收到双 message_start
当上游在 SSE 流中途返回 event:error 时,handleStreamingResponse 已将
部分 SSE 事件写入客户端,但原先的 failover 逻辑仍会切换到下一个账号
并写入完整流,导致客户端收到两个 message_start 进而产生 400 错误。

修复方案:在每次 Forward 调用前记录 c.Writer.Size(),若 Forward 返回
UpstreamFailoverError 后 writer 字节数增加,说明 SSE 内容已不可撤销地
发送给客户端,此时直接调用 handleFailoverExhausted 发送 SSE error 事件
终止流,而非继续 failover。

Ping-only 场景不受影响:slot 等待期的 ping 字节在 Forward 前后相等,
正常 failover 流程照常进行。

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-14 22:49:23 +08:00
SsageParuders
37c044fb4b fix: Bedrock 账户配额限制不生效,配额计数器始终为 $0.00
applyUsageBillingEffects() 中配额更新条件仅检查了 AccountTypeAPIKey,
遗漏了 AccountTypeBedrock,导致 Bedrock 账户的配额计数器永远不递增。
扩展条件以同时支持 APIKey 和 Bedrock 类型。

同时在前端账户筛选下拉框中添加 AWS Bedrock 选项。
2026-03-14 22:47:44 +08:00
219 changed files with 12075 additions and 2357 deletions

7
.gitattributes vendored
View File

@@ -4,6 +4,13 @@ backend/migrations/*.sql text eol=lf
# Go 源代码文件
*.go text eol=lf
# 前端 源代码文件
*.ts text eol=lf
*.tsx text eol=lf
*.js text eol=lf
*.jsx text eol=lf
*.vue text eol=lf
# Shell 脚本
*.sh text eol=lf

View File

@@ -47,6 +47,8 @@ dockers:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
dockerfile: Dockerfile.goreleaser
use: buildx
extra_files:
- deploy/docker-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"

View File

@@ -63,6 +63,8 @@ dockers:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
dockerfile: Dockerfile.goreleaser
use: buildx
extra_files:
- deploy/docker-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"
@@ -76,6 +78,8 @@ dockers:
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
dockerfile: Dockerfile.goreleaser
use: buildx
extra_files:
- deploy/docker-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{ .Version }}"
@@ -89,6 +93,8 @@ dockers:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
dockerfile: Dockerfile.goreleaser
use: buildx
extra_files:
- deploy/docker-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{ .Version }}"
@@ -102,6 +108,8 @@ dockers:
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64"
dockerfile: Dockerfile.goreleaser
use: buildx
extra_files:
- deploy/docker-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{ .Version }}"

View File

@@ -92,6 +92,7 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
RUN apk add --no-cache \
ca-certificates \
tzdata \
su-exec \
libpq \
zstd-libs \
lz4-libs \
@@ -120,8 +121,9 @@ COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/
# Create data directory
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
# Switch to non-root user
USER sub2api
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
RUN chmod +x /app/docker-entrypoint.sh
# Expose port (can be overridden by SERVER_PORT env var)
EXPOSE 8080
@@ -130,5 +132,6 @@ EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
# Run the application
ENTRYPOINT ["/app/sub2api"]
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
ENTRYPOINT ["/app/docker-entrypoint.sh"]
CMD ["/app/sub2api"]

View File

@@ -5,7 +5,12 @@
# It only packages the pre-built binary, no compilation needed.
# =============================================================================
FROM alpine:3.19
ARG ALPINE_IMAGE=alpine:3.21
ARG POSTGRES_IMAGE=postgres:18-alpine
FROM ${POSTGRES_IMAGE} AS pg-client
FROM ${ALPINE_IMAGE}
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
LABEL description="Sub2API - AI API Gateway Platform"
@@ -16,8 +21,21 @@ RUN apk add --no-cache \
ca-certificates \
tzdata \
curl \
su-exec \
libpq \
zstd-libs \
lz4-libs \
krb5-libs \
libldap \
libedit \
&& rm -rf /var/cache/apk/*
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
# restore work in the runtime container without requiring Docker socket access.
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
@@ -30,11 +48,15 @@ COPY sub2api /app/sub2api
# Create data directory
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
USER sub2api
# Copy entrypoint script (fixes volume permissions then drops to sub2api)
COPY deploy/docker-entrypoint.sh /app/docker-entrypoint.sh
RUN chmod +x /app/docker-entrypoint.sh
EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
ENTRYPOINT ["/app/sub2api"]
# Run the application (entrypoint fixes /app/data ownership then execs as sub2api)
ENTRYPOINT ["/app/docker-entrypoint.sh"]
CMD ["/app/sub2api"]

View File

@@ -8,27 +8,31 @@
[![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/)
[![Docker](https://img.shields.io/badge/Docker-Ready-2496ED.svg)](https://www.docker.com/)
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
**AI API Gateway Platform for Subscription Quota Distribution**
English | [中文](README_CN.md)
</div>
> **Sub2API officially uses only the domains `sub2api.org` and `pincc.ai`. Other websites using the Sub2API name may be third-party deployments or services and are not affiliated with this project. Please verify and exercise your own judgment.**
---
## Demo
Try Sub2API online: **https://demo.sub2api.org/**
Try Sub2API online: **[https://demo.sub2api.org/](https://demo.sub2api.org/)**
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
| Email | Password |
|-------|----------|
| admin@sub2api.com | admin123 |
| admin@sub2api.org | admin123 |
## Overview
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions (like Claude Code $200/month). Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
Sub2API is an AI API gateway platform designed to distribute and manage API quotas from AI product subscriptions. Users can access upstream AI services through platform-generated API Keys, while the platform handles authentication, billing, load balancing, and request forwarding.
## Features
@@ -41,6 +45,15 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
- **Admin Dashboard** - Web interface for monitoring and management
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
## Don't Want to Self-Host?
<table>
<tr>
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.</td>
</tr>
</table>
## Ecosystem
Community projects that extend or integrate with Sub2API:
@@ -61,10 +74,15 @@ Community projects that extend or integrate with Sub2API:
---
## Documentation
## Nginx Reverse Proxy Note
- Dependency Security: `docs/dependency-security.md`
- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md`
When using Nginx as a reverse proxy for Sub2API (or CRS) with Codex CLI, add the following to the `http` block in your Nginx configuration:
```nginx
underscores_in_headers on;
```
Nginx drops headers containing underscores by default (e.g. `session_id`), which breaks sticky session routing in multi-account setups.
---

View File

@@ -8,27 +8,30 @@
[![Redis](https://img.shields.io/badge/Redis-7+-DC382D.svg)](https://redis.io/)
[![Docker](https://img.shields.io/badge/Docker-Ready-2496ED.svg)](https://www.docker.com/)
<a href="https://trendshift.io/repositories/21823" target="_blank"><img src="https://trendshift.io/api/badge/repositories/21823" alt="Wei-Shaw%2Fsub2api | Trendshift" width="250" height="55"/></a>
**AI API 网关平台 - 订阅配额分发管理**
[English](README.md) | 中文
</div>
> **Sub2API 官方仅使用 `sub2api.org` 与 `pincc.ai` 两个域名。其他使用 Sub2API 名义的网站可能为第三方部署或服务,与本项目无关,请自行甄别。**
---
## 在线体验
体验地址:**https://v2.pincc.ai/**
体验地址:**[https://demo.sub2api.org/](https://demo.sub2api.org/)**
演示账号(共享演示环境;自建部署不会自动创建该账号):
| 邮箱 | 密码 |
|------|------|
| admin@sub2api.com | admin123 |
| admin@sub2api.org | admin123 |
## 项目概述
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(如 Claude Code $200/月)的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 API 配额。用户通过平台生成的 API Key 调用上游 AI 服务,平台负责鉴权、计费、负载均衡和请求转发。
## 核心功能
@@ -41,6 +44,15 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
- **管理后台** - Web 界面进行监控和管理
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
## 不想自建?试试官方中转
<table>
<tr>
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。</td>
</tr>
</table>
## 生态项目
围绕 Sub2API 的社区扩展与集成项目:
@@ -61,17 +73,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
---
## 文档
## Nginx 反向代理注意事项
- 依赖安全:`docs/dependency-security.md`
通过 Nginx 反向代理 Sub2API或 CRS 服务)并搭配 Codex CLI 使用时,需要在 Nginx 配置的 `http` 块中添加:
```nginx
underscores_in_headers on;
```
Nginx 默认会丢弃名称中含下划线的请求头(如 `session_id`),这会导致多账号环境下的粘性会话功能失效。
---
## OpenAI Responses 兼容注意事项
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id``tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
## 部署方式
### 方式一:脚本安装(推荐)

Binary file not shown.

After

Width:  |  Height:  |  Size: 171 KiB

View File

@@ -110,7 +110,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
@@ -124,6 +123,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
@@ -132,16 +132,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageCache := service.NewUsageCache()
identityCache := repository.NewIdentityCache(redisClient)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()
@@ -166,10 +168,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
@@ -232,7 +234,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)

View File

@@ -716,6 +716,7 @@ var (
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "request_id", Type: field.TypeString, Size: 64},
{Name: "model", Type: field.TypeString, Size: 100},
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
@@ -755,31 +756,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[31]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[32]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[33]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -788,32 +789,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31]},
Columns: []*schema.Column{UsageLogsColumns[32]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[29]},
Columns: []*schema.Column{UsageLogsColumns[30]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30]},
Columns: []*schema.Column{UsageLogsColumns[31]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[33]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_model",
@@ -828,17 +829,17 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]},
},
{
Name: "usagelog_group_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]},
},
},
}

View File

@@ -18239,6 +18239,7 @@ type UsageLogMutation struct {
id *int64
request_id *string
model *string
upstream_model *string
input_tokens *int
addinput_tokens *int
output_tokens *int
@@ -18576,6 +18577,55 @@ func (m *UsageLogMutation) ResetModel() {
m.model = nil
}
// SetUpstreamModel sets the "upstream_model" field.
func (m *UsageLogMutation) SetUpstreamModel(s string) {
m.upstream_model = &s
}
// UpstreamModel returns the value of the "upstream_model" field in the mutation.
func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) {
v := m.upstream_model
if v == nil {
return
}
return *v, true
}
// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUpstreamModel requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err)
}
return oldValue.UpstreamModel, nil
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (m *UsageLogMutation) ClearUpstreamModel() {
m.upstream_model = nil
m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{}
}
// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation.
func (m *UsageLogMutation) UpstreamModelCleared() bool {
_, ok := m.clearedFields[usagelog.FieldUpstreamModel]
return ok
}
// ResetUpstreamModel resets all changes to the "upstream_model" field.
func (m *UsageLogMutation) ResetUpstreamModel() {
m.upstream_model = nil
delete(m.clearedFields, usagelog.FieldUpstreamModel)
}
// SetGroupID sets the "group_id" field.
func (m *UsageLogMutation) SetGroupID(i int64) {
m.group = &i
@@ -20197,7 +20247,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 32)
fields := make([]string, 0, 33)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -20213,6 +20263,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.model != nil {
fields = append(fields, usagelog.FieldModel)
}
if m.upstream_model != nil {
fields = append(fields, usagelog.FieldUpstreamModel)
}
if m.group != nil {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -20312,6 +20365,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.RequestID()
case usagelog.FieldModel:
return m.Model()
case usagelog.FieldUpstreamModel:
return m.UpstreamModel()
case usagelog.FieldGroupID:
return m.GroupID()
case usagelog.FieldSubscriptionID:
@@ -20385,6 +20440,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldRequestID(ctx)
case usagelog.FieldModel:
return m.OldModel(ctx)
case usagelog.FieldUpstreamModel:
return m.OldUpstreamModel(ctx)
case usagelog.FieldGroupID:
return m.OldGroupID(ctx)
case usagelog.FieldSubscriptionID:
@@ -20483,6 +20540,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetModel(v)
return nil
case usagelog.FieldUpstreamModel:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUpstreamModel(v)
return nil
case usagelog.FieldGroupID:
v, ok := value.(int64)
if !ok {
@@ -20921,6 +20985,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
// mutation.
func (m *UsageLogMutation) ClearedFields() []string {
var fields []string
if m.FieldCleared(usagelog.FieldUpstreamModel) {
fields = append(fields, usagelog.FieldUpstreamModel)
}
if m.FieldCleared(usagelog.FieldGroupID) {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -20962,6 +21029,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool {
// error if the field is not defined in the schema.
func (m *UsageLogMutation) ClearField(name string) error {
switch name {
case usagelog.FieldUpstreamModel:
m.ClearUpstreamModel()
return nil
case usagelog.FieldGroupID:
m.ClearGroupID()
return nil
@@ -21012,6 +21082,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldModel:
m.ResetModel()
return nil
case usagelog.FieldUpstreamModel:
m.ResetUpstreamModel()
return nil
case usagelog.FieldGroupID:
m.ResetGroupID()
return nil

View File

@@ -821,92 +821,96 @@ func init() {
return nil
}
}()
// usagelogDescUpstreamModel is the schema descriptor for upstream_model field.
usagelogDescUpstreamModel := usagelogFields[5].Descriptor()
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
usagelogDescInputTokens := usagelogFields[7].Descriptor()
usagelogDescInputTokens := usagelogFields[8].Descriptor()
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
usagelogDescOutputTokens := usagelogFields[8].Descriptor()
usagelogDescOutputTokens := usagelogFields[9].Descriptor()
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor()
usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor()
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
usagelogDescCacheReadTokens := usagelogFields[10].Descriptor()
usagelogDescCacheReadTokens := usagelogFields[11].Descriptor()
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor()
usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor()
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor()
usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor()
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
// usagelogDescInputCost is the schema descriptor for input_cost field.
usagelogDescInputCost := usagelogFields[13].Descriptor()
usagelogDescInputCost := usagelogFields[14].Descriptor()
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
// usagelogDescOutputCost is the schema descriptor for output_cost field.
usagelogDescOutputCost := usagelogFields[14].Descriptor()
usagelogDescOutputCost := usagelogFields[15].Descriptor()
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
usagelogDescCacheCreationCost := usagelogFields[15].Descriptor()
usagelogDescCacheCreationCost := usagelogFields[16].Descriptor()
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
usagelogDescCacheReadCost := usagelogFields[16].Descriptor()
usagelogDescCacheReadCost := usagelogFields[17].Descriptor()
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
// usagelogDescTotalCost is the schema descriptor for total_cost field.
usagelogDescTotalCost := usagelogFields[17].Descriptor()
usagelogDescTotalCost := usagelogFields[18].Descriptor()
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
// usagelogDescActualCost is the schema descriptor for actual_cost field.
usagelogDescActualCost := usagelogFields[18].Descriptor()
usagelogDescActualCost := usagelogFields[19].Descriptor()
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
usagelogDescRateMultiplier := usagelogFields[19].Descriptor()
usagelogDescRateMultiplier := usagelogFields[20].Descriptor()
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field.
usagelogDescBillingType := usagelogFields[21].Descriptor()
usagelogDescBillingType := usagelogFields[22].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field.
usagelogDescStream := usagelogFields[22].Descriptor()
usagelogDescStream := usagelogFields[23].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field.
usagelogDescUserAgent := usagelogFields[25].Descriptor()
usagelogDescUserAgent := usagelogFields[26].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[26].Descriptor()
usagelogDescIPAddress := usagelogFields[27].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[27].Descriptor()
usagelogDescImageCount := usagelogFields[28].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[28].Descriptor()
usagelogDescImageSize := usagelogFields[29].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescMediaType is the schema descriptor for media_type field.
usagelogDescMediaType := usagelogFields[29].Descriptor()
usagelogDescMediaType := usagelogFields[30].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[31].Descriptor()
usagelogDescCreatedAt := usagelogFields[32].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field {
field.String("model").
MaxLen(100).
NotEmpty(),
// UpstreamModel stores the actual upstream model name when model mapping
// is applied. NULL means no mapping — the requested model was used as-is.
field.String("upstream_model").
MaxLen(100).
Optional().
Nillable(),
field.Int64("group_id").
Optional().
Nillable(),

View File

@@ -32,6 +32,8 @@ type UsageLog struct {
RequestID string `json:"request_id,omitempty"`
// Model holds the value of the "model" field.
Model string `json:"model,omitempty"`
// UpstreamModel holds the value of the "upstream_model" field.
UpstreamModel *string `json:"upstream_model,omitempty"`
// GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"`
// SubscriptionID holds the value of the "subscription_id" field.
@@ -175,7 +177,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@@ -230,6 +232,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Model = value.String
}
case usagelog.FieldUpstreamModel:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field upstream_model", values[i])
} else if value.Valid {
_m.UpstreamModel = new(string)
*_m.UpstreamModel = value.String
}
case usagelog.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i])
@@ -477,6 +486,11 @@ func (_m *UsageLog) String() string {
builder.WriteString("model=")
builder.WriteString(_m.Model)
builder.WriteString(", ")
if v := _m.UpstreamModel; v != nil {
builder.WriteString("upstream_model=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.GroupID; v != nil {
builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))

View File

@@ -24,6 +24,8 @@ const (
FieldRequestID = "request_id"
// FieldModel holds the string denoting the model field in the database.
FieldModel = "model"
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
FieldUpstreamModel = "upstream_model"
// FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id"
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
@@ -135,6 +137,7 @@ var Columns = []string{
FieldAccountID,
FieldRequestID,
FieldModel,
FieldUpstreamModel,
FieldGroupID,
FieldSubscriptionID,
FieldInputTokens,
@@ -179,6 +182,8 @@ var (
RequestIDValidator func(string) error
// ModelValidator is a validator for the "model" field. It is called by the builders before save.
ModelValidator func(string) error
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
UpstreamModelValidator func(string) error
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
DefaultInputTokens int
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
@@ -258,6 +263,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldModel, opts...).ToFunc()
}
// ByUpstreamModel orders the results by the upstream_model field.
func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
}
// ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc()

View File

@@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldModel, v))
}
// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ.
func UpstreamModel(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
}
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
@@ -405,6 +410,81 @@ func ModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v))
}
// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field.
func UpstreamModelEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
}
// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field.
func UpstreamModelNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v))
}
// UpstreamModelIn applies the In predicate on the "upstream_model" field.
func UpstreamModelIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...))
}
// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field.
func UpstreamModelNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...))
}
// UpstreamModelGT applies the GT predicate on the "upstream_model" field.
func UpstreamModelGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v))
}
// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field.
func UpstreamModelGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v))
}
// UpstreamModelLT applies the LT predicate on the "upstream_model" field.
func UpstreamModelLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v))
}
// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field.
func UpstreamModelLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v))
}
// UpstreamModelContains applies the Contains predicate on the "upstream_model" field.
func UpstreamModelContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v))
}
// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field.
func UpstreamModelHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v))
}
// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field.
func UpstreamModelHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v))
}
// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field.
func UpstreamModelIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel))
}
// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field.
func UpstreamModelNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel))
}
// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field.
func UpstreamModelEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v))
}
// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field.
func UpstreamModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
}
// GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))

View File

@@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate {
return _c
}
// SetUpstreamModel sets the "upstream_model" field.
func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate {
_c.mutation.SetUpstreamModel(v)
return _c
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
if v != nil {
_c.SetUpstreamModel(*v)
}
return _c
}
// SetGroupID sets the "group_id" field.
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
_c.mutation.SetGroupID(v)
@@ -596,6 +610,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
}
}
if v, ok := _c.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if _, ok := _c.mutation.InputTokens(); !ok {
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
}
@@ -714,6 +733,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
_node.Model = value
}
if value, ok := _c.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
_node.UpstreamModel = &value
}
if value, ok := _c.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
_node.InputTokens = value
@@ -1011,6 +1034,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert {
return u
}
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert {
u.Set(usagelog.FieldUpstreamModel, v)
return u
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldUpstreamModel)
return u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
u.SetNull(usagelog.FieldUpstreamModel)
return u
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
u.Set(usagelog.FieldGroupID, v)
@@ -1600,6 +1641,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne {
})
}
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetUpstreamModel(v)
})
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateUpstreamModel()
})
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearUpstreamModel()
})
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2434,6 +2496,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk {
})
}
// SetUpstreamModel sets the "upstream_model" field.
func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetUpstreamModel(v)
})
}
// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateUpstreamModel()
})
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearUpstreamModel()
})
}
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {

View File

@@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate {
return _u
}
// SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate {
_u.mutation.SetUpstreamModel(v)
return _u
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate {
if v != nil {
_u.SetUpstreamModel(*v)
}
return _u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
_u.mutation.ClearUpstreamModel()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
_u.mutation.SetGroupID(v)
@@ -745,6 +765,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
}
}
if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -795,6 +820,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
}
if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
}
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}
@@ -1177,6 +1208,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne {
return _u
}
// SetUpstreamModel sets the "upstream_model" field.
func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne {
_u.mutation.SetUpstreamModel(v)
return _u
}
// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetUpstreamModel(*v)
}
return _u
}
// ClearUpstreamModel clears the value of the "upstream_model" field.
func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
_u.mutation.ClearUpstreamModel()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
_u.mutation.SetGroupID(v)
@@ -1833,6 +1884,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)}
}
}
if v, ok := _u.mutation.UpstreamModel(); ok {
if err := usagelog.UpstreamModelValidator(v); err != nil {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -1900,6 +1956,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if value, ok := _u.mutation.Model(); ok {
_spec.SetField(usagelog.FieldModel, field.TypeString, value)
}
if value, ok := _u.mutation.UpstreamModel(); ok {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
}
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}

View File

@@ -22,8 +22,6 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
@@ -60,8 +58,6 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWA
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
@@ -98,10 +94,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
@@ -238,8 +230,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -273,8 +263,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -326,8 +314,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=

View File

@@ -82,8 +82,8 @@ var DefaultAntigravityModelMapping = map[string]string{
"claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
// Claude Haiku → Sonnet无 Haiku 支持)
"claude-haiku-4-5": "claude-sonnet-4-5",
"claude-haiku-4-5-20251001": "claude-sonnet-4-5",
"claude-haiku-4-5": "claude-sonnet-4-6",
"claude-haiku-4-5-20251001": "claude-sonnet-4-6",
// Gemini 2.5 白名单
"gemini-2.5-flash": "gemini-2.5-flash",
"gemini-2.5-flash-image": "gemini-2.5-flash-image",

View File

@@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc, nil)
groupHandler := NewGroupHandler(adminSvc)
groupHandler := NewGroupHandler(adminSvc, nil, nil)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc, nil)

View File

@@ -98,12 +98,12 @@ func (h *BackupHandler) CreateBackup(c *gin.Context) {
expireDays = *req.ExpireDays
}
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
record, err := h.backupService.StartBackup(c.Request.Context(), "manual", expireDays)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, record)
response.Accepted(c, record)
}
func (h *BackupHandler) ListBackups(c *gin.Context) {
@@ -196,9 +196,10 @@ func (h *BackupHandler) RestoreBackup(c *gin.Context) {
return
}
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
record, err := h.backupService.StartRestore(c.Request.Context(), backupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"restored": true})
response.Accepted(c, record)
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -272,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
modelSource := usagestats.ModelSourceRequested
var requestType *int16
var stream *bool
var billingType *int8
@@ -296,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
groupID = id
}
}
if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" {
if !usagestats.IsValidModelSource(rawModelSource) {
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
return
}
modelSource = rawModelSource
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
@@ -322,7 +331,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
@@ -512,6 +521,8 @@ func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
payload := gin.H{
"ranking": ranking.Ranking,
"total_actual_cost": ranking.TotalActualCost,
"total_requests": ranking.TotalRequests,
"total_tokens": ranking.TotalTokens,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
}
@@ -602,3 +613,47 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// GetUserBreakdown handles getting per-user usage breakdown within a dimension.
// GET /api/v1/admin/dashboard/user-breakdown
// Query params: start_date, end_date, group_id, model, endpoint, endpoint_type, limit
func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
dim := usagestats.UserBreakdownDimension{}
if v := c.Query("group_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.GroupID = id
}
}
dim.Model = c.Query("model")
rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested))
if !usagestats.IsValidModelSource(rawModelSource) {
response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping")
return
}
dim.ModelType = rawModelSource
dim.Endpoint = c.Query("endpoint")
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
limit := 50
if v := c.Query("limit"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
limit = n
}
}
stats, err := h.dashboardService.GetUserBreakdownStats(
c.Request.Context(), startTime, endTime, dim, limit,
)
if err != nil {
response.Error(c, 500, "Failed to get user breakdown stats")
return
}
response.Success(c, gin.H{
"users": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}

View File

@@ -61,6 +61,8 @@ func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
return &usagestats.UserSpendingRankingResponse{
Ranking: s.ranking,
TotalActualCost: s.rankingTotal,
TotalRequests: 44,
TotalTokens: 1234,
}, nil
}
@@ -147,6 +149,28 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsInvalidModelSource(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsValidModelSource(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
repo := &dashboardUsageRepoCapture{
@@ -164,6 +188,8 @@ func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 50, repo.rankingLimit)
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
require.Contains(t, rec.Body.String(), "\"total_requests\":44")
require.Contains(t, rec.Body.String(), "\"total_tokens\":1234")
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)

View File

@@ -0,0 +1,229 @@
package admin
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// --- mock repo ---
type userBreakdownRepoCapture struct {
service.UsageLogRepository
capturedDim usagestats.UserBreakdownDimension
capturedLimit int
result []usagestats.UserBreakdownItem
}
func (r *userBreakdownRepoCapture) GetUserBreakdownStats(
_ context.Context, _, _ time.Time,
dim usagestats.UserBreakdownDimension, limit int,
) ([]usagestats.UserBreakdownItem, error) {
r.capturedDim = dim
r.capturedLimit = limit
if r.result != nil {
return r.result, nil
}
return []usagestats.UserBreakdownItem{}, nil
}
func newUserBreakdownRouter(repo *userBreakdownRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
svc := service.NewDashboardService(repo, nil, nil, nil)
h := NewDashboardHandler(svc, nil)
router := gin.New()
router.GET("/admin/dashboard/user-breakdown", h.GetUserBreakdown)
return router
}
// --- tests ---
func TestGetUserBreakdown_GroupIDFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=42", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, int64(42), repo.capturedDim.GroupID)
require.Empty(t, repo.capturedDim.Model)
require.Empty(t, repo.capturedDim.Endpoint)
require.Equal(t, 50, repo.capturedLimit) // default limit
}
func TestGetUserBreakdown_ModelFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model)
require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType)
require.Equal(t, int64(0), repo.capturedDim.GroupID)
}
func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType)
}
func TestGetUserBreakdown_InvalidModelSource(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code)
}
func TestGetUserBreakdown_EndpointFilter(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/v1/messages&endpoint_type=upstream", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "/v1/messages", repo.capturedDim.Endpoint)
require.Equal(t, "upstream", repo.capturedDim.EndpointType)
}
func TestGetUserBreakdown_DefaultEndpointType(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&endpoint=/chat", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, "inbound", repo.capturedDim.EndpointType)
}
func TestGetUserBreakdown_CustomLimit(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=100", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, 100, repo.capturedLimit)
}
func TestGetUserBreakdown_LimitClamped(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
// limit > 200 should fall back to default 50
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=test&limit=999", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, 50, repo.capturedLimit)
}
func TestGetUserBreakdown_ResponseFormat(t *testing.T) {
repo := &userBreakdownRepoCapture{
result: []usagestats.UserBreakdownItem{
{UserID: 1, Email: "alice@test.com", Requests: 100, TotalTokens: 50000, Cost: 1.5, ActualCost: 1.2},
{UserID: 2, Email: "bob@test.com", Requests: 50, TotalTokens: 25000, Cost: 0.8, ActualCost: 0.6},
},
}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=1", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Users []usagestats.UserBreakdownItem `json:"users"`
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
} `json:"data"`
}
err := json.Unmarshal(w.Body.Bytes(), &resp)
require.NoError(t, err)
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Users, 2)
require.Equal(t, int64(1), resp.Data.Users[0].UserID)
require.Equal(t, "alice@test.com", resp.Data.Users[0].Email)
require.Equal(t, int64(100), resp.Data.Users[0].Requests)
require.InDelta(t, 1.2, resp.Data.Users[0].ActualCost, 0.001)
require.Equal(t, "2026-03-01", resp.Data.StartDate)
require.Equal(t, "2026-03-16", resp.Data.EndDate)
}
func TestGetUserBreakdown_EmptyResult(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&group_id=999", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var resp struct {
Data struct {
Users []usagestats.UserBreakdownItem `json:"users"`
} `json:"data"`
}
err := json.Unmarshal(w.Body.Bytes(), &resp)
require.NoError(t, err)
require.Empty(t, resp.Data.Users)
}
func TestGetUserBreakdown_NoFilters(t *testing.T) {
repo := &userBreakdownRepoCapture{}
router := newUserBreakdownRouter(repo)
req := httptest.NewRequest(http.MethodGet,
"/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, int64(0), repo.capturedDim.GroupID)
require.Empty(t, repo.capturedDim.Model)
require.Empty(t, repo.capturedDim.Endpoint)
}

View File

@@ -38,6 +38,7 @@ type dashboardModelGroupCacheKey struct {
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
ModelSource string `json:"model_source,omitempty"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
@@ -111,6 +112,7 @@ func (h *DashboardHandler) getModelStatsCached(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
modelSource string,
requestType *int16,
stream *bool,
billingType *int8,
@@ -122,12 +124,13 @@ func (h *DashboardHandler) getModelStatsCached(
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
ModelSource: usagestats.NormalizeModelSource(modelSource),
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource)
})
if err != nil {
return nil, hit, err

View File

@@ -200,6 +200,7 @@ func (h *DashboardHandler) buildSnapshotV2Response(
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
usagestats.ModelSourceRequested,
filters.RequestType,
filters.Stream,
filters.BillingType,

View File

@@ -1,11 +1,15 @@
package admin
import (
"bytes"
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -13,27 +17,80 @@ import (
// GroupHandler handles admin group management
type GroupHandler struct {
adminService service.AdminService
adminService service.AdminService
dashboardService *service.DashboardService
groupCapacityService *service.GroupCapacityService
}
type optionalLimitField struct {
set bool
value *float64
}
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
f.set = true
trimmed := bytes.TrimSpace(data)
if bytes.Equal(trimmed, []byte("null")) {
f.value = nil
return nil
}
var number float64
if err := json.Unmarshal(trimmed, &number); err == nil {
f.value = &number
return nil
}
var text string
if err := json.Unmarshal(trimmed, &text); err == nil {
text = strings.TrimSpace(text)
if text == "" {
f.value = nil
return nil
}
number, err = strconv.ParseFloat(text, 64)
if err != nil {
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
}
f.value = &number
return nil
}
return fmt.Errorf("invalid limit value: %s", string(trimmed))
}
func (f optionalLimitField) ToServiceInput() *float64 {
if !f.set {
return nil
}
if f.value != nil {
return f.value
}
zero := 0.0
return &zero
}
// NewGroupHandler creates a new admin group handler
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler {
return &GroupHandler{
adminService: adminService,
adminService: adminService,
dashboardService: dashboardService,
groupCapacityService: groupCapacityService,
}
}
// CreateGroupRequest represents create group request
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
@@ -62,16 +119,16 @@ type CreateGroupRequest struct {
// UpdateGroupRequest represents update group request
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
@@ -191,9 +248,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
@@ -244,9 +301,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
@@ -311,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) {
_ = groupID // TODO: implement actual stats
}
// GetUsageSummary returns today's and cumulative cost for all groups.
// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai
func (h *GroupHandler) GetUsageSummary(c *gin.Context) {
userTZ := c.Query("timezone")
now := timezone.NowInUserLocation(userTZ)
todayStart := timezone.StartOfDayInUserLocation(now, userTZ)
results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart)
if err != nil {
response.Error(c, 500, "Failed to get group usage summary")
return
}
response.Success(c, results)
}
// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups.
// GET /api/v1/admin/groups/capacity-summary
func (h *GroupHandler) GetCapacitySummary(c *gin.Context) {
results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get group capacity summary")
return
}
response.Success(c, results)
}
// GetGroupAPIKeys handles getting API keys in a group
// GET /api/v1/admin/groups/:id/api-keys
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {

View File

@@ -80,6 +80,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
FrontendURL: settings.FrontendURL,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
@@ -137,6 +138,7 @@ type UpdateSettingsRequest struct {
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
FrontendURL string `json:"frontend_url"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
@@ -326,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// Frontend URL 验证
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
if req.FrontendURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
return
}
}
// 自定义菜单项验证
const (
maxCustomMenuItems = 20
@@ -437,6 +448,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
FrontendURL: req.FrontendURL,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
@@ -531,6 +543,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
FrontendURL: updatedSettings.FrontendURL,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
@@ -614,6 +627,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
if before.FrontendURL != after.FrontendURL {
changed = append(changed, "frontend_url")
}
if before.TotpEnabled != after.TotpEnabled {
changed = append(changed, "totp_enabled")
}
@@ -961,6 +977,58 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
response.Success(c, gin.H{"message": "Admin API key deleted"})
}
// GetOverloadCooldownSettings 获取529过载冷却配置
// GET /api/v1/admin/settings/overload-cooldown
func (h *SettingHandler) GetOverloadCooldownSettings(c *gin.Context) {
settings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.OverloadCooldownSettings{
Enabled: settings.Enabled,
CooldownMinutes: settings.CooldownMinutes,
})
}
// UpdateOverloadCooldownSettingsRequest 更新529过载冷却配置请求
type UpdateOverloadCooldownSettingsRequest struct {
Enabled bool `json:"enabled"`
CooldownMinutes int `json:"cooldown_minutes"`
}
// UpdateOverloadCooldownSettings 更新529过载冷却配置
// PUT /api/v1/admin/settings/overload-cooldown
func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) {
var req UpdateOverloadCooldownSettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
settings := &service.OverloadCooldownSettings{
Enabled: req.Enabled,
CooldownMinutes: req.CooldownMinutes,
}
if err := h.settingService.SetOverloadCooldownSettings(c.Request.Context(), settings); err != nil {
response.BadRequest(c, err.Error())
return
}
updatedSettings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.OverloadCooldownSettings{
Enabled: updatedSettings.Enabled,
CooldownMinutes: updatedSettings.CooldownMinutes,
})
}
// GetStreamTimeoutSettings 获取流超时处理配置
// GET /api/v1/admin/settings/stream-timeout
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {

View File

@@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
}
}
status := c.Query("status")
platform := c.Query("platform")
// Parse sorting parameters
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -159,8 +159,8 @@ func (h *UsageHandler) List(c *gin.Context) {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// Set end time to end of day
t = t.Add(24*time.Hour - time.Nanosecond)
// Use half-open range [start, end), move to next calendar day start (DST-safe).
t = t.AddDate(0, 0, 1)
endTime = &t
}
@@ -285,7 +285,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界DST-safe
endTime = endTime.AddDate(0, 0, 1)
} else {
period := c.DefaultQuery("period", "today")
switch period {

View File

@@ -459,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return
}
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
if frontendBaseURL == "" {
slog.Error("server.frontend_url not configured; cannot build password reset link")
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
response.InternalError(c, "Password reset is not configured")
return
}

View File

@@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
SortOrder: g.SortOrder,
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
ActiveAccountCount: g.ActiveAccountCount,
RateLimitedAccountCount: g.RateLimitedAccountCount,
SortOrder: g.SortOrder,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
@@ -521,8 +523,11 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
UpstreamModel: l.UpstreamModel,
ServiceTier: l.ServiceTier,
ReasoningEffort: l.ReasoningEffort,
InboundEndpoint: l.InboundEndpoint,
UpstreamEndpoint: l.UpstreamEndpoint,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens,

View File

@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
t.Parallel()
serviceTier := "priority"
inboundEndpoint := "/v1/chat/completions"
upstreamEndpoint := "/v1/responses"
log := &service.UsageLog{
RequestID: "req_3",
Model: "gpt-5.4",
ServiceTier: &serviceTier,
InboundEndpoint: &inboundEndpoint,
UpstreamEndpoint: &upstreamEndpoint,
AccountRateMultiplier: f64Ptr(1.5),
}
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
require.NotNil(t, userDTO.ServiceTier)
require.Equal(t, serviceTier, *userDTO.ServiceTier)
require.NotNil(t, userDTO.InboundEndpoint)
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
require.NotNil(t, userDTO.UpstreamEndpoint)
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
require.NotNil(t, adminDTO.ServiceTier)
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
require.NotNil(t, adminDTO.InboundEndpoint)
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
require.NotNil(t, adminDTO.UpstreamEndpoint)
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
require.NotNil(t, adminDTO.AccountRateMultiplier)
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
}

View File

@@ -22,6 +22,7 @@ type SystemSettings struct {
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
FrontendURL string `json:"frontend_url"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
@@ -156,6 +157,12 @@ type ListSoraS3ProfilesResponse struct {
Items []SoraS3Profile `json:"items"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO
type OverloadCooldownSettings struct {
Enabled bool `json:"enabled"`
CooldownMinutes int `json:"cooldown_minutes"`
}
// StreamTimeoutSettings 流超时处理配置 DTO
type StreamTimeoutSettings struct {
Enabled bool `json:"enabled"`

View File

@@ -122,9 +122,11 @@ type AdminGroup struct {
DefaultMappedModel string `json:"default_mapped_model"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
SupportedModelScopes []string `json:"supported_model_scopes"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
ActiveAccountCount int64 `json:"active_account_count,omitempty"`
RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"`
// 分组排序
SortOrder int `json:"sort_order"`
@@ -332,11 +334,18 @@ type UsageLog struct {
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
// UpstreamModel is the actual model sent to the upstream provider after mapping.
// Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"`
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string `json:"service_tier,omitempty"`
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
// nil means not provided / not applicable.
// ReasoningEffort is the request's reasoning effort level.
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"`

View File

@@ -0,0 +1,174 @@
package handler
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ──────────────────────────────────────────────────────────
// Canonical inbound / upstream endpoint paths.
// All normalization and derivation reference this single set
// of constants — add new paths HERE when a new API surface
// is introduced.
// ──────────────────────────────────────────────────────────
const (
EndpointMessages = "/v1/messages"
EndpointChatCompletions = "/v1/chat/completions"
EndpointResponses = "/v1/responses"
EndpointGeminiModels = "/v1beta/models"
)
// gin.Context keys used by the middleware and helpers below.
const (
ctxKeyInboundEndpoint = "_gateway_inbound_endpoint"
)
// ──────────────────────────────────────────────────────────
// Normalization functions
// ──────────────────────────────────────────────────────────
// NormalizeInboundEndpoint maps a raw request path (which may carry
// prefixes like /antigravity, /openai, /sora) to its canonical form.
//
// "/antigravity/v1/messages" → "/v1/messages"
// "/v1/chat/completions" → "/v1/chat/completions"
// "/openai/v1/responses/foo" → "/v1/responses"
// "/v1beta/models/gemini:gen" → "/v1beta/models"
func NormalizeInboundEndpoint(path string) string {
path = strings.TrimSpace(path)
switch {
case strings.Contains(path, EndpointChatCompletions):
return EndpointChatCompletions
case strings.Contains(path, EndpointMessages):
return EndpointMessages
case strings.Contains(path, EndpointResponses):
return EndpointResponses
case strings.Contains(path, EndpointGeminiModels):
return EndpointGeminiModels
default:
return path
}
}
// DeriveUpstreamEndpoint determines the upstream endpoint from the
// account platform and the normalized inbound endpoint.
//
// Platform-specific rules:
// - OpenAI always forwards to /v1/responses (with optional subpath
// such as /v1/responses/compact preserved from the raw URL).
// - Anthropic → /v1/messages
// - Gemini → /v1beta/models
// - Sora → /v1/chat/completions
// - Antigravity routes may target either Claude or Gemini, so the
// inbound endpoint is used to distinguish.
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
inbound = strings.TrimSpace(inbound)
switch platform {
case service.PlatformOpenAI:
// OpenAI forwards everything to the Responses API.
// Preserve subresource suffix (e.g. /v1/responses/compact).
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
return EndpointResponses + suffix
}
return EndpointResponses
case service.PlatformAnthropic:
return EndpointMessages
case service.PlatformGemini:
return EndpointGeminiModels
case service.PlatformSora:
return EndpointChatCompletions
case service.PlatformAntigravity:
// Antigravity accounts serve both Claude and Gemini.
if inbound == EndpointGeminiModels {
return EndpointGeminiModels
}
return EndpointMessages
}
// Unknown platform — fall back to inbound.
return inbound
}
// responsesSubpathSuffix extracts the part after "/responses" in a raw
// request path, e.g. "/openai/v1/responses/compact" → "/compact".
// Returns "" when there is no meaningful suffix.
func responsesSubpathSuffix(rawPath string) string {
trimmed := strings.TrimRight(strings.TrimSpace(rawPath), "/")
idx := strings.LastIndex(trimmed, "/responses")
if idx < 0 {
return ""
}
suffix := trimmed[idx+len("/responses"):]
if suffix == "" || suffix == "/" {
return ""
}
if !strings.HasPrefix(suffix, "/") {
return ""
}
return suffix
}
// ──────────────────────────────────────────────────────────
// Middleware
// ──────────────────────────────────────────────────────────
// InboundEndpointMiddleware normalizes the request path and stores the
// canonical inbound endpoint in gin.Context so that every handler in
// the chain can read it via GetInboundEndpoint.
//
// Apply this middleware to all gateway route groups.
func InboundEndpointMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
path := c.FullPath()
if path == "" && c.Request != nil && c.Request.URL != nil {
path = c.Request.URL.Path
}
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(path))
c.Next()
}
}
// ──────────────────────────────────────────────────────────
// Context helpers — used by handlers before building
// RecordUsageInput / RecordUsageLongContextInput.
// ──────────────────────────────────────────────────────────
// GetInboundEndpoint returns the canonical inbound endpoint stored by
// InboundEndpointMiddleware. If the middleware did not run (e.g. in
// tests), it falls back to normalizing c.FullPath() on the fly.
func GetInboundEndpoint(c *gin.Context) string {
if v, ok := c.Get(ctxKeyInboundEndpoint); ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
// Fallback: normalize on the fly.
path := ""
if c != nil {
path = c.FullPath()
if path == "" && c.Request != nil && c.Request.URL != nil {
path = c.Request.URL.Path
}
}
return NormalizeInboundEndpoint(path)
}
// GetUpstreamEndpoint derives the upstream endpoint from the context
// and the account platform. Handlers call this after scheduling an
// account, passing account.Platform.
func GetUpstreamEndpoint(c *gin.Context, platform string) string {
inbound := GetInboundEndpoint(c)
rawPath := ""
if c != nil && c.Request != nil && c.Request.URL != nil {
rawPath = c.Request.URL.Path
}
return DeriveUpstreamEndpoint(inbound, rawPath, platform)
}

View File

@@ -0,0 +1,159 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func init() { gin.SetMode(gin.TestMode) }
// ──────────────────────────────────────────────────────────
// NormalizeInboundEndpoint
// ──────────────────────────────────────────────────────────
func TestNormalizeInboundEndpoint(t *testing.T) {
tests := []struct {
path string
want string
}{
// Direct canonical paths.
{"/v1/messages", EndpointMessages},
{"/v1/chat/completions", EndpointChatCompletions},
{"/v1/responses", EndpointResponses},
{"/v1beta/models", EndpointGeminiModels},
// Prefixed paths (antigravity, openai, sora).
{"/antigravity/v1/messages", EndpointMessages},
{"/openai/v1/responses", EndpointResponses},
{"/openai/v1/responses/compact", EndpointResponses},
{"/sora/v1/chat/completions", EndpointChatCompletions},
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
// Gin route patterns with wildcards.
{"/v1beta/models/*modelAction", EndpointGeminiModels},
{"/v1/responses/*subpath", EndpointResponses},
// Unknown path is returned as-is.
{"/v1/embeddings", "/v1/embeddings"},
{"", ""},
{" /v1/messages ", EndpointMessages},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
require.Equal(t, tt.want, NormalizeInboundEndpoint(tt.path))
})
}
}
// ──────────────────────────────────────────────────────────
// DeriveUpstreamEndpoint
// ──────────────────────────────────────────────────────────
func TestDeriveUpstreamEndpoint(t *testing.T) {
tests := []struct {
name string
inbound string
rawPath string
platform string
want string
}{
// Anthropic.
{"anthropic messages", EndpointMessages, "/v1/messages", service.PlatformAnthropic, EndpointMessages},
// Gemini.
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
// Sora.
{"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
// OpenAI — always /v1/responses.
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
{"antigravity gemini", EndpointGeminiModels, "/antigravity/v1beta/models", service.PlatformAntigravity, EndpointGeminiModels},
// Unknown platform — passthrough.
{"unknown platform", "/v1/embeddings", "/v1/embeddings", "unknown", "/v1/embeddings"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, DeriveUpstreamEndpoint(tt.inbound, tt.rawPath, tt.platform))
})
}
}
// ──────────────────────────────────────────────────────────
// responsesSubpathSuffix
// ──────────────────────────────────────────────────────────
func TestResponsesSubpathSuffix(t *testing.T) {
tests := []struct {
raw string
want string
}{
{"/v1/responses", ""},
{"/v1/responses/", ""},
{"/v1/responses/compact", "/compact"},
{"/openai/v1/responses/compact/detail", "/compact/detail"},
{"/v1/messages", ""},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.raw, func(t *testing.T) {
require.Equal(t, tt.want, responsesSubpathSuffix(tt.raw))
})
}
}
// ──────────────────────────────────────────────────────────
// InboundEndpointMiddleware + context helpers
// ──────────────────────────────────────────────────────────
func TestInboundEndpointMiddleware(t *testing.T) {
router := gin.New()
router.Use(InboundEndpointMiddleware())
var captured string
router.POST("/v1/messages", func(c *gin.Context) {
captured = GetInboundEndpoint(c)
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, EndpointMessages, captured)
}
func TestGetInboundEndpoint_FallbackWithoutMiddleware(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/antigravity/v1/messages", nil)
// Middleware did not run — fallback to normalizing c.Request.URL.Path.
got := GetInboundEndpoint(c)
require.Equal(t, EndpointMessages, got)
}
func TestGetUpstreamEndpoint_FullFlow(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses/compact", nil)
// Simulate middleware.
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(c.Request.URL.Path))
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
require.Equal(t, "/v1/responses/compact", got)
}

View File

@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if fs.SwitchCount > 0 {
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
// 记录 Forward 前已写入字节数Forward 后若增加则说明 SSE 内容已发,禁止 failover
writerSizeBeforeForward := c.Writer.Size()
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
} else {
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
@@ -435,6 +442,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
if result.ReasoningEffort == nil {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
@@ -444,6 +457,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@@ -637,6 +652,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if fs.SwitchCount > 0 {
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
// 记录 Forward 前已写入字节数Forward 后若增加则说明 SSE 内容已发,禁止 failover
writerSizeBeforeForward := c.Writer.Size()
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
} else {
@@ -706,6 +723,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
@@ -739,6 +761,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
if result.ReasoningEffort == nil {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
@@ -748,6 +776,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
User: currentAPIKey.User,
Account: account,
Subscription: currentSubscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@@ -913,7 +943,7 @@ func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Ti
}
if s := c.Query("end_date"); s != "" {
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
endTime = t.Add(24*time.Hour - time.Second) // end of day
endTime = t.AddDate(0, 0, 1) // half-open range upper bound
}
}
return startTime, endTime

View File

@@ -0,0 +1,122 @@
package handler
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
// 具体验证:
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
// 3. 响应体中只出现一个 message_start不存在第二个防止流拼接腐化
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// 步骤 1记录 Forward 前的 writer size模拟 writerSizeBeforeForward := c.Writer.Size()
sizeBeforeForward := c.Writer.Size()
require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1未写入任何字节")
// 步骤 2模拟 Forward 已向客户端写入部分 SSE 内容message_start + content_block_start
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
require.NoError(t, err)
// 步骤 3验证守卫条件成立c.Writer.Size() != sizeBeforeForward
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true")
// 步骤 4模拟 UpstreamFailoverError上游在流中途返回 403
failoverErr := &service.UpstreamFailoverError{
StatusCode: http.StatusForbidden,
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
}
// 步骤 5守卫触发 → 调用 handleFailoverExhaustedstreamStarted=true
h := &GatewayHandler{}
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
body := w.Body.String()
// 断言 A响应体中包含最初写入的 message_start SSE 事件行
require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件")
// 断言 B响应体以 SSE 错误事件结尾data: {"type":"error",...}\n\n
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
"响应体应以 JSON 对象结尾SSE error event 的 data 字段)")
require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件")
// 断言 CSSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
firstIdx := strings.Index(body, "event: message_start")
lastIdx := strings.LastIndex(body, "event: message_start")
assert.Equal(t, firstIdx, lastIdx,
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次")
}
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
// 验证 Gemini 路径使用 service.PlatformGemini而非 account.Platform时行为一致。
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
sizeBeforeForward := c.Writer.Size()
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
require.NoError(t, err)
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
failoverErr := &service.UpstreamFailoverError{
StatusCode: http.StatusForbidden,
}
h := &GatewayHandler{}
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
body := w.Body.String()
require.Contains(t, body, "event: message_start")
require.Contains(t, body, `"type":"error"`)
firstIdx := strings.Index(body, "event: message_start")
lastIdx := strings.LastIndex(body, "event: message_start")
assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start")
}
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
// 守卫条件c.Writer.Size() != sizeBeforeForward为 false不应中止 failover。
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// 模拟 writerSizeBeforeForward初始为 -1
sizeBeforeForward := c.Writer.Size()
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
// c.Writer.Size() 仍为 -1
// 守卫条件sizeBeforeForward == c.Writer.Size() → 不触发
guardTriggered := c.Writer.Size() != sizeBeforeForward
require.False(t, guardTriggered,
"未写入任何字节时,守卫条件必须为 false应允许正常 failover 继续")
}

View File

@@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service
return nil, nil
}
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil }
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}

View File

@@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte {
return []byte(`{
"model":"claude-3-5-sonnet-20241022",
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}
}`)
}
@@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
System: []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
},
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
}
// body 非法 JSON如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
@@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing
"system": []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
},
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
})
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)

View File

@@ -504,6 +504,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
@@ -511,6 +513,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,

View File

@@ -256,14 +256,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"),

View File

@@ -0,0 +1,56 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the
// unified GetUpstreamEndpoint helper produces the same results as the
// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests.
func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
path string
want string
}{
{
name: "responses root maps to responses upstream",
path: "/v1/responses",
want: EndpointResponses,
},
{
name: "responses compact keeps compact suffix",
path: "/openai/v1/responses/compact",
want: "/v1/responses/compact",
},
{
name: "responses nested suffix preserved",
path: "/openai/v1/responses/compact/detail",
want: "/v1/responses/compact/detail",
},
{
name: "non responses path uses platform fallback",
path: "/v1/messages",
want: EndpointResponses,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
require.Equal(t, tt.want, got)
})
}
}

View File

@@ -362,6 +362,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@@ -738,6 +740,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@@ -1235,6 +1239,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),

View File

@@ -26,6 +26,22 @@ const (
opsStreamKey = "ops_stream"
opsRequestBodyKey = "ops_request_body"
opsAccountIDKey = "ops_account_id"
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
opsErrContextCanceled = "context canceled"
opsErrNoAvailableAccounts = "no available accounts"
opsErrInvalidAPIKey = "invalid_api_key"
opsErrAPIKeyRequired = "api_key_required"
opsErrInsufficientBalance = "insufficient balance"
opsErrInsufficientAccountBalance = "insufficient account balance"
opsErrInsufficientQuota = "insufficient_quota"
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
opsCodeUserInactive = "USER_INACTIVE"
)
const (
@@ -1024,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
return errType
}
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE":
case opsCodeInsufficientBalance:
return "billing_error"
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
return "subscription_error"
default:
return "api_error"
@@ -1038,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
// Standardized phases: request|auth|routing|upstream|network|internal
// Map billing/concurrency/response => request; scheduling => routing.
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
return "request"
}
@@ -1057,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
case "upstream_error", "overloaded_error":
return "upstream"
case "api_error":
if strings.Contains(msg, "no available accounts") {
if strings.Contains(msg, opsErrNoAvailableAccounts) {
return "routing"
}
return "internal"
@@ -1103,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
return true
}
if phase == "billing" || phase == "concurrency" {
@@ -1197,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
// Check if context canceled errors should be ignored (client disconnects)
if settings.IgnoreContextCanceled {
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
return true
}
}
// Check if "no available accounts" errors should be ignored
if settings.IgnoreNoAvailableAccounts {
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
return true
}
}
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
if settings.IgnoreInvalidApiKeyErrors {
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
return true
}
}
// Check if insufficient balance errors should be ignored
if settings.IgnoreInsufficientBalanceErrors {
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
return true
}
}

View File

@@ -400,6 +400,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
@@ -409,6 +411,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,

View File

@@ -273,8 +273,8 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, nil
}
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
@@ -334,9 +334,23 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return []usagestats.EndpointStat{}, nil
}
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return []usagestats.EndpointStat{}, nil
}
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil
}

View File

@@ -114,8 +114,8 @@ func (h *UsageHandler) List(c *gin.Context) {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// Set end time to end of day
t = t.Add(24*time.Hour - time.Nanosecond)
// Use half-open range [start, end), move to next calendar day start (DST-safe).
t = t.AddDate(0, 0, 1)
endTime = &t
}
@@ -227,8 +227,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// 设置结束时间为当天结束
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界DST-safe
endTime = endTime.AddDate(0, 0, 1)
} else {
// 使用 period 参数
period := c.DefaultQuery("period", "today")

View File

@@ -124,10 +124,68 @@ type IneligibleTier struct {
type LoadCodeAssistResponse struct {
CloudAICompanionProject string `json:"cloudaicompanionProject"`
CurrentTier *TierInfo `json:"currentTier,omitempty"`
PaidTier *TierInfo `json:"paidTier,omitempty"`
PaidTier *PaidTierInfo `json:"paidTier,omitempty"`
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
}
// PaidTierInfo 付费等级信息,包含 AI Credits 余额。
type PaidTierInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
AvailableCredits []AvailableCredit `json:"availableCredits,omitempty"`
}
// UnmarshalJSON 兼容 paidTier 既可能是字符串也可能是对象的情况。
func (p *PaidTierInfo) UnmarshalJSON(data []byte) error {
data = bytes.TrimSpace(data)
if len(data) == 0 || string(data) == "null" {
return nil
}
if data[0] == '"' {
var id string
if err := json.Unmarshal(data, &id); err != nil {
return err
}
p.ID = id
return nil
}
type alias PaidTierInfo
var raw alias
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
*p = PaidTierInfo(raw)
return nil
}
// AvailableCredit 表示一条 AI Credits 余额记录。
type AvailableCredit struct {
CreditType string `json:"creditType,omitempty"`
CreditAmount string `json:"creditAmount,omitempty"`
MinimumCreditAmountForUsage string `json:"minimumCreditAmountForUsage,omitempty"`
}
// GetAmount 将 creditAmount 解析为浮点数。
func (c *AvailableCredit) GetAmount() float64 {
if c.CreditAmount == "" {
return 0
}
var value float64
_, _ = fmt.Sscanf(c.CreditAmount, "%f", &value)
return value
}
// GetMinimumAmount 将 minimumCreditAmountForUsage 解析为浮点数。
func (c *AvailableCredit) GetMinimumAmount() float64 {
if c.MinimumCreditAmountForUsage == "" {
return 0
}
var value float64
_, _ = fmt.Sscanf(c.MinimumCreditAmountForUsage, "%f", &value)
return value
}
// OnboardUserRequest onboardUser 请求
type OnboardUserRequest struct {
TierID string `json:"tierId"`
@@ -157,6 +215,14 @@ func (r *LoadCodeAssistResponse) GetTier() string {
return ""
}
// GetAvailableCredits 返回 paid tier 中的 AI Credits 余额列表。
func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit {
if r.PaidTier == nil {
return nil
}
return r.PaidTier.AvailableCredits
}
// Client Antigravity API 客户端
type Client struct {
httpClient *http.Client

View File

@@ -190,7 +190,7 @@ func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) {
func TestGetTier_PaidTier优先(t *testing.T) {
resp := &LoadCodeAssistResponse{
CurrentTier: &TierInfo{ID: "free-tier"},
PaidTier: &TierInfo{ID: "g1-pro-tier"},
PaidTier: &PaidTierInfo{ID: "g1-pro-tier"},
}
if got := resp.GetTier(); got != "g1-pro-tier" {
t.Errorf("应返回 paidTier: got %s", got)
@@ -209,7 +209,7 @@ func TestGetTier_回退到CurrentTier(t *testing.T) {
func TestGetTier_PaidTier为空ID(t *testing.T) {
resp := &LoadCodeAssistResponse{
CurrentTier: &TierInfo{ID: "free-tier"},
PaidTier: &TierInfo{ID: ""},
PaidTier: &PaidTierInfo{ID: ""},
}
// paidTier.ID 为空时应回退到 currentTier
if got := resp.GetTier(); got != "free-tier" {
@@ -217,6 +217,32 @@ func TestGetTier_PaidTier为空ID(t *testing.T) {
}
}
func TestGetAvailableCredits(t *testing.T) {
resp := &LoadCodeAssistResponse{
PaidTier: &PaidTierInfo{
ID: "g1-pro-tier",
AvailableCredits: []AvailableCredit{
{
CreditType: "GOOGLE_ONE_AI",
CreditAmount: "25",
MinimumCreditAmountForUsage: "5",
},
},
},
}
credits := resp.GetAvailableCredits()
if len(credits) != 1 {
t.Fatalf("AI Credits 数量不匹配: got %d", len(credits))
}
if credits[0].GetAmount() != 25 {
t.Errorf("CreditAmount 解析不正确: got %v", credits[0].GetAmount())
}
if credits[0].GetMinimumAmount() != 5 {
t.Errorf("MinimumCreditAmountForUsage 解析不正确: got %v", credits[0].GetMinimumAmount())
}
}
func TestGetTier_两者都为nil(t *testing.T) {
resp := &LoadCodeAssistResponse{}
if got := resp.GetTier(); got != "" {

View File

@@ -49,8 +49,8 @@ const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.4
var defaultUserAgentVersion = "1.20.4"
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
var defaultUserAgentVersion = "1.20.5"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var defaultClientSecret = "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"

View File

@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
if RedirectURI != "http://localhost:8085/callback" {
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
}
if GetUserAgent() != "antigravity/1.20.4 windows/amd64" {
if GetUserAgent() != "antigravity/1.20.5 windows/amd64" {
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
}
if SessionTTL != 30*time.Minute {

View File

@@ -47,6 +47,15 @@ func Created(c *gin.Context, data any) {
})
}
// Accepted 返回异步接受响应 (HTTP 202)
func Accepted(c *gin.Context, data any) {
c.JSON(http.StatusAccepted, Response{
Code: 0,
Message: "accepted",
Data: data,
})
}
// Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{

View File

@@ -3,6 +3,28 @@ package usagestats
import "time"
const (
ModelSourceRequested = "requested"
ModelSourceUpstream = "upstream"
ModelSourceMapping = "mapping"
)
func IsValidModelSource(source string) bool {
switch source {
case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping:
return true
default:
return false
}
}
func NormalizeModelSource(source string) string {
if IsValidModelSource(source) {
return source
}
return ModelSourceRequested
}
// DashboardStats 仪表盘统计
type DashboardStats struct {
// 用户统计
@@ -81,6 +103,22 @@ type ModelStat struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// EndpointStat represents usage statistics for a single request endpoint.
type EndpointStat struct {
Endpoint string `json:"endpoint"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// GroupUsageSummary represents today's and cumulative cost for a single group.
type GroupUsageSummary struct {
GroupID int64 `json:"group_id"`
TodayCost float64 `json:"today_cost"`
TotalCost float64 `json:"total_cost"`
}
// GroupStat represents usage statistics for a single group
type GroupStat struct {
GroupID int64 `json:"group_id"`
@@ -116,6 +154,27 @@ type UserSpendingRankingItem struct {
type UserSpendingRankingResponse struct {
Ranking []UserSpendingRankingItem `json:"ranking"`
TotalActualCost float64 `json:"total_actual_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
}
// UserBreakdownItem represents per-user usage breakdown within a dimension (group, model, endpoint).
type UserBreakdownItem struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
type UserBreakdownDimension struct {
GroupID int64 // filter by group_id (>0 to enable)
Model string // filter by model name (non-empty to enable)
ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path"
}
// APIKeyUsageTrendPoint represents API key usage trend data point
@@ -179,15 +238,18 @@ type UsageLogFilters struct {
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
AverageDurationMs float64 `json:"average_duration_ms"`
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
AverageDurationMs float64 `json:"average_duration_ms"`
Endpoints []EndpointStat `json:"endpoints,omitempty"`
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
}
// BatchUserUsageStats represents usage stats for a single user
@@ -254,7 +316,9 @@ type AccountUsageSummary struct {
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse struct {
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
Endpoints []EndpointStat `json:"endpoints"`
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
}

View File

@@ -0,0 +1,47 @@
package usagestats
import "testing"
func TestIsValidModelSource(t *testing.T) {
tests := []struct {
name string
source string
want bool
}{
{name: "requested", source: ModelSourceRequested, want: true},
{name: "upstream", source: ModelSourceUpstream, want: true},
{name: "mapping", source: ModelSourceMapping, want: true},
{name: "invalid", source: "foobar", want: false},
{name: "empty", source: "", want: false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := IsValidModelSource(tc.source); got != tc.want {
t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want)
}
})
}
}
func TestNormalizeModelSource(t *testing.T) {
tests := []struct {
name string
source string
want string
}{
{name: "requested", source: ModelSourceRequested, want: ModelSourceRequested},
{name: "upstream", source: ModelSourceUpstream, want: ModelSourceUpstream},
{name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping},
{name: "invalid falls back", source: "foobar", want: ModelSourceRequested},
{name: "empty falls back", source: "", want: ModelSourceRequested},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := NormalizeModelSource(tc.source); got != tc.want {
t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want)
}
})
}
}

View File

@@ -57,6 +57,7 @@ func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
// 读取全部内容以获取大小S3 PutObject 需要知道内容长度)
// 注意:阿里云 OSS 不兼容 s3manager 分片上传的签名方式,因此使用 PutObject
data, err := io.ReadAll(body)
if err != nil {
return 0, fmt.Errorf("read body: %w", err)

View File

@@ -20,6 +20,11 @@ const (
billingCacheTTL = 5 * time.Minute
billingCacheJitter = 30 * time.Second
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
// Rate limit window durations — must match service.RateLimitWindow* constants.
rateLimitWindow5h = 5 * time.Hour
rateLimitWindow1d = 24 * time.Hour
rateLimitWindow7d = 7 * 24 * time.Hour
)
// jitteredTTL 返回带随机抖动的 TTL防止缓存雪崩
@@ -90,17 +95,40 @@ var (
return 1
`)
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
// updateRateLimitUsageScript atomically increments all three rate limit usage counters
// with window expiration checking. If a window has expired, its usage is reset to cost
// (instead of accumulated) and the window timestamp is updated, matching the DB-side
// IncrementRateLimitUsage semantics.
//
// ARGV: [1]=cost, [2]=ttl_seconds, [3]=now_unix, [4]=window_5h_seconds, [5]=window_1d_seconds, [6]=window_7d_seconds
updateRateLimitUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
local now = tonumber(ARGV[3])
local win5h = tonumber(ARGV[4])
local win1d = tonumber(ARGV[5])
local win7d = tonumber(ARGV[6])
-- Helper: check if window is expired and update usage + window accordingly
-- Returns nothing, modifies the hash in-place.
local function update_window(usage_field, window_field, window_duration)
local w = tonumber(redis.call('HGET', KEYS[1], window_field) or 0)
if w == 0 or (now - w) >= window_duration then
-- Window expired or never started: reset usage to cost, start new window
redis.call('HSET', KEYS[1], usage_field, tostring(cost))
redis.call('HSET', KEYS[1], window_field, tostring(now))
else
-- Window still valid: accumulate
redis.call('HINCRBYFLOAT', KEYS[1], usage_field, cost)
end
end
update_window('usage_5h', 'window_5h', win5h)
update_window('usage_1d', 'window_1d', win1d)
update_window('usage_7d', 'window_7d', win7d)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
@@ -280,7 +308,15 @@ func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
key := billingRateLimitKey(keyID)
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
now := time.Now().Unix()
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key},
cost,
int(rateLimitCacheTTL.Seconds()),
now,
int(rateLimitWindow5h.Seconds()),
int(rateLimitWindow1d.Seconds()),
int(rateLimitWindow7d.Seconds()),
).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
return err

View File

@@ -88,8 +88,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
if err != nil {
return nil, err
}
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
total, active, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = total
out.ActiveAccountCount = active
return out, nil
}
@@ -256,7 +257,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
}
}
@@ -283,7 +287,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
}
}
@@ -310,7 +317,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform str
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
c := counts[outGroups[i].ID]
outGroups[i].AccountCount = c.Total
outGroups[i].ActiveAccountCount = c.Active
outGroups[i].RateLimitedAccountCount = c.RateLimited
}
}
@@ -369,12 +379,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int
return result, nil
}
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
return 0, err
}
return count, nil
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) {
var rateLimited int64
err = scanSingleRow(ctx, r.sql,
`SELECT COUNT(*),
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true),
COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
))
FROM account_groups ag JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = $1`,
[]any{groupID}, &total, &active, &rateLimited)
return
}
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -500,15 +518,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return affectedUserIDs, nil
}
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) {
counts = make(map[int64]int64, len(groupIDs))
type groupAccountCounts struct {
Total int64
Active int64
RateLimited int64
}
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) {
counts = make(map[int64]groupAccountCounts, len(groupIDs))
if len(groupIDs) == 0 {
return counts, nil
}
rows, err := r.sql.QueryContext(
ctx,
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id",
`SELECT ag.group_id,
COUNT(*) AS total,
COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active,
COUNT(*) FILTER (WHERE a.status = 'active' AND (
a.rate_limit_reset_at > NOW() OR
a.overload_until > NOW() OR
a.temp_unschedulable_until > NOW()
)) AS rate_limited
FROM account_groups ag
JOIN accounts a ON a.id = ag.account_id
WHERE ag.group_id = ANY($1)
GROUP BY ag.group_id`,
pq.Array(groupIDs),
)
if err != nil {
@@ -523,11 +558,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
for rows.Next() {
var groupID int64
var count int64
if err = rows.Scan(&groupID, &count); err != nil {
var c groupAccountCounts
if err = rows.Scan(&groupID, &c.Total, &c.Active, &c.RateLimited); err != nil {
return nil, err
}
counts[groupID] = count
counts[groupID] = c
}
if err = rows.Err(); err != nil {
return nil, err

View File

@@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
s.Require().NoError(err)
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(2), count)
}
@@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
}
s.Require().NoError(s.repo.Create(s.ctx, group))
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
count, _, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
@@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
count, _, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(0), count, "expected 0 account groups")
}
@@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
s.Require().NoError(err)
s.Require().Equal(int64(3), affected)
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}

View File

@@ -132,7 +132,7 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
}
}
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
return err
}

View File

@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
var usageLogInsertArgTypes = [...]string{
"bigint",
@@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{
"bigint",
"text",
"text",
"text",
"bigint",
"bigint",
"integer",
@@ -65,6 +66,8 @@ var usageLogInsertArgTypes = [...]string{
"text",
"text",
"text",
"text",
"text",
"boolean",
"timestamptz",
}
@@ -275,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -304,15 +308,17 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
$1, $2, $3, $4, $5, $6,
$7, $8,
$9, $10, $11, $12,
$13, $14,
$15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -703,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -732,11 +739,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*37)
args := make([]any, 0, len(keys)*39)
argPos := 1
for idx, key := range keys {
if idx > 0 {
@@ -770,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -799,6 +809,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
)
@@ -808,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -837,6 +850,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
FROM input
@@ -886,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -915,11 +931,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*36)
args := make([]any, 0, len(preparedList)*39)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -950,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -979,6 +998,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
)
@@ -988,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -1017,6 +1039,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
FROM input
@@ -1034,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
account_id,
request_id,
model,
upstream_model,
group_id,
subscription_id,
input_tokens,
@@ -1063,15 +1088,17 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
$1, $2, $3, $4, $5, $6,
$7, $8,
$9, $10, $11, $12,
$13, $14,
$15, $16, $17, $18, $19, $20,
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1101,6 +1128,9 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
mediaType := nullString(log.MediaType)
serviceTier := nullString(log.ServiceTier)
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint)
upstreamModel := nullString(log.UpstreamModel)
var requestIDArg any
if requestID != "" {
@@ -1118,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.AccountID,
requestIDArg,
log.Model,
upstreamModel,
groupID,
subscriptionID,
log.InputTokens,
@@ -1147,6 +1178,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
mediaType,
serviceTier,
reasoningEffort,
inboundEndpoint,
upstreamEndpoint,
log.CacheTTLOverridden,
createdAt,
},
@@ -2139,7 +2172,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
actual_cost,
requests,
tokens,
COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost
COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost,
COALESCE(SUM(requests) OVER (), 0) as total_requests,
COALESCE(SUM(tokens) OVER (), 0) as total_tokens
FROM user_spend
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
LIMIT $3
@@ -2150,7 +2185,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
actual_cost,
requests,
tokens,
total_actual_cost
total_actual_cost,
total_requests,
total_tokens
FROM ranked
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
`
@@ -2168,9 +2205,11 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
ranking := make([]UserSpendingRankingItem, 0)
totalActualCost := 0.0
totalRequests := int64(0)
totalTokens := int64(0)
for rows.Next() {
var row UserSpendingRankingItem
if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost); err != nil {
if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost, &totalRequests, &totalTokens); err != nil {
return nil, err
}
ranking = append(ranking, row)
@@ -2182,6 +2221,8 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
return &UserSpendingRankingResponse{
Ranking: ranking,
TotalActualCost: totalActualCost,
TotalRequests: totalRequests,
TotalTokens: totalTokens,
}, nil
}
@@ -2505,7 +2546,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
args = append(args, *filters.EndTime)
}
@@ -2834,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st
// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested)
}
// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension.
// source: requested | upstream | mapping.
func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source)
}
func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
modelExpr := resolveModelDimensionExpression(source)
query := fmt.Sprintf(`
SELECT
model,
%s as model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
@@ -2853,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, actualCostExpr)
`, modelExpr, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
@@ -2877,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY model ORDER BY total_tokens DESC"
query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr)
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
@@ -2970,6 +3022,132 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
return results, nil
}
// GetUserBreakdownStats returns per-user usage breakdown within a specific dimension.
func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) (results []usagestats.UserBreakdownItem, err error) {
query := `
SELECT
COALESCE(ul.user_id, 0) as user_id,
COALESCE(u.email, '') as email,
COUNT(*) as requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(ul.total_cost), 0) as cost,
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
FROM usage_logs ul
LEFT JOIN users u ON u.id = ul.user_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
`
args := []any{startTime, endTime}
if dim.GroupID > 0 {
query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1)
args = append(args, dim.GroupID)
}
if dim.Model != "" {
query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1)
args = append(args, dim.Model)
}
if dim.Endpoint != "" {
col := resolveEndpointColumn(dim.EndpointType)
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
args = append(args, dim.Endpoint)
}
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
if limit > 0 {
query += fmt.Sprintf(" LIMIT %d", limit)
}
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results = make([]usagestats.UserBreakdownItem, 0)
for rows.Next() {
var row usagestats.UserBreakdownItem
if err := rows.Scan(
&row.UserID,
&row.Email,
&row.Requests,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group.
// todayStart is the start-of-day in the caller's timezone (UTC-based).
// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation.
// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s)
// or a materialized view / pre-aggregation table for cumulative costs.
func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
query := `
SELECT
g.id AS group_id,
COALESCE(SUM(ul.actual_cost), 0) AS total_cost,
COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost
FROM groups g
LEFT JOIN usage_logs ul ON ul.group_id = g.id
GROUP BY g.id
`
rows, err := r.sql.QueryContext(ctx, query, todayStart)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var results []usagestats.GroupUsageSummary
for rows.Next() {
var row usagestats.GroupUsageSummary
if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// resolveModelDimensionExpression maps model source type to a safe SQL expression.
func resolveModelDimensionExpression(modelType string) string {
switch usagestats.NormalizeModelSource(modelType) {
case usagestats.ModelSourceUpstream:
return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"
case usagestats.ModelSourceMapping:
return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"
default:
return "model"
}
}
// resolveEndpointColumn maps endpoint type to the corresponding DB column name.
func resolveEndpointColumn(endpointType string) string {
switch endpointType {
case "upstream":
return "ul.upstream_endpoint"
case "path":
return "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"
default:
return "ul.inbound_endpoint"
}
}
// GetGlobalStats gets usage statistics for all users within a time range
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
query := `
@@ -2982,7 +3160,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at <= $2
WHERE created_at >= $1 AND created_at < $2
`
stats := &UsageStats{}
@@ -3040,7 +3218,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
args = append(args, *filters.EndTime)
}
@@ -3080,6 +3258,35 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
stats.TotalAccountCost = &totalAccountCost
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
start := time.Unix(0, 0).UTC()
if filters.StartTime != nil {
start = *filters.StartTime
}
end := time.Now().UTC()
if filters.EndTime != nil {
end = *filters.EndTime
}
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
if endpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetStatsWithFilters: %v", endpointErr)
endpoints = []EndpointStat{}
}
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
if upstreamEndpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetStatsWithFilters: %v", upstreamEndpointErr)
upstreamEndpoints = []EndpointStat{}
}
endpointPaths, endpointPathErr := r.getEndpointPathStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
if endpointPathErr != nil {
logger.LegacyPrintf("repository.usage_log", "getEndpointPathStatsWithFilters failed in GetStatsWithFilters: %v", endpointPathErr)
endpointPaths = []EndpointStat{}
}
stats.Endpoints = endpoints
stats.UpstreamEndpoints = upstreamEndpoints
stats.EndpointPaths = endpointPaths
return stats, nil
}
@@ -3092,6 +3299,163 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// EndpointStat represents endpoint usage statistics row.
type EndpointStat = usagestats.EndpointStat
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
SELECT
COALESCE(NULLIF(TRIM(%s), ''), 'unknown') AS endpoint,
COUNT(*) AS requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, endpointColumn, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
args = append(args, userID)
}
if apiKeyID > 0 {
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
args = append(args, apiKeyID)
}
if accountID > 0 {
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if model != "" {
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY endpoint ORDER BY requests DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results = make([]EndpointStat, 0)
for rows.Next() {
var row EndpointStat
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
SELECT
CONCAT(
COALESCE(NULLIF(TRIM(inbound_endpoint), ''), 'unknown'),
' -> ',
COALESCE(NULLIF(TRIM(upstream_endpoint), ''), 'unknown')
) AS endpoint,
COUNT(*) AS requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
args = append(args, userID)
}
if apiKeyID > 0 {
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
args = append(args, apiKeyID)
}
if accountID > 0 {
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if model != "" {
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY endpoint ORDER BY requests DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results = make([]EndpointStat, 0)
for rows.Next() {
var row EndpointStat
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// GetEndpointStatsWithFilters returns inbound endpoint statistics with optional filters.
func (r *usageLogRepository) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
return r.getEndpointStatsByColumnWithFilters(ctx, "inbound_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
}
// GetUpstreamEndpointStatsWithFilters returns upstream endpoint statistics with optional filters.
func (r *usageLogRepository) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
return r.getEndpointStatsByColumnWithFilters(ctx, "upstream_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
}
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
@@ -3254,11 +3618,23 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
if err != nil {
models = []ModelStat{}
}
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
if endpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetAccountUsageStats: %v", endpointErr)
endpoints = []EndpointStat{}
}
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
if upstreamEndpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetAccountUsageStats: %v", upstreamEndpointErr)
upstreamEndpoints = []EndpointStat{}
}
resp = &AccountUsageStatsResponse{
History: history,
Summary: summary,
Models: models,
History: history,
Summary: summary,
Models: models,
Endpoints: endpoints,
UpstreamEndpoints: upstreamEndpoints,
}
return resp, nil
}
@@ -3512,6 +3888,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
accountID int64
requestID sql.NullString
model string
upstreamModel sql.NullString
groupID sql.NullInt64
subscriptionID sql.NullInt64
inputTokens int
@@ -3541,6 +3918,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
mediaType sql.NullString
serviceTier sql.NullString
reasoningEffort sql.NullString
inboundEndpoint sql.NullString
upstreamEndpoint sql.NullString
cacheTTLOverridden bool
createdAt time.Time
)
@@ -3552,6 +3931,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&accountID,
&requestID,
&model,
&upstreamModel,
&groupID,
&subscriptionID,
&inputTokens,
@@ -3581,6 +3961,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&mediaType,
&serviceTier,
&reasoningEffort,
&inboundEndpoint,
&upstreamEndpoint,
&cacheTTLOverridden,
&createdAt,
); err != nil {
@@ -3656,6 +4038,15 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if reasoningEffort.Valid {
log.ReasoningEffort = &reasoningEffort.String
}
if inboundEndpoint.Valid {
log.InboundEndpoint = &inboundEndpoint.String
}
if upstreamEndpoint.Valid {
log.UpstreamEndpoint = &upstreamEndpoint.String
}
if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String
}
return log, nil
}

View File

@@ -0,0 +1,50 @@
//go:build unit
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/require"
)
func TestResolveEndpointColumn(t *testing.T) {
tests := []struct {
endpointType string
want string
}{
{"inbound", "ul.inbound_endpoint"},
{"upstream", "ul.upstream_endpoint"},
{"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"},
{"", "ul.inbound_endpoint"}, // default
{"unknown", "ul.inbound_endpoint"}, // fallback
}
for _, tc := range tests {
t.Run(tc.endpointType, func(t *testing.T) {
got := resolveEndpointColumn(tc.endpointType)
require.Equal(t, tc.want, got)
})
}
}
func TestResolveModelDimensionExpression(t *testing.T) {
tests := []struct {
modelType string
want string
}{
{usagestats.ModelSourceRequested, "model"},
{usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"},
{usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"},
{"", "model"},
{"invalid", "model"},
}
for _, tc := range tests {
t.Run(tc.modelType, func(t *testing.T) {
got := resolveModelDimensionExpression(tc.modelType)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.AccountID,
log.RequestID,
log.Model,
sqlmock.AnyArg(), // upstream_model
sqlmock.AnyArg(), // group_id
sqlmock.AnyArg(), // subscription_id
log.InputTokens,
@@ -73,6 +74,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // media_type
sqlmock.AnyArg(), // service_tier
sqlmock.AnyArg(), // reasoning_effort
sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden,
createdAt,
).
@@ -114,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.Model,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.InputTokens,
log.OutputTokens,
log.CacheCreationTokens,
@@ -141,6 +145,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
serviceTier,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.CacheTTLOverridden,
createdAt,
).
@@ -255,10 +261,10 @@ func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost", "total_requests", "total_tokens"}).
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0, int64(30), int64(2600)).
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0, int64(30), int64(2600)).
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0, int64(30), int64(2600))
mock.ExpectQuery("WITH user_spend AS \\(").
WithArgs(start, end, 12).
@@ -273,6 +279,8 @@ func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
},
TotalActualCost: 40.0,
TotalRequests: 30,
TotalTokens: 2600,
}, got)
require.NoError(t, mock.ExpectationsWereMet())
}
@@ -347,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(30), // account_id
sql.NullString{Valid: true, String: "req-1"},
"gpt-5", // model
sql.NullString{}, // upstream_model
sql.NullInt64{}, // group_id
sql.NullInt64{}, // subscription_id
1, // input_tokens
@@ -376,6 +385,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})
@@ -396,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(31),
sql.NullString{Valid: true, String: "req-2"},
"gpt-5",
sql.NullString{},
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
@@ -415,6 +427,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{Valid: true, String: "flex"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})
@@ -435,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
int64(32),
sql.NullString{Valid: true, String: "req-3"},
"gpt-5.4",
sql.NullString{},
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
@@ -454,6 +469,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})

View File

@@ -5,6 +5,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
@@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID))
}
if platform != "" {
q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform)))
}
// Status filtering with real-time expiration check
now := time.Now()

View File

@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)

View File

@@ -493,6 +493,7 @@ func TestAPIContracts(t *testing.T) {
"registration_email_suffix_whitelist": [],
"promo_code_enabled": true,
"password_reset_enabled": false,
"frontend_url": "",
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
@@ -923,8 +924,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error
return false, errors.New("not implemented")
}
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
return 0, 0, errors.New("not implemented")
}
func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -1288,7 +1289,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
@@ -1624,10 +1625,22 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
@@ -1773,6 +1786,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i
func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
return nil, errors.New("not implemented")
}
type stubSettingRepo struct {
all map[string]string

View File

@@ -135,7 +135,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {

View File

@@ -646,7 +646,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}

View File

@@ -198,6 +198,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
dashboard.GET("/user-breakdown", h.Admin.Dashboard.GetUserBreakdown)
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
}
}
@@ -226,6 +227,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary)
groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary)
groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
@@ -399,6 +402,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
// 529过载冷却配置
adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings)
adminSettings.PUT("/overload-cooldown", h.Admin.Setting.UpdateOverloadCooldownSettings)
// 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)

View File

@@ -30,6 +30,7 @@ func RegisterGatewayRoutes(
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
endpointNorm := handler.InboundEndpointMiddleware()
// 未分组 Key 拦截中间件(按协议格式区分错误响应)
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
@@ -40,6 +41,7 @@ func RegisterGatewayRoutes(
gateway.Use(bodyLimit)
gateway.Use(clientRequestID)
gateway.Use(opsErrorLogger)
gateway.Use(endpointNorm)
gateway.Use(gin.HandlerFunc(apiKeyAuth))
gateway.Use(requireGroupAnthropic)
{
@@ -80,6 +82,7 @@ func RegisterGatewayRoutes(
gemini.Use(bodyLimit)
gemini.Use(clientRequestID)
gemini.Use(opsErrorLogger)
gemini.Use(endpointNorm)
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
gemini.Use(requireGroupGoogle)
{
@@ -90,11 +93,11 @@ func RegisterGatewayRoutes(
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API不带v1前缀的别名
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
@@ -104,6 +107,7 @@ func RegisterGatewayRoutes(
antigravityV1.Use(bodyLimit)
antigravityV1.Use(clientRequestID)
antigravityV1.Use(opsErrorLogger)
antigravityV1.Use(endpointNorm)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
antigravityV1.Use(requireGroupAnthropic)
@@ -118,6 +122,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(clientRequestID)
antigravityV1Beta.Use(opsErrorLogger)
antigravityV1Beta.Use(endpointNorm)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
antigravityV1Beta.Use(requireGroupGoogle)
@@ -132,6 +137,7 @@ func RegisterGatewayRoutes(
soraV1.Use(soraBodyLimit)
soraV1.Use(clientRequestID)
soraV1.Use(opsErrorLogger)
soraV1.Use(endpointNorm)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
soraV1.Use(requireGroupAnthropic)

View File

@@ -901,6 +901,22 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
return false
}
// IsOveragesEnabled 检查 Antigravity 账号是否启用 AI Credits 超量请求。
func (a *Account) IsOveragesEnabled() bool {
if a.Platform != PlatformAntigravity {
return false
}
if a.Extra == nil {
return false
}
if v, ok := a.Extra["allow_overages"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
//
// 新字段accounts.extra.openai_passthrough。

View File

@@ -113,15 +113,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error)
return normalized, nil
}
// generateSessionString generates a Claude Code style session string
// generateSessionString generates a Claude Code style session string.
// The output format is determined by the UA version in claude.DefaultHeaders,
// ensuring consistency between the user_id format and the UA sent to upstream.
func generateSessionString() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
hex64 := hex.EncodeToString(bytes)
hex64 := hex.EncodeToString(b)
sessionUUID := uuid.New().String()
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
uaVersion := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"])
return FormatMetadataUserID(hex64, "", sessionUUID, uaVersion), nil
}
// createTestPayload creates a Claude Code style test request payload

View File

@@ -45,7 +45,11 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error)
GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
@@ -164,6 +168,13 @@ type AntigravityModelDetail struct {
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
}
// AICredit 表示 Antigravity 账号的 AI Credits 余额信息。
type AICredit struct {
CreditType string `json:"credit_type,omitempty"`
Amount float64 `json:"amount,omitempty"`
MinimumBalance float64 `json:"minimum_balance,omitempty"`
}
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
@@ -187,6 +198,9 @@ type UsageInfo struct {
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
// Antigravity AI Credits 余额
AICredits []AICredit `json:"ai_credits,omitempty"`
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
@@ -434,23 +448,17 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
}
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil {
windowStats := windowStatsFromAccountStats(stats)
if hasMeaningfulWindowStats(windowStats) {
if usage.FiveHour == nil {
usage.FiveHour = &UsageProgress{Utilization: 0}
}
usage.FiveHour.WindowStats = windowStats
if usage.FiveHour == nil {
usage.FiveHour = &UsageProgress{Utilization: 0}
}
usage.FiveHour.WindowStats = windowStatsFromAccountStats(stats)
}
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil {
windowStats := windowStatsFromAccountStats(stats)
if hasMeaningfulWindowStats(windowStats) {
if usage.SevenDay == nil {
usage.SevenDay = &UsageProgress{Utilization: 0}
}
usage.SevenDay.WindowStats = windowStats
if usage.SevenDay == nil {
usage.SevenDay = &UsageProgress{Utilization: 0}
}
usage.SevenDay.WindowStats = windowStatsFromAccountStats(stats)
}
return usage, nil
@@ -980,13 +988,6 @@ func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
}
}
func hasMeaningfulWindowStats(stats *WindowStats) bool {
if stats == nil {
return false
}
return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0
}
func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress {
if len(extra) == 0 {
return nil
@@ -1043,6 +1044,11 @@ func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now t
}
}
// 窗口已过期resetAt 在 now 之前)→ 额度已重置,归零
if progress.ResetsAt != nil && !now.Before(*progress.ResetsAt) {
progress.Utilization = 0
}
return progress
}

View File

@@ -148,3 +148,54 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
t.Fatal("waiting for codex probe rate limit persistence timed out")
}
}
func TestBuildCodexUsageProgressFromExtra_ZerosExpiredWindow(t *testing.T) {
t.Parallel()
now := time.Date(2026, 3, 16, 12, 0, 0, 0, time.UTC)
t.Run("expired 5h window zeroes utilization", func(t *testing.T) {
extra := map[string]any{
"codex_5h_used_percent": 42.0,
"codex_5h_reset_at": "2026-03-16T10:00:00Z", // 2h ago
}
progress := buildCodexUsageProgressFromExtra(extra, "5h", now)
if progress == nil {
t.Fatal("expected non-nil progress")
}
if progress.Utilization != 0 {
t.Fatalf("expected Utilization=0 for expired window, got %v", progress.Utilization)
}
if progress.RemainingSeconds != 0 {
t.Fatalf("expected RemainingSeconds=0, got %v", progress.RemainingSeconds)
}
})
t.Run("active 5h window keeps utilization", func(t *testing.T) {
resetAt := now.Add(2 * time.Hour).Format(time.RFC3339)
extra := map[string]any{
"codex_5h_used_percent": 42.0,
"codex_5h_reset_at": resetAt,
}
progress := buildCodexUsageProgressFromExtra(extra, "5h", now)
if progress == nil {
t.Fatal("expected non-nil progress")
}
if progress.Utilization != 42.0 {
t.Fatalf("expected Utilization=42, got %v", progress.Utilization)
}
})
t.Run("expired 7d window zeroes utilization", func(t *testing.T) {
extra := map[string]any{
"codex_7d_used_percent": 88.0,
"codex_7d_reset_at": "2026-03-15T00:00:00Z", // yesterday
}
progress := buildCodexUsageProgressFromExtra(extra, "7d", now)
if progress == nil {
t.Fatal("expected non-nil progress")
}
if progress.Utilization != 0 {
t.Fatalf("expected Utilization=0 for expired 7d window, got %v", progress.Utilization)
}
})
}

View File

@@ -368,6 +368,10 @@ type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
type proxyQualityTarget struct {
Target string
URL string
@@ -445,10 +449,6 @@ type userGroupRateBatchReader interface {
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
}
type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAdminService creates a new AdminService
func NewAdminService(
userRepo UserRepository,
@@ -832,7 +832,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType = SubscriptionTypeStandard
}
// 限额字段:0 和 nil 都表示"无限制"
// 限额字段:nil/负数 表示"无限制"0 表示"不允许用量",正数表示具体限额
dailyLimit := normalizeLimit(input.DailyLimitUSD)
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
@@ -944,9 +944,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil
}
// normalizeLimit 将 0 或负数转换为 nil表示无限制
// normalizeLimit 将负数转换为 nil表示无限制0 保留(表示限额为零)
func normalizeLimit(limit *float64) *float64 {
if limit == nil || *limit <= 0 {
if limit == nil || *limit < 0 {
return nil
}
return limit
@@ -1058,16 +1058,11 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType
}
// 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
if input.DailyLimitUSD != nil {
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
}
if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
}
if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
}
// 限额字段:nil/负数 表示"无限制"0 表示"不允许用量",正数表示具体限额
// 前端始终发送这三个字段,无需 nil 守卫
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
// 图片生成计费配置:负数表示清除(使用默认价格)
if input.ImagePrice1K != nil {
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
@@ -1521,6 +1516,7 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if err != nil {
return nil, err
}
wasOveragesEnabled := account.IsOveragesEnabled()
if input.Name != "" {
account.Name = input.Name
@@ -1534,7 +1530,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if len(input.Credentials) > 0 {
account.Credentials = input.Credentials
}
if len(input.Extra) > 0 {
// Extra 使用 map需要区分“未提供(nil)”与“显式清空({})”。
// 关闭配额限制时前端会删除 quota_* 键并提交 extra:{},此时也必须落库。
if input.Extra != nil {
// 保留配额用量字段,防止编辑账号时意外重置
for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} {
if v, ok := account.Extra[key]; ok {
@@ -1542,6 +1540,17 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
}
}
account.Extra = input.Extra
if account.Platform == PlatformAntigravity && wasOveragesEnabled && !account.IsOveragesEnabled() {
delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态
// 清除 AICredits 限流 key
if rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any); ok {
delete(rawLimits, creditsExhaustedKey)
}
}
if account.Platform == PlatformAntigravity && !wasOveragesEnabled && account.IsOveragesEnabled() {
delete(account.Extra, modelRateLimitsKey)
delete(account.Extra, "antigravity_credits_overages") // 清理旧版 overages 运行态
}
// 校验并预计算固定时间重置的下次重置时间
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
return nil, err

View File

@@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {

View File

@@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er
panic("unexpected ExistsByName call")
}
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}

View File

@@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool,
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
@@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}
@@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context,
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) {
panic("unexpected GetAccountCount call")
}

View File

@@ -0,0 +1,155 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type updateAccountOveragesRepoStub struct {
mockAccountRepoForGemini
account *Account
updateCalls int
}
func (r *updateAccountOveragesRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
return r.account, nil
}
func (r *updateAccountOveragesRepoStub) Update(ctx context.Context, account *Account) error {
r.updateCalls++
r.account = account
return nil
}
func TestUpdateAccount_DisableOveragesClearsAICreditsKey(t *testing.T) {
accountID := int64(101)
repo := &updateAccountOveragesRepoStub{
account: &Account{
ID: accountID,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Extra: map[string]any{
"allow_overages": true,
"mixed_scheduling": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
},
}
svc := &adminServiceImpl{accountRepo: repo}
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
Extra: map[string]any{
"mixed_scheduling": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 1, repo.updateCalls)
require.False(t, updated.IsOveragesEnabled())
// 关闭 overages 后AICredits key 应被清除
rawLimits, ok := repo.account.Extra[modelRateLimitsKey].(map[string]any)
if ok {
_, exists := rawLimits[creditsExhaustedKey]
require.False(t, exists, "关闭 overages 时应清除 AICredits 限流 key")
}
// 普通模型限流应保留
require.True(t, ok)
_, exists := rawLimits["claude-sonnet-4-5"]
require.True(t, exists, "普通模型限流应保留")
}
func TestUpdateAccount_EnableOveragesClearsModelRateLimitsBeforePersist(t *testing.T) {
accountID := int64(102)
repo := &updateAccountOveragesRepoStub{
account: &Account{
ID: accountID,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Extra: map[string]any{
"mixed_scheduling": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
},
},
},
}
svc := &adminServiceImpl{accountRepo: repo}
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
Extra: map[string]any{
"mixed_scheduling": true,
"allow_overages": true,
},
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 1, repo.updateCalls)
require.True(t, updated.IsOveragesEnabled())
_, exists := repo.account.Extra[modelRateLimitsKey]
require.False(t, exists, "开启 overages 时应在持久化前清掉旧模型限流")
}
func TestUpdateAccount_EmptyExtraPayloadCanClearQuotaLimits(t *testing.T) {
accountID := int64(103)
repo := &updateAccountOveragesRepoStub{
account: &Account{
ID: accountID,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Status: StatusActive,
Extra: map[string]any{
"quota_limit": 100.0,
"quota_daily_limit": 10.0,
"quota_weekly_limit": 40.0,
},
},
}
svc := &adminServiceImpl{accountRepo: repo}
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
// 显式空对象:语义是“清空 extra 中的可配置键”(例如关闭配额限制)
Extra: map[string]any{},
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 1, repo.updateCalls)
require.NotNil(t, repo.account.Extra)
require.NotContains(t, repo.account.Extra, "quota_limit")
require.NotContains(t, repo.account.Extra, "quota_daily_limit")
require.NotContains(t, repo.account.Extra, "quota_weekly_limit")
require.Len(t, repo.account.Extra, 0)
}

View File

@@ -0,0 +1,234 @@
package service
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
// creditsExhaustedKey 是 model_rate_limits 中标记积分耗尽的特殊 key。
// 与普通模型限流完全同构:通过 SetModelRateLimit / isRateLimitActiveForKey 读写。
creditsExhaustedKey = "AICredits"
creditsExhaustedDuration = 5 * time.Hour
)
type antigravity429Category string
const (
antigravity429Unknown antigravity429Category = "unknown"
antigravity429RateLimited antigravity429Category = "rate_limited"
antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
)
var (
antigravityQuotaExhaustedKeywords = []string{
"quota_exhausted",
"quota exhausted",
}
creditsExhaustedKeywords = []string{
"google_one_ai",
"insufficient credit",
"insufficient credits",
"not enough credit",
"not enough credits",
"credit exhausted",
"credits exhausted",
"credit balance",
"minimumcreditamountforusage",
"minimum credit amount for usage",
"minimum credit",
}
)
// isCreditsExhausted 检查账号的 AICredits 限流 key 是否生效(积分是否耗尽)。
func (a *Account) isCreditsExhausted() bool {
if a == nil {
return false
}
return a.isRateLimitActiveForKey(creditsExhaustedKey)
}
// setCreditsExhausted 标记账号积分耗尽:写入 model_rate_limits["AICredits"] + 更新缓存。
func (s *AntigravityGatewayService) setCreditsExhausted(ctx context.Context, account *Account) {
if account == nil || account.ID == 0 {
return
}
resetAt := time.Now().Add(creditsExhaustedDuration)
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, creditsExhaustedKey, resetAt); err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "set credits exhausted failed: account=%d err=%v", account.ID, err)
return
}
s.updateAccountModelRateLimitInCache(ctx, account, creditsExhaustedKey, resetAt)
logger.LegacyPrintf("service.antigravity_gateway", "credits_exhausted_marked account=%d reset_at=%s",
account.ID, resetAt.UTC().Format(time.RFC3339))
}
// clearCreditsExhausted 清除账号的 AICredits 限流 key。
func (s *AntigravityGatewayService) clearCreditsExhausted(ctx context.Context, account *Account) {
if account == nil || account.ID == 0 || account.Extra == nil {
return
}
rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any)
if !ok {
return
}
if _, exists := rawLimits[creditsExhaustedKey]; !exists {
return
}
delete(rawLimits, creditsExhaustedKey)
account.Extra[modelRateLimitsKey] = rawLimits
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
modelRateLimitsKey: rawLimits,
}); err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "clear credits exhausted failed: account=%d err=%v", account.ID, err)
}
}
// classifyAntigravity429 将 Antigravity 的 429 响应归类为配额耗尽、限流或未知。
func classifyAntigravity429(body []byte) antigravity429Category {
if len(body) == 0 {
return antigravity429Unknown
}
lowerBody := strings.ToLower(string(body))
for _, keyword := range antigravityQuotaExhaustedKeywords {
if strings.Contains(lowerBody, keyword) {
return antigravity429QuotaExhausted
}
}
if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted {
return antigravity429RateLimited
}
return antigravity429Unknown
}
// injectEnabledCreditTypes 在已序列化的 v1internal JSON body 中注入 AI Credits 类型。
func injectEnabledCreditTypes(body []byte) []byte {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil
}
payload["enabledCreditTypes"] = []string{"GOOGLE_ONE_AI"}
result, err := json.Marshal(payload)
if err != nil {
return nil
}
return result
}
// resolveCreditsOveragesModelKey 解析当前请求对应的 overages 状态模型 key。
func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstreamModelName, requestedModel string) string {
modelKey := strings.TrimSpace(upstreamModelName)
if modelKey != "" {
return modelKey
}
if account == nil {
return ""
}
modelKey = resolveFinalAntigravityModelKey(ctx, account, requestedModel)
if strings.TrimSpace(modelKey) != "" {
return modelKey
}
return resolveAntigravityModelKey(requestedModel)
}
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
if reqErr != nil || resp == nil {
return false
}
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
return false
}
if isURLLevelRateLimit(respBody) {
return false
}
if info := parseAntigravitySmartRetryInfo(respBody); info != nil {
return false
}
bodyLower := strings.ToLower(string(respBody))
for _, keyword := range creditsExhaustedKeywords {
if strings.Contains(bodyLower, keyword) {
return true
}
}
return false
}
type creditsOveragesRetryResult struct {
handled bool
resp *http.Response
}
// attemptCreditsOveragesRetry 在确认免费配额耗尽后,尝试注入 AI Credits 继续请求。
func (s *AntigravityGatewayService) attemptCreditsOveragesRetry(
p antigravityRetryLoopParams,
baseURL string,
modelName string,
waitDuration time.Duration,
originalStatusCode int,
respBody []byte,
) *creditsOveragesRetryResult {
creditsBody := injectEnabledCreditTypes(p.body)
if creditsBody == nil {
return &creditsOveragesRetryResult{handled: false}
}
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)",
p.prefix, modelKey, p.account.ID)
creditsReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, creditsBody)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d build_request_err=%v",
p.prefix, modelKey, p.account.ID, err)
return &creditsOveragesRetryResult{handled: true}
}
creditsResp, err := p.httpUpstream.Do(creditsReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if err == nil && creditsResp != nil && creditsResp.StatusCode < 400 {
s.clearCreditsExhausted(p.ctx, p.account)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d credit_overages_success model=%s account=%d",
p.prefix, creditsResp.StatusCode, modelKey, p.account.ID)
return &creditsOveragesRetryResult{handled: true, resp: creditsResp}
}
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, creditsResp, err)
return &creditsOveragesRetryResult{handled: true}
}
func (s *AntigravityGatewayService) handleCreditsRetryFailure(
ctx context.Context,
prefix string,
modelKey string,
account *Account,
creditsResp *http.Response,
reqErr error,
) {
var creditsRespBody []byte
creditsStatusCode := 0
if creditsResp != nil {
creditsStatusCode = creditsResp.StatusCode
if creditsResp.Body != nil {
creditsRespBody, _ = io.ReadAll(io.LimitReader(creditsResp.Body, 64<<10))
_ = creditsResp.Body.Close()
}
}
if shouldMarkCreditsExhausted(creditsResp, creditsRespBody, reqErr) && account != nil {
s.setCreditsExhausted(ctx, account)
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=true status=%d body=%s",
prefix, modelKey, account.ID, creditsStatusCode, truncateForLog(creditsRespBody, 200))
return
}
if account != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=false status=%d err=%v body=%s",
prefix, modelKey, account.ID, creditsStatusCode, reqErr, truncateForLog(creditsRespBody, 200))
}
}

View File

@@ -0,0 +1,538 @@
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
func TestClassifyAntigravity429(t *testing.T) {
t.Run("明确配额耗尽", func(t *testing.T) {
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
})
t.Run("结构化限流", func(t *testing.T) {
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
require.Equal(t, antigravity429RateLimited, classifyAntigravity429(body))
})
t.Run("未知429", func(t *testing.T) {
body := []byte(`{"error":{"message":"too many requests"}}`)
require.Equal(t, antigravity429Unknown, classifyAntigravity429(body))
})
}
func TestIsCreditsExhausted_UsesAICreditsKey(t *testing.T) {
t.Run("无 AICredits key 则积分可用", func(t *testing.T) {
account := &Account{
ID: 1,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
},
}
require.False(t, account.isCreditsExhausted())
})
t.Run("AICredits key 生效则积分耗尽", func(t *testing.T) {
account := &Account{
ID: 2,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
creditsExhaustedKey: map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
require.True(t, account.isCreditsExhausted())
})
t.Run("AICredits key 过期则积分可用", func(t *testing.T) {
account := &Account{
ID: 3,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
creditsExhaustedKey: map[string]any{
"rate_limited_at": time.Now().Add(-6 * time.Hour).UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
require.False(t, account.isCreditsExhausted())
})
}
func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 101,
Name: "acc-101",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
},
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-opus-4-6": "claude-sonnet-4-5",
},
},
}
respBody := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-opus-4-6","request":{}}`),
httpUpstream: upstream,
accountRepo: repo,
requestedModel: "claude-opus-4-6",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)
require.Nil(t, result.switchError)
require.Len(t, upstream.requestBodies, 1)
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
require.Empty(t, repo.modelRateLimitCalls, "overages 成功后不应写入普通 model_rate_limits")
}
func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 102,
Name: "acc-102",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
},
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)
require.Len(t, upstream.requestBodies, 1)
require.NotContains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
require.Empty(t, repo.extraUpdateCalls)
require.Empty(t, repo.modelRateLimitCalls)
}
func TestAntigravityRetryLoop_ModelRateLimited_InjectsCredits(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
antigravity.BaseURLs = []string{"https://ag-1.test"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
},
},
errors: []error{nil},
}
// 模型已限流 + overages 启用 + 无 AICredits key → 应直接注入积分
account := &Account{
ID: 103,
Name: "acc-103",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err)
require.NotNil(t, result)
require.Len(t, upstream.requestBodies, 1)
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
}
func TestAntigravityRetryLoop_CreditsExhausted_DoesNotInject(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
antigravity.BaseURLs = []string{"https://ag-1.test"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
// 模型限流 + overages 启用 + AICredits key 生效 → 不应注入积分,应切号
account := &Account{
ID: 104,
Name: "acc-104",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
_, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
// 模型限流 + 积分耗尽 → 应触发切号错误
require.Error(t, err)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
}
func TestAntigravityRetryLoop_CreditErrorMarksExhausted(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
antigravity.BaseURLs = []string{"https://ag-1.test"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
repo := &stubAntigravityAccountRepo{}
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusForbidden,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`)),
},
},
errors: []error{nil},
}
// 模型限流 + overages 启用 + 积分可用 → 注入积分但上游返回积分不足
account := &Account{
ID: 105,
Name: "acc-105",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{accountRepo: repo}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
httpUpstream: upstream,
accountRepo: repo,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err)
require.NotNil(t, result)
// 验证 AICredits key 已通过 SetModelRateLimit 写入数据库
require.Len(t, repo.modelRateLimitCalls, 1, "应通过 SetModelRateLimit 写入 AICredits key")
require.Equal(t, creditsExhaustedKey, repo.modelRateLimitCalls[0].modelKey)
}
func TestShouldMarkCreditsExhausted(t *testing.T) {
t.Run("reqErr 不为 nil 时不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusForbidden}
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), io.ErrUnexpectedEOF))
})
t.Run("resp 为 nil 时不标记", func(t *testing.T) {
require.False(t, shouldMarkCreditsExhausted(nil, []byte(`{"error":"Insufficient credits"}`), nil))
})
t.Run("5xx 响应不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusInternalServerError}
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
})
t.Run("408 RequestTimeout 不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusRequestTimeout}
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
})
t.Run("URL 级限流不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"message":"Resource has been exhausted"}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
t.Run("结构化限流不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
t.Run("含 credits 关键词时标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusForbidden}
for _, keyword := range []string{
"Insufficient GOOGLE_ONE_AI credits",
"insufficient credit balance",
"not enough credits for this request",
"Credits exhausted",
"minimumCreditAmountForUsage requirement not met",
} {
body := []byte(`{"error":{"message":"` + keyword + `"}}`)
require.True(t, shouldMarkCreditsExhausted(resp, body, nil), "should mark for keyword: %s", keyword)
}
})
t.Run("无 credits 关键词时不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusForbidden}
body := []byte(`{"error":{"message":"permission denied"}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
}
func TestInjectEnabledCreditTypes(t *testing.T) {
t.Run("正常 JSON 注入成功", func(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","request":{}}`)
result := injectEnabledCreditTypes(body)
require.NotNil(t, result)
require.Contains(t, string(result), `"enabledCreditTypes"`)
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
})
t.Run("非法 JSON 返回 nil", func(t *testing.T) {
require.Nil(t, injectEnabledCreditTypes([]byte(`not json`)))
})
t.Run("空 body 返回 nil", func(t *testing.T) {
require.Nil(t, injectEnabledCreditTypes([]byte{}))
})
t.Run("已有 enabledCreditTypes 会被覆盖", func(t *testing.T) {
body := []byte(`{"enabledCreditTypes":["OLD"],"model":"test"}`)
result := injectEnabledCreditTypes(body)
require.NotNil(t, result)
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
require.NotContains(t, string(result), `OLD`)
})
}
func TestClearCreditsExhausted(t *testing.T) {
t.Run("account 为 nil 不操作", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), nil)
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("Extra 为 nil 不操作", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), &Account{ID: 1})
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("无 modelRateLimitsKey 不操作", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), &Account{
ID: 1,
Extra: map[string]any{"some_key": "value"},
})
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("无 AICredits key 不操作", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), &Account{
ID: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
},
},
})
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("有 AICredits key 时删除并调用 UpdateExtra", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{
ID: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
svc.clearCreditsExhausted(context.Background(), account)
require.Len(t, repo.extraUpdateCalls, 1)
// AICredits key 应被删除
rawLimits := account.Extra[modelRateLimitsKey].(map[string]any)
_, exists := rawLimits[creditsExhaustedKey]
require.False(t, exists, "AICredits key 应被删除")
// 普通模型限流应保留
_, exists = rawLimits["claude-sonnet-4-5"]
require.True(t, exists, "普通模型限流应保留")
})
}

View File

@@ -188,9 +188,29 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
return &smartRetryResult{action: smartRetryActionContinueURL}
}
category := antigravity429Unknown
if resp.StatusCode == http.StatusTooManyRequests {
category = classifyAntigravity429(respBody)
}
// 判断是否触发智能重试
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody)
// AI Credits 超量请求:
// 仅在上游明确返回免费配额耗尽时才允许切换到 credits。
if resp.StatusCode == http.StatusTooManyRequests &&
category == antigravity429QuotaExhausted &&
p.account.IsOveragesEnabled() &&
!p.account.isCreditsExhausted() {
result := s.attemptCreditsOveragesRetry(p, baseURL, modelName, waitDuration, resp.StatusCode, respBody)
if result.handled && result.resp != nil {
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: result.resp,
}
}
}
// 情况1: retryDelay >= 阈值,限流模型并切换账号
if shouldRateLimitModel {
// 单账号 503 退避重试模式:不设限流、不切换账号,改为原地等待+重试
@@ -532,14 +552,31 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
// antigravityRetryLoop 执行带 URL fallback 的重试循环
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
// 预检查:模型限流 + overages 启用 + 积分未耗尽 → 直接注入 AI Credits
overagesInjected := false
if p.requestedModel != "" && p.account.Platform == PlatformAntigravity &&
p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() &&
p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) {
if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil {
p.body = creditsBody
overagesInjected = true
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)",
p.prefix, p.requestedModel, p.account.ID)
}
}
// 预检查:如果账号已限流,直接返回切换信号
if p.requestedModel != "" {
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
// 单账号 503 退避重试模式:跳过限流预检查,直接发请求
// 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。
// 如果上游确实还不可用handleSmartRetry → handleSingleAccountRetryInPlace
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
if isSingleAccountRetry(p.ctx) {
// 已注入积分的请求不再受普通模型限流预检查阻断
if overagesInjected {
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: credits_injected_ignore_rate_limit remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
} else if isSingleAccountRetry(p.ctx) {
// 单账号 503 退避重试模式:跳过限流预检查,直接发请求。
// 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。
// 如果上游确实还不可用handleSmartRetry → handleSingleAccountRetryInPlace
// 会在 Service 层原地等待+重试,不需要在预检查这里等。
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
} else {
@@ -631,6 +668,15 @@ urlFallbackLoop:
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if overagesInjected && shouldMarkCreditsExhausted(resp, respBody, nil) {
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, "", p.requestedModel)
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}, nil)
}
// ★ 统一入口:自定义错误码 + 临时不可调度
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
if policyErr != nil {
@@ -884,7 +930,7 @@ func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParam
case ErrorPolicyTempUnscheduled:
slog.Info("temp_unschedulable_matched",
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, RateLimitedModel: p.requestedModel, IsStickySession: p.isStickySession}
}
return false, statusCode, nil
}
@@ -955,8 +1001,9 @@ type TestConnectionResult struct {
MappedModel string // 实际使用的模型
}
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
// TestConnection 测试 Antigravity 账号连接
// 复用 antigravityRetryLoop 的完整重试 / credits overages / 智能重试逻辑,
// 与真实调度行为一致。差异:不做账号切换(测试指定账号)、不记录 ops 错误。
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
// 获取 token
@@ -980,10 +1027,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
// 构建请求体
var requestBody []byte
if strings.HasPrefix(modelID, "gemini-") {
// Gemini 模型:直接使用 Gemini 格式
requestBody, err = s.buildGeminiTestRequest(projectID, mappedModel)
} else {
// Claude 模型:使用协议转换
requestBody, err = s.buildClaudeTestRequest(projectID, mappedModel)
}
if err != nil {
@@ -996,64 +1041,63 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
proxyURL = account.Proxy.URL()
}
baseURL := resolveAntigravityForwardBaseURL()
if baseURL == "" {
return nil, errors.New("no antigravity forward base url configured")
}
availableURLs := []string{baseURL}
var lastErr error
for urlIdx, baseURL := range availableURLs {
// 构建 HTTP 请求(总是使用流式 endpoint与官方客户端一致
req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody)
if err != nil {
lastErr = err
continue
}
// 调试日志Test 请求信息
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String())
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
lastErr = fmt.Errorf("请求失败: %w", err)
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
return nil, lastErr
}
// 读取响应
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
// 检查是否需要 URL 降级
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
}
// 解析流式响应,提取文本
text := extractTextFromSSEResponse(respBody)
// 标记成功的 URL下次优先使用
antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
return &TestConnectionResult{
Text: text,
MappedModel: mappedModel,
}, nil
// 复用 antigravityRetryLoop完整的重试 / credits overages / 智能重试
prefix := fmt.Sprintf("[antigravity-Test] account=%d(%s)", account.ID, account.Name)
p := antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
account: account,
proxyURL: proxyURL,
accessToken: accessToken,
action: "streamGenerateContent",
body: requestBody,
c: nil, // 无 gin.Context → 跳过 ops 追踪
httpUpstream: s.httpUpstream,
settingService: s.settingService,
accountRepo: s.accountRepo,
requestedModel: modelID,
handleError: testConnectionHandleError,
}
return nil, lastErr
result, err := s.antigravityRetryLoop(p)
if err != nil {
// AccountSwitchError → 测试时不切换账号,返回友好提示
var switchErr *AntigravityAccountSwitchError
if errors.As(err, &switchErr) {
return nil, fmt.Errorf("该账号模型 %s 当前限流中,请稍后重试", switchErr.RateLimitedModel)
}
return nil, err
}
if result == nil || result.resp == nil {
return nil, errors.New("upstream returned empty response")
}
defer func() { _ = result.resp.Body.Close() }()
respBody, err := io.ReadAll(io.LimitReader(result.resp.Body, 2<<20))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if result.resp.StatusCode >= 400 {
return nil, fmt.Errorf("API 返回 %d: %s", result.resp.StatusCode, string(respBody))
}
text := extractTextFromSSEResponse(respBody)
return &TestConnectionResult{Text: text, MappedModel: mappedModel}, nil
}
// testConnectionHandleError 是 TestConnection 使用的轻量 handleError 回调。
// 仅记录日志,不做 ops 错误追踪或粘性会话清除。
func testConnectionHandleError(
_ context.Context, prefix string, account *Account,
statusCode int, _ http.Header, body []byte,
requestedModel string, _ int64, _ string, _ bool,
) *handleModelRateLimitResult {
logger.LegacyPrintf("service.antigravity_gateway",
"%s test_handle_error status=%d model=%s account=%d body=%s",
prefix, statusCode, requestedModel, account.ID, truncateForLog(body, 200))
return nil
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求
@@ -3033,6 +3077,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
intervalCh = intervalTicker.C
}
// 下游 keepalive防止代理/Cloudflare Tunnel 因连接空闲而断开
keepaliveInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini")
// 仅发送一次错误事件,避免多次写入导致协议混乱
@@ -3065,6 +3125,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return nil, ev.err
}
lastDataAt = time.Now()
line := ev.line
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
@@ -3124,6 +3186,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if cw.Disconnected() {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// SSE ping/keepalive保持连接活跃防止 Cloudflare Tunnel 等代理断开
if !cw.Fprintf(":\n\n") {
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity gemini), continuing to drain upstream for billing")
continue
}
}
}
}
@@ -3849,6 +3924,22 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
intervalCh = intervalTicker.C
}
// 下游 keepalive防止代理/Cloudflare Tunnel 因连接空闲而断开
keepaliveInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude")
// 仅发送一次错误事件,避免多次写入导致协议混乱
@@ -3901,6 +3992,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
return nil, fmt.Errorf("stream read error: %w", ev.err)
}
lastDataAt = time.Now()
// 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n"))
if len(claudeEvents) > 0 {
@@ -3923,6 +4016,20 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if cw.Disconnected() {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// SSE ping 事件Anthropic 原生格式,客户端会正确处理,
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") {
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity claude), continuing to drain upstream for billing")
continue
}
}
}
}
@@ -4253,6 +4360,22 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
intervalCh = intervalTicker.C
}
// 下游 keepalive防止代理/Cloudflare Tunnel 因连接空闲而断开
keepaliveInterval := time.Duration(0)
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.settingService.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
flusher, _ := c.Writer.(http.Flusher)
cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream")
@@ -4270,6 +4393,8 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
}
lastDataAt = time.Now()
line := ev.line
// 记录首 token 时间
@@ -4295,6 +4420,20 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
}
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}
case <-keepaliveCh:
if cw.Disconnected() {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// SSE ping 事件Anthropic 原生格式,客户端会正确处理,
// 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开
if !cw.Fprintf("event: ping\ndata: {\"type\": \"ping\"}\n\n") {
logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during keepalive ping (antigravity upstream), continuing to drain upstream for billing")
continue
}
}
}
}

View File

@@ -57,16 +57,16 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-opus-4-6-thinking",
},
{
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5",
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-6",
requestedModel: "claude-haiku-4-5",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-sonnet-4-6",
},
{
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-6",
requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil,
expected: "claude-sonnet-4-5",
expected: "claude-sonnet-4-6",
},
{
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",

View File

@@ -78,11 +78,11 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
return nil, err
}
// 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程)
tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken)
// 调用 LoadCodeAssist 获取订阅等级和 AI Credits 余额(非关键路径,失败不影响主流程)
tierRaw, tierNormalized, loadResp := f.fetchSubscriptionTier(ctx, client, accessToken)
// 转换为 UsageInfo
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized)
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized, loadResp)
return &QuotaResult{
UsageInfo: usageInfo,
@@ -90,20 +90,21 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
}, nil
}
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) {
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
// 同时返回 LoadCodeAssistResponse以便提取 AI Credits 余额。
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string, loadResp *antigravity.LoadCodeAssistResponse) {
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
if err != nil {
slog.Warn("failed to fetch subscription tier", "error", err)
return "", ""
return "", "", nil
}
if loadResp == nil {
return "", ""
return "", "", nil
}
raw = loadResp.GetTier() // 已有方法paidTier > currentTier
normalized = normalizeTier(raw)
return raw, normalized
return raw, normalized, loadResp
}
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
@@ -124,8 +125,8 @@ func normalizeTier(raw string) string {
}
}
// buildUsageInfo 将 API 响应转换为 UsageInfo
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo {
// buildUsageInfo 将 API 响应转换为 UsageInfo
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string, loadResp *antigravity.LoadCodeAssistResponse) *UsageInfo {
now := time.Now()
info := &UsageInfo{
UpdatedAt: &now,
@@ -190,6 +191,16 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
}
}
if loadResp != nil {
for _, credit := range loadResp.GetAvailableCredits() {
info.AICredits = append(info.AICredits, AICredit{
CreditType: credit.CreditType,
Amount: credit.GetAmount(),
MinimumBalance: credit.GetMinimumAmount(),
})
}
}
return info
}

View File

@@ -81,7 +81,7 @@ func TestBuildUsageInfo_BasicModels(t *testing.T) {
},
}
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO")
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", nil)
// 基本字段
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
@@ -141,7 +141,7 @@ func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.Len(t, info.ModelForwardingRules, 2)
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
@@ -159,7 +159,7 @@ func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
}
@@ -171,7 +171,7 @@ func TestBuildUsageInfo_EmptyModels(t *testing.T) {
Models: map[string]antigravity.ModelInfo{},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info)
require.NotNil(t, info.AntigravityQuota)
@@ -193,7 +193,7 @@ func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info)
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
@@ -222,7 +222,7 @@ func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
// claude-sonnet-4-20250514 is first in priority list, so it should be used
@@ -251,7 +251,7 @@ func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour)
expectedUtilization := (1.0 - 0.60) * 100 // 40
@@ -277,7 +277,7 @@ func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour)
expectedUtilization := (1.0 - 0.30) * 100 // 70
@@ -298,7 +298,7 @@ func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
}
@@ -317,7 +317,7 @@ func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour)
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
@@ -338,7 +338,7 @@ func TestBuildUsageInfo_FullUtilization(t *testing.T) {
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, quota)
@@ -358,13 +358,38 @@ func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, quota)
require.Equal(t, 0, quota.Utilization)
}
func TestBuildUsageInfo_AICredits(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{},
}
loadResp := &antigravity.LoadCodeAssistResponse{
PaidTier: &antigravity.PaidTierInfo{
ID: "g1-pro-tier",
AvailableCredits: []antigravity.AvailableCredit{
{
CreditType: "GOOGLE_ONE_AI",
CreditAmount: "25",
MinimumCreditAmountForUsage: "5",
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", loadResp)
require.Len(t, info.AICredits, 1)
require.Equal(t, "GOOGLE_ONE_AI", info.AICredits[0].CreditType)
require.Equal(t, 25.0, info.AICredits[0].Amount)
require.Equal(t, 5.0, info.AICredits[0].MinimumBalance)
}
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
// 模拟 FetchQuota 遇到 403 时的行为:
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true

View File

@@ -32,6 +32,10 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
return false
}
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
// Antigravity + overages 启用 + 积分未耗尽 → 放行(有积分可用)
if a.Platform == PlatformAntigravity && a.IsOveragesEnabled() && !a.isCreditsExhausted() {
return true
}
return false
}
return true

View File

@@ -76,10 +76,16 @@ type modelRateLimitCall struct {
resetAt time.Time
}
type extraUpdateCall struct {
accountID int64
updates map[string]any
}
type stubAntigravityAccountRepo struct {
AccountRepository
rateCalls []rateLimitCall
modelRateLimitCalls []modelRateLimitCall
extraUpdateCalls []extraUpdateCall
}
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
@@ -92,6 +98,11 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i
return nil
}
func (s *stubAntigravityAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
s.extraUpdateCalls = append(s.extraUpdateCalls, extraUpdateCall{accountID: id, updates: updates})
return nil
}
func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
t.Setenv(antigravityForwardBaseURLEnv, "")

View File

@@ -260,14 +260,15 @@ func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *test
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503不设限流
// 使用 RATE_LIMIT_EXCEEDED走 1 次智能重试),避免 MODEL_CAPACITY_EXHAUSTED 的 60 次重试导致测试超时
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
// 智能重试也返回 503
failRespBody := `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
@@ -278,8 +279,9 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
responses: []*http.Response{failResp},
errors: []error{nil},
repeatLast: true,
}
repo := &stubAntigravityAccountRepo{}
@@ -294,9 +296,9 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
@@ -569,8 +571,9 @@ func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) {
svc := &AntigravityGatewayService{}
// 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro")
// waitDuration=0 会被 clamp 到 antigravitySmartRetryMinWait=1s。
// 首次重试即成功200总耗时 ~1s。
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 0, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)

View File

@@ -32,20 +32,65 @@ func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID
// mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream
type mockSmartRetryUpstream struct {
responses []*http.Response
errors []error
callIdx int
calls []string
responses []*http.Response
responseBodies [][]byte // 缓存的 response body 字节(用于 repeatLast 重建)
errors []error
callIdx int
calls []string
requestBodies [][]byte
repeatLast bool // 超出范围时重复最后一个响应
}
func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
idx := m.callIdx
m.calls = append(m.calls, req.URL.String())
m.callIdx++
if idx < len(m.responses) {
return m.responses[idx], m.errors[idx]
if req != nil && req.Body != nil {
body, _ := io.ReadAll(req.Body)
m.requestBodies = append(m.requestBodies, body)
req.Body = io.NopCloser(bytes.NewReader(body))
} else {
m.requestBodies = append(m.requestBodies, nil)
}
return nil, nil
m.callIdx++
// 确定使用哪个索引
respIdx := idx
if respIdx >= len(m.responses) {
if !m.repeatLast || len(m.responses) == 0 {
return nil, nil
}
respIdx = len(m.responses) - 1
}
resp := m.responses[respIdx]
respErr := m.errors[respIdx]
if resp == nil {
return nil, respErr
}
// 首次调用时缓存 body 字节
if respIdx >= len(m.responseBodies) {
for len(m.responseBodies) <= respIdx {
m.responseBodies = append(m.responseBodies, nil)
}
}
if m.responseBodies[respIdx] == nil && resp.Body != nil {
bodyBytes, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
m.responseBodies[respIdx] = bodyBytes
}
// 用缓存的 body 字节重建新的 reader
var body io.ReadCloser
if m.responseBodies[respIdx] != nil {
body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx]))
}
return &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: body,
}, respErr
}
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {

View File

@@ -3,7 +3,6 @@ package service
import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
@@ -17,15 +16,18 @@ const (
antigravityBackfillCooldown = 5 * time.Minute
)
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
// AntigravityTokenCache token cache interface.
type AntigravityTokenCache = GeminiTokenCache
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
// AntigravityTokenProvider manages access_token for antigravity accounts.
type AntigravityTokenProvider struct {
accountRepo AccountRepository
tokenCache AntigravityTokenCache
antigravityOAuthService *AntigravityOAuthService
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
backfillCooldown sync.Map // key: accountID -> last attempt time
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
refreshPolicy ProviderRefreshPolicy
}
func NewAntigravityTokenProvider(
@@ -37,10 +39,22 @@ func NewAntigravityTokenProvider(
accountRepo: accountRepo,
tokenCache: tokenCache,
antigravityOAuthService: antigravityOAuthService,
refreshPolicy: AntigravityProviderRefreshPolicy(),
}
}
// GetAccessToken 获取有效的 access_token
// SetRefreshAPI injects unified OAuth refresh API and executor.
func (p *AntigravityTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
p.refreshAPI = api
p.executor = executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func (p *AntigravityTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
// GetAccessToken returns a valid access_token.
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
@@ -48,7 +62,8 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if account.Platform != PlatformAntigravity {
return "", errors.New("not an antigravity account")
}
// upstream 类型:直接从 credentials 读取 api_key不走 OAuth 刷新流程
// upstream accounts use static api_key and never refresh oauth token.
if account.Type == AccountTypeUpstream {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
@@ -62,46 +77,38 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
cacheKey := AntigravityTokenCacheKey(account)
// 1. 先尝试缓存
// 1) Try cache first.
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
// 2. 如果即将过期则刷新
// 2) Refresh if needed (pre-expiry skew).
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
if needsRefresh && p.tokenCache != nil {
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, antigravityTokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
// default policy: continue with existing token.
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
if p.antigravityOAuthService == nil {
return "", errors.New("antigravity oauth service not configured")
}
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", err
}
p.mergeCredentials(account, tokenInfo)
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
@@ -110,32 +117,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("access_token not found in credentials")
}
// 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现
// "Invalid project resource name projects/"。
// 仅调用 loadProjectIDWithRetry不刷新 OAuth token带冷却机制防止频繁重试。
// Backfill project_id online when missing, with cooldown to avoid hammering.
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
if p.shouldAttemptBackfill(account.ID) {
p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
slog.Warn("antigravity_project_id_backfill_persist_failed",
"account_id", account.ID,
"error", updateErr,
)
}
}
}
}
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
// 3) Populate cache with TTL.
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
@@ -156,18 +162,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil
}
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
}
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id冷却期内不重复尝试
// shouldAttemptBackfill checks backfill cooldown.
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
if v, ok := p.backfillCooldown.Load(accountID); ok {
if lastAttempt, ok := v.(time.Time); ok {

View File

@@ -25,6 +25,11 @@ func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthServi
}
}
// CacheKey 返回用于分布式锁的缓存键
func (r *AntigravityTokenRefresher) CacheKey(account *Account) string {
return AntigravityTokenCacheKey(account)
}
// CanRefresh 检查是否可以刷新此账户
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
@@ -58,11 +63,7 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
// 合并旧的 credentials保留新 credentials 中不存在的字段
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
newCredentials = MergeCredentials(account.Credentials, newCredentials)
// 特殊处理 project_id如果新值为空但旧值非空保留旧值
// 这确保了即使 LoadCodeAssist 失败project_id 也不会丢失

View File

@@ -22,8 +22,9 @@ const (
)
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
// A nil windowStart is treated as expired — no initialized window means any accumulated usage is stale.
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
return windowStart != nil && time.Since(*windowStart) >= duration
return windowStart == nil || time.Since(*windowStart) >= duration
}
type APIKey struct {

View File

@@ -15,10 +15,10 @@ func TestIsWindowExpired(t *testing.T) {
want bool
}{
{
name: "nil window start",
name: "nil window start (treated as expired)",
start: nil,
duration: RateLimitWindow5h,
want: false,
want: true,
},
{
name: "active window (started 1h ago, 5h window)",
@@ -113,7 +113,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
want7d: 0,
},
{
name: "nil window starts return raw usage",
name: "nil window starts return 0 (stale usage reset)",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
@@ -122,9 +122,9 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
Window1dStart: nil,
Window7dStart: nil,
},
want5h: 5.0,
want1d: 10.0,
want7d: 50.0,
want5h: 0,
want1d: 0,
want7d: 0,
},
{
name: "mixed: 5h expired, 1d active, 7d nil",
@@ -138,7 +138,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
},
want5h: 0,
want1d: 10.0,
want7d: 50.0,
want7d: 0,
},
{
name: "zero usage with active windows",
@@ -210,7 +210,7 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
want7d: 0,
},
{
name: "nil window starts return raw usage",
name: "nil window starts return 0 (stale usage reset)",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
@@ -219,9 +219,9 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
Window1dStart: nil,
Window7dStart: nil,
},
want5h: 3.0,
want1d: 8.0,
want7d: 40.0,
want5h: 0,
want1d: 0,
want7d: 0,
},
}

View File

@@ -4,11 +4,13 @@ import (
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -84,17 +86,21 @@ type BackupScheduleConfig struct {
// BackupRecord 备份记录
type BackupRecord struct {
ID string `json:"id"`
Status string `json:"status"` // pending, running, completed, failed
BackupType string `json:"backup_type"` // postgres
FileName string `json:"file_name"`
S3Key string `json:"s3_key"`
SizeBytes int64 `json:"size_bytes"`
TriggeredBy string `json:"triggered_by"` // manual, scheduled
ErrorMsg string `json:"error_message,omitempty"`
StartedAt string `json:"started_at"`
FinishedAt string `json:"finished_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
ID string `json:"id"`
Status string `json:"status"` // pending, running, completed, failed
BackupType string `json:"backup_type"` // postgres
FileName string `json:"file_name"`
S3Key string `json:"s3_key"`
SizeBytes int64 `json:"size_bytes"`
TriggeredBy string `json:"triggered_by"` // manual, scheduled
ErrorMsg string `json:"error_message,omitempty"`
StartedAt string `json:"started_at"`
FinishedAt string `json:"finished_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
Progress string `json:"progress,omitempty"` // "dumping", "uploading", ""
RestoreStatus string `json:"restore_status,omitempty"` // "", "running", "completed", "failed"
RestoreError string `json:"restore_error,omitempty"`
RestoredAt string `json:"restored_at,omitempty"`
}
// BackupService 数据库备份恢复服务
@@ -105,17 +111,24 @@ type BackupService struct {
storeFactory BackupObjectStoreFactory
dumper DBDumper
mu sync.Mutex
store BackupObjectStore
s3Cfg *BackupS3Config
opMu sync.Mutex // 保护 backingUp/restoring 标志
backingUp bool
restoring bool
storeMu sync.Mutex // 保护 store/s3Cfg 缓存
store BackupObjectStore
s3Cfg *BackupS3Config
recordsMu sync.Mutex // 保护 records 的 load/save 操作
cronMu sync.Mutex
cronSched *cron.Cron
cronEntryID cron.EntryID
wg sync.WaitGroup // 追踪活跃的备份/恢复 goroutine
shuttingDown atomic.Bool // 阻止新备份启动
bgCtx context.Context // 所有后台操作的 parent context
bgCancel context.CancelFunc // 取消所有活跃后台操作
}
func NewBackupService(
@@ -125,20 +138,26 @@ func NewBackupService(
storeFactory BackupObjectStoreFactory,
dumper DBDumper,
) *BackupService {
bgCtx, bgCancel := context.WithCancel(context.Background())
return &BackupService{
settingRepo: settingRepo,
dbCfg: &cfg.Database,
encryptor: encryptor,
storeFactory: storeFactory,
dumper: dumper,
bgCtx: bgCtx,
bgCancel: bgCancel,
}
}
// Start 启动定时备份调度器
// Start 启动定时备份调度器并清理孤立记录
func (s *BackupService) Start() {
s.cronSched = cron.New()
s.cronSched.Start()
// 清理重启后孤立的 running 记录
s.recoverStaleRecords()
// 加载已有的定时配置
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -154,13 +173,65 @@ func (s *BackupService) Start() {
}
}
// Stop 停止定时备份
// recoverStaleRecords 启动时将孤立的 running 记录标记为 failed
func (s *BackupService) recoverStaleRecords() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
records, err := s.loadRecords(ctx)
if err != nil {
return
}
for i := range records {
if records[i].Status == "running" {
records[i].Status = "failed"
records[i].ErrorMsg = "interrupted by server restart"
records[i].Progress = ""
records[i].FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, &records[i])
logger.LegacyPrintf("service.backup", "[Backup] recovered stale running record: %s", records[i].ID)
}
if records[i].RestoreStatus == "running" {
records[i].RestoreStatus = "failed"
records[i].RestoreError = "interrupted by server restart"
_ = s.saveRecord(ctx, &records[i])
logger.LegacyPrintf("service.backup", "[Backup] recovered stale restoring record: %s", records[i].ID)
}
}
}
// Stop 停止定时备份并等待活跃操作完成
func (s *BackupService) Stop() {
s.shuttingDown.Store(true)
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched != nil {
s.cronSched.Stop()
}
s.cronMu.Unlock()
// 等待活跃备份/恢复完成(最多 5 分钟)
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
logger.LegacyPrintf("service.backup", "[Backup] all active operations finished")
case <-time.After(5 * time.Minute):
logger.LegacyPrintf("service.backup", "[Backup] shutdown timeout after 5min, cancelling active operations")
if s.bgCancel != nil {
s.bgCancel() // 取消所有后台操作
}
// 给 goroutine 时间响应取消并完成清理
select {
case <-done:
logger.LegacyPrintf("service.backup", "[Backup] active operations cancelled and cleaned up")
case <-time.After(10 * time.Second):
logger.LegacyPrintf("service.backup", "[Backup] goroutine cleanup timed out")
}
}
}
// ─── S3 配置管理 ───
@@ -203,10 +274,10 @@ func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config)
}
// 清除缓存的 S3 客户端
s.mu.Lock()
s.storeMu.Lock()
s.store = nil
s.s3Cfg = nil
s.mu.Unlock()
s.storeMu.Unlock()
cfg.SecretAccessKey = ""
return &cfg, nil
@@ -314,7 +385,10 @@ func (s *BackupService) removeCronSchedule() {
}
func (s *BackupService) runScheduledBackup() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
s.wg.Add(1)
defer s.wg.Done()
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
defer cancel()
// 读取定时备份配置中的过期天数
@@ -327,7 +401,11 @@ func (s *BackupService) runScheduledBackup() {
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
if err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
if errors.Is(err, ErrBackupInProgress) {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份跳过: 已有备份正在进行中")
} else {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
}
return
}
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
@@ -346,17 +424,21 @@ func (s *BackupService) runScheduledBackup() {
// CreateBackup 创建全量数据库备份并上传到 S3流式处理
// expireDays: 备份过期天数0=永不过期默认14天
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
s.mu.Lock()
if s.shuttingDown.Load() {
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
}
s.opMu.Lock()
if s.backingUp {
s.mu.Unlock()
s.opMu.Unlock()
return nil, ErrBackupInProgress
}
s.backingUp = true
s.mu.Unlock()
s.opMu.Unlock()
defer func() {
s.mu.Lock()
s.opMu.Lock()
s.backingUp = false
s.mu.Unlock()
s.opMu.Unlock()
}()
s3Cfg, err := s.loadS3Config(ctx)
@@ -405,36 +487,47 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
pr, pw := io.Pipe()
var gzipErr error
gzipDone := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck
gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r)
}
}()
gzWriter := gzip.NewWriter(pw)
_, gzipErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
var gzErr error
_, gzErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil {
gzErr = closeErr
}
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil {
gzErr = closeErr
}
if gzipErr != nil {
_ = pw.CloseWithError(gzipErr)
if gzErr != nil {
_ = pw.CloseWithError(gzErr)
} else {
_ = pw.Close()
}
gzipDone <- gzErr
}()
contentType := "application/gzip"
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
if err != nil {
_ = pr.CloseWithError(err) // 确保 gzip goroutine 不会悬挂
gzErr := <-gzipDone // 安全等待 gzip goroutine 完成
record.Status = "failed"
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
if gzipErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
if gzErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzErr)
}
record.ErrorMsg = errMsg
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, record)
return record, fmt.Errorf("backup upload: %w", err)
}
<-gzipDone // 确保 gzip goroutine 已退出
record.SizeBytes = sizeBytes
record.Status = "completed"
@@ -446,19 +539,187 @@ func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, ex
return record, nil
}
// StartBackup 异步创建备份,立即返回 running 状态的记录
func (s *BackupService) StartBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
if s.shuttingDown.Load() {
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
}
s.opMu.Lock()
if s.backingUp {
s.opMu.Unlock()
return nil, ErrBackupInProgress
}
s.backingUp = true
s.opMu.Unlock()
// 初始化阶段出错时自动重置标志
launched := false
defer func() {
if !launched {
s.opMu.Lock()
s.backingUp = false
s.opMu.Unlock()
}
}()
// 在返回前加载 S3 配置和创建 store避免 goroutine 中配置被修改
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
if s3Cfg == nil || !s3Cfg.IsConfigured() {
return nil, ErrBackupS3NotConfigured
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return nil, fmt.Errorf("init object store: %w", err)
}
now := time.Now()
backupID := uuid.New().String()[:8]
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
s3Key := s.buildS3Key(s3Cfg, fileName)
var expiresAt string
if expireDays > 0 {
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
}
record := &BackupRecord{
ID: backupID,
Status: "running",
BackupType: "postgres",
FileName: fileName,
S3Key: s3Key,
TriggeredBy: triggeredBy,
StartedAt: now.Format(time.RFC3339),
ExpiresAt: expiresAt,
Progress: "pending",
}
if err := s.saveRecord(ctx, record); err != nil {
return nil, fmt.Errorf("save initial record: %w", err)
}
launched = true
// 在启动 goroutine 前完成拷贝,避免数据竞争
result := *record
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() {
s.opMu.Lock()
s.backingUp = false
s.opMu.Unlock()
}()
defer func() {
if r := recover(); r != nil {
logger.LegacyPrintf("service.backup", "[Backup] panic recovered: %v", r)
record.Status = "failed"
record.ErrorMsg = fmt.Sprintf("internal panic: %v", r)
record.Progress = ""
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(context.Background(), record)
}
}()
s.executeBackup(record, objectStore)
}()
return &result, nil
}
// executeBackup 后台执行备份(独立于 HTTP context
func (s *BackupService) executeBackup(record *BackupRecord, objectStore BackupObjectStore) {
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
defer cancel()
// 阶段1: pg_dump
record.Progress = "dumping"
_ = s.saveRecord(ctx, record)
dumpReader, err := s.dumper.Dump(ctx)
if err != nil {
record.Status = "failed"
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
record.Progress = ""
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(context.Background(), record)
return
}
// 阶段2: gzip + upload
record.Progress = "uploading"
_ = s.saveRecord(ctx, record)
pr, pw := io.Pipe()
gzipDone := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
pw.CloseWithError(fmt.Errorf("gzip goroutine panic: %v", r)) //nolint:errcheck
gzipDone <- fmt.Errorf("gzip goroutine panic: %v", r)
}
}()
gzWriter := gzip.NewWriter(pw)
var gzErr error
_, gzErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzErr == nil {
gzErr = closeErr
}
if closeErr := dumpReader.Close(); closeErr != nil && gzErr == nil {
gzErr = closeErr
}
if gzErr != nil {
_ = pw.CloseWithError(gzErr)
} else {
_ = pw.Close()
}
gzipDone <- gzErr
}()
contentType := "application/gzip"
sizeBytes, err := objectStore.Upload(ctx, record.S3Key, pr, contentType)
if err != nil {
_ = pr.CloseWithError(err) // 确保 gzip goroutine 不会悬挂
gzErr := <-gzipDone // 安全等待 gzip goroutine 完成
record.Status = "failed"
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
if gzErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzErr)
}
record.ErrorMsg = errMsg
record.Progress = ""
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(context.Background(), record)
return
}
<-gzipDone // 确保 gzip goroutine 已退出
record.SizeBytes = sizeBytes
record.Status = "completed"
record.Progress = ""
record.FinishedAt = time.Now().Format(time.RFC3339)
if err := s.saveRecord(context.Background(), record); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
}
}
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
s.mu.Lock()
s.opMu.Lock()
if s.restoring {
s.mu.Unlock()
s.opMu.Unlock()
return ErrRestoreInProgress
}
s.restoring = true
s.mu.Unlock()
s.opMu.Unlock()
defer func() {
s.mu.Lock()
s.opMu.Lock()
s.restoring = false
s.mu.Unlock()
s.opMu.Unlock()
}()
record, err := s.GetBackupRecord(ctx, backupID)
@@ -500,6 +761,112 @@ func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) erro
return nil
}
// StartRestore 异步恢复备份,立即返回
func (s *BackupService) StartRestore(ctx context.Context, backupID string) (*BackupRecord, error) {
if s.shuttingDown.Load() {
return nil, infraerrors.ServiceUnavailable("SERVER_SHUTTING_DOWN", "server is shutting down")
}
s.opMu.Lock()
if s.restoring {
s.opMu.Unlock()
return nil, ErrRestoreInProgress
}
s.restoring = true
s.opMu.Unlock()
// 初始化阶段出错时自动重置标志
launched := false
defer func() {
if !launched {
s.opMu.Lock()
s.restoring = false
s.opMu.Unlock()
}
}()
record, err := s.GetBackupRecord(ctx, backupID)
if err != nil {
return nil, err
}
if record.Status != "completed" {
return nil, infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
}
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return nil, fmt.Errorf("init object store: %w", err)
}
record.RestoreStatus = "running"
_ = s.saveRecord(ctx, record)
launched = true
result := *record
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() {
s.opMu.Lock()
s.restoring = false
s.opMu.Unlock()
}()
defer func() {
if r := recover(); r != nil {
logger.LegacyPrintf("service.backup", "[Backup] restore panic recovered: %v", r)
record.RestoreStatus = "failed"
record.RestoreError = fmt.Sprintf("internal panic: %v", r)
_ = s.saveRecord(context.Background(), record)
}
}()
s.executeRestore(record, objectStore)
}()
return &result, nil
}
// executeRestore 后台执行恢复
func (s *BackupService) executeRestore(record *BackupRecord, objectStore BackupObjectStore) {
ctx, cancel := context.WithTimeout(s.bgCtx, 30*time.Minute)
defer cancel()
body, err := objectStore.Download(ctx, record.S3Key)
if err != nil {
record.RestoreStatus = "failed"
record.RestoreError = fmt.Sprintf("S3 download failed: %v", err)
_ = s.saveRecord(context.Background(), record)
return
}
defer func() { _ = body.Close() }()
gzReader, err := gzip.NewReader(body)
if err != nil {
record.RestoreStatus = "failed"
record.RestoreError = fmt.Sprintf("gzip reader: %v", err)
_ = s.saveRecord(context.Background(), record)
return
}
defer func() { _ = gzReader.Close() }()
if err := s.dumper.Restore(ctx, gzReader); err != nil {
record.RestoreStatus = "failed"
record.RestoreError = fmt.Sprintf("pg restore: %v", err)
_ = s.saveRecord(context.Background(), record)
return
}
record.RestoreStatus = "completed"
record.RestoredAt = time.Now().Format(time.RFC3339)
if err := s.saveRecord(context.Background(), record); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 保存恢复记录失败: %v", err)
}
}
// ─── 备份记录管理 ───
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
@@ -614,8 +981,8 @@ func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, erro
}
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.storeMu.Lock()
defer s.storeMu.Unlock()
if s.store != nil && s.s3Cfg != nil {
return s.store, nil

View File

@@ -134,6 +134,30 @@ func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
return nil
}
// blockingDumper 可控延迟的 dumper用于测试异步行为
type blockingDumper struct {
blockCh chan struct{}
data []byte
restErr error
}
func (d *blockingDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
select {
case <-d.blockCh:
case <-ctx.Done():
return nil, ctx.Err()
}
return io.NopCloser(bytes.NewReader(d.data)), nil
}
func (d *blockingDumper) Restore(_ context.Context, data io.Reader) error {
if d.restErr != nil {
return d.restErr
}
_, _ = io.ReadAll(data)
return nil
}
type mockObjectStore struct {
objects map[string][]byte
mu sync.Mutex
@@ -179,7 +203,7 @@ func (m *mockObjectStore) HeadBucket(_ context.Context) error {
return nil
}
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
func newTestBackupService(repo *mockSettingRepo, dumper DBDumper, store *mockObjectStore) *BackupService {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
@@ -361,9 +385,9 @@ func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
svc := newTestBackupService(repo, dumper, store)
// 手动设置 backingUp 标志
svc.mu.Lock()
svc.opMu.Lock()
svc.backingUp = true
svc.mu.Unlock()
svc.opMu.Unlock()
_, err := svc.CreateBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupInProgress)
@@ -526,3 +550,154 @@ func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
require.Error(t, err)
require.Nil(t, cfg)
}
// ─── Async Backup Tests ───
func TestStartBackup_ReturnsImmediately(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.StartBackup(context.Background(), "manual", 14)
require.NoError(t, err)
require.Equal(t, "running", record.Status)
require.NotEmpty(t, record.ID)
// 释放 dumper 让后台完成
close(dumper.blockCh)
svc.wg.Wait()
// 验证最终状态
final, err := svc.GetBackupRecord(context.Background(), record.ID)
require.NoError(t, err)
require.Equal(t, "completed", final.Status)
require.Greater(t, final.SizeBytes, int64(0))
}
func TestStartBackup_ConcurrentBlocked(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// 第一次启动
_, err := svc.StartBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// 第二次应被阻塞
_, err = svc.StartBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupInProgress)
close(dumper.blockCh)
svc.wg.Wait()
}
func TestStartBackup_ShuttingDown(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
svc := newTestBackupService(repo, &mockDumper{dumpData: []byte("data")}, newMockObjectStore())
svc.shuttingDown.Store(true)
_, err := svc.StartBackup(context.Background(), "manual", 14)
require.Error(t, err)
require.Contains(t, err.Error(), "shutting down")
}
func TestRecoverStaleRecords(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// 模拟一条孤立的 running 记录
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: "stale-1",
Status: "running",
StartedAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
})
// 模拟一条孤立的恢复中记录
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: "stale-2",
Status: "completed",
RestoreStatus: "running",
StartedAt: time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
})
svc.recoverStaleRecords()
r1, _ := svc.GetBackupRecord(context.Background(), "stale-1")
require.Equal(t, "failed", r1.Status)
require.Contains(t, r1.ErrorMsg, "server restart")
r2, _ := svc.GetBackupRecord(context.Background(), "stale-2")
require.Equal(t, "failed", r2.RestoreStatus)
require.Contains(t, r2.RestoreError, "server restart")
}
func TestGracefulShutdown(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &blockingDumper{blockCh: make(chan struct{}), data: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
_, err := svc.StartBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// Stop 应该等待备份完成
done := make(chan struct{})
go func() {
svc.Stop()
close(done)
}()
// 短暂等待确认 Stop 还在等待
select {
case <-done:
t.Fatal("Stop returned before backup finished")
case <-time.After(100 * time.Millisecond):
// 预期Stop 还在等待
}
// 释放备份
close(dumper.blockCh)
// 现在 Stop 应该完成
select {
case <-done:
// 预期
case <-time.After(5 * time.Second):
t.Fatal("Stop did not return after backup finished")
}
}
func TestStartRestore_Async(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// 先创建一个备份(同步方式)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// 异步恢复
restored, err := svc.StartRestore(context.Background(), record.ID)
require.NoError(t, err)
require.Equal(t, "running", restored.RestoreStatus)
svc.wg.Wait()
// 验证最终状态
final, err := svc.GetBackupRecord(context.Background(), record.ID)
require.NoError(t, err)
require.Equal(t, "completed", final.RestoreStatus)
}

View File

@@ -21,9 +21,6 @@ var (
// 带捕获组的版本提取正则
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
systemPromptThreshold = 0.5
)
@@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
return false
}
if !userIDPattern.MatchString(userID) {
if ParseMetadataUserID(userID) == nil {
return false
}
@@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
return ExtractCLIVersion(ua)
}
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"log/slog"
"strconv"
"strings"
"time"
)
@@ -15,14 +14,17 @@ const (
claudeLockWaitTime = 200 * time.Millisecond
)
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
// ClaudeTokenCache token cache interface.
type ClaudeTokenCache = GeminiTokenCache
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
type ClaudeTokenProvider struct {
accountRepo AccountRepository
tokenCache ClaudeTokenCache
oauthService *OAuthService
accountRepo AccountRepository
tokenCache ClaudeTokenCache
oauthService *OAuthService
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
refreshPolicy ProviderRefreshPolicy
}
func NewClaudeTokenProvider(
@@ -31,13 +33,25 @@ func NewClaudeTokenProvider(
oauthService *OAuthService,
) *ClaudeTokenProvider {
return &ClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
oauthService: oauthService,
accountRepo: accountRepo,
tokenCache: tokenCache,
oauthService: oauthService,
refreshPolicy: ClaudeProviderRefreshPolicy(),
}
}
// GetAccessToken 获取有效的 access_token
// SetRefreshAPI injects unified OAuth refresh API and executor.
func (p *ClaudeTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
p.refreshAPI = api
p.executor = executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func (p *ClaudeTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
// GetAccessToken returns a valid access_token.
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
@@ -48,7 +62,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
cacheKey := ClaudeTokenCacheKey(account)
// 1. 先尝试缓存
// 1) Try cache first.
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
@@ -60,114 +74,39 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
// 2. 如果即将过期则刷新
// 2) Refresh if needed (pre-expiry skew).
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, claudeTokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
// 构建新 credentials保留原有字段
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if lockErr != nil {
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
// 检查 ctx 是否已取消
if ctx.Err() != nil {
return "", ctx.Err()
}
// 从数据库获取最新账户信息
if p.accountRepo != nil {
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
}
expiresAt = account.GetCredentialAsTime("expires_at")
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
refreshFailed = true
} else {
// 构建新 credentials保留原有字段
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
time.Sleep(claudeLockWaitTime)
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
} else {
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
} else if lockErr != nil {
slog.Warn("claude_token_lock_failed", "account_id", account.ID, "error", lockErr)
} else {
time.Sleep(claudeLockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
@@ -181,22 +120,23 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
// 3) Populate cache with TTL.
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
if p.refreshPolicy.FailureTTL > 0 {
ttl = p.refreshPolicy.FailureTTL
} else {
ttl = time.Minute
}
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)

View File

@@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) {
normalizedSource := usagestats.NormalizeModelSource(modelSource)
if normalizedSource == usagestats.ModelSourceRequested {
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
type modelStatsBySourceRepo interface {
GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error)
}
if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok {
stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource)
if err != nil {
return nil, fmt.Errorf("get model stats with filters by source: %w", err)
}
return stats, nil
}
return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
}
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
@@ -148,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
// GetGroupUsageSummary returns today's and cumulative cost for all groups.
func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) {
results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart)
if err != nil {
return nil, fmt.Errorf("get group usage summary: %w", err)
}
return results, nil
}
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
data, err := s.cache.GetDashboardStats(ctx)
if err != nil {
@@ -335,6 +365,14 @@ func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime
return ranking, nil
}
func (s *DashboardService) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) {
stats, err := s.usageRepo.GetUserBreakdownStats(ctx, startTime, endTime, dim, limit)
if err != nil {
return nil, fmt.Errorf("get user breakdown stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
if err != nil {

Some files were not shown because too many files have changed in this diff Show More