Compare commits


108 Commits

Author SHA1 Message Date
Wesley Liddick
d3a9f5bb88 Merge pull request #1027 from touwaeriol/feat/ignore-insufficient-balance-errors
feat(ops): add ignore insufficient balance errors toggle and extract error constants
2026-03-15 19:10:18 +08:00
Wesley Liddick
7eb0415a8a Merge pull request #1028 from IanShaw027/fix/open-issues-cleanup
fix: resolve multiple open issues - Gemini schema compatibility, bulk-edit model whitelist, Docker tooling support, and quota-limit field handling
2026-03-15 19:09:49 +08:00
erio
bdbc8fa08f fix(ops): align constant declarations for gofmt compliance 2026-03-15 18:55:14 +08:00
erio
63f3af0f94 fix(ops): match "insufficient account balance" in error filter
The upstream Gemini API returns "Insufficient account balance" which
doesn't contain the substring "insufficient balance". Add explicit
match for the full phrase to ensure the filter works correctly.
2026-03-15 18:45:48 +08:00
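A minimal Go sketch of the filter match described above; the constant names are assumptions based on this commit and the error-constant extraction commit below, not the actual source.

package ops

import "strings"

const (
    errKeywordInsufficientBalance        = "insufficient balance"
    errKeywordInsufficientAccountBalance = "insufficient account balance"
)

// matchesInsufficientBalance reports whether an upstream error message should be
// suppressed by the IgnoreInsufficientBalanceErrors toggle.
func matchesInsufficientBalance(msg string) bool {
    lower := strings.ToLower(msg)
    // "Insufficient account balance" does not contain the substring
    // "insufficient balance", so both phrases are matched explicitly.
    return strings.Contains(lower, errKeywordInsufficientBalance) ||
        strings.Contains(lower, errKeywordInsufficientAccountBalance)
}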
IanShaw027
686f890fbf style: fix gofmt formatting issues 2026-03-15 18:42:32 +08:00
shaw
220fbe6544 fix: restore the window stats display accidentally removed from UsageProgressBar
Commit 0debe0a8 accidentally removed the window stats rendering logic and formatting functions from UsageProgressBar while fixing the OpenAI WS usage window refresh issue.

Restore the stats row above the progress bar (requests, tokens, account cost, user cost) and the four corresponding formatted computed properties.
2026-03-15 18:29:23 +08:00
shaw
ae44a94325 fix: add UI configuration of the sender email domain for the password reset feature 2026-03-15 17:52:29 +08:00
IanShaw
3718d6dcd4 Merge branch 'Wei-Shaw:main' into fix/open-issues-cleanup 2026-03-15 17:49:20 +08:00
IanShaw027
90b3838173 fix: remove the patternProperties field that Gemini does not support #795 2026-03-15 17:46:58 +08:00
IanShaw027
19d3ecc76f fix: fix mismatch between the displayed and persisted model whitelist when bulk-editing accounts #982
Fixes the issue where, when bulk-editing accounts, the UI displayed plain model names (e.g. GPT-5) while dated model names were actually persisted.

Core changes:
1. The bulk-edit whitelist no longer uses the hand-written, outdated model list in BulkEditAccountModal.vue
   - Removed the hardcoded allModels and presetMappings lists (200+ lines in total)
   - Reuses the ModelWhitelistSelector.vue component directly

2. ModelWhitelistSelector now supports combined filtering across multiple platforms
   - New platforms prop accepts multiple platforms
   - Added a normalizedPlatforms computed property to handle single- and multi-platform cases uniformly
   - availableOptions dynamically filters the model list as the union of the selected platforms
   - fillRelated can fill in related models for multiple platforms at once

3. Model mapping presets are now generated dynamically
   - filteredPresets uses getPresetMappingsByPlatform to generate presets per platform from the unified model source
   - No longer depends on the hand-written preset list in the modal

Behavior now:
- Whatever models the UI displays and the user checks are exactly what is sent to the backend
- Fully resolves the "displayed vs. persisted" mismatch along the bulk-edit path
- The model list and mapping presets always stay in sync with the system definitions
2026-03-15 17:46:58 +08:00
IanShaw027
6fba4ebb13 fix: add pg_dump and psql tools to Dockerfile.goreleaser #1002
Add PostgreSQL client tools to the runtime image to support in-container database backup and restore.

Changes:
- Use a multi-stage build to copy the pg_dump and psql binaries from the postgres:18-alpine image
- Add the required dependency libraries (libpq, zstd-libs, lz4-libs, krb5-libs, libldap, libedit)
- Upgrade the base image to alpine:3.21
- Copy the libpq.so.5 shared library so the tools run correctly

Database backup and restore can now run directly inside the runtime container, without access to the Docker socket.
2026-03-15 17:46:58 +08:00
IanShaw027
c31974c913 fix: tolerate partially empty quota limit fields #1021
Fixes the error raised when not all three quota limits (daily, weekly, monthly) were filled in.

Changes:
- Backend: add an optionalLimitField type that handles null and empty-string values, so partially empty limit fields are accepted (see the sketch after this commit)
- Frontend: add a normalizeOptionalLimit function that normalizes limit input, mapping null, empty strings, and invalid numbers to null
2026-03-15 17:46:58 +08:00
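A minimal Go sketch of the optionalLimitField idea described above, assuming a custom JSON unmarshaler; the field shape and package are assumptions, not the actual implementation.

package dto

import (
    "bytes"
    "encoding/json"
)

// optionalLimitField accepts null, "" and numeric JSON values so a request
// with only some quota limits filled in no longer fails to decode.
type optionalLimitField struct {
    Value *float64
}

func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
    trimmed := bytes.TrimSpace(data)
    // Treat null and "" as "not set" instead of rejecting the whole request.
    if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) || bytes.Equal(trimmed, []byte(`""`)) {
        f.Value = nil
        return nil
    }
    var v float64
    if err := json.Unmarshal(trimmed, &v); err != nil {
        return err
    }
    f.Value = &v
    return nil
}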
erio
6177fa5dd8 fix(i18n): correct insufficient balance error hint text
Remove misleading "upstream" wording - the error is about client API key
user balance, not upstream account balance.
2026-03-15 17:41:51 +08:00
erio
cfe72159d0 feat(ops): add ignore insufficient balance errors toggle and extract error constants
- Add 5th error filter switch IgnoreInsufficientBalanceErrors to suppress
  upstream insufficient balance / insufficient_quota errors from ops log
- Extract hardcoded error strings into package-level constants for
  shouldSkipOpsErrorLog, normalizeOpsErrorType, classifyOpsPhase, and
  classifyOpsIsBusinessLimited
- Define ErrNoAvailableAccounts sentinel error and replace all
  errors.New("no available accounts") call sites
- Update tests to use require.ErrorIs with the sentinel error
2026-03-15 17:26:18 +08:00
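A short Go sketch of the sentinel-error pattern this commit describes; only ErrNoAvailableAccounts and the errors.Is/require.ErrorIs usage come from the commit, the surrounding function is illustrative.

package service

import "errors"

// ErrNoAvailableAccounts replaces the ad-hoc errors.New("no available accounts")
// call sites so callers and tests can match it reliably.
var ErrNoAvailableAccounts = errors.New("no available accounts")

func pickAccount(candidates []int64) (int64, error) {
    if len(candidates) == 0 {
        return 0, ErrNoAvailableAccounts
    }
    return candidates[0], nil
}

// Callers then match with errors.Is, and tests with require.ErrorIs:
//   if errors.Is(err, ErrNoAvailableAccounts) { ... }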
Wesley Liddick
8321e4a647 Merge pull request #1023 from YanzheL/fix/claude-output-effort-logging
fix: extract and log Claude output_config.effort in usage records
2026-03-15 16:45:37 +08:00
Wesley Liddick
3084330d0c Merge pull request #1019 from Ethan0x0000/feat/usage-endpoint-distribution
feat: add endpoint metadata and usage endpoint distribution insights
2026-03-15 16:42:03 +08:00
Wesley Liddick
b566649e79 Merge pull request #1025 from touwaeriol/fix/rate-limit-nil-window-reset
fix(billing): treat nil rate limit window as expired to prevent usage accumulation
2026-03-15 16:33:14 +08:00
Wesley Liddick
10a6180e4a Merge pull request #1026 from touwaeriol/fix/group-quota-clear
fix(billing): allow clearing group quota limits and treat 0 as zero-limit
2026-03-15 16:33:00 +08:00
Wesley Liddick
cbe9e78977 Merge pull request #1007 from StarryKira/fix/streaming-failover-corruption
fix(gateway): prevent streaming failover splice corruption from sending the client duplicate message_start events, fix issue #991
2026-03-15 16:29:31 +08:00
Wesley Liddick
74145b1f39 Merge pull request #1017 from SsageParuders/fix/bedrock-account-quota
fix: Bedrock account quota limits not taking effect
2026-03-15 16:28:42 +08:00
Elysia
359e56751b add tests 2026-03-15 16:21:49 +08:00
erio
5899784aa4 fix(billing): allow clearing group quota limits and treat 0 as zero-limit
Previously, v-model.number produced "" when input was cleared, causing
JSON decode errors on the backend. Also, normalizeLimit treated 0 as
"unlimited" which prevented setting a zero quota. Now "" is converted
to null (unlimited) in frontend, and 0 is preserved as a valid limit.

Closes Wei-Shaw/sub2api#1021
2026-03-15 16:15:15 +08:00
erio
9e8959c56d fix(billing): treat nil rate limit window as expired to prevent usage accumulation
When Redis cache is populated from DB with a NULL window_1d_start, the
Lua increment script only updates usage counters without setting window
timestamps. IsWindowExpired(nil) previously returned false, so the
accumulated usage was never reset across time windows, effectively
turning usage_1d into a lifetime counter. Once this exceeded
rate_limit_1d, the key was incorrectly blocked with "日限额已用完" (daily limit exhausted).

Fixes Wei-Shaw/sub2api#1022
2026-03-15 14:04:13 +08:00
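A minimal Go sketch of the nil-window check described above; the function signature is an assumption based on the commit message.

package billing

import "time"

// IsWindowExpired treats a nil window start (e.g. cached from a NULL
// window_1d_start column) as expired, so accumulated usage is reset instead of
// growing into a lifetime counter.
func IsWindowExpired(windowStart *time.Time, windowSize time.Duration, now time.Time) bool {
    if windowStart == nil {
        return true
    }
    return now.Sub(*windowStart) >= windowSize
}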
YanzheL
1bff2292a6 fix: extract and log Claude output_config.effort in usage records
Claude's output_config.effort parameter (low/medium/high/max) was not
being extracted from requests or logged in the reasoning_effort column
of usage logs. Only the OpenAI path populated this field.

Changes:
- Extract output_config.effort in ParseGatewayRequest
- Add ReasoningEffort field to ForwardResult
- Populate reasoning_effort in both RecordUsage and RecordUsageWithLongContext
- Guard against overwriting service-set effort values in handler
- Update stale comments that described reasoning_effort as OpenAI-only
- Add unit tests for extraction, normalization, and persistence
2026-03-15 12:55:37 +08:00
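A rough Go sketch of extracting output_config.effort from a Claude request body; the request structure here is an assumption inferred from the commit message.

package service

import "encoding/json"

type claudeOutputConfig struct {
    Effort string `json:"effort"` // low | medium | high | max
}

type claudeRequestEnvelope struct {
    OutputConfig *claudeOutputConfig `json:"output_config"`
}

// extractReasoningEffort returns the effort value to be logged in the
// reasoning_effort column, or "" when the request does not set it.
func extractReasoningEffort(body []byte) string {
    var req claudeRequestEnvelope
    if err := json.Unmarshal(body, &req); err != nil || req.OutputConfig == nil {
        return ""
    }
    return req.OutputConfig.Effort
}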
Ethan0x0000
cf9247754e test: fix usage repo stubs for unit builds 2026-03-15 12:51:34 +08:00
Ethan0x0000
eefab15958 feat: improve usage-record endpoint observability and distribution statistics
Unify the inbound, upstream, and path endpoint distributions into consistent card interactions on the usage records page, and complete the endpoint metadata and statistics pipeline to make troubleshooting and traffic analysis more efficient.
2026-03-15 11:26:42 +08:00
Elysia
0e23732631 fix(gateway): prevent streaming failover splice corruption from sending the client duplicate message_start events
When the upstream returns event:error mid-way through an SSE stream, handleStreamingResponse has already written some SSE events to the client, but the previous failover logic would still switch to the next account and write a complete stream, so the client received two message_start events and then hit a 400 error.

Fix: record c.Writer.Size() before each Forward call. If the writer byte count has grown by the time Forward returns UpstreamFailoverError, SSE content has already been irrevocably sent to the client; in that case call handleFailoverExhausted directly to emit an SSE error event and terminate the stream instead of continuing failover (see the sketch after this commit).

Ping-only scenarios are unaffected: ping bytes written while waiting for a slot are the same before and after Forward, so the normal failover flow proceeds as usual.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-14 22:49:23 +08:00
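A rough Go sketch of the writer-size guard described above; forwardOnce, continueFailover, and handleFailoverExhausted stand in for the service's real helpers and are assumptions.

func forwardWithFailoverGuard(c *gin.Context, account *Account) {
    sizeBefore := c.Writer.Size() // bytes already written to the client

    err := forwardOnce(c, account)
    var failoverErr *UpstreamFailoverError
    if !errors.As(err, &failoverErr) {
        return
    }
    if c.Writer.Size() > sizeBefore {
        // Real SSE content already reached the client; replaying a full stream
        // from another account would emit a second message_start.
        handleFailoverExhausted(c, failoverErr)
        return
    }
    // Only ping bytes (or nothing) were written while waiting for a slot,
    // so failing over to the next account is still safe.
    continueFailover(c, failoverErr)
}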
SsageParuders
37c044fb4b fix: Bedrock account quota limits not taking effect; quota counter stuck at $0.00
The quota update condition in applyUsageBillingEffects() only checked AccountTypeAPIKey and missed AccountTypeBedrock, so the quota counter for Bedrock accounts never incremented. Extend the condition to cover both the APIKey and Bedrock types.

Also add an AWS Bedrock option to the account filter dropdown on the frontend.
2026-03-14 22:47:44 +08:00
shaw
6da5fa01b9 fix(frontend): fix the ops settings dialog save button being permanently disabled
The backend defaults to alert.enabled=true while recipients is empty; frontend validation treated this as an error and blocked the save button. Remove that blocking validation and instead automatically disable email notification configs that have no recipients when saving.
2026-03-14 20:39:29 +08:00
shaw
616930f9d3 refactor(frontend): merge the backup and data management pages into settings tabs
Integrate the standalone /admin/backup and /admin/data-management pages into the settings page as "Backup" and "Sora Storage" tabs, reducing sidebar entries and centralizing configuration management.

- Remove the AppLayout wrapper from BackupView and DataManagementView
- Embed them as child components in SettingsView, switching tabs with v-show
- Delete the standalone routes and sidebar menu entries
- Hide the main save button under the backup/data tabs (each has its own save action)
- Adjust the tab bar styling to fit 7 tabs, with a thin scrollbar on desktop
- Clean up unused icon components and i18n keys
2026-03-14 20:22:39 +08:00
Wesley Liddick
b9c31fa7c4 Merge pull request #999 from InCerryGit/fix/enc_coot
fix: handle invalid encrypted content error and retry logic.
2026-03-14 19:29:07 +08:00
Wesley Liddick
17b339972c Merge pull request #1000 from touwaeriol/fix/ops-agg-tuning
fix(ops): tune aggregation constants to prevent PG overload
2026-03-14 19:03:03 +08:00
shaw
39f8bd91b9 fix: remove unused saveRecords method to pass lint 2026-03-14 19:01:27 +08:00
Wesley Liddick
aa4e37d085 Merge pull request #966 from GuangYiDing/feat/db-backup-restore
feat: scheduled database backup and restore (S3-compatible storage, Cloudflare R2 supported)
2026-03-14 18:58:56 +08:00
erio
f59b66b7d4 fix(ops): tune aggregation constants to prevent PG overload
Increase MAX(bucket_start) query timeout from 3s to 5s to reduce
timeout-induced fallbacks. Shrink backfill window from 30 days to
1 hour so that fallback recomputation stays lightweight instead of
scanning the entire retention range.
2026-03-14 18:47:37 +08:00
InCerry
8f0ea7a02d Merge branch 'main' into fix/enc_coot 2026-03-14 18:46:33 +08:00
Wesley Liddick
a1dc00890e Merge pull request #944 from miraserver/feat/backend-mode
feat: add Backend Mode toggle to disable user self-service
2026-03-14 17:53:54 +08:00
Wesley Liddick
dfbcc363d1 Merge pull request #969 from wucm667/feat/quota-fixed-reset-mode
feat: support fixed-time reset mode for account quotas
2026-03-14 17:52:56 +08:00
Rose Ding
1047f973d5 fix: refactor the database backup service per review feedback (security + architecture + robustness)
1. Encrypt S3 credentials at rest: use SecretEncryptor (AES-256-GCM) to encrypt the SecretAccessKey so backups cannot leak S3 credentials; remains compatible with older unencrypted data
2. Fix the saveRecord race condition: add a recordsMu mutex protecting loads and saves of records
3. Add server-side verification to restore operations: the handler layer requires the admin password to be re-entered and checks it with bcrypt; the frontend shows a password prompt
4. Abstract pg_dump/psql/S3 operations behind interfaces: define DBDumper and BackupObjectStore interfaces with implementations in the repository layer, following the project's dependency-injection conventions
5. Switch to streaming to avoid OOM on large databases: backups flow pg_dump stdout -> gzip -> io.Pipe -> S3 upload, restores flow S3 download -> gzip reader -> psql stdin, with no full in-memory load (see the sketch after this commit)
6. loadRecords distinguishes "no data" from "corrupted data": JSON parse failures return an explicit error
7. Add 18 unit tests for the core logic, covering encryption, concurrency, streaming backup/restore, and error handling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-14 17:48:21 +08:00
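A minimal Go sketch of the streaming backup pipeline from item 5 above; the DBDumper and BackupObjectStore interface names come from the commit, while their method shapes are assumptions.

package backup

import (
    "compress/gzip"
    "context"
    "io"
)

type DBDumper interface {
    Dump(ctx context.Context, w io.Writer) error
}

type BackupObjectStore interface {
    Upload(ctx context.Context, key string, r io.Reader) error
}

// streamBackup pipes pg_dump output through gzip into the object-store upload
// without ever buffering the whole dump in memory.
func streamBackup(ctx context.Context, dump DBDumper, store BackupObjectStore, key string) error {
    pr, pw := io.Pipe()
    go func() {
        gz := gzip.NewWriter(pw)
        err := dump.Dump(ctx, gz) // pg_dump stdout -> gzip -> pipe
        if cerr := gz.Close(); err == nil {
            err = cerr
        }
        pw.CloseWithError(err) // surfaces any dump/compress error to the reader
    }()
    // The upload consumes the pipe reader while the dump is still being produced.
    return store.Upload(ctx, key, pr)
}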
Wesley Liddick
e32977dd73 Merge pull request #997 from SsageParuders/refactor/bedrock-channel-merge
refactor: merge bedrock-apikey into unified bedrock channel with auth_mode
2026-03-14 17:48:01 +08:00
wucm667
b5f78ec1e8 feat: implement SQL expressions for the fixed-time reset mode and add related unit tests 2026-03-14 17:37:34 +08:00
SsageParuders
e0f290fdc8 style: fix gofmt formatting for account type constants 2026-03-14 17:32:53 +08:00
Wesley Liddick
fc00a4e3b2 Merge pull request #959 from touwaeriol/feat/antigravity-403-detection
feat(antigravity): add 403 forbidden status detection and display
2026-03-14 17:23:22 +08:00
Wesley Liddick
db1f6ded88 Merge pull request #961 from 0xObjc/codex/ops-openai-token-visibility
feat(ops): make OpenAI token stats optional
2026-03-14 17:23:01 +08:00
SsageParuders
4644af2ccc refactor: merge bedrock-apikey into bedrock with auth_mode credential
Consolidate two separate channel types (bedrock + bedrock-apikey) into
a single "AWS Bedrock" channel. Authentication mode is now distinguished
by credentials.auth_mode ("sigv4" | "apikey") instead of separate types.

Backend:
- Remove AccountTypeBedrockAPIKey constant
- IsBedrock() simplified; IsBedrockAPIKey() checks auth_mode
- Add IsAPIKeyOrBedrock() helper to eliminate repeated type checks
- Extend pool mode, quota scheduling, and billing to bedrock
- Add RetryableOnSameAccount to handleBedrockUpstreamErrors
- Add "bedrock" scope to Beta Policy for independent control

Frontend:
- Merge two buttons into one "AWS Bedrock" with auth mode radio
- Badge displays "Anthropic | AWS"
- Pool mode and quota limit UI available for bedrock
- Quota display in account list (usage bars, capacity badges, reset)
- Remove all bedrock-apikey type references
2026-03-14 17:13:30 +08:00
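A minimal Go sketch of the auth_mode distinction described above; the Account struct is reduced to the fields needed here and is an assumption.

type Account struct {
    Type        string
    Credentials map[string]any
}

func (a *Account) IsBedrock() bool {
    return a.Type == AccountTypeBedrock
}

// IsBedrockAPIKey checks credentials.auth_mode instead of a separate account
// type, now that bedrock-apikey has been merged into the bedrock channel.
func (a *Account) IsBedrockAPIKey() bool {
    if !a.IsBedrock() {
        return false
    }
    mode, _ := a.Credentials["auth_mode"].(string)
    return mode == "apikey" // the alternative is "sigv4" (SigV4 request signing)
}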
Wesley Liddick
2e3e8687e1 Merge pull request #993 from xvhuan/fix/codex-responses-id-prefix-hotfix-20260314
fix: hotfix to stop native Codex/Responses input ids from being rewritten to fc_*
2026-03-14 13:54:26 +08:00
ius
ca42a45802 fix: stop rewriting native responses input ids 2026-03-14 13:47:01 +08:00
Wesley Liddick
9350ecb62b Merge pull request #987 from Ethan0x0000/feat-chatcompletions2repsonses-fix
fix: chat compatibility model fallback and reasoning_content output
2026-03-14 13:41:47 +08:00
Wesley Liddick
a4a026e8da Merge pull request #990 from LvyuanW/admin-openai-available-models-fix
fix: respect OpenAI OAuth model mapping in admin available models
2026-03-14 13:33:18 +08:00
Wesley Liddick
342fd03e72 Merge pull request #986 from LvyuanW/openai-model-mapping-fix
fix: honor account model mapping before group fallback
2026-03-14 13:32:26 +08:00
Ethan0x0000
e3f1fd9b63 fix: handle strings.Builder write errors in assistant parsing 2026-03-14 13:12:17 +08:00
InCerry
e4a4dfd038 Merge remote-tracking branch 'origin/main' into fix/enc_coot
# Conflicts:
#	backend/internal/service/openai_gateway_service.go
2026-03-14 13:04:24 +08:00
Wang Lvyuan
a377e99088 fix: remove unused wildcard mapping helper 2026-03-14 12:56:34 +08:00
Wang Lvyuan
1d3d7a3033 fix: respect OpenAI model mapping in admin available models 2026-03-14 12:45:10 +08:00
Wesley Liddick
e7086cb3a3 Merge pull request #988 from LvyuanW/scheduler-snapshot-sync
fix: sync scheduler snapshot on account updates
2026-03-14 12:38:48 +08:00
shaw
4f2a97073e chore: update docs 2026-03-14 12:37:36 +08:00
Wesley Liddick
7407e3b45d Merge pull request #984 from touwaeriol/docs/ecosystem-projects
docs: add iframe integration feature and ecosystem projects section
2026-03-14 12:31:25 +08:00
Wang Lvyuan
01ef7340aa Merge remote-tracking branch 'origin/main' into openai-model-mapping-fix 2026-03-14 12:27:08 +08:00
Wang Lvyuan
1c960d22c1 fix: sync scheduler snapshot on account updates 2026-03-14 12:21:28 +08:00
Ethan0x0000
ece0606fed fix: consolidate chat-completions compatibility fixes
- apply default mapped model only when scheduling fallback is actually used

- preserve reasoning in OpenAI-compatible output via reasoning_content and avoid invalid input function_call ids
2026-03-14 12:12:08 +08:00
InCerry
2666422b99 fix: handle invalid encrypted content error and retry logic. 2026-03-14 11:42:42 +08:00
Wesley Liddick
e6d59216d4 Merge pull request #975 from Ylarod/aws-bedrock
sub2api: add bedrock support
2026-03-14 10:52:24 +08:00
Wang Lvyuan
4e8615f276 fix: honor account model mapping before group fallback 2026-03-14 10:47:31 +08:00
erio
91e4d95660 docs: add iframe integration feature and ecosystem projects section 2026-03-14 03:16:07 +08:00
erio
45456fa24c fix: restore OAuth 401 temp-unschedulable for Gemini, update Antigravity tests
The 403 detection PR changed the 401 handler condition from
`account.Type == AccountTypeOAuth` to
`account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI`,
which accidentally excluded Gemini OAuth from the temp-unschedulable path.

Fix: use `!= PlatformAntigravity` instead, preserving Gemini behavior
while correctly excluding Antigravity (whose 401 is handled by
applyErrorPolicy's temp_unschedulable_rules).

Update tests to reflect Antigravity's new 401 semantics:
- HandleUpstreamError: Antigravity OAuth 401 now uses SetError
- CheckErrorPolicy: Antigravity 401 second hit stays TempUnscheduled
- DB fallback: split into Gemini (escalates) and Antigravity (stays temp)
2026-03-14 02:21:22 +08:00
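A short Go sketch of the corrected 401 condition described above; markTempUnschedulable stands in for the real temp-unschedulable handling and is an assumption.

if resp.StatusCode == http.StatusUnauthorized &&
    account.Type == AccountTypeOAuth &&
    account.Platform != PlatformAntigravity {
    // Gemini (and other non-Antigravity) OAuth accounts keep the
    // temp-unschedulable path; Antigravity 401s are handled by
    // applyErrorPolicy's temp_unschedulable_rules instead.
    markTempUnschedulable(account)
}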
Wesley Liddick
4588258d80 Merge pull request #960 from 0xObjc/codex/user-spending-ranking
feat(admin): add user spending ranking dashboard view
2026-03-13 23:06:30 +08:00
Wesley Liddick
c12e48f966 Merge pull request #949 from kunish/fix/remove-done-stop-sequence
fix: remove SSE termination marker from DefaultStopSequences
2026-03-13 22:56:29 +08:00
Wesley Liddick
ec8f50a658 Merge pull request #951 from wanXcode/fix/dashboard-user-trend-label
fix(dashboard): prefer username over email prefix in recent usage chart
2026-03-13 22:56:13 +08:00
Wesley Liddick
99c9191784 Merge pull request #974 from 0xObjc/codex/next-pr-base-20260313
fix(admin): default dashboard date range to today
2026-03-13 22:55:14 +08:00
shaw
6bb02d141f chore: remove accidentally committed PR diagnostic report 2026-03-13 22:52:05 +08:00
Wesley Liddick
07bb2a5f3f Merge pull request #952 from xvhuan/feat/billing-ledger-decouple-usage-log-20260312
feat: decouple billing correctness from usage_logs batch write pressure
2026-03-13 22:46:09 +08:00
Wesley Liddick
417861a48e Merge pull request #956 from share-wey/main
chore: codex transform fixes and feature compatibility
2026-03-13 22:36:29 +08:00
Wesley Liddick
b7e878de64 Merge pull request #980 from touwaeriol/feat/redeem-subscription-support
feat(redeem): support subscription type in create-and-redeem API
2026-03-13 22:15:33 +08:00
erio
05edb5514b feat(redeem): support subscription type in create-and-redeem API
Add group_id and validity_days fields to CreateAndRedeemCodeRequest,
enabling subscription-type redemption codes to be created and redeemed
in a single API call.

- Type defaults to "balance" when omitted for backward compatibility
- Subscription type requires group_id (non-nil) and validity_days (>0)
- Existing balance/concurrency callers are unaffected
2026-03-13 21:26:46 +08:00
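A rough Go sketch of the request validation described above; the struct mirrors the fields named in the commit, while the exact tags, types, and method are assumptions.

package handler

import "errors"

type CreateAndRedeemCodeRequest struct {
    Type         string `json:"type"`
    GroupID      *int64 `json:"group_id"`
    ValidityDays int    `json:"validity_days"`
}

func (r *CreateAndRedeemCodeRequest) normalizeAndValidate() error {
    if r.Type == "" {
        r.Type = "balance" // backward-compatible default
    }
    if r.Type != "subscription" {
        return nil // existing balance/concurrency callers are unaffected
    }
    if r.GroupID == nil {
        return errors.New("group_id is required for subscription codes")
    }
    if r.ValidityDays <= 0 {
        return errors.New("validity_days must be greater than 0 for subscription codes")
    }
    return nil
}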
Ylarod
e90ec847b6 fix lint 2026-03-13 19:15:27 +08:00
erio
6344fa2a86 feat(antigravity): add 403 forbidden status detection, classification and display
Backend:
- Detect and classify 403 responses into three types:
  validation (account needs Google verification),
  violation (terms of service / banned),
  forbidden (generic 403)
- Extract verification/appeal URLs from 403 response body
  (structured JSON parsing with regex fallback)
- Add needs_verify, is_banned, needs_reauth, error_code fields
  to UsageInfo (omitempty for zero impact on other platforms)
- Handle 403 in request path: classify and permanently set account error
- Save validation_url in error_message for degraded path recovery
- Enrich usage with account error on both success and degraded paths
- Add singleflight dedup for usage requests with independent context
- Differentiate cache TTL: success/403 → 3min, errors → 1min
- Return degraded UsageInfo instead of HTTP 500 on quota fetch errors

Frontend:
- Display forbidden status badges with color coding (red for banned,
  amber for needs verification, gray for generic)
- Show clickable verification/appeal URL links
- Display needs_reauth and degraded error states in usage cell
- Add Antigravity tier label badge next to platform type

Tests:
- Comprehensive unit tests for classifyForbiddenType (7 cases)
- Unit tests for extractValidationURL (8 cases including unicode escapes)
- Integration test for FetchQuota forbidden path
2026-03-13 18:22:45 +08:00
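A simplified Go sketch of the three-way 403 classification described above; the marker strings checked here are illustrative assumptions, and the real implementation parses structured JSON with a regex fallback.

package antigravity

import "strings"

const (
    forbiddenValidation = "validation" // account needs Google verification
    forbiddenViolation  = "violation"  // terms-of-service violation / banned
    forbiddenGeneric    = "forbidden"  // any other 403
)

func classifyForbiddenType(body string) string {
    lower := strings.ToLower(body)
    switch {
    case strings.Contains(lower, "verification") || strings.Contains(lower, "verify"):
        return forbiddenValidation
    case strings.Contains(lower, "violation") || strings.Contains(lower, "banned"):
        return forbiddenViolation
    default:
        return forbiddenGeneric
    }
}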
Connie Borer
7e288acc90 Merge branch 'Wei-Shaw:main' into main 2026-03-13 17:28:14 +08:00
Peter
29b0e4a8a5 feat(ops): allow hiding alert events 2026-03-13 17:18:04 +08:00
Peter
27ff222cfb fix(admin): default dashboard date range to today 2026-03-13 17:02:54 +08:00
Ylarod
11f7b83522 sub2api: add bedrock support 2026-03-13 17:00:16 +08:00
Rose Ding
f7177be3b6 fix: golangci-lint fixes (gofmt formatting + errcheck return value checks)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-13 12:01:05 +08:00
Rose Ding
875b417fde fix: add the missing backupSvc argument to provideCleanup in wire_gen_test.go
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-13 11:53:46 +08:00
wucm667
2573107b32 refactor: change the map type in ComputeQuotaResetAt and ValidateQuotaResetConfig from map[string]interface{} to map[string]any 2026-03-13 11:44:49 +08:00
wucm667
5b85005945 feat: support fixed-time reset mode for account quotas
- Backend adds two quota reset modes, rolling and fixed, for both daily and weekly quotas
- In fixed mode, the reset hour, reset weekday (weekly quotas), and IANA timezone are configurable (see the sketch after this commit)
- account_repo.go uses SQL expressions to handle expiry checks and reset-time advancement for both modes
- Add helper functions such as ComputeQuotaResetAt / ValidateQuotaResetConfig
- Add the corresponding DTO fields and map them fully in the mappers
- Frontend QuotaLimitCard gains a rolling/fixed toggle and a timezone selector
- CreateAccountModal / EditAccountModal pass the new config fields through
- i18n (zh/en) adds the related translation entries
2026-03-13 11:12:37 +08:00
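A rough Go sketch of computing the next fixed-mode daily reset in a configured IANA timezone; the function shape is an assumption and the real logic lives in SQL expressions plus ComputeQuotaResetAt.

package quota

import "time"

func computeNextDailyReset(now time.Time, resetHour int, tz string) (time.Time, error) {
    loc, err := time.LoadLocation(tz) // e.g. "Asia/Shanghai"
    if err != nil {
        return time.Time{}, err
    }
    local := now.In(loc)
    next := time.Date(local.Year(), local.Month(), local.Day(), resetHour, 0, 0, 0, loc)
    if !next.After(local) {
        next = next.AddDate(0, 0, 1) // today's reset time has already passed
    }
    return next, nil
}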
Wesley Liddick
1ee984478f Merge pull request #957 from touwaeriol/feat/group-rate-multipliers-modal
feat(groups): add rate multipliers management modal
2026-03-13 11:11:13 +08:00
Wesley Liddick
fd693dc526 Merge pull request #967 from StarryKira/fix/admin-reset-quota-monthly
fix: add the monthly field to admin quota reset and fix the ristretto async cache issue, fix issue #964
2026-03-13 11:10:47 +08:00
haruka
e73531ce9b fix: add the monthly field to admin quota reset and fix the ristretto async cache issue
- Backend handler: add a Monthly field to ResetSubscriptionQuotaRequest; validation now requires at least one of daily/weekly/monthly to be true
- Backend service: AdminResetQuota gains a resetMonthly parameter and calls ResetMonthlyUsage; after the reset, subCacheL1.Wait() is called so ristretto's asynchronous Del() takes effect immediately, closing the race window where /v1/usage returned stale usage data after a reset (see the sketch after this commit)
- Backend tests: update existing cases for the new signature and add TestAdminResetQuota_ResetMonthlyOnly / TestAdminResetQuota_ResetMonthlyUsageError
- Frontend API: the resetQuota options type gains monthly: boolean
- Frontend view: confirmResetQuota now resets daily/weekly/monthly together
- i18n: update the zh/en confirmation prompts to mention the monthly quota

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-13 10:39:35 +08:00
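A minimal Go fragment of the cache fix described above; subCacheL1 follows the commit message, while subscriptionCacheKey is a hypothetical helper used only for illustration.

// ristretto's Del is asynchronous, so flush pending operations before the
// next /v1/usage read can hit the stale entry.
subCacheL1.Del(subscriptionCacheKey(userID)) // subscriptionCacheKey is hypothetical
subCacheL1.Wait()                            // block until buffered cache ops are applied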
Rose Ding
53ad1645cf feat: scheduled database backup and restore (S3-compatible storage, Cloudflare R2 supported)
Add an admin-only database backup and restore feature:
- Full PostgreSQL backup (pg_dump), gzip-compressed and uploaded to S3-compatible storage
- Supports manual backups and cron-scheduled backups
- Supports restoring from a backup (psql --single-transaction)
- Backup files are cleaned up automatically on expiry (14 days by default)
- Complete frontend management page (S3 config, schedule config, backup list, restore/download/delete)
- Built-in Cloudflare R2 setup tutorial modal
- The Dockerfile copies pg_dump/psql from the postgres image in a multi-stage build to keep versions consistent

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-13 10:38:19 +08:00
Wesley Liddick
ecea13757b Merge pull request #955 from DaydreamCoding/feat/gpt-training-off
feat: GPT privacy setting to keep data out of training
2026-03-13 09:12:33 +08:00
Peter
af9c4a7dd0 feat(ops): make openai token stats optional 2026-03-13 04:11:58 +08:00
Peter
80d8d6c3bc feat(admin): add user spending ranking dashboard view 2026-03-13 03:43:03 +08:00
erio
d648811233 feat(groups): add rate multipliers management modal
Add a dedicated modal in group management for viewing, adding, editing,
and deleting per-user rate multipliers within a group.

Backend:
- GET /admin/groups/:id/rate-multipliers - list entries with user details
- PUT /admin/groups/:id/rate-multipliers - batch sync (full replace)
- DELETE /admin/groups/:id/rate-multipliers - clear all entries
- Repository: GetByGroupID, SyncGroupRateMultipliers methods on
  user_group_rate_multipliers table (same table as user-side rates)

Frontend:
- New GroupRateMultipliersModal component with:
  - User search and add with email autocomplete
  - Editable rate column with local edit mode (cancel/save)
  - Batch adjust: multiply all rates by a factor
  - Clear all (local operation, requires save to persist)
  - Pagination (10/20/50 per page)
  - Platform icon with brand colors in group info bar
  - Unsaved changes indicator with revert option
- Unit tests for all three backend endpoints
2026-03-12 23:37:36 +08:00
QTom
34695acb85 fix: remove the synchronous disableOpenAITraining call during account import, which caused imports to fail on network timeouts
privacy_mode is now set asynchronously by TokenRefreshService after the token refresh.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-12 22:36:25 +08:00
QTom
a63de12182 feat: GPT privacy mode + improved no-train frontend display 2026-03-12 21:24:01 +08:00
yexueduxing
f16910d616 chore: codex transform fixes and feature compatibility 2026-03-12 20:52:35 +08:00
ius
64b3f3cec1 test: relocate best-effort usage log stub 2026-03-12 18:43:37 +08:00
ius
6a685727d0 fix: harden usage billing idempotency and backpressure 2026-03-12 18:38:09 +08:00
ius
32d25f76fc fix: respect preconfigured usage log batch channels 2026-03-12 17:44:57 +08:00
wanXcode
69cafe8674 fix(dashboard): prefer username in user usage trend 2026-03-12 17:42:41 +08:00
ius
18ba8d9166 fix: stabilize repository integration paths 2026-03-12 17:42:41 +08:00
ius
e97fd7e81c test: align oauth passthrough stream expectations 2026-03-12 17:22:01 +08:00
kunish
cdb64b0d33 fix: remove SSE termination marker from DefaultStopSequences
The SSE stream termination marker string was incorrectly included in
DefaultStopSequences, causing Gemini to prematurely stop generating
output whenever the model produced text containing that marker.

The SSE-level protocol filtering in stream_transformer.go already
handles this marker correctly; it should not be a stop sequence for
the model's text generation.
2026-03-12 17:10:01 +08:00
ius
8d4d3b03bb fix: remove unused gateway usage helpers 2026-03-12 17:08:57 +08:00
ius
addefe79e1 fix: align docker health checks with runtime image 2026-03-12 17:03:21 +08:00
ius
b764d3b8f6 Merge remote-tracking branch 'origin/main' into feat/billing-ledger-decouple-usage-log-20260312 2026-03-12 16:53:28 +08:00
ius
611fd884bd feat: decouple billing correctness from usage log batching 2026-03-12 16:53:18 +08:00
John Doe
6826149a8f feat: add Backend Mode toggle to disable user self-service
Add a system-wide "Backend Mode" that disables user self-registration
and self-service while keeping admin panel and API gateway fully
functional. When enabled, only admin can log in; all user-facing
routes return 403.

Backend:
- New setting key `backend_mode_enabled` with atomic cached reads (60s TTL)
- BackendModeUserGuard middleware blocks non-admin authenticated routes
- BackendModeAuthGuard middleware blocks registration/password-reset auth routes
- Login/Login2FA/RefreshToken handlers reject non-admin when enabled
- TokenPairWithUser struct for role-aware token refresh
- 20 unit tests (middleware + service layer)

Frontend:
- Router guards redirect unauthenticated users to /login
- Admin toggle in Settings page
- Login page hides register link and footer in backend mode
- 9 unit tests for router guard logic
- i18n support (en/zh)

27 files changed, 833 insertions(+), 17 deletions(-)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-12 02:42:57 +03:00
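A rough Go sketch of a gin middleware for the guard described above; the backend_mode_enabled setting key follows the commit, while the role lookup and response shape are assumptions.

package middleware

import (
    "net/http"

    "github.com/gin-gonic/gin"
)

// BackendModeUserGuard blocks non-admin authenticated routes while backend
// mode is enabled; isEnabled is expected to be a cached read of the
// backend_mode_enabled setting.
func BackendModeUserGuard(isEnabled func() bool) gin.HandlerFunc {
    return func(c *gin.Context) {
        if !isEnabled() {
            c.Next()
            return
        }
        role, _ := c.Get("role") // claim key is an assumption
        if role != "admin" {
            // All user-facing routes return 403 while backend mode is on.
            c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "backend mode enabled"})
            return
        }
        c.Next()
    }
}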
ius
c9debc50b1 Batch usage log writes in repository 2026-03-11 20:29:48 +08:00
207 changed files with 19238 additions and 1280 deletions

View File

@@ -9,6 +9,7 @@
ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.26.1-alpine
ARG ALPINE_IMAGE=alpine:3.21
ARG POSTGRES_IMAGE=postgres:18-alpine
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
./cmd/server
# -----------------------------------------------------------------------------
# Stage 3: Final Runtime Image
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
# -----------------------------------------------------------------------------
FROM ${POSTGRES_IMAGE} AS pg-client
# -----------------------------------------------------------------------------
# Stage 4: Final Runtime Image
# -----------------------------------------------------------------------------
FROM ${ALPINE_IMAGE}
@@ -86,8 +92,20 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
RUN apk add --no-cache \
ca-certificates \
tzdata \
libpq \
zstd-libs \
lz4-libs \
krb5-libs \
libldap \
libedit \
&& rm -rf /var/cache/apk/*
# Copy pg_dump and psql from the same postgres image used in docker-compose
# This ensures version consistency between backup tools and the database server
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api

View File

@@ -5,7 +5,12 @@
# It only packages the pre-built binary, no compilation needed.
# =============================================================================
FROM alpine:3.19
ARG ALPINE_IMAGE=alpine:3.21
ARG POSTGRES_IMAGE=postgres:18-alpine
FROM ${POSTGRES_IMAGE} AS pg-client
FROM ${ALPINE_IMAGE}
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
LABEL description="Sub2API - AI API Gateway Platform"
@@ -16,8 +21,20 @@ RUN apk add --no-cache \
ca-certificates \
tzdata \
curl \
libpq \
zstd-libs \
lz4-libs \
krb5-libs \
libldap \
libedit \
&& rm -rf /var/cache/apk/*
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
# restore work in the runtime container without requiring Docker socket access.
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api

View File

@@ -39,6 +39,16 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
- **Concurrency Control** - Per-user and per-account concurrency limits
- **Rate Limiting** - Configurable request and token rate limits
- **Admin Dashboard** - Web interface for monitoring and management
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
## Ecosystem
Community projects that extend or integrate with Sub2API:
| Project | Description | Features |
|---------|-------------|----------|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
## Tech Stack

View File

@@ -39,6 +39,16 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
- **并发控制** - 用户级和账号级并发限制
- **速率限制** - 可配置的请求和 Token 速率限制
- **管理后台** - Web 界面进行监控和管理
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
## 生态项目
围绕 Sub2API 的社区扩展与集成项目:
| 项目 | 说明 | 功能 |
|------|------|------|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web);支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
## 技术栈

View File

@@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// Server layer ProviderSet
server.ProviderSet,
// Privacy client factory for OpenAI training opt-out
providePrivacyClientFactory,
// BuildInfo provider
provideServiceBuildInfo,
@@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, nil
}
func providePrivacyClientFactory() service.PrivacyClientFactory {
return repository.CreatePrivacyReqClient
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,
@@ -87,6 +94,7 @@ func provideCleanup(
antigravityOAuth *service.AntigravityOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -223,6 +231,12 @@ func provideCleanup(
}
return nil
}},
{"BackupService", func() error {
if backupSvc != nil {
backupSvc.Stop()
}
return nil
}},
}
infraSteps := []cleanupStep{

View File

@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemHandler := handler.NewRedeemHandler(redeemService)
@@ -104,7 +105,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
privacyClientFactory := providePrivacyClientFactory()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
@@ -144,6 +146,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
dbDumper := repository.NewPgDumper(configConfig)
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
backupHandler := admin.NewBackupHandler(backupService, userService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -162,9 +168,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
@@ -199,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -226,11 +232,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -245,6 +251,10 @@ type Application struct {
Cleanup func()
}
func providePrivacyClientFactory() service.PrivacyClientFactory {
return repository.CreatePrivacyReqClient
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,
@@ -279,6 +289,7 @@ func provideCleanup(
antigravityOAuth *service.AntigravityOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -414,6 +425,12 @@ func provideCleanup(
}
return nil
}},
{"BackupService", func() error {
if backupSvc != nil {
backupSvc.Stop()
}
return nil
}},
}
infraSteps := []cleanupStep{

View File

@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
antigravityOAuthSvc,
nil, // openAIGateway
nil, // scheduledTestRunner
nil, // backupSvc
)
require.NotPanics(t, func() {

View File

@@ -7,7 +7,7 @@ require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
github.com/aws/aws-sdk-go-v2 v1.41.2
github.com/aws/aws-sdk-go-v2 v1.41.3
github.com/aws/aws-sdk-go-v2/config v1.32.10
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
@@ -66,7 +66,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
github.com/aws/smithy-go v1.24.1 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/bdandy/go-errors v1.2.2 // indirect
github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect

View File

@@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
@@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=

View File

@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
// DashboardAggregationRetentionConfig 预聚合保留窗口
type DashboardAggregationRetentionConfig struct {
UsageLogsDays int `mapstructure:"usage_logs_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
UsageLogsDays int `mapstructure:"usage_logs_days"`
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
}
// UsageCleanupConfig 使用记录清理任务配置
@@ -1301,6 +1302,7 @@ func setDefaults() {
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
}
if c.DashboardAgg.Retention.HourlyDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
}
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
}
if c.DashboardAgg.Retention.HourlyDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
}

View File

@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
}
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
}
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
}
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
wantErr: "dashboard_aggregation.retention.usage_logs_days",
},
{
name: "dashboard aggregation dedup retention",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
},
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
},
{
name: "dashboard aggregation dedup retention smaller than usage logs",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.Retention.UsageLogsDays = 30
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
},
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
},
{
name: "dashboard aggregation disabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },

View File

@@ -31,6 +31,7 @@ const (
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
)
// Redeem type constants
@@ -113,3 +114,27 @@ var DefaultAntigravityModelMapping = map[string]string{
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",
}
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
var DefaultBedrockModelMapping = map[string]string{
// Claude Opus
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
// Claude Sonnet
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
// Claude Haiku
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
}

View File

@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -865,6 +865,9 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
}
}
// OpenAI OAuth: 刷新成功后检查并设置 privacy_mode
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
return updatedAccount, "", nil
}
@@ -1715,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
// Handle OpenAI accounts
if account.IsOpenAI() {
// For OAuth accounts: return default OpenAI models
if account.IsOAuth() {
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
if account.IsOpenAIPassthroughEnabled() {
response.Success(c, openai.DefaultModels)
return
}
// For API Key accounts: check model_mapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
response.Success(c, openai.DefaultModels)

View File

@@ -0,0 +1,105 @@
package admin
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type availableModelsAdminService struct {
*stubAdminService
account service.Account
}
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
if s.account.ID == id {
acc := s.account
return &acc, nil
}
return s.stubAdminService.GetAccount(context.Background(), id)
}
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
return router
}
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 42,
Name: "openai-oauth",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.1",
},
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Len(t, resp.Data, 1)
require.Equal(t, "gpt-5", resp.Data[0].ID)
}
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 43,
Name: "openai-oauth-passthrough",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.1",
},
},
Extra: map[string]any{
"openai_passthrough": true,
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp.Data)
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
}

View File

@@ -175,6 +175,18 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
return nil, nil
}
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
return nil
}
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
return nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}
@@ -429,5 +441,9 @@ func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) erro
return nil
}
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
return ""
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -0,0 +1,204 @@
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type BackupHandler struct {
backupService *service.BackupService
userService *service.UserService
}
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
return &BackupHandler{
backupService: backupService,
userService: userService,
}
}
// ─── S3 配置 ───
func (h *BackupHandler) GetS3Config(c *gin.Context) {
cfg, err := h.backupService.GetS3Config(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
var req service.BackupS3Config
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
var req service.BackupS3Config
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.backupService.TestS3Connection(c.Request.Context(), req)
if err != nil {
response.Success(c, gin.H{"ok": false, "message": err.Error()})
return
}
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
}
// ─── 定时备份 ───
func (h *BackupHandler) GetSchedule(c *gin.Context) {
cfg, err := h.backupService.GetSchedule(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
var req service.BackupScheduleConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
// ─── Backup operations ───
type CreateBackupRequest struct {
ExpireDays *int `json:"expire_days"` // nil = use the default (14); 0 = never expire
}
func (h *BackupHandler) CreateBackup(c *gin.Context) {
var req CreateBackupRequest
_ = c.ShouldBindJSON(&req) // tolerate an empty body
expireDays := 14 // default: expire after 14 days
if req.ExpireDays != nil {
expireDays = *req.ExpireDays
}
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, record)
}
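// Illustrative note (not part of the diff; request bodies are hypothetical): how CreateBackup
// above interprets the request, based on the code and the ExpireDays comment:
//
//   {}                   -> expireDays = 14 (ShouldBindJSON tolerates an empty body)
//   {"expire_days": 30}  -> expireDays = 30
//   {"expire_days": 0}   -> expireDays = 0, presumably treated downstream as "never expire"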
func (h *BackupHandler) ListBackups(c *gin.Context) {
records, err := h.backupService.ListBackups(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
if records == nil {
records = []service.BackupRecord{}
}
response.Success(c, gin.H{"items": records})
}
func (h *BackupHandler) GetBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, record)
}
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"url": url})
}
// ─── Restore operations (admin password must be re-entered) ───
type RestoreBackupRequest struct {
Password string `json:"password" binding:"required"`
}
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
var req RestoreBackupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "password is required for restore operation")
return
}
// Get the current admin user ID from the context
sub, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "unauthorized")
return
}
// Load the admin user and verify the password
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if !user.CheckPassword(req.Password) {
response.BadRequest(c, "incorrect admin password")
return
}
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"restored": true})
}

View File

@@ -466,9 +466,60 @@ type BatchUsersUsageRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
func parseRankingLimit(raw string) int {
limit, err := strconv.Atoi(strings.TrimSpace(raw))
if err != nil || limit <= 0 {
return 12
}
if limit > 50 {
return 50
}
return limit
}
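// Illustrative sketch (not part of the diff): parseRankingLimit falls back to 12 on invalid
// or non-positive input and caps the result at 50:
//
//   parseRankingLimit("")    -> 12
//   parseRankingLimit("abc") -> 12
//   parseRankingLimit("-5")  -> 12
//   parseRankingLimit("30")  -> 30
//   parseRankingLimit("100") -> 50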
// GetUserSpendingRanking handles getting user spending ranking data.
// GET /api/v1/admin/dashboard/users-ranking
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
keyRaw, _ := json.Marshal(struct {
Start string `json:"start"`
End string `json:"end"`
Limit int `json:"limit"`
}{
Start: startTime.UTC().Format(time.RFC3339),
End: endTime.UTC().Format(time.RFC3339),
Limit: limit,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
if err != nil {
response.Error(c, 500, "Failed to get user spending ranking")
return
}
payload := gin.H{
"ranking": ranking.Ranking,
"total_actual_cost": ranking.TotalActualCost,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
}
dashboardUsersRankingCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {

View File

@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
trendStream *bool
modelRequestType *int16
modelStream *bool
rankingLimit int
ranking []usagestats.UserSpendingRankingItem
rankingTotal float64
}
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
@@ -49,6 +52,18 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
return []usagestats.ModelStat{}, nil
}
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
ctx context.Context,
startTime, endTime time.Time,
limit int,
) (*usagestats.UserSpendingRankingResponse, error) {
s.rankingLimit = limit
return &usagestats.UserSpendingRankingResponse{
Ranking: s.ranking,
TotalActualCost: s.rankingTotal,
}, nil
}
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
@@ -56,6 +71,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
router.GET("/admin/dashboard/models", handler.GetModelStats)
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
return router
}
@@ -130,3 +146,30 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
repo := &dashboardUsageRepoCapture{
ranking: []usagestats.UserSpendingRankingItem{
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
},
rankingTotal: 88.8,
}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 50, repo.rankingLimit)
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
}

View File

@@ -1,6 +1,9 @@
package admin
import (
"bytes"
"encoding/json"
"fmt"
"strconv"
"strings"
@@ -16,6 +19,55 @@ type GroupHandler struct {
adminService service.AdminService
}
type optionalLimitField struct {
set bool
value *float64
}
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
f.set = true
trimmed := bytes.TrimSpace(data)
if bytes.Equal(trimmed, []byte("null")) {
f.value = nil
return nil
}
var number float64
if err := json.Unmarshal(trimmed, &number); err == nil {
f.value = &number
return nil
}
var text string
if err := json.Unmarshal(trimmed, &text); err == nil {
text = strings.TrimSpace(text)
if text == "" {
f.value = nil
return nil
}
number, err = strconv.ParseFloat(text, 64)
if err != nil {
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
}
f.value = &number
return nil
}
return fmt.Errorf("invalid limit value: %s", string(trimmed))
}
func (f optionalLimitField) ToServiceInput() *float64 {
if !f.set {
return nil
}
if f.value != nil {
return f.value
}
zero := 0.0
return &zero
}
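// Illustrative sketch (not part of the diff): how optionalLimitField distinguishes
// "field omitted" from "explicitly cleared", per UnmarshalJSON/ToServiceInput above:
//
//   {"daily_limit_usd": 12.5} -> set, value=12.5 -> ToServiceInput() == &12.5
//   {"daily_limit_usd": "7"}  -> set, parsed     -> ToServiceInput() == &7.0
//   {"daily_limit_usd": null} -> set, value=nil  -> ToServiceInput() == pointer to 0
//   {"daily_limit_usd": ""}   -> set, value=nil  -> ToServiceInput() == pointer to 0
//   {}                        -> not set         -> ToServiceInput() == nil (leave field untouched)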
// NewGroupHandler creates a new admin group handler
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
return &GroupHandler{
@@ -25,15 +77,15 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
// CreateGroupRequest represents create group request
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// Image generation pricing config (used by the antigravity and gemini platforms; a negative value clears the setting)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
@@ -62,16 +114,16 @@ type CreateGroupRequest struct {
// UpdateGroupRequest represents update group request
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// Image generation pricing config (used by the antigravity and gemini platforms; a negative value clears the setting)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
@@ -191,9 +243,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
@@ -244,9 +296,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
@@ -335,6 +387,72 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
response.Paginated(c, outKeys, total, page, pageSize)
}
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
// GET /api/v1/admin/groups/:id/rate-multipliers
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if entries == nil {
entries = []service.UserGroupRateEntry{}
}
response.Success(c, entries)
}
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
// DELETE /api/v1/admin/groups/:id/rate-multipliers
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
}
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
type BatchSetGroupRateMultipliersRequest struct {
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
}
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
// PUT /api/v1/admin/groups/:id/rate-multipliers
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
var req BatchSetGroupRateMultipliersRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
}
// UpdateSortOrderRequest represents the request to update group sort orders
type UpdateSortOrderRequest struct {
Updates []struct {

View File

@@ -289,6 +289,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
Platform: platform,
Type: "oauth",
Credentials: credentials,
Extra: nil,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,

View File

@@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct {
}
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
// Type uses omitempty rather than required for backward compatibility with legacy callers (type defaults to balance when omitted).
type CreateAndRedeemCodeRequest struct {
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
Notes string `json:"notes"`
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // defaults to balance when omitted (backward compatibility)
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
GroupID *int64 `json:"group_id"` // required for the subscription type
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // required for the subscription type, must be > 0
Notes string `json:"notes"`
}
// List handles listing all redeem codes with pagination
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
return
}
req.Code = strings.TrimSpace(req.Code)
// Backward compatibility: legacy callers (such as Sub2ApiPay) do not send the type field; treat it as a balance top-up by default.
// Do not remove this default, or those legacy callers will start getting 400 errors.
if req.Type == "" {
req.Type = "balance"
}
if req.Type == "subscription" {
if req.GroupID == nil {
response.BadRequest(c, "group_id is required for subscription type")
return
}
if req.ValidityDays <= 0 {
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
return
}
}
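// Illustrative request bodies (not part of the diff; values hypothetical) against the validation above:
//
//   {"code":"top-up-001","value":10,"user_id":1}                                                -> type defaults to "balance", passes
//   {"code":"sub-001","type":"subscription","value":29.9,"user_id":1}                           -> 400 (group_id required)
//   {"code":"sub-002","type":"subscription","value":29.9,"user_id":1,"group_id":5}              -> 400 (validity_days must be > 0)
//   {"code":"sub-003","type":"subscription","value":29.9,"user_id":1,"group_id":5,"validity_days":30} -> passes validation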
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
existing, err := h.redeemService.GetByCode(ctx, req.Code)
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
}
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
Code: req.Code,
Type: req.Type,
Value: req.Value,
Status: service.StatusUnused,
Notes: req.Notes,
Code: req.Code,
Type: req.Type,
Value: req.Value,
Status: service.StatusUnused,
Notes: req.Notes,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if createErr != nil {
// Unique code race: if code now exists, use idempotent semantics by used_by.

View File

@@ -0,0 +1,135 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
// parameter-validation layer that runs before any service call.
func newCreateAndRedeemHandler() *RedeemHandler {
return &RedeemHandler{
adminService: newStubAdminService(),
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
}
}
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
// status code. For cases that pass validation and proceed into the service layer,
// a panic may occur (because RedeemService internals are nil); this is expected
// and treated as "validation passed" (returns 0 to indicate panic).
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
jsonBytes, err := json.Marshal(body)
require.NoError(t, err)
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
c.Request.Header.Set("Content-Type", "application/json")
defer func() {
if r := recover(); r != nil {
// Panic means we passed validation and entered service layer (expected for minimal stub).
code = 0
}
}()
handler.CreateAndRedeem(c)
return w.Code
}
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
// Omitting the type field should default to balance and must not trigger the subscription validation.
// Once validation passes, the service layer panics and the helper returns 0, which shows the default took effect.
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-balance-default",
"value": 10.0,
"user_id": 1,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"omitting type should default to balance and pass validation")
}
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-no-group",
"type": "subscription",
"value": 29.9,
"user_id": 1,
"validity_days": 30,
// group_id is missing
})
assert.Equal(t, http.StatusBadRequest, code)
}
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
groupID := int64(5)
h := newCreateAndRedeemHandler()
cases := []struct {
name string
validityDays int
}{
{"zero", 0},
{"negative", -1},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-bad-days-" + tc.name,
"type": "subscription",
"value": 29.9,
"user_id": 1,
"group_id": groupID,
"validity_days": tc.validityDays,
})
assert.Equal(t, http.StatusBadRequest, code)
})
}
}
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
groupID := int64(5)
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-valid",
"type": "subscription",
"value": 29.9,
"user_id": 1,
"group_id": groupID,
"validity_days": 31,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"valid subscription params should pass validation")
}
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
h := newCreateAndRedeemHandler()
// The balance type omits group_id and validity_days and should not return a 400.
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-balance-no-extras",
"type": "balance",
"value": 50.0,
"user_id": 1,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"balance type should not require group_id or validity_days")
}

View File

@@ -80,6 +80,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
FrontendURL: settings.FrontendURL,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
@@ -125,6 +126,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
BackendModeEnabled: settings.BackendModeEnabled,
})
}
@@ -136,6 +138,7 @@ type UpdateSettingsRequest struct {
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
FrontendURL string `json:"frontend_url"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP two-factor authentication
@@ -199,6 +202,9 @@ type UpdateSettingsRequest struct {
// Group isolation
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
// Backend Mode
BackendModeEnabled bool `json:"backend_mode_enabled"`
}
// UpdateSettings 更新系统设置
@@ -322,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// Frontend URL validation
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
if req.FrontendURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
return
}
}
// Custom menu item validation
const (
maxCustomMenuItems = 20
@@ -433,6 +448,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
FrontendURL: req.FrontendURL,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
@@ -473,6 +489,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
BackendModeEnabled: req.BackendModeEnabled,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -526,6 +543,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
FrontendURL: updatedSettings.FrontendURL,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
@@ -571,6 +589,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
BackendModeEnabled: updatedSettings.BackendModeEnabled,
})
}
@@ -608,6 +627,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
if before.FrontendURL != after.FrontendURL {
changed = append(changed, "frontend_url")
}
if before.TotpEnabled != after.TotpEnabled {
changed = append(changed, "totp_enabled")
}
@@ -725,6 +747,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
changed = append(changed, "allow_ungrouped_key_scheduling")
}
if before.BackendModeEnabled != after.BackendModeEnabled {
changed = append(changed, "backend_mode_enabled")
}
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
changed = append(changed, "purchase_subscription_enabled")
}

View File

@@ -218,11 +218,12 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
// ResetSubscriptionQuotaRequest represents the reset quota request
type ResetSubscriptionQuotaRequest struct {
Daily bool `json:"daily"`
Weekly bool `json:"weekly"`
Daily bool `json:"daily"`
Weekly bool `json:"weekly"`
Monthly bool `json:"monthly"`
}
// ResetQuota resets daily and/or weekly usage for a subscription.
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
// POST /api/v1/admin/subscriptions/:id/reset-quota
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -235,11 +236,11 @@ func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !req.Daily && !req.Weekly {
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
if !req.Daily && !req.Weekly && !req.Monthly {
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
return
}
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
// Backend mode: only admin can login
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
h.respondWithTokenPair(c, user)
}
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
return
}
// Delete the login session
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
// Get the user
// Get the user (before session deletion so we can check backend mode)
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Backend mode: only admin can login (check BEFORE deleting session)
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
// Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
h.respondWithTokenPair(c, user)
}
@@ -447,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return
}
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
if frontendBaseURL == "" {
slog.Error("server.frontend_url not configured; cannot build password reset link")
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
response.InternalError(c, "Password reset is not configured")
return
}
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
return
}
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Backend mode: block non-admin token refresh
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
response.Success(c, RefreshTokenResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
ExpiresIn: tokenPair.ExpiresIn,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresIn: result.ExpiresIn,
TokenType: "Bearer",
})
}

View File

@@ -264,8 +264,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
}
}
// Extract API key account quota limits (only valid for the apikey type)
if a.Type == service.AccountTypeAPIKey {
// Extract account quota limits (valid for the apikey / bedrock types)
if a.IsAPIKeyOrBedrock() {
if limit := a.GetQuotaLimit(); limit > 0 {
out.QuotaLimit = &limit
used := a.GetQuotaUsed()
@@ -281,6 +281,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
used := a.GetQuotaWeeklyUsed()
out.QuotaWeeklyUsed = &used
}
// Fixed-time reset configuration
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
out.QuotaDailyResetMode = &mode
hour := a.GetQuotaDailyResetHour()
out.QuotaDailyResetHour = &hour
}
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
out.QuotaWeeklyResetMode = &mode
day := a.GetQuotaWeeklyResetDay()
out.QuotaWeeklyResetDay = &day
hour := a.GetQuotaWeeklyResetHour()
out.QuotaWeeklyResetHour = &hour
}
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
tz := a.GetQuotaResetTimezone()
out.QuotaResetTimezone = &tz
}
if a.Extra != nil {
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
out.QuotaDailyResetAt = &v
}
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
out.QuotaWeeklyResetAt = &v
}
}
}
return out
@@ -498,6 +523,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
Model: l.Model,
ServiceTier: l.ServiceTier,
ReasoningEffort: l.ReasoningEffort,
InboundEndpoint: l.InboundEndpoint,
UpstreamEndpoint: l.UpstreamEndpoint,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens,

View File

@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
t.Parallel()
serviceTier := "priority"
inboundEndpoint := "/v1/chat/completions"
upstreamEndpoint := "/v1/responses"
log := &service.UsageLog{
RequestID: "req_3",
Model: "gpt-5.4",
ServiceTier: &serviceTier,
InboundEndpoint: &inboundEndpoint,
UpstreamEndpoint: &upstreamEndpoint,
AccountRateMultiplier: f64Ptr(1.5),
}
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
require.NotNil(t, userDTO.ServiceTier)
require.Equal(t, serviceTier, *userDTO.ServiceTier)
require.NotNil(t, userDTO.InboundEndpoint)
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
require.NotNil(t, userDTO.UpstreamEndpoint)
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
require.NotNil(t, adminDTO.ServiceTier)
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
require.NotNil(t, adminDTO.InboundEndpoint)
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
require.NotNil(t, adminDTO.UpstreamEndpoint)
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
require.NotNil(t, adminDTO.AccountRateMultiplier)
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
}

View File

@@ -22,6 +22,7 @@ type SystemSettings struct {
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
FrontendURL string `json:"frontend_url"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP two-factor authentication
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // whether the TOTP encryption key is configured
@@ -81,6 +82,9 @@ type SystemSettings struct {
// Group isolation
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
// Backend Mode
BackendModeEnabled bool `json:"backend_mode_enabled"`
}
type DefaultSubscriptionSetting struct {
@@ -111,6 +115,7 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version"`
}

View File

@@ -203,6 +203,16 @@ type Account struct {
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
// Fixed-time quota reset configuration
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -324,9 +334,13 @@ type UsageLog struct {
Model string `json:"model"`
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string `json:"service_tier,omitempty"`
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
// nil means not provided / not applicable.
// ReasoningEffort is the request's reasoning effort level.
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"`

View File

@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if fs.SwitchCount > 0 {
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
// Record the bytes already written before Forward; if the count grows afterwards, SSE content has been sent and failover must be disabled
writerSizeBeforeForward := c.Writer.Size()
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
} else {
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
// Streamed content has already been written to the client and cannot be undone; disable failover to prevent stream-splicing corruption
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
@@ -434,19 +441,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Capture request info (for async recording; avoid touching gin.Context from a goroutine)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
if result.ReasoningEffort == nil {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
}
// Usage records are submitted through a bounded worker pool to avoid spawning unbounded goroutines on the request hot path.
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),
@@ -635,6 +648,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if fs.SwitchCount > 0 {
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
// Record the bytes already written before Forward; if the count grows afterwards, SSE content has been sent and failover must be disabled
writerSizeBeforeForward := c.Writer.Size()
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
} else {
@@ -704,6 +719,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
// Streamed content has already been written to the client and cannot be undone; disable failover to prevent stream-splicing corruption
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
@@ -736,19 +756,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Capture request info (for async recording; avoid touching gin.Context from a goroutine)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
if result.ReasoningEffort == nil {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
}
// Usage records are submitted through a bounded worker pool to avoid spawning unbounded goroutines on the request hot path.
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
Subscription: currentSubscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
Subscription: currentSubscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),

View File

@@ -0,0 +1,122 @@
package handler
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// partialMessageStartSSE simulates the first batch of SSE events already written by handleStreamingResponse.
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten verifies that
// when Forward has already written SSE content to the client before returning an
// UpstreamFailoverError, the failover guard must break the loop and emit an SSE error event
// instead of attempting another Forward. Specifically:
// 1. The c.Writer.Size() guard condition fires correctly (the byte count has grown)
// 2. After handleFailoverExhausted is called with streamStarted=true, the body ends with an SSE error event
// 3. message_start appears only once in the body (never a second time), preventing stream-splicing corruption
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// Step 1: record the writer size before Forward, mirroring writerSizeBeforeForward := c.Writer.Size()
sizeBeforeForward := c.Writer.Size()
require.Equal(t, -1, sizeBeforeForward, "the initial gin writer Size should be -1 (no bytes written yet)")
// Step 2: simulate Forward having already written partial SSE content to the client (message_start + content_block_start)
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
require.NoError(t, err)
// Step 3: verify the guard condition holds: c.Writer.Size() != sizeBeforeForward
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
"after writing SSE content the writer size must grow, so the guard condition should be true")
// Step 4: simulate an UpstreamFailoverError (upstream returned 403 mid-stream)
failoverErr := &service.UpstreamFailoverError{
StatusCode: http.StatusForbidden,
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
}
// Step 5: the guard fires → call handleFailoverExhausted with streamStarted=true
h := &GatewayHandler{}
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
body := w.Body.String()
// Assertion A: the body contains the originally written message_start SSE event line
require.Contains(t, body, "event: message_start", "the body should contain the already-written message_start SSE event")
// Assertion B: the body ends with an SSE error event (data: {"type":"error",...}\n\n)
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
"the body should end with a JSON object (the data field of the SSE error event)")
require.Contains(t, body, `"type":"error"`, "the end of the body must contain an SSE error event")
// Assertion C: the SSE event line "event: message_start" appears exactly once (no double message_start splicing corruption)
firstIdx := strings.Index(body, "event: message_start")
lastIdx := strings.LastIndex(body, "event: message_start")
assert.Equal(t, firstIdx, lastIdx,
"'event: message_start' must appear exactly once in the body; failover must not splice in a second one")
}
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten mirrors the test above,
// verifying the Gemini path (which passes service.PlatformGemini rather than account.Platform) behaves the same.
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
sizeBeforeForward := c.Writer.Size()
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
require.NoError(t, err)
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
failoverErr := &service.UpstreamFailoverError{
StatusCode: http.StatusForbidden,
}
h := &GatewayHandler{}
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
body := w.Body.String()
require.Contains(t, body, "event: message_start")
require.Contains(t, body, `"type":"error"`)
firstIdx := strings.Index(body, "event: message_start")
lastIdx := strings.LastIndex(body, "event: message_start")
assert.Equal(t, firstIdx, lastIdx, "the Gemini path must not produce a double message_start")
}
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered verifies the opposite case:
// when Forward returns an UpstreamFailoverError without having written any SSE content to the client,
// the guard condition (c.Writer.Size() != sizeBeforeForward) is false and failover must not be aborted.
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// Simulate writerSizeBeforeForward (initially -1)
sizeBeforeForward := c.Writer.Size()
// Forward returned an error without writing any bytes (e.g. a 401 before the connection was established)
// c.Writer.Size() is still -1
// Guard condition: sizeBeforeForward == c.Writer.Size() → not triggered
guardTriggered := c.Writer.Size() != sizeBeforeForward
require.False(t, guardTriggered,
"未写入任何字节时,守卫条件必须为 false应允许正常 failover 继续")
}

View File

@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // accountRepo (not used: scheduler snapshot hit)
&fakeGroupRepo{group: group},
nil, // usageLogRepo
nil, // usageBillingRepo
nil, // userRepo
nil, // userSubRepo
nil, // userGroupRateRepo

View File

@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// Usage records are submitted through a bounded worker pool to avoid spawning unbounded goroutines on the request hot path.
requestPayloadHash := service.HashUsageRequestPayload(body)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
LongContextThreshold: 200000, // Gemini 200K threshold
LongContextMultiplier: 2.0, // bill the excess at double rate
ForceCacheBilling: fs.ForceCacheBilling,

View File

@@ -12,6 +12,7 @@ type AdminHandlers struct {
Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
DataManagement *admin.DataManagementHandler
Backup *admin.BackupHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler

View File

@@ -181,13 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
defaultMappedModel := ""
if apiKey.Group != nil {
defaultMappedModel = apiKey.Group.DefaultMappedModel
}
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
defaultMappedModel = fallbackModel
}
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
@@ -262,14 +256,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointChatCompletions),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"),

View File

@@ -0,0 +1,57 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestNormalizedOpenAIUpstreamEndpoint(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
path string
fallback string
want string
}{
{
name: "responses root maps to responses upstream",
path: "/v1/responses",
fallback: openAIUpstreamEndpointResponses,
want: "/v1/responses",
},
{
name: "responses compact keeps compact suffix",
path: "/openai/v1/responses/compact",
fallback: openAIUpstreamEndpointResponses,
want: "/v1/responses/compact",
},
{
name: "responses nested suffix preserved",
path: "/openai/v1/responses/compact/detail",
fallback: openAIUpstreamEndpointResponses,
want: "/v1/responses/compact/detail",
},
{
name: "non responses path uses fallback",
path: "/v1/messages",
fallback: openAIUpstreamEndpointResponses,
want: "/v1/responses",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
got := normalizedOpenAIUpstreamEndpoint(c, tt.fallback)
require.Equal(t, tt.want, got)
})
}
}

View File

@@ -37,6 +37,13 @@ type OpenAIGatewayHandler struct {
cfg *config.Config
}
const (
openAIInboundEndpointResponses = "/v1/responses"
openAIInboundEndpointMessages = "/v1/messages"
openAIInboundEndpointChatCompletions = "/v1/chat/completions"
openAIUpstreamEndpointResponses = "/v1/responses"
)
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
@@ -352,18 +359,22 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Capture request info (for async recording; avoid touching gin.Context from a goroutine)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// Usage records are submitted through a bounded worker pool to avoid spawning unbounded goroutines on the request hot path.
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
@@ -653,14 +664,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
defaultMappedModel := ""
if apiKey.Group != nil {
defaultMappedModel = apiKey.Group.DefaultMappedModel
}
// If fallback-model scheduling was used, force the fallback model
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
defaultMappedModel = fallbackModel
}
// Pass the fallback model to the Forward layer for model substitution only when scheduling actually fell back
// (no account was available for the original model and the retry with the default model succeeded); otherwise keep the user's requested model.
defaultMappedModel := c.GetString("openai_messages_fallback_model")
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
@@ -732,17 +738,21 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointMessages),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.messages"),
@@ -1231,14 +1241,17 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
h.submitUsageRecordTask(func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses),
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
@@ -1530,6 +1543,62 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64)
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
}
func normalizedOpenAIInboundEndpoint(c *gin.Context, fallback string) string {
path := strings.TrimSpace(fallback)
if c != nil {
if fullPath := strings.TrimSpace(c.FullPath()); fullPath != "" {
path = fullPath
} else if c.Request != nil && c.Request.URL != nil {
if requestPath := strings.TrimSpace(c.Request.URL.Path); requestPath != "" {
path = requestPath
}
}
}
switch {
case strings.Contains(path, openAIInboundEndpointChatCompletions):
return openAIInboundEndpointChatCompletions
case strings.Contains(path, openAIInboundEndpointMessages):
return openAIInboundEndpointMessages
case strings.Contains(path, openAIInboundEndpointResponses):
return openAIInboundEndpointResponses
default:
return path
}
}
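// Illustrative sketch (not part of the diff; example paths hypothetical): inbound paths are
// collapsed to one of the three canonical endpoints above; anything that matches none of
// them is returned unchanged:
//
//   "/openai/v1/chat/completions"  -> "/v1/chat/completions"
//   "/v1/messages"                 -> "/v1/messages"
//   "/openai/v1/responses/compact" -> "/v1/responses"
//   "/v1beta/models"               -> "/v1beta/models"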
func normalizedOpenAIUpstreamEndpoint(c *gin.Context, fallback string) string {
base := strings.TrimSpace(fallback)
if base == "" {
base = openAIUpstreamEndpointResponses
}
base = strings.TrimRight(base, "/")
if c == nil || c.Request == nil || c.Request.URL == nil {
return base
}
path := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
if path == "" {
return base
}
idx := strings.LastIndex(path, "/responses")
if idx < 0 {
return base
}
suffix := strings.TrimSpace(path[idx+len("/responses"):])
if suffix == "" || suffix == "/" {
return base
}
if !strings.HasPrefix(suffix, "/") {
return base
}
return base + suffix
}
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
if r == nil {
return false

View File

@@ -26,6 +26,22 @@ const (
opsStreamKey = "ops_stream"
opsRequestBodyKey = "ops_request_body"
opsAccountIDKey = "ops_account_id"
// Error-filter match constants, shared by shouldSkipOpsErrorLog and the error classifiers
opsErrContextCanceled = "context canceled"
opsErrNoAvailableAccounts = "no available accounts"
opsErrInvalidAPIKey = "invalid_api_key"
opsErrAPIKeyRequired = "api_key_required"
opsErrInsufficientBalance = "insufficient balance"
opsErrInsufficientAccountBalance = "insufficient account balance"
opsErrInsufficientQuota = "insufficient_quota"
// Upstream error-code constants, used by error classification (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
opsCodeUserInactive = "USER_INACTIVE"
)
const (
@@ -1024,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
return errType
}
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE":
case opsCodeInsufficientBalance:
return "billing_error"
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
return "subscription_error"
default:
return "api_error"
@@ -1038,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
// Standardized phases: request|auth|routing|upstream|network|internal
// Map billing/concurrency/response => request; scheduling => routing.
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
return "request"
}
@@ -1057,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
case "upstream_error", "overloaded_error":
return "upstream"
case "api_error":
if strings.Contains(msg, "no available accounts") {
if strings.Contains(msg, opsErrNoAvailableAccounts) {
return "routing"
}
return "internal"
@@ -1103,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
return true
}
if phase == "billing" || phase == "concurrency" {
@@ -1197,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
// Check if context canceled errors should be ignored (client disconnects)
if settings.IgnoreContextCanceled {
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
return true
}
}
// Check if "no available accounts" errors should be ignored
if settings.IgnoreNoAvailableAccounts {
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
return true
}
}
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
if settings.IgnoreInvalidApiKeyErrors {
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
return true
}
}
// Check if insufficient balance errors should be ignored
if settings.IgnoreInsufficientBalanceErrors {
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
return true
}
}

View File

@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
Version: h.version,
})
}

View File

@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService用于测试 SelectAccountForModel
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil,
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
)
}

View File

@@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
}); err != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),

View File

@@ -334,6 +334,14 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return []usagestats.EndpointStat{}, nil
}
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return []usagestats.EndpointStat{}, nil
}
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, nil
}
@@ -343,6 +351,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, nil
}
@@ -431,6 +442,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil,
nil,
nil,
nil,
testutil.StubGatewayCache{},
cfg,
nil,

View File

@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
accountHandler *admin.AccountHandler,
announcementHandler *admin.AnnouncementHandler,
dataManagementHandler *admin.DataManagementHandler,
backupHandler *admin.BackupHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
Account: accountHandler,
Announcement: announcementHandler,
DataManagement: dataManagementHandler,
Backup: backupHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
admin.NewAccountHandler,
admin.NewAnnouncementHandler,
admin.NewDataManagementHandler,
admin.NewBackupHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,

View File

@@ -19,6 +19,16 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
)
// ForbiddenError 表示上游返回 403 Forbidden
type ForbiddenError struct {
StatusCode int
Body string
}
func (e *ForbiddenError) Error() string {
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
}
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求v1internal 端点)
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
// 构建 URL流式请求添加 ?alt=sse 参数
@@ -514,7 +524,20 @@ type ModelQuotaInfo struct {
// ModelInfo 模型信息
type ModelInfo struct {
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
DisplayName string `json:"displayName,omitempty"`
SupportsImages *bool `json:"supportsImages,omitempty"`
SupportsThinking *bool `json:"supportsThinking,omitempty"`
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
Recommended *bool `json:"recommended,omitempty"`
MaxTokens *int `json:"maxTokens,omitempty"`
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
}
// DeprecatedModelInfo 废弃模型转发信息
type DeprecatedModelInfo struct {
NewModelID string `json:"newModelId"`
}
// FetchAvailableModelsRequest fetchAvailableModels 请求
@@ -524,7 +547,8 @@ type FetchAvailableModelsRequest struct {
// FetchAvailableModelsResponse fetchAvailableModels 响应
type FetchAvailableModelsResponse struct {
Models map[string]ModelInfo `json:"models"`
Models map[string]ModelInfo `json:"models"`
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
}
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
@@ -573,6 +597,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
continue
}
if resp.StatusCode == http.StatusForbidden {
return nil, nil, &ForbiddenError{
StatusCode: resp.StatusCode,
Body: string(respBodyBytes),
}
}
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
}

View File

@@ -189,6 +189,5 @@ var DefaultStopSequences = []string{
"<|user|>",
"<|endoftext|>",
"<|end_of_turn|>",
"[DONE]",
"\n\nHuman:",
}

View File

@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
assert.Equal(t, "assistant", items[1].Role)
assert.Equal(t, "function_call", items[2].Type)
assert.Equal(t, "fc_call_1", items[2].CallID)
assert.Empty(t, items[2].ID)
assert.Equal(t, "function_call_output", items[3].Type)
assert.Equal(t, "fc_call_1", items[3].CallID)
assert.Equal(t, "Sunny, 72°F", items[3].Output)

View File

@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
CallID: fcID,
Name: b.Name,
Arguments: args,
ID: fcID,
})
}

View File

@@ -99,6 +99,7 @@ func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
// Check function_call item
assert.Equal(t, "function_call", items[1].Type)
assert.Equal(t, "call_1", items[1].CallID)
assert.Empty(t, items[1].ID)
assert.Equal(t, "ping", items[1].Name)
// Check function_call_output item
@@ -252,6 +253,55 @@ func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T)
assert.Equal(t, "user", items[0].Role)
assert.Equal(t, "assistant", items[1].Role)
assert.Equal(t, "function_call", items[2].Type)
assert.Empty(t, items[2].ID)
}
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Hi"`)},
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2)
assert.Equal(t, "assistant", items[1].Role)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "output_text", parts[0].Type)
assert.Equal(t, "AB", parts[0].Text)
}
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Hi"`)},
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "output_text", parts[0].Type)
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
assert.Contains(t, parts[0].Text, "final answer")
}
// ---------------------------------------------------------------------------
@@ -344,8 +394,8 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
var content string
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
// Reasoning summary is prepended to text
assert.Equal(t, "I thought about it.The answer is 42.", content)
assert.Equal(t, "The answer is 42.", content)
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
}
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
@@ -582,8 +632,35 @@ func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
Delta: "Thinking...",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.reasoning_summary_text.done",
}, state)
require.Len(t, chunks, 0)
}
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.SentRole = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.reasoning_summary_text.delta",
Delta: "plan",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.output_text.delta",
Delta: "answer",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
}
func TestFinalizeResponsesChatStream(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package apicompat
import (
"encoding/json"
"fmt"
"strings"
)
// ChatCompletionsToResponses converts a Chat Completions request into a
@@ -174,8 +175,11 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// Emit assistant message with output_text if content is non-empty.
if len(m.Content) > 0 {
var s string
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
s, err := parseAssistantContent(m.Content)
if err != nil {
return nil, err
}
if s != "" {
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
partsJSON, err := json.Marshal(parts)
if err != nil {
@@ -196,13 +200,82 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: args,
ID: tc.ID,
})
}
return items, nil
}
// parseAssistantContent returns assistant content as plain text.
//
// Supported formats:
// - JSON string
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
//
// For structured thinking/reasoning parts, it preserves semantics by wrapping
// the text in explicit tags so downstream can still distinguish it from normal text.
func parseAssistantContent(raw json.RawMessage) (string, error) {
if len(raw) == 0 {
return "", nil
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return s, nil
}
var parts []map[string]any
if err := json.Unmarshal(raw, &parts); err != nil {
// Keep compatibility with prior behavior: unsupported assistant content
// formats are ignored instead of failing the whole request conversion.
return "", nil
}
var b strings.Builder
write := func(v string) error {
_, err := b.WriteString(v)
return err
}
for _, p := range parts {
typ, _ := p["type"].(string)
text, _ := p["text"].(string)
thinking, _ := p["thinking"].(string)
switch typ {
case "thinking", "reasoning":
if thinking != "" {
if err := write("<thinking>"); err != nil {
return "", err
}
if err := write(thinking); err != nil {
return "", err
}
if err := write("</thinking>"); err != nil {
return "", err
}
} else if text != "" {
if err := write("<thinking>"); err != nil {
return "", err
}
if err := write(text); err != nil {
return "", err
}
if err := write("</thinking>"); err != nil {
return "", err
}
}
default:
if text != "" {
if err := write(text); err != nil {
return "", err
}
}
}
}
return b.String(), nil
}
// chatToolToResponses converts a tool result message (role=tool) into a
// function_call_output item.
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {

View File

@@ -29,6 +29,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
}
var contentText string
var reasoningText string
var toolCalls []ChatToolCall
for _, item := range resp.Output {
@@ -51,7 +52,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
case "reasoning":
for _, s := range item.Summary {
if s.Type == "summary_text" && s.Text != "" {
contentText += s.Text
reasoningText += s.Text
}
}
case "web_search_call":
@@ -67,6 +68,9 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
raw, _ := json.Marshal(contentText)
msg.Content = raw
}
if reasoningText != "" {
msg.ReasoningContent = reasoningText
}
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
@@ -153,6 +157,8 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
return resToChatHandleFuncArgsDelta(evt, state)
case "response.reasoning_summary_text.delta":
return resToChatHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done":
return nil
case "response.completed", "response.incomplete", "response.failed":
return resToChatHandleCompleted(evt, state)
default:
@@ -276,8 +282,8 @@ func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEv
if evt.Delta == "" {
return nil
}
content := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
reasoning := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
}
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {

View File

@@ -361,11 +361,12 @@ type ChatStreamOptions struct {
// ChatMessage is a single message in the Chat Completions conversation.
type ChatMessage struct {
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
Content json.RawMessage `json:"content,omitempty"`
Name string `json:"name,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
Content json.RawMessage `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Name string `json:"name,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
// Legacy function calling
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
@@ -466,9 +467,10 @@ type ChatChunkChoice struct {
// ChatDelta carries incremental content in a streaming chunk.
type ChatDelta struct {
Role string `json:"role,omitempty"`
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
Role string `json:"role,omitempty"`
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
ReasoningContent *string `json:"reasoning_content,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
}
// ---------------------------------------------------------------------------

View File

@@ -81,6 +81,15 @@ type ModelStat struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// EndpointStat represents usage statistics for a single request endpoint.
type EndpointStat struct {
Endpoint string `json:"endpoint"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// GroupStat represents usage statistics for a single group
type GroupStat struct {
GroupID int64 `json:"group_id"`
@@ -96,12 +105,28 @@ type UserUsageTrendPoint struct {
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserSpendingRankingItem represents a user spending ranking row.
type UserSpendingRankingItem struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
ActualCost float64 `json:"actual_cost"` // 实际扣除
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
type UserSpendingRankingResponse struct {
Ranking []UserSpendingRankingItem `json:"ranking"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint struct {
Date string `json:"date"`
@@ -163,15 +188,18 @@ type UsageLogFilters struct {
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
AverageDurationMs float64 `json:"average_duration_ms"`
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
AverageDurationMs float64 `json:"average_duration_ms"`
Endpoints []EndpointStat `json:"endpoints,omitempty"`
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
}
// BatchUserUsageStats represents usage stats for a single user
@@ -238,7 +266,9 @@ type AccountUsageSummary struct {
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse struct {
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
Endpoints []EndpointStat `json:"endpoints"`
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
}

View File

@@ -397,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
r.syncSchedulerAccountSnapshot(ctx, account.ID)
}
// 普通账号编辑(如 model_mapping / credentials也需要立即刷新单账号快照
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
r.syncSchedulerAccountSnapshot(ctx, account.ID)
return nil
}
@@ -1727,8 +1727,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
const dailyExpiredExpr = `(
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
END
)`
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
const weeklyExpiredExpr = `(
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
END
)`
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
const nextDailyResetAtExpr = `(
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
THEN to_char((
-- Compute today's reset point in the configured timezone, then pick next future one
CASE WHEN NOW() >= (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- NOW() is at or past today's reset point → next reset is tomorrow
THEN (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
+ '1 day'::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- NOW() is before today's reset point → next reset is today
ELSE (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
END
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
ELSE NULL END
)`
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
const nextWeeklyResetAtExpr = `(
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
THEN to_char((
-- Compute this week's reset point in the configured timezone
-- Step 1: get today's date at reset hour in configured tz
-- Step 2: compute days forward to target weekday
-- Step 3: if same day but past reset hour, advance 7 days
CASE
WHEN (
-- days_forward = (target_day - current_day + 7) % 7
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
+ 7) % 7
) = 0 AND NOW() >= (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- Same weekday and past reset hour → next week
THEN (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
+ '7 days'::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
ELSE (
-- Advance to target weekday this week (or next if days_forward > 0)
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
+ ((
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
+ 7) % 7
) || ' days')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
END
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
ELSE NULL END
)`
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
// 日/周额度在周期过期时自动重置为 0 再递增。
// 支持滚动窗口rolling和固定时间fixed两种重置模式。
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
rows, err := r.sql.QueryContext(ctx,
`UPDATE accounts SET extra = (
@@ -1739,31 +1827,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
CASE WHEN `+dailyExpiredExpr+`
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
CASE WHEN `+dailyExpiredExpr+`
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
-- 固定模式重置时更新下次重置时间
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
CASE WHEN `+weeklyExpiredExpr+`
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
CASE WHEN `+weeklyExpiredExpr+`
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
-- 固定模式重置时更新下次重置时间
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
@@ -1796,12 +1888,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
}
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
// 保留固定重置模式的配置字段quota_daily_reset_mode 等),仅清零用量和窗口起始时间
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`,
id)
if err != nil {

View File

@@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
}
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "sync-credentials-update",
Status: service.StatusActive,
Schedulable: true,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.1",
},
},
})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
account.Credentials = map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.2",
},
}
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
s.Require().True(ok)
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})

View File

@@ -0,0 +1,98 @@
package repository
import (
"context"
"fmt"
"io"
"os/exec"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// PgDumper implements service.DBDumper using pg_dump/psql
type PgDumper struct {
cfg *config.DatabaseConfig
}
// NewPgDumper creates a new PgDumper
func NewPgDumper(cfg *config.Config) service.DBDumper {
return &PgDumper{cfg: &cfg.Database}
}
// Dump executes pg_dump and returns a streaming reader of the output
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
args := []string{
"-h", d.cfg.Host,
"-p", fmt.Sprintf("%d", d.cfg.Port),
"-U", d.cfg.User,
"-d", d.cfg.DBName,
"--no-owner",
"--no-acl",
"--clean",
"--if-exists",
}
cmd := exec.CommandContext(ctx, "pg_dump", args...)
if d.cfg.Password != "" {
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
}
if d.cfg.SSLMode != "" {
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("create stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("start pg_dump: %w", err)
}
// 返回一个 ReadCloser读 stdout关闭时等待进程退出
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
}
// Restore executes psql to restore from a streaming reader
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
args := []string{
"-h", d.cfg.Host,
"-p", fmt.Sprintf("%d", d.cfg.Port),
"-U", d.cfg.User,
"-d", d.cfg.DBName,
"--single-transaction",
}
cmd := exec.CommandContext(ctx, "psql", args...)
if d.cfg.Password != "" {
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
}
if d.cfg.SSLMode != "" {
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
}
cmd.Stdin = data
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("%v: %s", err, string(output))
}
return nil
}
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
type cmdReadCloser struct {
io.ReadCloser
cmd *exec.Cmd
}
func (c *cmdReadCloser) Close() error {
// Close the pipe first
_ = c.ReadCloser.Close()
// Wait for the process to exit
if err := c.cmd.Wait(); err != nil {
return fmt.Errorf("pg_dump exited with error: %w", err)
}
return nil
}

View File

@@ -0,0 +1,116 @@
package repository
import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
type S3BackupStore struct {
client *s3.Client
bucket string
}
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
region := cfg.Region
if region == "" {
region = "auto" // Cloudflare R2 默认 region
}
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
awsconfig.WithRegion(region),
awsconfig.WithCredentialsProvider(
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
),
)
if err != nil {
return nil, fmt.Errorf("load aws config: %w", err)
}
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
if cfg.Endpoint != "" {
o.BaseEndpoint = &cfg.Endpoint
}
if cfg.ForcePathStyle {
o.UsePathStyle = true
}
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
})
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
}
}
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
// 读取全部内容以获取大小S3 PutObject 需要知道内容长度)
data, err := io.ReadAll(body)
if err != nil {
return 0, fmt.Errorf("read body: %w", err)
}
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
Bucket: &s.bucket,
Key: &key,
Body: bytes.NewReader(data),
ContentType: &contentType,
})
if err != nil {
return 0, fmt.Errorf("S3 PutObject: %w", err)
}
return int64(len(data)), nil
}
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &s.bucket,
Key: &key,
})
if err != nil {
return nil, fmt.Errorf("S3 GetObject: %w", err)
}
return result.Body, nil
}
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: &s.bucket,
Key: &key,
})
return err
}
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
presignClient := s3.NewPresignClient(s.client)
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
Bucket: &s.bucket,
Key: &key,
}, s3.WithPresignExpires(expiry))
if err != nil {
return "", fmt.Errorf("presign url: %w", err)
}
return result.URL, nil
}
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
Bucket: &s.bucket,
})
if err != nil {
return fmt.Errorf("S3 HeadBucket failed: %w", err)
}
return nil
}

View File

@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
sql sqlExecutor
}
const usageLogsCleanupBatchSize = 10000
const usageBillingDedupCleanupBatchSize = 10000
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
if sqlDB == nil {
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
if r == nil || r.sql == nil {
return nil
}
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
dayEnd = dayEnd.Add(24 * time.Hour)
}
if db, ok := r.sql.(*sql.DB); ok {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
}
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
if isPartitioned {
return r.dropUsageLogsPartitions(ctx, cutoff)
}
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
return err
for {
res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid
FROM usage_logs
WHERE created_at < $1
LIMIT $2
)
DELETE FROM usage_logs
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageLogsCleanupBatchSize)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected < usageLogsCleanupBatchSize {
return nil
}
}
}
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
for {
res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
FROM usage_billing_dedup
WHERE created_at < $1
LIMIT $2
), archived AS (
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
SELECT request_id, api_key_id, request_fingerprint, created_at
FROM victims
ON CONFLICT (request_id, api_key_id) DO NOTHING
)
DELETE FROM usage_billing_dedup
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected < usageBillingDedupCleanupBatchSize {
return nil
}
}
}
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {

View File

@@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
SetKey(k.Key).
SetName(k.Name).
SetStatus(k.Status)
if k.Quota != 0 {
create.SetQuota(k.Quota)
}
if k.QuotaUsed != 0 {
create.SetQuotaUsed(k.QuotaUsed)
}
if k.RateLimit5h != 0 {
create.SetRateLimit5h(k.RateLimit5h)
}
if k.RateLimit1d != 0 {
create.SetRateLimit1d(k.RateLimit1d)
}
if k.RateLimit7d != 0 {
create.SetRateLimit7d(k.RateLimit7d)
}
if k.Usage5h != 0 {
create.SetUsage5h(k.Usage5h)
}
if k.Usage1d != 0 {
create.SetUsage1d(k.Usage1d)
}
if k.Usage7d != 0 {
create.SetUsage7d(k.Usage7d)
}
if k.Window5hStart != nil {
create.SetWindow5hStart(*k.Window5hStart)
}
if k.Window1dStart != nil {
create.SetWindow1dStart(*k.Window1dStart)
}
if k.Window7dStart != nil {
create.SetWindow7dStart(*k.Window7dStart)
}
if k.ExpiresAt != nil {
create.SetExpiresAt(*k.ExpiresAt)
}
if k.GroupID != nil {
create.SetGroupID(*k.GroupID)
}

View File

@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
// usage_billing_dedup: billing idempotency narrow table
var usageBillingDedupRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
var usageBillingDedupArchiveRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
// settings table should exist
var settingsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
var exists bool
err := tx.QueryRowContext(context.Background(), `
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`, table, index).Scan(&exists)
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
require.True(t, exists, "expected index %s on %s", index, table)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()

View File

@@ -73,3 +73,14 @@ func buildReqClientKey(opts reqClientOptions) string {
opts.ForceHTTP2,
)
}
// CreatePrivacyReqClient creates an HTTP client for OpenAI privacy settings API
// This is exported for use by OpenAIPrivacyService
// Uses Chrome TLS fingerprint impersonation to bypass Cloudflare checks
func CreatePrivacyReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
Impersonate: true, // Enable Chrome TLS fingerprint impersonation
})
}

View File

@@ -0,0 +1,308 @@
package repository
import (
"context"
"database/sql"
"errors"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type usageBillingRepository struct {
db *sql.DB
}
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
return &usageBillingRepository{db: sqlDB}
}
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
if cmd == nil {
return &service.UsageBillingApplyResult{}, nil
}
if r == nil || r.db == nil {
return nil, errors.New("usage billing repository db is nil")
}
cmd.Normalize()
if cmd.RequestID == "" {
return nil, service.ErrUsageBillingRequestIDRequired
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
if err != nil {
return nil, err
}
if !applied {
return &service.UsageBillingApplyResult{Applied: false}, nil
}
result := &service.UsageBillingApplyResult{Applied: true}
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
tx = nil
return result, nil
}
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
var id int64
err := tx.QueryRowContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
VALUES ($1, $2, $3)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
if errors.Is(err, sql.ErrNoRows) {
var existingFingerprint string
if err := tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
return false, err
}
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
return false, service.ErrUsageBillingRequestConflict
}
return false, nil
}
if err != nil {
return false, err
}
var archivedFingerprint string
err = tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup_archive
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
if err == nil {
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
return false, service.ErrUsageBillingRequestConflict
}
return false, nil
}
if !errors.Is(err, sql.ErrNoRows) {
return false, err
}
return true, nil
}
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
return err
}
}
if cmd.BalanceCost > 0 {
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
return err
}
}
if cmd.APIKeyQuotaCost > 0 {
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
if err != nil {
return err
}
result.APIKeyQuotaExhausted = exhausted
}
if cmd.APIKeyRateLimitCost > 0 {
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
return err
}
}
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
return err
}
}
return nil
}
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
const updateSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
weekly_usage_usd = us.weekly_usage_usd + $1,
monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
`
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrSubscriptionNotFound
}
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
res, err := tx.ExecContext(ctx, `
UPDATE users
SET balance = balance - $1,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, amount, userID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrUserNotFound
}
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
var exhausted bool
err := tx.QueryRowContext(ctx, `
UPDATE api_keys
SET quota_used = quota_used + $1,
status = CASE
WHEN quota > 0
AND status = $3
AND quota_used < quota
AND quota_used + $1 >= quota
THEN $4
ELSE status
END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
if errors.Is(err, sql.ErrNoRows) {
return false, service.ErrAPIKeyNotFound
}
if err != nil {
return false, err
}
return exhausted, nil
}
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
res, err := tx.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, cost, apiKeyID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAPIKeyNotFound
}
return nil
}
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
rows, err := tx.QueryContext(ctx,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
COALESCE((extra->>'quota_limit')::numeric, 0)`,
amount, accountID)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
var newUsed, limit float64
if rows.Next() {
if err := rows.Scan(&newUsed, &limit); err != nil {
return err
}
} else {
if err := rows.Err(); err != nil {
return err
}
return service.ErrAccountNotFound
}
if err := rows.Err(); err != nil {
return err
}
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
return err
}
}
return nil
}

View File

@@ -0,0 +1,279 @@
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-" + uuid.NewString(),
Name: "billing",
Quota: 1,
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-account-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: account.ID,
AccountType: service.AccountTypeAPIKey,
BalanceCost: 1.25,
APIKeyQuotaCost: 1.25,
APIKeyRateLimitCost: 1.25,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.NotNil(t, result1)
require.True(t, result1.Applied)
require.True(t, result1.APIKeyQuotaExhausted)
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.NotNil(t, result2)
require.False(t, result2.Applied)
var balance float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
require.InDelta(t, 98.75, balance, 0.000001)
var quotaUsed float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan(&quotaUsed))
require.InDelta(t, 1.25, quotaUsed, 0.000001)
var usage5h float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
require.InDelta(t, 1.25, usage5h, 0.000001)
var status string
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
var dedupCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
require.Equal(t, 1, dedupCount)
}
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
})
group := mustCreateGroup(t, client, &service.Group{
Name: "usage-billing-group-" + uuid.NewString(),
Platform: service.PlatformAnthropic,
SubscriptionType: service.SubscriptionTypeSubscription,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
GroupID: &group.ID,
Key: "sk-usage-billing-sub-" + uuid.NewString(),
Name: "billing-sub",
})
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: 0,
SubscriptionID: &subscription.ID,
SubscriptionCost: 2.5,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.True(t, result1.Applied)
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.False(t, result2.Applied)
var dailyUsage float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
require.InDelta(t, 2.5, dailyUsage, 0.000001)
}
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
Name: "billing-conflict",
})
requestID := uuid.NewString()
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 1.25,
})
require.NoError(t, err)
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 2.50,
})
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
}
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-account-" + uuid.NewString(),
Name: "billing-account",
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-account-quota-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
Extra: map[string]any{
"quota_limit": 100.0,
},
})
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: account.ID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 3.5,
})
require.NoError(t, err)
var quotaUsed float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan(&quotaUsed))
require.InDelta(t, 3.5, quotaUsed, 0.000001)
}
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
ctx := context.Background()
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
oldRequestID := "dedup-old-" + uuid.NewString()
newRequestID := "dedup-new-" + uuid.NewString()
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
newCreatedAt := time.Now().UTC().Add(-time.Hour)
_, err := integrationDB.ExecContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
`,
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
newRequestID, strings.Repeat("b", 64), newCreatedAt,
)
require.NoError(t, err)
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
var oldCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
require.Equal(t, 0, oldCount)
var newCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
require.Equal(t, 1, newCount)
var archivedCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
require.Equal(t, 1, archivedCount)
}
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-archive-" + uuid.NewString(),
Name: "billing-archive",
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 1.25,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.True(t, result1.Applied)
_, err = integrationDB.ExecContext(ctx, `
UPDATE usage_billing_dedup
SET created_at = $1
WHERE request_id = $2 AND api_key_id = $3
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
require.NoError(t, err)
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.False(t, result2.Applied)
var balance float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
require.InDelta(t, 98.75, balance, 0.000001)
}

File diff suppressed because it is too large

View File

@@ -4,6 +4,8 @@ package repository
import (
"context"
"fmt"
"sync"
"testing"
"time"
@@ -14,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() {
s.Require().NotZero(log.ID)
}
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
const total = 16
results := make([]bool, total)
errs := make([]error, total)
logs := make([]*service.UsageLog, total)
var wg sync.WaitGroup
wg.Add(total)
for i := 0; i < total; i++ {
i := i
logs[i] = &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10 + i,
OutputTokens: 20 + i,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
go func() {
defer wg.Done()
results[i], errs[i] = repo.Create(ctx, logs[i])
}()
}
wg.Wait()
for i := 0; i < total; i++ {
require.NoError(t, errs[i])
require.True(t, results[i])
require.NotZero(t, logs[i].ID)
}
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
require.Equal(t, total, count)
}
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
requestID := uuid.NewString()
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
inserted1, err1 := repo.Create(ctx, log1)
inserted2, err2 := repo.Create(ctx, log2)
require.NoError(t, err1)
require.NoError(t, err2)
require.True(t, inserted1)
require.False(t, inserted2)
require.Equal(t, log1.ID, log2.ID)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
require.Equal(t, 1, count)
}
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
requestID := uuid.NewString()
const total = 8
batch := make([]usageLogCreateRequest, 0, total)
logs := make([]*service.UsageLog, 0, total)
for i := 0; i < total; i++ {
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10 + i,
OutputTokens: 20 + i,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
logs = append(logs, log)
batch = append(batch, usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
resultCh: make(chan usageLogCreateResult, 1),
})
}
repo.flushCreateBatch(integrationDB, batch)
insertedCount := 0
var firstID int64
for idx, req := range batch {
res := <-req.resultCh
require.NoError(t, res.err)
if res.inserted {
insertedCount++
}
require.NotZero(t, logs[idx].ID)
if idx == 0 {
firstID = logs[idx].ID
} else {
require.Equal(t, firstID, logs[idx].ID)
}
}
require.Equal(t, 1, insertedCount)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
require.Equal(t, 1, count)
}
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
requestID := uuid.NewString()
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
require.NoError(t, repo.CreateBestEffort(ctx, log1))
require.NoError(t, repo.CreateBestEffort(ctx, log2))
require.Eventually(t, func() bool {
var count int
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
return err == nil && count == 1
}, 3*time.Second, 20*time.Millisecond)
}
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
err := repo.CreateBestEffort(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.Error(t, err)
require.True(t, service.IsUsageLogCreateDropped(err))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
ctx, cancel := context.WithCancel(context.Background())
cancel()
inserted, err := repo.Create(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.False(t, inserted)
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
repo.createBatchCh <- usageLogCreateRequest{}
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
inserted, err := repo.Create(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.False(t, inserted)
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 1)
go func() {
_, err := repo.createBatched(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
errCh <- err
}()
req := <-repo.createBatchCh
require.NotNil(t, req.shared)
cancel()
err := <-errCh
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
}
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
req := usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
shared: &usageLogCreateShared{},
resultCh: make(chan usageLogCreateResult, 1),
}
req.shared.state.Store(usageLogCreateStateCanceled)
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
res := <-req.resultCh
require.False(t, res.inserted)
require.Error(t, res.err)
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})

View File

@@ -73,6 +73,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // media_type
sqlmock.AnyArg(), // service_tier
sqlmock.AnyArg(), // reasoning_effort
sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden,
createdAt,
).
@@ -141,6 +143,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
serviceTier,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.CacheTTLOverridden,
createdAt,
).
@@ -248,6 +252,35 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
mock.ExpectQuery("WITH user_spend AS \\(").
WithArgs(start, end, 12).
WillReturnRows(rows)
got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12)
require.NoError(t, err)
require.Equal(t, &usagestats.UserSpendingRankingResponse{
Ranking: []usagestats.UserSpendingRankingItem{
{UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900},
{UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800},
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
},
TotalActualCost: 40.0,
}, got)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
tests := []struct {
name string
@@ -347,6 +380,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})
@@ -386,6 +421,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{Valid: true, String: "flex"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})
@@ -425,6 +462,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})

View File

@@ -3,8 +3,11 @@
package repository
import (
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
})
}
}
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
log := &service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-batch-no-update",
Model: "gpt-5",
InputTokens: 10,
OutputTokens: 5,
TotalCost: 1.2,
ActualCost: 1.2,
CreatedAt: time.Now().UTC(),
}
prepared := prepareUsageLogInsert(log)
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
})
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
}

View File

@@ -95,6 +95,35 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
return result, nil
}
// GetByGroupID returns the dedicated rate multipliers of every user in the given group
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
query := `
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id
WHERE ugr.group_id = $1
ORDER BY ugr.user_id
`
rows, err := r.sql.QueryContext(ctx, query, groupID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var result []service.UserGroupRateEntry
for rows.Next() {
var entry service.UserGroupRateEntry
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
return nil, err
}
result = append(result, entry)
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
@@ -164,6 +193,31 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return nil
}
// SyncGroupRateMultipliers bulk-syncs a group's per-user rate multipliers (delete first, then insert)
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
return err
}
if len(entries) == 0 {
return nil
}
userIDs := make([]int64, len(entries))
rates := make([]float64, len(entries))
for i, e := range entries {
userIDs[i] = e.UserID
rates[i] = e.RateMultiplier
}
now := time.Now()
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT data.user_id, $1::bigint, data.rate_multiplier, $2::timestamptz, $2::timestamptz
FROM unnest($3::bigint[], $4::double precision[]) AS data(user_id, rate_multiplier)
ON CONFLICT (user_id, group_id)
DO UPDATE SET rate_multiplier = EXCLUDED.rate_multiplier, updated_at = EXCLUDED.updated_at
`, groupID, now, pq.Array(userIDs), pq.Array(rates))
return err
}
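// Illustrative caller sketch (not part of this change; the service variable and
// field values are assumptions): replacing a group's per-user multipliers is a
// single repository call that deletes the old rows and bulk-upserts the new set.
//
//	entries := []service.GroupRateMultiplierInput{
//		{UserID: 101, RateMultiplier: 0.8},
//		{UserID: 102, RateMultiplier: 1.5},
//	}
//	if err := rateRepo.SyncGroupRateMultipliers(ctx, groupID, entries); err != nil {
//		return err
//	}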
// DeleteByGroupID deletes all per-user rate multipliers of the given group
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)

View File

@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
NewAnnouncementRepository,
NewAnnouncementReadRepository,
NewUsageLogRepository,
NewUsageBillingRepository,
NewIdempotencyRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository,
@@ -99,6 +100,10 @@ var ProviderSet = wire.NewSet(
// Encryptors
NewAESEncryptor,
// Backup infrastructure
NewPgDumper,
NewS3BackupStoreFactory,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
ProvidePricingRemoteClient,

View File

@@ -493,6 +493,7 @@ func TestAPIContracts(t *testing.T) {
"registration_email_suffix_whitelist": [],
"promo_code_enabled": true,
"password_reset_enabled": false,
"frontend_url": "",
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
@@ -537,6 +538,7 @@ func TestAPIContracts(t *testing.T) {
"purchase_subscription_url": "",
"min_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"custom_menu_items": []
}
}`,
@@ -645,7 +647,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
@@ -1623,6 +1625,14 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, errors.New("not implemented")
}
@@ -1635,6 +1645,10 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
logs := r.userLogs[userID]
if len(logs) == 0 {

View File

@@ -0,0 +1,51 @@
package middleware
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
role, _ := GetUserRoleFromContext(c)
if role == "admin" {
c.Next()
return
}
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
c.Abort()
}
}
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
// Allows: login, login/2fa, logout, refresh (admin needs these).
// Blocks: register, forgot-password, reset-password, OAuth, etc.
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
path := c.Request.URL.Path
// Allow login, 2FA, logout, refresh, public settings
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
for _, suffix := range allowedSuffixes {
if strings.HasSuffix(path, suffix) {
c.Next()
return
}
}
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
c.Abort()
}
}
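// Wiring sketch (illustrative; it mirrors the route registration changes later in
// this compare): the auth guard wraps the public /auth group, while the user guard
// is mounted after JWT auth so the role is already available in the gin context.
//
//	auth := v1.Group("/auth")
//	auth.Use(BackendModeAuthGuard(settingService))
//
//	authenticated := v1.Group("")
//	authenticated.Use(gin.HandlerFunc(jwtAuth))
//	authenticated.Use(BackendModeUserGuard(settingService))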

View File

@@ -0,0 +1,239 @@
//go:build unit
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type bmSettingRepo struct {
values map[string]string
}
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
v, ok := r.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return v, nil
}
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
panic("unexpected Set call")
}
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
if r.values == nil {
r.values = make(map[string]string, len(settings))
}
for key, value := range settings {
r.values[key] = value
}
return nil
}
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
panic("unexpected Delete call")
}
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
t.Helper()
repo := &bmSettingRepo{
values: map[string]string{
service.SettingKeyBackendModeEnabled: enabled,
},
}
svc := service.NewSettingService(repo, &config.Config{})
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
BackendModeEnabled: enabled == "true",
}))
return svc
}
func stringPtr(v string) *string {
return &v
}
func TestBackendModeUserGuard(t *testing.T) {
tests := []struct {
name string
nilService bool
enabled string
role *string
wantStatus int
}{
{
name: "disabled_allows_all",
enabled: "false",
role: stringPtr("user"),
wantStatus: http.StatusOK,
},
{
name: "nil_service_allows_all",
nilService: true,
role: stringPtr("user"),
wantStatus: http.StatusOK,
},
{
name: "enabled_admin_allowed",
enabled: "true",
role: stringPtr("admin"),
wantStatus: http.StatusOK,
},
{
name: "enabled_user_blocked",
enabled: "true",
role: stringPtr("user"),
wantStatus: http.StatusForbidden,
},
{
name: "enabled_no_role_blocked",
enabled: "true",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_empty_role_blocked",
enabled: "true",
role: stringPtr(""),
wantStatus: http.StatusForbidden,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
if tc.role != nil {
role := *tc.role
r.Use(func(c *gin.Context) {
c.Set(string(ContextKeyUserRole), role)
c.Next()
})
}
var svc *service.SettingService
if !tc.nilService {
svc = newBackendModeSettingService(t, tc.enabled)
}
r.Use(BackendModeUserGuard(svc))
r.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
r.ServeHTTP(w, req)
require.Equal(t, tc.wantStatus, w.Code)
})
}
}
func TestBackendModeAuthGuard(t *testing.T) {
tests := []struct {
name string
nilService bool
enabled string
path string
wantStatus int
}{
{
name: "disabled_allows_all",
enabled: "false",
path: "/api/v1/auth/register",
wantStatus: http.StatusOK,
},
{
name: "nil_service_allows_all",
nilService: true,
path: "/api/v1/auth/register",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_login",
enabled: "true",
path: "/api/v1/auth/login",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_login_2fa",
enabled: "true",
path: "/api/v1/auth/login/2fa",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_logout",
enabled: "true",
path: "/api/v1/auth/logout",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_refresh",
enabled: "true",
path: "/api/v1/auth/refresh",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_register",
enabled: "true",
path: "/api/v1/auth/register",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_blocks_forgot_password",
enabled: "true",
path: "/api/v1/auth/forgot-password",
wantStatus: http.StatusForbidden,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
var svc *service.SettingService
if !tc.nilService {
svc = newBackendModeSettingService(t, tc.enabled)
}
r.Use(BackendModeAuthGuard(svc))
r.Any("/*path", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
r.ServeHTTP(w, req)
require.Equal(t, tc.wantStatus, w.Code)
})
}
}

View File

@@ -107,9 +107,9 @@ func registerRoutes(
v1 := r.Group("/api/v1")
// Register the routes of each module
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
}

View File

@@ -58,6 +58,9 @@ func RegisterAdminRoutes(
// Data management
registerDataManagementRoutes(admin, h)
// Database backup and restore
registerBackupRoutes(admin, h)
// Ops monitoring
registerOpsRoutes(admin, h)
@@ -192,6 +195,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
@@ -228,6 +232,9 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
}
@@ -436,6 +443,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
backup := admin.Group("/backups")
{
// S3 storage configuration
backup.GET("/s3-config", h.Admin.Backup.GetS3Config)
backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config)
backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection)
// Scheduled backup configuration
backup.GET("/schedule", h.Admin.Backup.GetSchedule)
backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule)
// Backup operations
backup.POST("", h.Admin.Backup.CreateBackup)
backup.GET("", h.Admin.Backup.ListBackups)
backup.GET("/:id", h.Admin.Backup.GetBackup)
backup.DELETE("/:id", h.Admin.Backup.DeleteBackup)
backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL)
// Restore operations
backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup)
}
}
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
system := admin.Group("/system")
{

View File

@@ -6,6 +6,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
h *handler.Handlers,
jwtAuth servermiddleware.JWTAuthMiddleware,
redisClient *redis.Client,
settingService *service.SettingService,
) {
// Create the rate limiter
rateLimiter := middleware.NewRateLimiter(redisClient)
// Public endpoints
auth := v1.Group("/auth")
auth.Use(servermiddleware.BackendModeAuthGuard(settingService))
{
// Registration, login, 2FA, and verification-code sending are all high-risk entry points; add server-side fallback rate limiting (fail-close when Redis is unavailable)
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
@@ -78,6 +81,7 @@ func RegisterAuthRoutes(
// Current-user endpoints that require authentication
authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService))
{
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// Revoke all sessions (requires authentication)

View File

@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
c.Next()
}),
redisClient,
nil,
)
return router

View File

@@ -3,6 +3,7 @@ package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
settingService *service.SettingService,
) {
if h.SoraClient == nil {
return
@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
authenticated := v1.Group("/sora")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{
authenticated.POST("/generate", h.SoraClient.Generate)
authenticated.GET("/generations", h.SoraClient.ListGenerations)

View File

@@ -3,6 +3,7 @@ package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -12,9 +13,11 @@ func RegisterUserRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
settingService *service.SettingService,
) {
authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{
// User endpoints
user := authenticated.Group("/user")

View File

@@ -3,6 +3,7 @@ package service
import (
"encoding/json"
"errors"
"hash/fnv"
"reflect"
"sort"
@@ -412,6 +413,7 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
}
// The default Bedrock mapping is handled centrally by forwardBedrock (it must be adjusted together with the region prefix)
return nil
}
if len(rawMapping) == 0 {
@@ -521,16 +523,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
// GetMappedModel returns the mapped model name (wildcards supported, longest match wins).
// If no mapping is configured, the original model name is returned.
func (a *Account) GetMappedModel(requestedModel string) string {
mappedModel, _ := a.ResolveMappedModel(requestedModel)
return mappedModel
}
// ResolveMappedModel returns the mapped model name along with whether an account-level mapping was hit.
// matched=true means an exact or wildcard mapping matched, even if the mapped result equals the original model name.
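// Example (illustrative mapping values): with {"gpt-5": "gpt-5-2025-08-07"} configured,
// requesting "gpt-5" yields ("gpt-5-2025-08-07", true); a model with no mapping entry is returned unchanged with matched=false.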
func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return requestedModel
return requestedModel, false
}
// Exact match takes priority
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel
return mappedModel, true
}
// Wildcard match (longest first)
return matchWildcardMapping(mapping, requestedModel)
return matchWildcardMappingResult(mapping, requestedModel)
}
func (a *Account) GetBaseURL() string {
@@ -604,9 +613,7 @@ func matchWildcard(pattern, str string) bool {
return matchAntigravityWildcard(pattern, str)
}
// matchWildcardMapping performs wildcard mapping matching (longest pattern first)
// If nothing matches, the original string is returned
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) {
// Collect all matching patterns and sort them by length in descending order (longest first)
type patternMatch struct {
pattern string
@@ -621,7 +628,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
}
if len(matches) == 0 {
return requestedModel // no match, return the original model name
return requestedModel, false // no match, return the original model name
}
// Sort by pattern length in descending order
@@ -632,7 +639,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
return matches[i].pattern < matches[j].pattern
})
return matches[0].target
return matches[0].target, true
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
@@ -650,7 +657,7 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
// IsPoolMode reports whether an API key account has pool mode enabled.
// In pool mode, upstream errors do not mark the local account status; requests are retried on the same account instead.
func (a *Account) IsPoolMode() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
if !a.IsAPIKeyOrBedrock() || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["pool_mode"]; ok {
@@ -764,6 +771,19 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
return false
}
func (a *Account) IsBedrock() bool {
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
}
func (a *Account) IsBedrockAPIKey() bool {
return a.IsBedrock() && a.GetCredential("auth_mode") == "apikey"
}
// IsAPIKeyOrBedrock reports whether the account type supports features such as quotas and pool mode
func (a *Account) IsAPIKeyOrBedrock() bool {
return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock
}
func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI
}
@@ -1260,6 +1280,240 @@ func (a *Account) getExtraTime(key string) time.Time {
return time.Time{}
}
// getExtraString reads the string value of the given key from Extra
func (a *Account) getExtraString(key string) string {
if a.Extra == nil {
return ""
}
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// getExtraInt reads the int value of the given key from Extra
func (a *Account) getExtraInt(key string) int {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra[key]; ok {
return int(parseExtraFloat64(v))
}
return 0
}
// GetQuotaDailyResetMode returns the daily quota reset mode: "rolling" (default) or "fixed"
func (a *Account) GetQuotaDailyResetMode() string {
if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" {
return "fixed"
}
return "rolling"
}
// GetQuotaDailyResetHour returns the hour (0-23) of the fixed daily reset, default 0
func (a *Account) GetQuotaDailyResetHour() int {
return a.getExtraInt("quota_daily_reset_hour")
}
// GetQuotaWeeklyResetMode returns the weekly quota reset mode: "rolling" (default) or "fixed"
func (a *Account) GetQuotaWeeklyResetMode() string {
if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" {
return "fixed"
}
return "rolling"
}
// GetQuotaWeeklyResetDay returns the weekday of the fixed reset (0=Sunday, 1=Monday, ..., 6=Saturday), default 1 (Monday)
func (a *Account) GetQuotaWeeklyResetDay() int {
if a.Extra == nil {
return 1
}
if _, ok := a.Extra["quota_weekly_reset_day"]; !ok {
return 1
}
return a.getExtraInt("quota_weekly_reset_day")
}
// GetQuotaWeeklyResetHour returns the hour (0-23) of the fixed weekly reset, default 0
func (a *Account) GetQuotaWeeklyResetHour() int {
return a.getExtraInt("quota_weekly_reset_hour")
}
// GetQuotaResetTimezone returns the IANA timezone name used for fixed resets, default "UTC"
func (a *Account) GetQuotaResetTimezone() string {
if tz := a.getExtraString("quota_reset_timezone"); tz != "" {
return tz
}
return "UTC"
}
// nextFixedDailyReset computes the next fixed daily reset point after `after`
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz)
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
if !after.Before(today) {
return today.AddDate(0, 0, 1)
}
return today
}
// lastFixedDailyReset computes the most recent fixed daily reset point at or before `now`
func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time {
t := now.In(tz)
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
if now.Before(today) {
return today.AddDate(0, 0, -1)
}
return today
}
// nextFixedWeeklyReset computes the next fixed weekly reset point after `after`
// day: 0=Sunday, 1=Monday, ..., 6=Saturday
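// Example: from Wednesday (currentDay=3) targeting Monday (day=1), daysForward = (1-3+7)%7 = 5, i.e. the following Monday.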
func nextFixedWeeklyReset(day, hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz)
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
currentDay := int(todayReset.Weekday())
daysForward := (day - currentDay + 7) % 7
if daysForward == 0 && !after.Before(todayReset) {
daysForward = 7
}
return todayReset.AddDate(0, 0, daysForward)
}
// lastFixedWeeklyReset computes the most recent fixed weekly reset point at or before `now`
func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time {
t := now.In(tz)
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
currentDay := int(todayReset.Weekday())
daysBack := (currentDay - day + 7) % 7
if daysBack == 0 && now.Before(todayReset) {
daysBack = 7
}
return todayReset.AddDate(0, 0, -daysBack)
}
// isFixedDailyPeriodExpired reports whether the daily quota period has expired under fixed-reset mode
func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool {
if periodStart.IsZero() {
return true
}
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
if err != nil {
tz = time.UTC
}
lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now())
return periodStart.Before(lastReset)
}
// isFixedWeeklyPeriodExpired reports whether the weekly quota period has expired under fixed-reset mode
func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool {
if periodStart.IsZero() {
return true
}
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
if err != nil {
tz = time.UTC
}
lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now())
return periodStart.Before(lastReset)
}
// ComputeQuotaResetAt computes and fills quota_daily_reset_at / quota_weekly_reset_at in extra based on the current configuration.
// Called when the account configuration is saved.
func ComputeQuotaResetAt(extra map[string]any) {
now := time.Now()
tzName, _ := extra["quota_reset_timezone"].(string)
if tzName == "" {
tzName = "UTC"
}
tz, err := time.LoadLocation(tzName)
if err != nil {
tz = time.UTC
}
// Fixed reset time for the daily quota
if mode, _ := extra["quota_daily_reset_mode"].(string); mode == "fixed" {
hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"]))
if hour < 0 || hour > 23 {
hour = 0
}
resetAt := nextFixedDailyReset(hour, tz, now)
extra["quota_daily_reset_at"] = resetAt.UTC().Format(time.RFC3339)
} else {
delete(extra, "quota_daily_reset_at")
}
// Fixed reset time for the weekly quota
if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" {
day := 1 // default: Monday
if d, ok := extra["quota_weekly_reset_day"]; ok {
day = int(parseExtraFloat64(d))
}
if day < 0 || day > 6 {
day = 1
}
hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"]))
if hour < 0 || hour > 23 {
hour = 0
}
resetAt := nextFixedWeeklyReset(day, hour, tz, now)
extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339)
} else {
delete(extra, "quota_weekly_reset_at")
}
}
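// Example (illustrative values, not part of this change): saving an account with a
// fixed daily reset at 09:00 Asia/Shanghai fills in the next daily boundary and
// clears the weekly key.
//
//	extra := map[string]any{
//		"quota_daily_reset_mode": "fixed",
//		"quota_daily_reset_hour": float64(9),
//		"quota_reset_timezone":   "Asia/Shanghai",
//	}
//	ComputeQuotaResetAt(extra)
//	// extra["quota_daily_reset_at"] is the next 09:00 Asia/Shanghai boundary as
//	// RFC3339 in UTC; quota_weekly_reset_at is removed since weekly mode is rolling.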
// ValidateQuotaResetConfig validates the fixed quota reset configuration
func ValidateQuotaResetConfig(extra map[string]any) error {
if extra == nil {
return nil
}
// Validate the timezone
if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" {
if _, err := time.LoadLocation(tz); err != nil {
return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name")
}
}
// Daily quota reset mode
if mode, ok := extra["quota_daily_reset_mode"].(string); ok {
if mode != "rolling" && mode != "fixed" {
return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'")
}
}
// Daily quota reset hour
if v, ok := extra["quota_daily_reset_hour"]; ok {
hour := int(parseExtraFloat64(v))
if hour < 0 || hour > 23 {
return errors.New("quota_daily_reset_hour must be between 0 and 23")
}
}
// Weekly quota reset mode
if mode, ok := extra["quota_weekly_reset_mode"].(string); ok {
if mode != "rolling" && mode != "fixed" {
return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'")
}
}
// Weekly quota reset weekday
if v, ok := extra["quota_weekly_reset_day"]; ok {
day := int(parseExtraFloat64(v))
if day < 0 || day > 6 {
return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)")
}
}
// Weekly quota reset hour
if v, ok := extra["quota_weekly_reset_hour"]; ok {
hour := int(parseExtraFloat64(v))
if hour < 0 || hour > 23 {
return errors.New("quota_weekly_reset_hour must be between 0 and 23")
}
}
return nil
}
// HasAnyQuotaLimit reports whether a quota limit is configured for any dimension
func (a *Account) HasAnyQuotaLimit() bool {
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
@@ -1282,14 +1536,26 @@ func (a *Account) IsQuotaExceeded() bool {
// Daily quota (an expired period is treated as not exceeded; the next increment resets it)
if limit := a.GetQuotaDailyLimit(); limit > 0 {
start := a.getExtraTime("quota_daily_start")
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
var expired bool
if a.GetQuotaDailyResetMode() == "fixed" {
expired = a.isFixedDailyPeriodExpired(start)
} else {
expired = isPeriodExpired(start, 24*time.Hour)
}
if !expired && a.GetQuotaDailyUsed() >= limit {
return true
}
}
// Weekly quota
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
start := a.getExtraTime("quota_weekly_start")
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
var expired bool
if a.GetQuotaWeeklyResetMode() == "fixed" {
expired = a.isFixedWeeklyPeriodExpired(start)
} else {
expired = isPeriodExpired(start, 7*24*time.Hour)
}
if !expired && a.GetQuotaWeeklyUsed() >= limit {
return true
}
}

View File

@@ -0,0 +1,516 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// nextFixedDailyReset
// ---------------------------------------------------------------------------
func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) {
tz := time.UTC
// 2026-03-14 06:00 UTC, reset hour = 9
after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_AtResetHour(t *testing.T) {
tz := time.UTC
// Exactly at reset hour → should return tomorrow
after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_AfterResetHour(t *testing.T) {
tz := time.UTC
// After reset hour → should return tomorrow
after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_MidnightReset(t *testing.T) {
tz := time.UTC
// Reset at hour 0 (midnight), currently 23:59
after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz)
got := nextFixedDailyReset(0, tz, after)
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) {
tz, err := time.LoadLocation("Asia/Shanghai")
require.NoError(t, err)
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC)
got := nextFixedDailyReset(9, tz, after)
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// lastFixedDailyReset
// ---------------------------------------------------------------------------
func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// Before today's 9:00 → yesterday 9:00
want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedDailyReset_AtResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// At exactly 9:00 → today 9:00
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedDailyReset_AfterResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// After 9:00 → today 9:00
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// nextFixedWeeklyReset
// ---------------------------------------------------------------------------
func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) {
tz := time.UTC
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday = 2026-03-16
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Today at 9:00
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday at 9:00
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday at 9:00
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) {
tz := time.UTC
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday = 2026-03-23
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_Sunday(t *testing.T) {
tz := time.UTC
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(0, 0, tz, after)
// Next Sunday = 2026-03-15
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// lastFixedWeeklyReset
// ---------------------------------------------------------------------------
func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Today at 9:00
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Last Monday at 9:00 = 2026-03-09
want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) {
tz := time.UTC
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Last Monday = 2026-03-16
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// isFixedDailyPeriodExpired
// ---------------------------------------------------------------------------
func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
assert.True(t, a.isFixedDailyPeriodExpired(time.Time{}))
}
func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started after the most recent reset → not expired
// (This test uses a time very close to "now", which is after the last reset)
periodStart := time.Now().Add(-1 * time.Minute)
assert.False(t, a.isFixedDailyPeriodExpired(periodStart))
}
func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 3 days ago → definitely expired
periodStart := time.Now().Add(-72 * time.Hour)
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
}
func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "Invalid/Timezone",
}}
// Invalid timezone falls back to UTC
periodStart := time.Now().Add(-72 * time.Hour)
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
}
// ---------------------------------------------------------------------------
// isFixedWeeklyPeriodExpired
// ---------------------------------------------------------------------------
func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{}))
}
func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 1 minute ago → not expired
periodStart := time.Now().Add(-1 * time.Minute)
assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart))
}
func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 10 days ago → definitely expired
periodStart := time.Now().Add(-240 * time.Hour)
assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart))
}
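The weekly variant follows the same pattern, only anchored to the configured weekday. A sketch reusing the hypothetical lastFixedWeeklyResetSketch helper from earlier (again an illustration, not the real method):
func (a *Account) isFixedWeeklyPeriodExpiredSketch(periodStart, now time.Time) bool {
	if periodStart.IsZero() {
		return true
	}
	day, hour := 0, 0
	if d, ok := a.Extra["quota_weekly_reset_day"].(float64); ok && d >= 0 && d <= 6 {
		day = int(d)
	}
	if h, ok := a.Extra["quota_weekly_reset_hour"].(float64); ok && h >= 0 && h <= 23 {
		hour = int(h)
	}
	loc := time.UTC
	if name, ok := a.Extra["quota_reset_timezone"].(string); ok && name != "" {
		if l, err := time.LoadLocation(name); err == nil {
			loc = l
		}
	}
	return periodStart.Before(lastFixedWeeklyResetSketch(day, hour, loc, now))
}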
// ---------------------------------------------------------------------------
// ValidateQuotaResetConfig
// ---------------------------------------------------------------------------
func TestValidateQuotaResetConfig_NilExtra(t *testing.T) {
assert.NoError(t, ValidateQuotaResetConfig(nil))
}
func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) {
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{}))
}
func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "Asia/Shanghai",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
}
func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
}
func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) {
extra := map[string]any{
"quota_reset_timezone": "Not/A/Timezone",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_reset_timezone")
}
func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "invalid",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_mode")
}
func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_hour": float64(24),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
}
func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_hour": float64(-1),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
}
func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_mode": "unknown",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_mode")
}
func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_day": float64(7),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
}
func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_day": float64(-1),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
}
func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_hour": float64(25),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_hour")
}
func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) {
// All boundary values should be valid
extra := map[string]any{
"quota_daily_reset_hour": float64(23),
"quota_weekly_reset_day": float64(0), // Sunday
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "UTC",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
extra2 := map[string]any{
"quota_daily_reset_hour": float64(0),
"quota_weekly_reset_day": float64(6), // Saturday
"quota_weekly_reset_hour": float64(23),
}
assert.NoError(t, ValidateQuotaResetConfig(extra2))
}
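Taken together, these validation tests imply a small rule set: modes must be "fixed" or "rolling", daily/weekly hours must fall in [0,23], the weekly day in [0,6] (0 = Sunday), and the timezone must load via time.LoadLocation. A compact sketch of such a validator (imports: fmt, time); error wording is illustrative, not the project's exact messages:
func validateQuotaResetConfigSketch(extra map[string]any) error {
	if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" {
		if _, err := time.LoadLocation(tz); err != nil {
			return fmt.Errorf("quota_reset_timezone: unknown timezone %q", tz)
		}
	}
	for _, key := range []string{"quota_daily_reset_mode", "quota_weekly_reset_mode"} {
		if mode, ok := extra[key].(string); ok && mode != "" && mode != "fixed" && mode != "rolling" {
			return fmt.Errorf("%s: must be %q or %q", key, "fixed", "rolling")
		}
	}
	for _, key := range []string{"quota_daily_reset_hour", "quota_weekly_reset_hour"} {
		if h, ok := extra[key].(float64); ok && (h < 0 || h > 23) {
			return fmt.Errorf("%s: must be between 0 and 23", key)
		}
	}
	if d, ok := extra["quota_weekly_reset_day"].(float64); ok && (d < 0 || d > 6) {
		return fmt.Errorf("quota_weekly_reset_day: must be between 0 and 6")
	}
	return nil
}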
// ---------------------------------------------------------------------------
// ComputeQuotaResetAt
// ---------------------------------------------------------------------------
func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
}
ComputeQuotaResetAt(extra)
_, hasDailyResetAt := extra["quota_daily_reset_at"]
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at")
assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at")
}
func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
"quota_daily_reset_at": "2026-03-14T09:00:00Z",
"quota_weekly_reset_at": "2026-03-16T09:00:00Z",
}
ComputeQuotaResetAt(extra)
_, hasDailyResetAt := extra["quota_daily_reset_at"]
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at")
assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at")
}
func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok, "quota_daily_reset_at should be set")
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Reset time should be in the future
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
// Reset hour should be 9 UTC
assert.Equal(t, 9, resetAt.UTC().Hour())
}
func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1), // Monday
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_weekly_reset_at"].(string)
require.True(t, ok, "quota_weekly_reset_at should be set")
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Reset time should be in the future
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
// Reset day should be Monday
assert.Equal(t, time.Monday, resetAt.UTC().Weekday())
}
func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) {
tz, err := time.LoadLocation("Asia/Shanghai")
require.NoError(t, err)
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "Asia/Shanghai",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// In Shanghai timezone, the hour should be 9
assert.Equal(t, 9, resetAt.In(tz).Hour())
}
func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(12),
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Default timezone is UTC
assert.Equal(t, 12, resetAt.UTC().Hour())
}
func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(99),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Invalid hour → clamped to 0
assert.Equal(t, 0, resetAt.UTC().Hour())
}
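The tests above constrain ComputeQuotaResetAt to: clear quota_daily_reset_at / quota_weekly_reset_at in rolling mode, and in fixed mode write the next future reset as an RFC3339 string in the configured timezone (defaulting to UTC, with out-of-range hours falling back to 0). A sketch of the daily half — the weekly half is analogous; the helper name is hypothetical:
func computeDailyResetAtSketch(extra map[string]any) {
	if mode, _ := extra["quota_daily_reset_mode"].(string); mode != "fixed" {
		delete(extra, "quota_daily_reset_at") // rolling mode clears any stale value
		return
	}
	hour := 0
	if h, ok := extra["quota_daily_reset_hour"].(float64); ok && h >= 0 && h <= 23 {
		hour = int(h) // out-of-range values fall back to 0
	}
	loc := time.UTC
	if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" {
		if l, err := time.LoadLocation(tz); err == nil {
			loc = l
		}
	}
	now := time.Now().In(loc)
	next := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, loc)
	if !next.After(now) {
		next = next.AddDate(0, 0, 1) // today's reset already passed → tomorrow
	}
	extra["quota_daily_reset_at"] = next.Format(time.RFC3339)
}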

View File

@@ -207,14 +207,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
testModelID = claude.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
// API Key 账号测试连接时也需要应用通配符模型映射。
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
testModelID = account.GetMappedModel(testModelID)
}
// Bedrock accounts use a separate test path
if account.IsBedrock() {
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
}
// Determine authentication method and API URL
@@ -312,6 +312,109 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.processClaudeStream(c, resp.Body)
}
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
region := bedrockRuntimeRegion(account)
resolvedModelID, ok := ResolveBedrockModelID(account, testModelID)
if !ok {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Bedrock model: %s", testModelID))
}
testModelID = resolvedModelID
// Set SSE headers (test UI expects SSE)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
// Create a minimal Bedrock-compatible payload (no stream, no cache_control)
bedrockPayload := map[string]any{
"anthropic_version": "bedrock-2023-05-31",
"messages": []map[string]any{
{
"role": "user",
"content": []map[string]any{
{
"type": "text",
"text": "hi",
},
},
},
},
"max_tokens": 256,
"temperature": 1,
}
bedrockBody, _ := json.Marshal(bedrockPayload)
// Use non-streaming endpoint (response is standard Claude JSON)
apiURL := BuildBedrockURL(region, testModelID, false)
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bedrockBody))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
req.Header.Set("Content-Type", "application/json")
// Sign or set auth based on account type
if account.IsBedrockAPIKey() {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
req.Header.Set("Authorization", "Bearer "+apiKey)
} else {
signer, err := NewBedrockSignerFromAccount(account)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Bedrock signer: %s", err.Error()))
}
if err := signer.SignRequest(ctx, req, bedrockBody); err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to sign request: %s", err.Error()))
}
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
// Bedrock non-streaming response is standard Claude JSON, extract the text
var result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
}
if err := json.Unmarshal(body, &result); err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
}
text := ""
if len(result.Content) > 0 {
text = result.Content[0].Text
}
if text == "" {
text = "(empty response)"
}
s.sendEvent(c, TestEvent{Type: "content", Text: text})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"log"
"log/slog"
"math/rand/v2"
"net/http"
"strings"
@@ -44,9 +45,12 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
@@ -99,6 +103,7 @@ type antigravityUsageCache struct {
const (
apiCacheTTL = 3 * time.Minute
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL(429 等错误缓存 1 分钟)
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute
@@ -107,11 +112,12 @@ const (
// UsageCache 封装账户使用量相关的缓存
type UsageCache struct {
apiCache sync.Map // accountID -> *apiUsageCache
windowStatsCache sync.Map // accountID -> *windowStatsCache
antigravityCache sync.Map // accountID -> *antigravityUsageCache
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存
openAIProbeCache sync.Map // accountID -> time.Time
apiCache sync.Map // accountID -> *apiUsageCache
windowStatsCache sync.Map // accountID -> *windowStatsCache
antigravityCache sync.Map // accountID -> *antigravityUsageCache
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
openAIProbeCache sync.Map // accountID -> time.Time
}
// NewUsageCache 创建 UsageCache 实例
@@ -148,6 +154,18 @@ type AntigravityModelQuota struct {
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
}
// AntigravityModelDetail Antigravity 单个模型的详细能力信息
type AntigravityModelDetail struct {
DisplayName string `json:"display_name,omitempty"`
SupportsImages *bool `json:"supports_images,omitempty"`
SupportsThinking *bool `json:"supports_thinking,omitempty"`
ThinkingBudget *int `json:"thinking_budget,omitempty"`
Recommended *bool `json:"recommended,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
}
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
@@ -163,6 +181,33 @@ type UsageInfo struct {
// Antigravity 多模型配额
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
// Antigravity 账号级信息
SubscriptionTier string `json:"subscription_tier,omitempty"` // 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN
SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // 上游原始订阅等级名称
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key)
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
// Antigravity 账号是否被上游禁止 (HTTP 403)
IsForbidden bool `json:"is_forbidden,omitempty"`
ForbiddenReason string `json:"forbidden_reason,omitempty"`
ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden"
ValidationURL string `json:"validation_url,omitempty"` // 验证/申诉链接
// 状态标记(从 ForbiddenType / HTTP 错误码推导)
NeedsVerify bool `json:"needs_verify,omitempty"` // 需要人工验证(forbidden_type=validation)
IsBanned bool `json:"is_banned,omitempty"` // 账号被封(forbidden_type=violation)
NeedsReauth bool `json:"needs_reauth,omitempty"` // token 失效需重新授权(401)
// 错误码(机器可读):forbidden / unauthenticated / rate_limited / network_error
ErrorCode string `json:"error_code,omitempty"`
// 获取 usage 时的错误信息(降级返回,而非 500)
Error string `json:"error,omitempty"`
}
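For orientation, a purely illustrative value (all field contents invented) of how a forbidden Antigravity account might surface through these new fields:
var exampleForbiddenUsage = &UsageInfo{
	IsForbidden:     true,
	ForbiddenType:   "violation",
	ForbiddenReason: "Account suspended due to policy violation",
	IsBanned:        true,
	ErrorCode:       "forbidden",
}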
// ClaudeUsageResponse Anthropic API返回的usage结构
@@ -647,34 +692,157 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
return &UsageInfo{UpdatedAt: &now}, nil
}
// 1. 检查缓存(10 分钟)
// 1. 检查缓存
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
// 重新计算 RemainingSeconds
usage := cache.usageInfo
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
if cache, ok := cached.(*antigravityUsageCache); ok {
ttl := antigravityCacheTTL(cache.usageInfo)
if time.Since(cache.timestamp) < ttl {
usage := cache.usageInfo
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
}
return usage, nil
}
return usage, nil
}
}
// 2. 获取代理 URL
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
// 2. singleflight 防止并发击穿
flightKey := fmt.Sprintf("ag-usage:%d", account.ID)
result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) {
// 再次检查缓存(等待期间可能已被填充)
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
if cache, ok := cached.(*antigravityUsageCache); ok {
ttl := antigravityCacheTTL(cache.usageInfo)
if time.Since(cache.timestamp) < ttl {
usage := cache.usageInfo
// 重新计算 RemainingSeconds,避免返回过时的剩余秒数
recalcAntigravityRemainingSeconds(usage)
return usage, nil
}
}
}
// 3. 调用 API 获取额度
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
if err != nil {
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
}
// 使用独立 context,避免调用方 cancel 导致所有共享 flight 的请求失败
fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer fetchCancel()
// 4. 缓存结果
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
usageInfo: result.UsageInfo,
timestamp: time.Now(),
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account)
fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL)
if err != nil {
degraded := buildAntigravityDegradedUsage(err)
enrichUsageWithAccountError(degraded, account)
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
usageInfo: degraded,
timestamp: time.Now(),
})
return degraded, nil
}
enrichUsageWithAccountError(fetchResult.UsageInfo, account)
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
usageInfo: fetchResult.UsageInfo,
timestamp: time.Now(),
})
return fetchResult.UsageInfo, nil
})
return result.UsageInfo, nil
if flightErr != nil {
return nil, flightErr
}
usage, ok := result.(*UsageInfo)
if !ok || usage == nil {
now := time.Now()
return &UsageInfo{UpdatedAt: &now}, nil
}
return usage, nil
}
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
if info == nil {
return
}
if info.FiveHour != nil && info.FiveHour.ResetsAt != nil {
remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds())
if remaining < 0 {
remaining = 0
}
info.FiveHour.RemainingSeconds = remaining
}
}
// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL
// 403 forbidden 状态稳定,缓存与成功相同(3 分钟);
// 其他错误(401/网络)可能快速恢复,缓存 1 分钟。
func antigravityCacheTTL(info *UsageInfo) time.Duration {
if info == nil {
return antigravityErrorTTL
}
if info.IsForbidden {
return apiCacheTTL // 封号/验证状态不会很快变
}
if info.ErrorCode != "" || info.Error != "" {
return antigravityErrorTTL
}
return apiCacheTTL
}
// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo
func buildAntigravityDegradedUsage(err error) *UsageInfo {
now := time.Now()
errMsg := fmt.Sprintf("usage API error: %v", err)
slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err)
info := &UsageInfo{
UpdatedAt: &now,
Error: errMsg,
}
// 从错误信息推断 error_code 和状态标记
// 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..."
errStr := err.Error()
switch {
case strings.Contains(errStr, "HTTP 401") ||
strings.Contains(errStr, "UNAUTHENTICATED") ||
strings.Contains(errStr, "invalid_grant"):
info.ErrorCode = errorCodeUnauthenticated
info.NeedsReauth = true
case strings.Contains(errStr, "HTTP 429"):
info.ErrorCode = errorCodeRateLimited
default:
info.ErrorCode = errorCodeNetworkError
}
return info
}
// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo
// 场景 1(成功路径):FetchAvailableModels 正常返回,但账号已因 403 被标记为 error
//
// 需要在正常 usage 数据上附加 forbidden/validation 信息。
//
// 场景 2(降级路径):被封号的账号 OAuth token 失效,FetchAvailableModels 返回 401
//
// 降级逻辑设置了 needs_reauth,但账号实际是 403 封号/需验证,需覆盖为正确状态。
func enrichUsageWithAccountError(info *UsageInfo, account *Account) {
if info == nil || account == nil || account.Status != StatusError {
return
}
msg := strings.ToLower(account.ErrorMessage)
if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") &&
!strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") {
return
}
fbType := classifyForbiddenType(account.ErrorMessage)
info.IsForbidden = true
info.ForbiddenType = fbType
info.ForbiddenReason = account.ErrorMessage
info.NeedsVerify = fbType == forbiddenTypeValidation
info.IsBanned = fbType == forbiddenTypeViolation
info.ValidationURL = extractValidationURL(account.ErrorMessage)
info.ErrorCode = errorCodeForbidden
info.NeedsReauth = false
}
// addWindowStats 为 usage 数据添加窗口期统计

View File

@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) {
}
}
func TestMatchWildcardMapping(t *testing.T) {
func TestMatchWildcardMappingResult(t *testing.T) {
tests := []struct {
name string
mapping map[string]string
requestedModel string
expected string
matched bool
}{
// 精确匹配优先于通配符
{
@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5-exact",
matched: true,
},
// 最长通配符优先
@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-series",
matched: true,
},
// 单个通配符
@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel: "claude-opus-4-5",
expected: "claude-mapped",
matched: true,
},
// 无匹配返回原始模型
@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel: "gemini-3-flash",
expected: "gemini-3-flash",
matched: false,
},
// 空映射返回原始模型
@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) {
mapping: map[string]string{},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5",
matched: false,
},
// Gemini 模型映射
@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel: "gemini-3-flash-preview",
expected: "gemini-3-pro-high",
matched: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchWildcardMapping(tt.mapping, tt.requestedModel)
if result != tt.expected {
t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected)
result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
if result != tt.expected || matched != tt.matched {
t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
}
})
}
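The table above fixes the contract of matchWildcardMappingResult: exact keys win, otherwise the longest matching wildcard pattern wins, and unmatched requests come back unchanged with matched=false. A hedged sketch under the assumption that wildcards are trailing-`*` prefixes (the real matcher may be more general; strings import assumed):
func matchWildcardMappingResultSketch(mapping map[string]string, requested string) (string, bool) {
	if target, ok := mapping[requested]; ok {
		return target, true // exact match has priority
	}
	bestLen, best := -1, ""
	for pattern, target := range mapping {
		if !strings.Contains(pattern, "*") {
			continue
		}
		prefix := strings.TrimSuffix(pattern, "*")
		if strings.HasPrefix(requested, prefix) && len(pattern) > bestLen {
			bestLen, best = len(pattern), target // longest wildcard wins
		}
	}
	if bestLen >= 0 {
		return best, true
	}
	return requested, false // no match → passthrough, reported as unmatched
}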
@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) {
}
}
func TestAccountResolveMappedModel(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
requestedModel string
expectedModel string
expectedMatch bool
}{
{
name: "no mapping reports unmatched",
credentials: nil,
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
{
name: "exact passthrough mapping still counts as matched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "wildcard passthrough mapping still counts as matched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-*": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "missing mapping reports unmatched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.2": "gpt-5.2",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Credentials: tt.credentials,
}
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
}
})
}
}
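Given those cases, ResolveMappedModel is plausibly a thin wrapper over the matcher above — a sketch only; the real method may do more (e.g. platform-specific default passthroughs):
func (a *Account) resolveMappedModelSketch(requested string) (string, bool) {
	return matchWildcardMappingResult(a.GetModelMapping(), requested)
}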
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,

View File

@@ -42,6 +42,9 @@ type AdminService interface {
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin)
@@ -57,6 +60,8 @@ type AdminService interface {
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountError(ctx context.Context, id int64, errorMsg string) error
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号 privacy_mode,未设置则尝试关闭训练数据共享并持久化。
EnsureOpenAIPrivacy(ctx context.Context, account *Account) string
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
@@ -433,6 +438,7 @@ type adminServiceImpl struct {
settingService *SettingService
defaultSubAssigner DefaultSubscriptionAssigner
userSubRepo UserSubscriptionRepository
privacyClientFactory PrivacyClientFactory
}
type userGroupRateBatchReader interface {
@@ -461,6 +467,7 @@ func NewAdminService(
settingService *SettingService,
defaultSubAssigner DefaultSubscriptionAssigner,
userSubRepo UserSubscriptionRepository,
privacyClientFactory PrivacyClientFactory,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@@ -479,6 +486,7 @@ func NewAdminService(
settingService: settingService,
defaultSubAssigner: defaultSubAssigner,
userSubRepo: userSubRepo,
privacyClientFactory: privacyClientFactory,
}
}
@@ -824,7 +832,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType = SubscriptionTypeStandard
}
// 限额字段:0 和 nil 都表示"无限制"
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
dailyLimit := normalizeLimit(input.DailyLimitUSD)
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
@@ -936,9 +944,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil
}
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
// normalizeLimit 将负数转换为 nil(表示无限制),0 保留(表示限额为零)
func normalizeLimit(limit *float64) *float64 {
if limit == nil || *limit <= 0 {
if limit == nil || *limit < 0 {
return nil
}
return limit
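The behavioural change is easiest to see as a table of inputs and outputs; a hedged test-style illustration (ptr is a local helper, not project code):
func TestNormalizeLimitSemanticsSketch(t *testing.T) {
	ptr := func(v float64) *float64 { return &v }
	require.Nil(t, normalizeLimit(nil))                // nil → unlimited
	require.Nil(t, normalizeLimit(ptr(-1)))            // negative → unlimited
	require.Equal(t, 0.0, *normalizeLimit(ptr(0)))     // zero is kept: usage not allowed
	require.Equal(t, 12.5, *normalizeLimit(ptr(12.5))) // positive limits pass through
}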
@@ -1050,16 +1058,11 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType
}
// 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
if input.DailyLimitUSD != nil {
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
}
if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
}
if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
}
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
// 前端始终发送这三个字段,无需 nil 守卫
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
// 图片生成计费配置:负数表示清除(使用默认价格)
if input.ImagePrice1K != nil {
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
@@ -1244,6 +1247,27 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
return keys, result.Total, nil
}
func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) {
if s.userGroupRateRepo == nil {
return nil, nil
}
return s.userGroupRateRepo.GetByGroupID(ctx, groupID)
}
func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error {
if s.userGroupRateRepo == nil {
return nil
}
return s.userGroupRateRepo.DeleteByGroupID(ctx, groupID)
}
func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
if s.userGroupRateRepo == nil {
return nil
}
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
}
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
@@ -1433,6 +1457,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
Status: StatusActive,
Schedulable: true,
}
// 预计算固定时间重置的下次重置时间
if account.Extra != nil {
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
return nil, err
}
ComputeQuotaResetAt(account.Extra)
}
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
expiresAt := time.Unix(*input.ExpiresAt, 0)
account.ExpiresAt = &expiresAt
@@ -1506,6 +1537,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
}
}
account.Extra = input.Extra
// 校验并预计算固定时间重置的下次重置时间
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
return nil, err
}
ComputeQuotaResetAt(account.Extra)
}
if input.ProxyID != nil {
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
@@ -2502,3 +2538,39 @@ func (e *MixedChannelError) Error() string {
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
return s.accountRepo.ResetQuotaUsed(ctx, id)
}
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号是否已设置 privacy_mode
// 未设置则调用 disableOpenAITraining 并持久化到 Extra,返回设置的 mode 值。
func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Account) string {
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return ""
}
if s.privacyClientFactory == nil {
return ""
}
if account.Extra != nil {
if _, ok := account.Extra["privacy_mode"]; ok {
return ""
}
}
token, _ := account.Credentials["access_token"].(string)
if token == "" {
return ""
}
var proxyURL string
if account.ProxyID != nil {
if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil {
proxyURL = p.URL()
}
}
mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL)
if mode == "" {
return ""
}
_ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode})
return mode
}

View File

@@ -0,0 +1,176 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
// userGroupRateRepoStubForGroupRate implements UserGroupRateRepository for group rate tests.
type userGroupRateRepoStubForGroupRate struct {
getByGroupIDData map[int64][]UserGroupRateEntry
getByGroupIDErr error
deletedGroupIDs []int64
deleteByGroupErr error
syncedGroupID int64
syncedEntries []GroupRateMultiplierInput
syncGroupErr error
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
panic("unexpected GetByUserID call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, _, _ int64) (*float64, error) {
panic("unexpected GetByUserAndGroup call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
if s.getByGroupIDErr != nil {
return nil, s.getByGroupIDErr
}
return s.getByGroupIDData[groupID], nil
}
func (s *userGroupRateRepoStubForGroupRate) SyncUserGroupRates(_ context.Context, _ int64, _ map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
s.syncedGroupID = groupID
s.syncedEntries = entries
return s.syncGroupErr
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
return s.deleteByGroupErr
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByUserID(_ context.Context, _ int64) error {
panic("unexpected DeleteByUserID call")
}
func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
t.Run("returns entries for group", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{
10: {
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
},
},
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.NoError(t, err)
require.Len(t, entries, 2)
require.Equal(t, int64(1), entries[0].UserID)
require.Equal(t, "alice", entries[0].UserName)
require.Equal(t, 1.5, entries[0].RateMultiplier)
require.Equal(t, int64(2), entries[1].UserID)
require.Equal(t, 0.8, entries[1].RateMultiplier)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.NoError(t, err)
require.Nil(t, entries)
})
t.Run("returns empty slice for group with no entries", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{},
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 99)
require.NoError(t, err)
require.Nil(t, entries)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDErr: errors.New("db error"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
_, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.Error(t, err)
require.Contains(t, err.Error(), "db error")
})
}
func TestAdminService_ClearGroupRateMultipliers(t *testing.T) {
t.Run("deletes by group ID", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.NoError(t, err)
require.Equal(t, []int64{42}, repo.deletedGroupIDs)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.NoError(t, err)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
deleteByGroupErr: errors.New("delete failed"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.Error(t, err)
require.Contains(t, err.Error(), "delete failed")
})
}
func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
t.Run("syncs entries to repo", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries := []GroupRateMultiplierInput{
{UserID: 1, RateMultiplier: 1.5},
{UserID: 2, RateMultiplier: 0.8},
}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, entries)
require.NoError(t, err)
require.Equal(t, int64(10), repo.syncedGroupID)
require.Equal(t, entries, repo.syncedEntries)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, nil)
require.NoError(t, err)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
syncGroupErr: errors.New("sync failed"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, []GroupRateMultiplierInput{
{UserID: 1, RateMultiplier: 1.0},
})
require.Error(t, err)
require.Contains(t, err.Error(), "sync failed")
})
}

View File

@@ -68,7 +68,15 @@ func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context
panic("unexpected SyncUserGroupRates call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, _ int64) ([]UserGroupRateEntry, error) {
panic("unexpected GetByGroupID call")
}
func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.Context, _ int64, _ []GroupRateMultiplierInput) error {
panic("unexpected SyncGroupRateMultipliers call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
panic("unexpected DeleteByGroupID call")
}

View File

@@ -2,12 +2,29 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
const (
forbiddenTypeValidation = "validation"
forbiddenTypeViolation = "violation"
forbiddenTypeForbidden = "forbidden"
// 机器可读的错误码
errorCodeForbidden = "forbidden"
errorCodeUnauthenticated = "unauthenticated"
errorCodeRateLimited = "rate_limited"
errorCodeNetworkError = "network_error"
)
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
type AntigravityQuotaFetcher struct {
proxyRepo ProxyRepository
@@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
// 调用 API 获取配额
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
if err != nil {
// 403 Forbidden: 不报错,返回 is_forbidden 标记
var forbiddenErr *antigravity.ForbiddenError
if errors.As(err, &forbiddenErr) {
now := time.Now()
fbType := classifyForbiddenType(forbiddenErr.Body)
return &QuotaResult{
UsageInfo: &UsageInfo{
UpdatedAt: &now,
IsForbidden: true,
ForbiddenReason: forbiddenErr.Body,
ForbiddenType: fbType,
ValidationURL: extractValidationURL(forbiddenErr.Body),
NeedsVerify: fbType == forbiddenTypeValidation,
IsBanned: fbType == forbiddenTypeViolation,
ErrorCode: errorCodeForbidden,
},
}, nil
}
return nil, err
}
// 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程)
tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken)
// 转换为 UsageInfo
usageInfo := f.buildUsageInfo(modelsResp)
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized)
return &QuotaResult{
UsageInfo: usageInfo,
@@ -52,15 +90,52 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
}, nil
}
// buildUsageInfo 将 API 响应转换为 UsageInfo
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
now := time.Now()
info := &UsageInfo{
UpdatedAt: &now,
AntigravityQuota: make(map[string]*AntigravityModelQuota),
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) {
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
if err != nil {
slog.Warn("failed to fetch subscription tier", "error", err)
return "", ""
}
if loadResp == nil {
return "", ""
}
// 遍历所有模型,填充 AntigravityQuota
raw = loadResp.GetTier() // 已有方法:paidTier > currentTier
normalized = normalizeTier(raw)
return raw, normalized
}
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
func normalizeTier(raw string) string {
if raw == "" {
return ""
}
lower := strings.ToLower(raw)
switch {
case strings.Contains(lower, "ultra"):
return "ULTRA"
case strings.Contains(lower, "pro"):
return "PRO"
case strings.Contains(lower, "free"):
return "FREE"
default:
return "UNKNOWN"
}
}
// buildUsageInfo 将 API 响应转换为 UsageInfo
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo {
now := time.Now()
info := &UsageInfo{
UpdatedAt: &now,
AntigravityQuota: make(map[string]*AntigravityModelQuota),
AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail),
SubscriptionTier: tierNormalized,
SubscriptionTierRaw: tierRaw,
}
// 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails
for modelName, modelInfo := range modelsResp.Models {
if modelInfo.QuotaInfo == nil {
continue
@@ -73,6 +148,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
Utilization: utilization,
ResetTime: modelInfo.QuotaInfo.ResetTime,
}
// 填充模型详细能力信息
detail := &AntigravityModelDetail{
DisplayName: modelInfo.DisplayName,
SupportsImages: modelInfo.SupportsImages,
SupportsThinking: modelInfo.SupportsThinking,
ThinkingBudget: modelInfo.ThinkingBudget,
Recommended: modelInfo.Recommended,
MaxTokens: modelInfo.MaxTokens,
MaxOutputTokens: modelInfo.MaxOutputTokens,
SupportedMimeTypes: modelInfo.SupportedMimeTypes,
}
info.AntigravityQuotaDetails[modelName] = detail
}
// 废弃模型转发规则
if len(modelsResp.DeprecatedModelIDs) > 0 {
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs))
for oldID, deprecated := range modelsResp.DeprecatedModelIDs {
info.ModelForwardingRules[oldID] = deprecated.NewModelID
}
}
// 同时设置 FiveHour 用于兼容展示(取主要模型)
@@ -108,3 +204,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco
}
return proxy.URL()
}
// classifyForbiddenType 根据 403 响应体判断禁止类型
func classifyForbiddenType(body string) string {
lower := strings.ToLower(body)
switch {
case strings.Contains(lower, "validation_required") ||
strings.Contains(lower, "verify your account") ||
strings.Contains(lower, "validation_url"):
return forbiddenTypeValidation
case strings.Contains(lower, "terms of service") ||
strings.Contains(lower, "violation"):
return forbiddenTypeViolation
default:
return forbiddenTypeForbidden
}
}
// urlPattern 用于从 403 响应体中提取 URL(降级方案)
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)
// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接
func extractValidationURL(body string) string {
// 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url
var parsed struct {
Error struct {
Details []struct {
Metadata map[string]string `json:"metadata"`
} `json:"details"`
} `json:"error"`
}
if json.Unmarshal([]byte(body), &parsed) == nil {
for _, detail := range parsed.Error.Details {
if u := detail.Metadata["validation_url"]; u != "" {
return u
}
if u := detail.Metadata["appeal_url"]; u != "" {
return u
}
}
}
// 2. 降级:正则匹配 URL
lower := strings.ToLower(body)
if !strings.Contains(lower, "validation") &&
!strings.Contains(lower, "verify") &&
!strings.Contains(lower, "appeal") {
return ""
}
// 先解码常见转义再匹配
normalized := strings.ReplaceAll(body, `\u0026`, "&")
if m := urlPattern.FindString(normalized); m != "" {
return m
}
return ""
}

View File

@@ -0,0 +1,497 @@
//go:build unit
package service
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// ---------------------------------------------------------------------------
// normalizeTier
// ---------------------------------------------------------------------------
func TestNormalizeTier(t *testing.T) {
tests := []struct {
name string
raw string
expected string
}{
{name: "empty string", raw: "", expected: ""},
{name: "free-tier", raw: "free-tier", expected: "FREE"},
{name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"},
{name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"},
{name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"},
{name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"},
{name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"},
{name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"},
{name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeTier(tt.raw)
require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw)
})
}
}
// ---------------------------------------------------------------------------
// buildUsageInfo
// ---------------------------------------------------------------------------
func aqfBoolPtr(v bool) *bool { return &v }
func aqfIntPtr(v int) *int { return &v }
func TestBuildUsageInfo_BasicModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.75,
ResetTime: "2026-03-08T12:00:00Z",
},
DisplayName: "Claude Sonnet 4",
SupportsImages: aqfBoolPtr(true),
SupportsThinking: aqfBoolPtr(false),
ThinkingBudget: aqfIntPtr(0),
Recommended: aqfBoolPtr(true),
MaxTokens: aqfIntPtr(200000),
MaxOutputTokens: aqfIntPtr(16384),
SupportedMimeTypes: map[string]bool{
"image/png": true,
"image/jpeg": true,
},
},
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
ResetTime: "2026-03-08T15:00:00Z",
},
DisplayName: "Gemini 2.5 Pro",
MaxTokens: aqfIntPtr(1000000),
MaxOutputTokens: aqfIntPtr(65536),
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO")
// 基本字段
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
require.Equal(t, "PRO", info.SubscriptionTier)
require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw)
// AntigravityQuota
require.Len(t, info.AntigravityQuota, 2)
sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, sonnetQuota)
require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25
require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime)
geminiQuota := info.AntigravityQuota["gemini-2.5-pro"]
require.NotNil(t, geminiQuota)
require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50
require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime)
// AntigravityQuotaDetails
require.Len(t, info.AntigravityQuotaDetails, 2)
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"]
require.NotNil(t, sonnetDetail)
require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName)
require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages)
require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking)
require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget)
require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended)
require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens)
require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens)
require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes)
geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"]
require.NotNil(t, geminiDetail)
require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName)
require.Nil(t, geminiDetail.SupportsImages)
require.Nil(t, geminiDetail.SupportsThinking)
require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens)
require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens)
}
func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 1.0,
},
},
},
DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{
"claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"},
"claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.Len(t, info.ModelForwardingRules, 2)
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"])
}
func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"some-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
}
func TestBuildUsageInfo_EmptyModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info)
require.NotNil(t, info.AntigravityQuota)
require.Empty(t, info.AntigravityQuota)
require.NotNil(t, info.AntigravityQuotaDetails)
require.Empty(t, info.AntigravityQuotaDetails)
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
}
func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"model-without-quota": {
DisplayName: "No Quota Model",
// QuotaInfo is nil
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info)
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too")
}
func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
// When the first priority model exists, it should be used for FiveHour
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.40,
ResetTime: "2026-03-08T18:00:00Z",
},
},
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.80,
ResetTime: "2026-03-08T12:00:00Z",
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
// claude-sonnet-4-20250514 is first in priority list, so it should be used
expectedUtilization := (1.0 - 0.80) * 100 // 20
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime")
}
func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.60,
ResetTime: "2026-03-08T14:00:00Z",
},
},
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.30,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info.FiveHour)
expectedUtilization := (1.0 - 0.60) * 100 // 40
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
}
func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// Only gemini-2.5-pro exists (third in priority list)
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.30,
},
},
"other-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.90,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info.FiveHour)
expectedUtilization := (1.0 - 0.30) * 100 // 70
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
}
func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// None of the priority models exist
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"some-other-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
}
func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
ResetTime: "", // empty reset time
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info.FiveHour)
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
require.Equal(t, 0, info.FiveHour.RemainingSeconds)
}
func TestBuildUsageInfo_FullUtilization(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.0, // fully used
ResetTime: "2026-03-08T12:00:00Z",
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, quota)
require.Equal(t, 100, quota.Utilization)
}
func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 1.0, // fully available
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, quota)
require.Equal(t, 0, quota.Utilization)
}
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
// 模拟 FetchQuota 遇到 403 时的行为:
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
forbiddenErr := &antigravity.ForbiddenError{
StatusCode: 403,
Body: "Access denied",
}
// 验证 ForbiddenError 满足 errors.As
var target *antigravity.ForbiddenError
require.True(t, errors.As(forbiddenErr, &target))
require.Equal(t, 403, target.StatusCode)
require.Equal(t, "Access denied", target.Body)
require.Contains(t, forbiddenErr.Error(), "403")
}
// ---------------------------------------------------------------------------
// classifyForbiddenType
// ---------------------------------------------------------------------------
func TestClassifyForbiddenType(t *testing.T) {
tests := []struct {
name string
body string
expected string
}{
{
name: "VALIDATION_REQUIRED keyword",
body: `{"error":{"message":"VALIDATION_REQUIRED"}}`,
expected: "validation",
},
{
name: "verify your account",
body: `Please verify your account to continue`,
expected: "validation",
},
{
name: "contains validation_url field",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`,
expected: "validation",
},
{
name: "terms of service violation",
body: `Your account has been suspended for Terms of Service violation`,
expected: "violation",
},
{
name: "violation keyword",
body: `Account suspended due to policy violation`,
expected: "violation",
},
{
name: "generic 403",
body: `Access denied`,
expected: "forbidden",
},
{
name: "empty body",
body: "",
expected: "forbidden",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := classifyForbiddenType(tt.body)
require.Equal(t, tt.expected, got)
})
}
}
// ---------------------------------------------------------------------------
// extractValidationURL
// ---------------------------------------------------------------------------
func TestExtractValidationURL(t *testing.T) {
tests := []struct {
name string
body string
expected string
}{
{
name: "structured validation_url",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
expected: "https://accounts.google.com/verify?token=abc",
},
{
name: "structured appeal_url",
body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`,
expected: "https://support.google.com/appeal/123",
},
{
name: "validation_url takes priority over appeal_url",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`,
expected: "https://v.com",
},
{
name: "fallback regex with verify keyword",
body: `Please verify your account at https://accounts.google.com/verify`,
expected: "https://accounts.google.com/verify",
},
{
name: "no URL in generic forbidden",
body: `Access denied`,
expected: "",
},
{
name: "empty body",
body: "",
expected: "",
},
{
name: "URL present but no validation keywords",
body: `Error at https://example.com/something`,
expected: "",
},
{
name: "unicode escaped ampersand",
body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`,
expected: "https://accounts.google.com/verify?a=1&b=2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractValidationURL(tt.body)
require.Equal(t, tt.expected, got)
})
}
}

View File

@@ -22,8 +22,9 @@ const (
)
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
// A nil windowStart is treated as expired — no initialized window means any accumulated usage is stale.
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
-return windowStart != nil && time.Since(*windowStart) >= duration
+return windowStart == nil || time.Since(*windowStart) >= duration
}
type APIKey struct {

View File

@@ -15,10 +15,10 @@ func TestIsWindowExpired(t *testing.T) {
want bool
}{
{
name: "nil window start",
name: "nil window start (treated as expired)",
start: nil,
duration: RateLimitWindow5h,
want: false,
want: true,
},
{
name: "active window (started 1h ago, 5h window)",
@@ -113,7 +113,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
want7d: 0,
},
{
name: "nil window starts return raw usage",
name: "nil window starts return 0 (stale usage reset)",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
@@ -122,9 +122,9 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
Window1dStart: nil,
Window7dStart: nil,
},
-want5h: 5.0,
-want1d: 10.0,
-want7d: 50.0,
+want5h: 0,
+want1d: 0,
+want7d: 0,
},
{
name: "mixed: 5h expired, 1d active, 7d nil",
@@ -138,7 +138,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
},
want5h: 0,
want1d: 10.0,
-want7d: 50.0,
+want7d: 0,
},
{
name: "zero usage with active windows",
@@ -210,7 +210,7 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
want7d: 0,
},
{
name: "nil window starts return raw usage",
name: "nil window starts return 0 (stale usage reset)",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
@@ -219,9 +219,9 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
Window1dStart: nil,
Window7dStart: nil,
},
-want5h: 3.0,
-want1d: 8.0,
-want7d: 40.0,
+want5h: 0,
+want1d: 0,
+want7d: 0,
},
}

View File

@@ -1087,6 +1087,12 @@ type TokenPair struct {
ExpiresIn int `json:"expires_in"` // Access Token有效期
}
// TokenPairWithUser extends TokenPair with user role for backend mode checks
type TokenPairWithUser struct {
TokenPair
UserRole string
}
// GenerateTokenPair generates an Access Token / Refresh Token pair.
// familyID: optional token family ID, used to keep the family relationship across token rotation.
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
@@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
// RefreshTokenPair refreshes the token pair using a Refresh Token.
// Implements token rotation: every refresh issues a new Refresh Token and immediately invalidates the old one.
-func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
+func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) {
// Check that refreshTokenCache is available
if s.refreshTokenCache == nil {
return nil, ErrRefreshTokenInvalid
@@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
}
// Generate a new token pair, keeping the same family ID
-return s.GenerateTokenPair(ctx, user, data.FamilyID)
+pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID)
+if err != nil {
+return nil, err
+}
+return &TokenPairWithUser{
+TokenPair: *pair,
+UserRole: user.Role,
+}, nil
}
// RevokeRefreshToken revokes a single Refresh Token

View File

@@ -0,0 +1,770 @@
package service
import (
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"sort"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/robfig/cron/v3"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
settingKeyBackupS3Config = "backup_s3_config"
settingKeyBackupSchedule = "backup_schedule"
settingKeyBackupRecords = "backup_records"
maxBackupRecords = 100
)
var (
ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured")
ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found")
ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress")
ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress")
ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted")
ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted")
)
// ─── Interface definitions ───
// DBDumper abstracts database dump/restore operations
type DBDumper interface {
Dump(ctx context.Context) (io.ReadCloser, error)
Restore(ctx context.Context, data io.Reader) error
}
// BackupObjectStore abstracts object storage for backup files
type BackupObjectStore interface {
Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error)
Download(ctx context.Context, key string) (io.ReadCloser, error)
Delete(ctx context.Context, key string) error
PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error)
HeadBucket(ctx context.Context) error
}
// BackupObjectStoreFactory creates an object store from S3 config
type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error)
// ─── Data models ───
// BackupS3Config is the S3-compatible storage configuration (supports Cloudflare R2)
type BackupS3Config struct {
Endpoint string `json:"endpoint"` // e.g. https://<account_id>.r2.cloudflarestorage.com
Region string `json:"region"` // R2 uses "auto"
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key,omitempty"` //nolint:revive // field name follows AWS convention
Prefix string `json:"prefix"` // S3 key prefix, e.g. "backups/"
ForcePathStyle bool `json:"force_path_style"`
}
// IsConfigured reports whether the required fields are set
func (c *BackupS3Config) IsConfigured() bool {
return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != ""
}
// BackupScheduleConfig is the scheduled-backup configuration
type BackupScheduleConfig struct {
Enabled bool `json:"enabled"`
CronExpr string `json:"cron_expr"` // cron expression, e.g. "0 2 * * *" for 2:00 AM daily
RetainDays int `json:"retain_days"` // backup expiry in days (default 14; 0 = no automatic cleanup)
RetainCount int `json:"retain_count"` // maximum number of backups to keep (0 = unlimited)
}
// BackupRecord is a single backup record
type BackupRecord struct {
ID string `json:"id"`
Status string `json:"status"` // pending, running, completed, failed
BackupType string `json:"backup_type"` // postgres
FileName string `json:"file_name"`
S3Key string `json:"s3_key"`
SizeBytes int64 `json:"size_bytes"`
TriggeredBy string `json:"triggered_by"` // manual, scheduled
ErrorMsg string `json:"error_message,omitempty"`
StartedAt string `json:"started_at"`
FinishedAt string `json:"finished_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"` // expiry time
}
// BackupService provides database backup and restore
type BackupService struct {
settingRepo SettingRepository
dbCfg *config.DatabaseConfig
encryptor SecretEncryptor
storeFactory BackupObjectStoreFactory
dumper DBDumper
mu sync.Mutex
store BackupObjectStore
s3Cfg *BackupS3Config
backingUp bool
restoring bool
recordsMu sync.Mutex // guards load/save of backup records
cronMu sync.Mutex
cronSched *cron.Cron
cronEntryID cron.EntryID
}
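// NewBackupService constructs a BackupService from the settings repository, database config, secret encryptor, object-store factory, and dumper.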
func NewBackupService(
settingRepo SettingRepository,
cfg *config.Config,
encryptor SecretEncryptor,
storeFactory BackupObjectStoreFactory,
dumper DBDumper,
) *BackupService {
return &BackupService{
settingRepo: settingRepo,
dbCfg: &cfg.Database,
encryptor: encryptor,
storeFactory: storeFactory,
dumper: dumper,
}
}
// Start launches the scheduled-backup cron scheduler
func (s *BackupService) Start() {
s.cronSched = cron.New()
s.cronSched.Start()
// Load any existing schedule configuration
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
schedule, err := s.GetSchedule(ctx)
if err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 加载定时备份配置失败: %v", err)
return
}
if schedule.Enabled && schedule.CronExpr != "" {
if err := s.applyCronSchedule(schedule); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 应用定时备份配置失败: %v", err)
}
}
}
// Stop halts the scheduled-backup scheduler
func (s *BackupService) Stop() {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched != nil {
s.cronSched.Stop()
}
}
// ─── S3 config management ───
func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) {
cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
if cfg == nil {
return &BackupS3Config{}, nil
}
// Redact the secret before returning
cfg.SecretAccessKey = ""
return cfg, nil
}
func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) {
// If no secret is provided, keep the existing value (loadS3Config returns it decrypted)
if cfg.SecretAccessKey == "" {
old, _ := s.loadS3Config(ctx)
if old != nil {
cfg.SecretAccessKey = old.SecretAccessKey
}
}
// Encrypt SecretAccessKey before persisting (also re-encrypts a secret carried over from the stored config)
if cfg.SecretAccessKey != "" {
encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey)
if err != nil {
return nil, fmt.Errorf("encrypt secret: %w", err)
}
cfg.SecretAccessKey = encrypted
}
data, err := json.Marshal(cfg)
if err != nil {
return nil, fmt.Errorf("marshal s3 config: %w", err)
}
if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil {
return nil, fmt.Errorf("save s3 config: %w", err)
}
// Invalidate the cached S3 client
s.mu.Lock()
s.store = nil
s.s3Cfg = nil
s.mu.Unlock()
cfg.SecretAccessKey = ""
return &cfg, nil
}
func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error {
// If no secret is provided, fall back to the stored one
if cfg.SecretAccessKey == "" {
old, _ := s.loadS3Config(ctx)
if old != nil {
cfg.SecretAccessKey = old.SecretAccessKey
}
}
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required")
}
store, err := s.storeFactory(ctx, &cfg)
if err != nil {
return err
}
return store.HeadBucket(ctx)
}
// ─── Schedule management ───
func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule)
if err != nil || raw == "" {
return &BackupScheduleConfig{}, nil
}
var cfg BackupScheduleConfig
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
return &BackupScheduleConfig{}, nil
}
return &cfg, nil
}
func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) {
if cfg.Enabled && cfg.CronExpr == "" {
return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled")
}
// Validate the cron expression
if cfg.CronExpr != "" {
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
if _, err := parser.Parse(cfg.CronExpr); err != nil {
return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err))
}
}
data, err := json.Marshal(cfg)
if err != nil {
return nil, fmt.Errorf("marshal schedule config: %w", err)
}
if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil {
return nil, fmt.Errorf("save schedule config: %w", err)
}
// Apply or stop the scheduled job
if cfg.Enabled {
if err := s.applyCronSchedule(&cfg); err != nil {
return nil, err
}
} else {
s.removeCronSchedule()
}
return &cfg, nil
}
func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched == nil {
return fmt.Errorf("cron scheduler not initialized")
}
// Remove the previous job
if s.cronEntryID != 0 {
s.cronSched.Remove(s.cronEntryID)
s.cronEntryID = 0
}
entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() {
s.runScheduledBackup()
})
if err != nil {
return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err))
}
s.cronEntryID = entryID
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已启用: %s", cfg.CronExpr)
return nil
}
func (s *BackupService) removeCronSchedule() {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched != nil && s.cronEntryID != 0 {
s.cronSched.Remove(s.cronEntryID)
s.cronEntryID = 0
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已停用")
}
}
func (s *BackupService) runScheduledBackup() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
// Read the retention days from the schedule config
schedule, _ := s.GetSchedule(ctx)
expireDays := 14 // default: expire after 14 days
if schedule != nil && schedule.RetainDays > 0 {
expireDays = schedule.RetainDays
}
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
if err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
return
}
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
// Clean up expired backups (reuse the already loaded schedule)
if schedule == nil {
return
}
if err := s.cleanupOldBackups(ctx, schedule); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 清理过期备份失败: %v", err)
}
}
// ─── Backup/restore core ───
// CreateBackup creates a full database backup and streams it to S3.
// expireDays: backup expiry in days (0 = never expires; default 14).
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
s.mu.Lock()
if s.backingUp {
s.mu.Unlock()
return nil, ErrBackupInProgress
}
s.backingUp = true
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.backingUp = false
s.mu.Unlock()
}()
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
if s3Cfg == nil || !s3Cfg.IsConfigured() {
return nil, ErrBackupS3NotConfigured
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return nil, fmt.Errorf("init object store: %w", err)
}
now := time.Now()
backupID := uuid.New().String()[:8]
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
s3Key := s.buildS3Key(s3Cfg, fileName)
var expiresAt string
if expireDays > 0 {
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
}
record := &BackupRecord{
ID: backupID,
Status: "running",
BackupType: "postgres",
FileName: fileName,
S3Key: s3Key,
TriggeredBy: triggeredBy,
StartedAt: now.Format(time.RFC3339),
ExpiresAt: expiresAt,
}
// Streaming pipeline: pg_dump -> gzip -> S3 upload
dumpReader, err := s.dumper.Dump(ctx)
if err != nil {
record.Status = "failed"
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, record)
return record, fmt.Errorf("pg_dump: %w", err)
}
// Use io.Pipe to stream the gzip-compressed data into the S3 upload
pr, pw := io.Pipe()
var gzipErr error
go func() {
gzWriter := gzip.NewWriter(pw)
_, gzipErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
}
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
}
if gzipErr != nil {
_ = pw.CloseWithError(gzipErr)
} else {
_ = pw.Close()
}
}()
contentType := "application/gzip"
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
if err != nil {
record.Status = "failed"
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
if gzipErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
}
record.ErrorMsg = errMsg
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, record)
return record, fmt.Errorf("backup upload: %w", err)
}
record.SizeBytes = sizeBytes
record.Status = "completed"
record.FinishedAt = time.Now().Format(time.RFC3339)
if err := s.saveRecord(ctx, record); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
}
return record, nil
}
// RestoreBackup downloads a backup from S3 and streams it into the database
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
s.mu.Lock()
if s.restoring {
s.mu.Unlock()
return ErrRestoreInProgress
}
s.restoring = true
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.restoring = false
s.mu.Unlock()
}()
record, err := s.GetBackupRecord(ctx, backupID)
if err != nil {
return err
}
if record.Status != "completed" {
return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
}
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return err
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return fmt.Errorf("init object store: %w", err)
}
// Stream the download from S3
body, err := objectStore.Download(ctx, record.S3Key)
if err != nil {
return fmt.Errorf("S3 download failed: %w", err)
}
defer func() { _ = body.Close() }()
// Stream-decompress gzip -> psql without loading everything into memory
gzReader, err := gzip.NewReader(body)
if err != nil {
return fmt.Errorf("gzip reader: %w", err)
}
defer func() { _ = gzReader.Close() }()
// Streaming restore
if err := s.dumper.Restore(ctx, gzReader); err != nil {
return fmt.Errorf("pg restore: %w", err)
}
return nil
}
// ─── Backup record management ───
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
records, err := s.loadRecords(ctx)
if err != nil {
return nil, err
}
// Return newest first
sort.Slice(records, func(i, j int) bool {
return records[i].StartedAt > records[j].StartedAt
})
return records, nil
}
func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) {
records, err := s.loadRecords(ctx)
if err != nil {
return nil, err
}
for i := range records {
if records[i].ID == backupID {
return &records[i], nil
}
}
return nil, ErrBackupNotFound
}
func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, err := s.loadRecordsLocked(ctx)
if err != nil {
return err
}
var found *BackupRecord
var remaining []BackupRecord
for i := range records {
if records[i].ID == backupID {
found = &records[i]
} else {
remaining = append(remaining, records[i])
}
}
if found == nil {
return ErrBackupNotFound
}
// Delete from S3
if found.S3Key != "" && found.Status == "completed" {
s3Cfg, err := s.loadS3Config(ctx)
if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() {
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err == nil {
_ = objectStore.Delete(ctx, found.S3Key)
}
}
}
return s.saveRecordsLocked(ctx, remaining)
}
// GetBackupDownloadURL returns a presigned download URL for the backup file
func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) {
record, err := s.GetBackupRecord(ctx, backupID)
if err != nil {
return "", err
}
if record.Status != "completed" {
return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed")
}
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return "", err
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return "", err
}
url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour)
if err != nil {
return "", fmt.Errorf("presign url: %w", err)
}
return url, nil
}
// ─── Internal helpers ───
func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
if err != nil || raw == "" {
return nil, nil //nolint:nilnil // no config is a valid state
}
var cfg BackupS3Config
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
return nil, ErrBackupS3ConfigCorrupt
}
// Decrypt SecretAccessKey
if cfg.SecretAccessKey != "" {
decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey)
if err != nil {
// Backward compatibility with legacy unencrypted data: keep the original value if decryption fails
logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err)
} else {
cfg.SecretAccessKey = decrypted
}
}
return &cfg, nil
}
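// getOrCreateStore returns the cached object store if present, otherwise creates one from cfg and caches it.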
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.store != nil && s.s3Cfg != nil {
return s.store, nil
}
if cfg == nil {
return nil, ErrBackupS3NotConfigured
}
store, err := s.storeFactory(ctx, cfg)
if err != nil {
return nil, err
}
s.store = store
s.s3Cfg = cfg
return store, nil
}
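// buildS3Key builds the object key as <prefix>/<yyyy/mm/dd>/<fileName>; an empty prefix defaults to "backups".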
func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string {
prefix := strings.TrimRight(cfg.Prefix, "/")
if prefix == "" {
prefix = "backups"
}
return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName)
}
// loadRecords loads backup records, distinguishing "no data" from "corrupted data"
func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
return s.loadRecordsLocked(ctx)
}
// loadRecordsLocked loads records while recordsMu is already held
func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords)
if err != nil || raw == "" {
return nil, nil //nolint:nilnil // no records is a valid state
}
var records []BackupRecord
if err := json.Unmarshal([]byte(raw), &records); err != nil {
return nil, ErrBackupRecordsCorrupt
}
return records, nil
}
// saveRecordsLocked saves records while recordsMu is already held
func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error {
data, err := json.Marshal(records)
if err != nil {
return err
}
return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data))
}
// saveRecord persists a single record (mutex-protected)
func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, _ := s.loadRecordsLocked(ctx)
// Update an existing record or append a new one
found := false
for i := range records {
if records[i].ID == record.ID {
records[i] = *record
found = true
break
}
}
if !found {
records = append(records, *record)
}
// Cap the number of records
if len(records) > maxBackupRecords {
records = records[len(records)-maxBackupRecords:]
}
return s.saveRecordsLocked(ctx, records)
}
func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error {
if schedule == nil {
return nil
}
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, err := s.loadRecordsLocked(ctx)
if err != nil {
return err
}
// Sort newest first
sort.Slice(records, func(i, j int) bool {
return records[i].StartedAt > records[j].StartedAt
})
var toDelete []BackupRecord
var toKeep []BackupRecord
for i, r := range records {
shouldDelete := false
// Prune by retained count
if schedule.RetainCount > 0 && i >= schedule.RetainCount {
shouldDelete = true
}
// Prune by retained days
if schedule.RetainDays > 0 && r.StartedAt != "" {
startedAt, err := time.Parse(time.RFC3339, r.StartedAt)
if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour {
shouldDelete = true
}
}
if shouldDelete && r.Status == "completed" {
toDelete = append(toDelete, r)
} else {
toKeep = append(toKeep, r)
}
}
// Delete the objects on S3
for _, r := range toDelete {
if r.S3Key != "" {
_ = s.deleteS3Object(ctx, r.S3Key)
}
}
if len(toDelete) > 0 {
logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete))
return s.saveRecordsLocked(ctx, toKeep)
}
return nil
}
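// deleteS3Object deletes a single object from the configured bucket; a missing or unloadable S3 config is treated as a no-op.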
func (s *BackupService) deleteS3Object(ctx context.Context, key string) error {
s3Cfg, err := s.loadS3Config(ctx)
if err != nil || s3Cfg == nil {
return nil
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return err
}
return objectStore.Delete(ctx, key)
}

View File

@@ -0,0 +1,528 @@
//go:build unit
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// ─── Mocks ───
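// mockSettingRepo is an in-memory, mutex-guarded SettingRepository stand-in for tests.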
type mockSettingRepo struct {
mu sync.Mutex
data map[string]string
}
func newMockSettingRepo() *mockSettingRepo {
return &mockSettingRepo{data: make(map[string]string)}
}
func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) {
m.mu.Lock()
defer m.mu.Unlock()
v, ok := m.data[key]
if !ok {
return nil, ErrSettingNotFound
}
return &Setting{Key: key, Value: v}, nil
}
func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
v, ok := m.data[key]
if !ok {
return "", nil
}
return v, nil
}
func (m *mockSettingRepo) Set(_ context.Context, key, value string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.data[key] = value
return nil
}
func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
result := make(map[string]string)
for _, k := range keys {
if v, ok := m.data[k]; ok {
result[k] = v
}
}
return result, nil
}
func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
m.mu.Lock()
defer m.mu.Unlock()
for k, v := range settings {
m.data[k] = v
}
return nil
}
func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
result := make(map[string]string, len(m.data))
for k, v := range m.data {
result[k] = v
}
return result, nil
}
func (m *mockSettingRepo) Delete(_ context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.data, key)
return nil
}
// plainEncryptor only applies a reversible "ENC:" prefix wrapper, for tests
type plainEncryptor struct{}
func (e *plainEncryptor) Encrypt(plaintext string) (string, error) {
return "ENC:" + plaintext, nil
}
func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) {
if strings.HasPrefix(ciphertext, "ENC:") {
return strings.TrimPrefix(ciphertext, "ENC:"), nil
}
return ciphertext, fmt.Errorf("not encrypted")
}
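// mockDumper returns canned dump data (or a configured error) and records the bytes passed to Restore.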
type mockDumper struct {
dumpData []byte
dumpErr error
restored []byte
restErr error
}
func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) {
if m.dumpErr != nil {
return nil, m.dumpErr
}
return io.NopCloser(bytes.NewReader(m.dumpData)), nil
}
func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
if m.restErr != nil {
return m.restErr
}
d, err := io.ReadAll(data)
if err != nil {
return err
}
m.restored = d
return nil
}
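// mockObjectStore is an in-memory BackupObjectStore used to capture uploads and serve downloads in tests.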
type mockObjectStore struct {
objects map[string][]byte
mu sync.Mutex
}
func newMockObjectStore() *mockObjectStore {
return &mockObjectStore{objects: make(map[string][]byte)}
}
func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) {
data, err := io.ReadAll(body)
if err != nil {
return 0, err
}
m.mu.Lock()
m.objects[key] = data
m.mu.Unlock()
return int64(len(data)), nil
}
func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) {
m.mu.Lock()
data, ok := m.objects[key]
m.mu.Unlock()
if !ok {
return nil, fmt.Errorf("not found: %s", key)
}
return io.NopCloser(bytes.NewReader(data)), nil
}
func (m *mockObjectStore) Delete(_ context.Context, key string) error {
m.mu.Lock()
delete(m.objects, key)
m.mu.Unlock()
return nil
}
func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) {
return "https://presigned.example.com/" + key, nil
}
func (m *mockObjectStore) HeadBucket(_ context.Context) error {
return nil
}
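// newTestBackupService wires a BackupService with the given mocks, a stub database config, and a plainEncryptor.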
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "test",
DBName: "testdb",
},
}
factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) {
return store, nil
}
return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper)
}
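// seedS3Config stores an already-"encrypted" S3 config directly in the mock settings repo.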
func seedS3Config(t *testing.T, repo *mockSettingRepo) {
t.Helper()
cfg := BackupS3Config{
Bucket: "test-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "ENC:secret123",
Prefix: "backups",
}
data, _ := json.Marshal(cfg)
require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data)))
}
// ─── Tests ───
func TestBackupService_S3ConfigEncryption(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// Saving the config -> SecretAccessKey should be stored encrypted
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "my-secret",
Prefix: "backups",
})
require.NoError(t, err)
// Read the stored value directly; it should be the encrypted form
raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config)
var stored BackupS3Config
require.NoError(t, json.Unmarshal([]byte(raw), &stored))
require.Equal(t, "ENC:my-secret", stored.SecretAccessKey)
// GetS3Config should return the secret redacted
cfg, err := svc.GetS3Config(context.Background())
require.NoError(t, err)
require.Empty(t, cfg.SecretAccessKey)
require.Equal(t, "my-bucket", cfg.Bucket)
// loadS3Config should decrypt internally
internal, err := svc.loadS3Config(context.Background())
require.NoError(t, err)
require.Equal(t, "my-secret", internal.SecretAccessKey)
}
func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// First save a config that includes a secret
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "original-secret",
})
require.NoError(t, err)
// Updating without a secret should keep the original value
_, err = svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID-NEW",
})
require.NoError(t, err)
internal, err := svc.loadS3Config(context.Background())
require.NoError(t, err)
require.Equal(t, "original-secret", internal.SecretAccessKey)
require.Equal(t, "AKID-NEW", internal.AccessKeyID)
}
func TestBackupService_SaveRecordConcurrency(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
var wg sync.WaitGroup
n := 20
wg.Add(n)
for i := 0; i < n; i++ {
go func(idx int) {
defer wg.Done()
record := &BackupRecord{
ID: fmt.Sprintf("rec-%d", idx),
Status: "completed",
StartedAt: time.Now().Format(time.RFC3339),
}
_ = svc.saveRecord(context.Background(), record)
}(i)
}
wg.Wait()
records, err := svc.loadRecords(context.Background())
require.NoError(t, err)
require.Len(t, records, n)
}
func TestBackupService_LoadRecords_Empty(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
records, err := svc.loadRecords(context.Background())
require.NoError(t, err)
require.Nil(t, records) // nil when there is no data
}
func TestBackupService_LoadRecords_Corrupted(t *testing.T) {
repo := newMockSettingRepo()
_ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{")
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
records, err := svc.loadRecords(context.Background())
require.Error(t, err) // corrupted data should return an error
require.Nil(t, records)
}
func TestBackupService_CreateBackup_Streaming(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
require.Equal(t, "completed", record.Status)
require.Greater(t, record.SizeBytes, int64(0))
require.NotEmpty(t, record.S3Key)
// Verify the object actually exists in S3
store.mu.Lock()
require.Len(t, store.objects, 1)
store.mu.Unlock()
}
func TestBackupService_CreateBackup_DumpFailure(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.Error(t, err)
require.Equal(t, "failed", record.Status)
require.Contains(t, record.ErrorMsg, "pg_dump")
}
func TestBackupService_CreateBackup_NoS3Config(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
_, err := svc.CreateBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupS3NotConfigured)
}
func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
// Simulate a backup already in progress
dumper := &mockDumper{dumpData: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// by manually setting the backingUp flag
svc.mu.Lock()
svc.backingUp = true
svc.mu.Unlock()
_, err := svc.CreateBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupInProgress)
}
func TestBackupService_RestoreBackup_Streaming(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// Create a backup first
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// Restore it
err = svc.RestoreBackup(context.Background(), record.ID)
require.NoError(t, err)
// Verify the data received by psql matches the original dump content
require.Equal(t, dumpContent, string(dumper.restored))
}
func TestBackupService_RestoreBackup_NotCompleted(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// Manually insert a failed record
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: "fail-1",
Status: "failed",
})
err := svc.RestoreBackup(context.Background(), "fail-1")
require.Error(t, err)
}
func TestBackupService_DeleteBackup(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "data"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// S3 should contain the object
store.mu.Lock()
require.Len(t, store.objects, 1)
store.mu.Unlock()
// Delete it
err = svc.DeleteBackup(context.Background(), record.ID)
require.NoError(t, err)
// The S3 object should be gone
store.mu.Lock()
require.Len(t, store.objects, 0)
store.mu.Unlock()
// The record should no longer exist
_, err = svc.GetBackupRecord(context.Background(), record.ID)
require.ErrorIs(t, err, ErrBackupNotFound)
}
func TestBackupService_GetDownloadURL(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &mockDumper{dumpData: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
url, err := svc.GetBackupDownloadURL(context.Background(), record.ID)
require.NoError(t, err)
require.Contains(t, url, "https://presigned.example.com/")
}
func TestBackupService_ListBackups_Sorted(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
now := time.Now()
for i := 0; i < 3; i++ {
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: fmt.Sprintf("rec-%d", i),
Status: "completed",
StartedAt: now.Add(time.Duration(i) * time.Hour).Format(time.RFC3339),
})
}
records, err := svc.ListBackups(context.Background())
require.NoError(t, err)
require.Len(t, records, 3)
// Newest first
require.Equal(t, "rec-2", records[0].ID)
require.Equal(t, "rec-0", records[2].ID)
}
func TestBackupService_TestS3Connection(t *testing.T) {
repo := newMockSettingRepo()
store := newMockObjectStore()
svc := newTestBackupService(repo, &mockDumper{}, store)
err := svc.TestS3Connection(context.Background(), BackupS3Config{
Bucket: "test",
AccessKeyID: "ak",
SecretAccessKey: "sk",
})
require.NoError(t, err)
}
func TestBackupService_TestS3Connection_Incomplete(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
err := svc.TestS3Connection(context.Background(), BackupS3Config{
Bucket: "test",
})
require.Error(t, err)
require.Contains(t, err.Error(), "incomplete")
}
func TestBackupService_Schedule_CronValidation(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
svc.cronSched = nil // cron not initialized
// Enabled but cron expression empty
_, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
Enabled: true,
CronExpr: "",
})
require.Error(t, err)
// Invalid cron expression
_, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
Enabled: true,
CronExpr: "invalid",
})
require.Error(t, err)
}
func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
repo := newMockSettingRepo()
_ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!")
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
cfg, err := svc.loadS3Config(context.Background())
require.Error(t, err)
require.Nil(t, cfg)
}

View File

@@ -0,0 +1,607 @@
package service
import (
"encoding/json"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const defaultBedrockRegion = "us-east-1"
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
// BedrockCrossRegionPrefix returns the model ID prefix used by Bedrock cross-region inference for a given AWS region.
// Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
func BedrockCrossRegionPrefix(region string) string {
switch {
case strings.HasPrefix(region, "us-gov"):
return "us-gov" // GovCloud 使用独立的 us-gov 前缀
case strings.HasPrefix(region, "us-"):
return "us"
case strings.HasPrefix(region, "eu-"):
return "eu"
case region == "ap-northeast-1":
return "jp" // 日本区域使用独立的 jp 前缀AWS 官方定义)
case region == "ap-southeast-2":
return "au" // 澳大利亚区域使用独立的 au 前缀AWS 官方定义)
case strings.HasPrefix(region, "ap-"):
return "apac" // 其余亚太区域使用通用 apac 前缀
case strings.HasPrefix(region, "ca-"):
return "us" // 加拿大区域使用 us 前缀的跨区域推理
case strings.HasPrefix(region, "sa-"):
return "us" // 南美区域使用 us 前缀的跨区域推理
default:
return "us"
}
}
// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀
// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1"
// 特殊值 region="global" 强制使用 global. 前缀
func AdjustBedrockModelRegionPrefix(modelID, region string) string {
var targetPrefix string
if region == "global" {
targetPrefix = "global"
} else {
targetPrefix = BedrockCrossRegionPrefix(region)
}
for _, p := range bedrockCrossRegionPrefixes {
if strings.HasPrefix(modelID, p) {
if p == targetPrefix+"." {
return modelID // prefix already matches; nothing to replace
}
return targetPrefix + "." + modelID[len(p):]
}
}
// Model ID has no known region prefix (e.g. "anthropic.claude-..."); leave unchanged
return modelID
}
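// bedrockRuntimeRegion returns the account's aws_region credential, falling back to defaultBedrockRegion.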
func bedrockRuntimeRegion(account *Account) string {
if account == nil {
return defaultBedrockRegion
}
if region := account.GetCredential("aws_region"); region != "" {
return region
}
return defaultBedrockRegion
}
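// shouldForceBedrockGlobal reports whether the account forces the global. prefix via the aws_force_global credential.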
func shouldForceBedrockGlobal(account *Account) bool {
return account != nil && account.GetCredential("aws_force_global") == "true"
}
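// isRegionalBedrockModelID reports whether the model ID already carries a known cross-region prefix.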
func isRegionalBedrockModelID(modelID string) bool {
for _, prefix := range bedrockCrossRegionPrefixes {
if strings.HasPrefix(modelID, prefix) {
return true
}
}
return false
}
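// isLikelyBedrockModelID reports whether the ID looks like a native Bedrock model ID: an ARN, a known vendor prefix, or a cross-region prefix.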
func isLikelyBedrockModelID(modelID string) bool {
lower := strings.ToLower(strings.TrimSpace(modelID))
if lower == "" {
return false
}
if strings.HasPrefix(lower, "arn:") {
return true
}
for _, prefix := range []string{
"anthropic.",
"amazon.",
"meta.",
"mistral.",
"cohere.",
"ai21.",
"deepseek.",
"stability.",
"writer.",
"nova.",
} {
if strings.HasPrefix(lower, prefix) {
return true
}
}
return isRegionalBedrockModelID(lower)
}
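// normalizeBedrockModelID maps an incoming model name to a Bedrock model ID, reporting whether its region prefix should be adjusted and whether it is usable at all.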
func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return "", false, false
}
if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists {
return mapped, true, true
}
if isRegionalBedrockModelID(modelID) {
return modelID, true, true
}
if isLikelyBedrockModelID(modelID) {
return modelID, false, true
}
return "", false, false
}
// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID.
// It applies account model_mapping first, then default Bedrock aliases, and finally
// adjusts Anthropic cross-region prefixes to match the account region.
func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) {
if account == nil {
return "", false
}
mappedModel := account.GetMappedModel(requestedModel)
modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel)
if !ok {
return "", false
}
if shouldAdjustRegion {
targetRegion := bedrockRuntimeRegion(account)
if shouldForceBedrockGlobal(account) {
targetRegion = "global"
}
modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion)
}
return modelID, true
}
// BuildBedrockURL builds the Bedrock InvokeModel URL.
// When stream=true, the invoke-with-response-stream endpoint is used.
// Special characters in modelID are URL-encoded (aligned with litellm's urllib.parse.quote(safe="")).
func BuildBedrockURL(region, modelID string, stream bool) string {
if region == "" {
region = defaultBedrockRegion
}
encodedModelID := url.PathEscape(modelID)
// url.PathEscape does not encode colons (the RFC allows ":" in paths),
// but AWS Bedrock expects colons in the model ID to be encoded as %3A.
encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A")
if stream {
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID)
}
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID)
}
// PrepareBedrockRequestBody adapts the request body for the Bedrock API:
// 1. Inject anthropic_version
// 2. Inject anthropic_beta (parsed from the client's anthropic-beta header)
// 3. Remove fields Bedrock does not support: model, stream, output_format, output_config
// 4. Remove the custom field from tool definitions (Claude Code sends custom: {defer_loading: true})
// 5. Strip cache_control fields Bedrock does not support: scope, ttl
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
}
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
var err error
// Inject anthropic_version (required by Bedrock)
body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
if err != nil {
return nil, fmt.Errorf("inject anthropic_version: %w", err)
}
// Inject anthropic_beta (Bedrock Invoke passes beta headers via the request body, not HTTP headers)
// 1. Parse from the client's anthropic-beta header
// 2. Auto-complete required beta tokens based on the request body
// Reference litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock()
if len(betaTokens) > 0 {
body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens)
if err != nil {
return nil, fmt.Errorf("inject anthropic_beta: %w", err)
}
}
// Remove the model field (Bedrock specifies the model via the URL)
body, err = sjson.DeleteBytes(body, "model")
if err != nil {
return nil, fmt.Errorf("remove model field: %w", err)
}
// Remove the stream field (Bedrock selects streaming via the endpoint and rejects stream in the body)
body, err = sjson.DeleteBytes(body, "stream")
if err != nil {
return nil, fmt.Errorf("remove stream field: %w", err)
}
// Convert output_format (Bedrock Invoke does not support it, but the schema can be inlined into the last user message)
// Reference litellm: _convert_output_format_to_inline_schema()
body = convertOutputFormatToInlineSchema(body)
// Remove the output_config field (not supported by Bedrock Invoke)
body, err = sjson.DeleteBytes(body, "output_config")
if err != nil {
return nil, fmt.Errorf("remove output_config field: %w", err)
}
// Remove the custom field from tool definitions.
// Claude Code (v2.1.69+) sends custom: {defer_loading: true} in tool definitions;
// the Anthropic API accepts it, but Bedrock rejects it with "Extra inputs are not permitted".
body = removeCustomFieldFromTools(body)
// Strip cache_control fields Bedrock does not support
body = sanitizeBedrockCacheControl(body, modelID)
return body, nil
}
// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering.
func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string {
betaTokens := parseAnthropicBetaHeader(betaHeader)
betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID)
return filterBedrockBetaTokens(betaTokens)
}
// convertOutputFormatToInlineSchema inlines the JSON schema from output_format into the last user message.
// Bedrock Invoke does not support the output_format parameter; litellm's approach is to append the schema to the user message.
// Reference: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema()
func convertOutputFormatToInlineSchema(body []byte) []byte {
outputFormat := gjson.GetBytes(body, "output_format")
if !outputFormat.Exists() || !outputFormat.IsObject() {
return body
}
// Remove output_format from the request body first
body, _ = sjson.DeleteBytes(body, "output_format")
schema := outputFormat.Get("schema")
if !schema.Exists() {
return body
}
// Find the last user message
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
msgArr := messages.Array()
lastUserIdx := -1
for i := len(msgArr) - 1; i >= 0; i-- {
if msgArr[i].Get("role").String() == "user" {
lastUserIdx = i
break
}
}
if lastUserIdx < 0 {
return body
}
// Serialize the schema to JSON text and append it to that message's content array
schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw))
if err != nil {
return body
}
content := msgArr[lastUserIdx].Get("content")
basePath := fmt.Sprintf("messages.%d.content", lastUserIdx)
if content.IsArray() {
// Append a text block to the end of the content array
idx := len(content.Array())
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text")
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON))
} else if content.Type == gjson.String {
// content is a plain string; convert it to array form
originalText := content.String()
body, _ = sjson.SetBytes(body, basePath, []map[string]string{
{"type": "text", "text": originalText},
{"type": "text", "text": string(schemaJSON)},
})
}
return body
}
// removeCustomFieldFromTools removes the custom field from every tool definition in the tools array
func removeCustomFieldFromTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if !tools.Exists() || !tools.IsArray() {
return body
}
var err error
for i := range tools.Array() {
body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i))
if err != nil {
// A failed delete does not affect the overall flow; skip
continue
}
}
return body
}
// claudeVersionRe matches the version portion of a Claude model ID.
// Supports both claude-{tier}-{major}-{minor} and claude-{tier}-{major}.{minor} formats.
var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`)
// isBedrockClaude45OrNewer reports whether a Bedrock model ID is Claude 4.5 or newer.
// Claude 4.5+ supports the ttl field in cache_control ("5m" and "1h").
func isBedrockClaude45OrNewer(modelID string) bool {
lower := strings.ToLower(modelID)
matches := claudeVersionRe.FindStringSubmatch(lower)
if matches == nil {
return false
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return major > 4 || (major == 4 && minor >= 5)
}
// sanitizeBedrockCacheControl strips cache_control fields in system and messages
// that Bedrock does not support:
// - scope: not supported by Bedrock (e.g. "global" cross-request caching)
// - ttl: only Claude 4.5+ supports "5m" and "1h"; removed for older models
func sanitizeBedrockCacheControl(body []byte, modelID string) []byte {
isClaude45 := isBedrockClaude45OrNewer(modelID)
// Sanitize cache_control in the system array
systemArr := gjson.GetBytes(body, "system")
if systemArr.Exists() && systemArr.IsArray() {
for i, item := range systemArr.Array() {
if !item.IsObject() {
continue
}
cc := item.Get("cache_control")
if !cc.Exists() || !cc.IsObject() {
continue
}
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45)
}
}
// Sanitize cache_control in messages
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
for mi, msg := range messages.Array() {
if !msg.IsObject() {
continue
}
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
continue
}
for ci, block := range content.Array() {
if !block.IsObject() {
continue
}
cc := block.Get("cache_control")
if !cc.Exists() || !cc.IsObject() {
continue
}
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45)
}
}
return body
}
// deleteCacheControlUnsupportedFields removes fields Bedrock does not support under the given cache_control path
func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte {
// Bedrock does not support scope (e.g. "global")
if cc.Get("scope").Exists() {
body, _ = sjson.DeleteBytes(body, basePath+".scope")
}
// ttl: only Claude 4.5+ supports "5m" and "1h"; remove otherwise
ttl := cc.Get("ttl")
if ttl.Exists() {
shouldRemove := true
if isClaude45 {
v := ttl.String()
if v == "5m" || v == "1h" {
shouldRemove = false
}
}
if shouldRemove {
body, _ = sjson.DeleteBytes(body, basePath+".ttl")
}
}
return body
}
// parseAnthropicBetaHeader parses the comma-separated anthropic-beta header into a token list
func parseAnthropicBetaHeader(header string) []string {
header = strings.TrimSpace(header)
if header == "" {
return nil
}
if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") {
var parsed []any
if err := json.Unmarshal([]byte(header), &parsed); err == nil {
tokens := make([]string, 0, len(parsed))
for _, item := range parsed {
token := strings.TrimSpace(fmt.Sprint(item))
if token != "" {
tokens = append(tokens, token)
}
}
return tokens
}
}
var tokens []string
for _, part := range strings.Split(header, ",") {
t := strings.TrimSpace(part)
if t != "" {
tokens = append(tokens, t)
}
}
return tokens
}
// bedrockSupportedBetaTokens is the whitelist of beta headers supported by Bedrock Invoke.
// Reference: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
// Maintenance: update this whitelist whenever AWS Bedrock adds support for new beta tokens.
var bedrockSupportedBetaTokens = map[string]bool{
"computer-use-2025-01-24": true,
"computer-use-2025-11-24": true,
"context-1m-2025-08-07": true,
"context-management-2025-06-27": true,
"compact-2026-01-12": true,
"interleaved-thinking-2025-05-14": true,
"tool-search-tool-2025-10-19": true,
"tool-examples-2025-10-29": true,
}
// bedrockBetaTokenTransforms defines beta header conversion rules specific to Bedrock Invoke.
// The direct Anthropic API uses generic headers; Bedrock Invoke requires specific substitutes.
var bedrockBetaTokenTransforms = map[string]string{
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
}
// autoInjectBedrockBetaTokens auto-completes required beta tokens based on the request body.
// Reference litellm: AnthropicModelInfo.get_anthropic_beta_list() and
// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock()
//
// Clients (especially non-Claude-Code clients) may enable a feature in the body without sending the
// corresponding beta token in the header. Detecting body features here and auto-completing the tokens
// ensures Bedrock Invoke does not return 400 for a missing required beta header.
func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string {
seen := make(map[string]bool, len(tokens))
for _, t := range tokens {
seen[t] = true
}
inject := func(token string) {
if !seen[token] {
tokens = append(tokens, token)
seen[token] = true
}
}
// Detect thinking / interleaved thinking:
// a "thinking" field in the body → the interleaved-thinking beta is required
if gjson.GetBytes(body, "thinking").Exists() {
inject("interleaved-thinking-2025-05-14")
}
// Detect computer_use tools:
// a tool with type="computer_20xxxxxx" in tools → the computer-use beta is required
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() {
toolSearchUsed := false
programmaticToolCallingUsed := false
inputExamplesUsed := false
for _, tool := range tools.Array() {
toolType := tool.Get("type").String()
if strings.HasPrefix(toolType, "computer_20") {
inject("computer-use-2025-11-24")
}
if isBedrockToolSearchType(toolType) {
toolSearchUsed = true
}
if hasCodeExecutionAllowedCallers(tool) {
programmaticToolCallingUsed = true
}
if hasInputExamples(tool) {
inputExamplesUsed = true
}
}
if programmaticToolCallingUsed || inputExamplesUsed {
// programmatic tool calling and input examples require advanced-tool-use;
// filterBedrockBetaTokens later converts it to the Bedrock-specific tool-search-tool
inject("advanced-tool-use-2025-11-20")
}
if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) {
// Pure tool search (no programmatic/inputExamples): inject the Bedrock-specific header directly,
// skipping the advanced-tool-use → tool-search-tool conversion step (aligned with litellm)
if !programmaticToolCallingUsed && !inputExamplesUsed {
inject("tool-search-tool-2025-10-19")
} else {
inject("advanced-tool-use-2025-11-20")
}
}
}
return tokens
}
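// isBedrockToolSearchType reports whether the tool type is one of the tool search tool variants.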
func isBedrockToolSearchType(toolType string) bool {
return toolType == "tool_search_tool_regex_20251119" || toolType == "tool_search_tool_bm25_20251119"
}
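// hasCodeExecutionAllowedCallers reports whether the tool (or its function wrapper) lists code_execution_20250825 in allowed_callers, i.e. programmatic tool calling.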
func hasCodeExecutionAllowedCallers(tool gjson.Result) bool {
allowedCallers := tool.Get("allowed_callers")
if containsStringInJSONArray(allowedCallers, "code_execution_20250825") {
return true
}
return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825")
}
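// hasInputExamples reports whether the tool (or its function wrapper) declares a non-empty input_examples array.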
func hasInputExamples(tool gjson.Result) bool {
if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 {
return true
}
arr := tool.Get("function.input_examples")
return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0
}
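// containsStringInJSONArray reports whether the gjson array contains the target string.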
func containsStringInJSONArray(result gjson.Result, target string) bool {
if !result.Exists() || !result.IsArray() {
return false
}
for _, item := range result.Array() {
if item.String() == target {
return true
}
}
return false
}
// bedrockModelSupportsToolSearch reports whether a Bedrock model supports tool search.
// Currently only Claude Opus/Sonnet 4.5+ support it; Haiku does not.
func bedrockModelSupportsToolSearch(modelID string) bool {
lower := strings.ToLower(modelID)
matches := claudeVersionRe.FindStringSubmatch(lower)
if matches == nil {
return false
}
// Haiku does not support tool search
if strings.Contains(lower, "haiku") {
return false
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return major > 4 || (major == 4 && minor >= 5)
}
// filterBedrockBetaTokens filters and converts the beta token list, keeping only tokens supported by Bedrock Invoke:
// 1. Apply conversion rules (e.g. advanced-tool-use → tool-search-tool)
// 2. Drop tokens Bedrock does not support (e.g. output-128k, files-api, structured-outputs)
// 3. Auto-add tool-examples when tool-search-tool is present
func filterBedrockBetaTokens(tokens []string) []string {
seen := make(map[string]bool, len(tokens))
var result []string
for _, t := range tokens {
// Apply conversion rules
if replacement, ok := bedrockBetaTokenTransforms[t]; ok {
t = replacement
}
// Keep only whitelisted tokens, deduplicated
if bedrockSupportedBetaTokens[t] && !seen[t] {
result = append(result, t)
seen[t] = true
}
}
// Auto-association: when tool-search-tool is present, ensure tool-examples is too
if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] {
result = append(result, "tool-examples-2025-10-29")
}
return result
}

View File

@@ -0,0 +1,659 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) {
input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
// anthropic_version should be injected
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
// model and stream should be removed
assert.False(t, gjson.GetBytes(result, "model").Exists())
assert.False(t, gjson.GetBytes(result, "stream").Exists())
// max_tokens should be preserved
assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int())
}
func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) {
t.Run("schema inlined into last user message array content", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
// the schema should be inlined at the end of the last user message's content array
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "text", contentArr[1].Get("type").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`)
})
t.Run("schema inlined into string content", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "compute this", contentArr[0].Get("text").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`)
})
t.Run("no schema field just removes output_format", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
})
t.Run("no messages just removes output_format", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
})
}
func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
}
func TestRemoveCustomFieldFromTools(t *testing.T) {
input := `{
"tools": [
{"name":"tool1","custom":{"defer_loading":true},"description":"desc1"},
{"name":"tool2","description":"desc2"},
{"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"}
]
}`
result := removeCustomFieldFromTools([]byte(input))
tools := gjson.GetBytes(result, "tools").Array()
require.Len(t, tools, 3)
// custom should be removed
assert.False(t, tools[0].Get("custom").Exists())
// name/description should be preserved
assert.Equal(t, "tool1", tools[0].Get("name").String())
assert.Equal(t, "desc1", tools[0].Get("description").String())
// tools without custom are unaffected
assert.Equal(t, "tool2", tools[1].Get("name").String())
// the third tool's custom should be removed as well
assert.False(t, tools[2].Get("custom").Exists())
assert.Equal(t, "tool3", tools[2].Get("name").String())
}
func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}]}`
result := removeCustomFieldFromTools([]byte(input))
// original data is unchanged when there are no tools
assert.JSONEq(t, input, string(result))
}
func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}],
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}]
}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
// scope should be removed
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists())
// type should be preserved
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String())
}
func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
}`
// older models (Claude 3.5) do not support ttl
result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
}
func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
}`
// Claude 4.5+ supports "5m" and "1h"
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
}
func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}]
}`
// Claude 4.5 does not support "10m"
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
}
func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) {
input := `{
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists())
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) {
input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
// original data is unchanged when there is no cache_control
assert.JSONEq(t, input, string(result))
}
func TestIsBedrockClaude45OrNewer(t *testing.T) {
tests := []struct {
modelID string
expect bool
}{
{"us.anthropic.claude-opus-4-6-v1", true},
{"us.anthropic.claude-sonnet-4-6", true},
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
{"us.anthropic.claude-haiku-4-5-20251001-v1:0", true},
{"anthropic.claude-3-5-sonnet-20241022-v2:0", false},
{"anthropic.claude-3-opus-20240229-v1:0", false},
{"anthropic.claude-3-haiku-20240307-v1:0", false},
// future versions should be supported automatically
{"us.anthropic.claude-sonnet-5-0-v1", true},
{"us.anthropic.claude-opus-4-7-v1", true},
// older versions
{"anthropic.claude-opus-4-1-v1", false},
{"anthropic.claude-sonnet-4-0-v1", false},
// non-Claude models
{"amazon.nova-pro-v1", false},
{"meta.llama3-70b", false},
}
for _, tt := range tests {
t.Run(tt.modelID, func(t *testing.T) {
assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID))
})
}
}
func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
// simulate a complete Claude Code request
input := `{
"model": "claude-opus-4-6",
"stream": true,
"max_tokens": 16384,
"output_format": {"type": "json", "schema": {"result": "string"}},
"output_config": {"max_tokens": 100},
"system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}],
"messages": [
{"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]}
],
"tools": [
{"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}},
{"name": "read", "description": "Read file", "input_schema": {"type": "object"}}
]
}`
betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
require.NoError(t, err)
// basic fields
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
assert.False(t, gjson.GetBytes(result, "model").Exists())
assert.False(t, gjson.GetBytes(result, "stream").Exists())
assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int())
// anthropic_beta should contain all beta tokens
betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, betaArr, 3)
assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
assert.Equal(t, "compact-2026-01-12", betaArr[2].String())
// output_format should be removed; the schema is inlined into the last user message
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
// content array: original text block + inlined schema block
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "hello", contentArr[0].Get("text").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`)
// custom in tools should be removed
assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists())
assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String())
assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String())
// cache_control: scope should be removed; valid ttl values are preserved on Claude 4.6
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
t.Run("empty beta header", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
})
t.Run("single beta token", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
})
t.Run("multiple beta tokens with spaces", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
})
t.Run("json array beta header", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
})
}
func TestParseAnthropicBetaHeader(t *testing.T) {
assert.Nil(t, parseAnthropicBetaHeader(""))
assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b "))
assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`))
}
func TestFilterBedrockBetaTokens(t *testing.T) {
t.Run("supported tokens pass through", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, tokens, result)
})
t.Run("unsupported tokens are filtered out", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
})
t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
tokens := []string{"advanced-tool-use-2025-11-20"}
result := filterBedrockBetaTokens(tokens)
assert.Contains(t, result, "tool-search-tool-2025-10-19")
// tool-examples is auto-associated
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) {
tokens := []string{"tool-search-tool-2025-10-19"}
result := filterBedrockBetaTokens(tokens)
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("no duplication when tool-examples already present", func(t *testing.T) {
tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"}
result := filterBedrockBetaTokens(tokens)
count := 0
for _, t := range result {
if t == "tool-examples-2025-10-29" {
count++
}
}
assert.Equal(t, 1, count)
})
t.Run("empty input returns nil", func(t *testing.T) {
result := filterBedrockBetaTokens(nil)
assert.Nil(t, result)
})
t.Run("all unsupported returns nil", func(t *testing.T) {
result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"})
assert.Nil(t, result)
})
t.Run("duplicate tokens are deduplicated", func(t *testing.T) {
tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
})
}
func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
})
t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
"advanced-tool-use-2025-11-20")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String())
assert.Equal(t, "tool-examples-2025-10-29", arr[1].String())
})
}
func TestBedrockCrossRegionPrefix(t *testing.T) {
tests := []struct {
region string
expect string
}{
// US regions
{"us-east-1", "us"},
{"us-east-2", "us"},
{"us-west-1", "us"},
{"us-west-2", "us"},
// GovCloud
{"us-gov-east-1", "us-gov"},
{"us-gov-west-1", "us-gov"},
// EU regions
{"eu-west-1", "eu"},
{"eu-west-2", "eu"},
{"eu-west-3", "eu"},
{"eu-central-1", "eu"},
{"eu-central-2", "eu"},
{"eu-north-1", "eu"},
{"eu-south-1", "eu"},
// APAC regions
{"ap-northeast-1", "jp"},
{"ap-northeast-2", "apac"},
{"ap-southeast-1", "apac"},
{"ap-southeast-2", "au"},
{"ap-south-1", "apac"},
// Canada / South America fallback to us
{"ca-central-1", "us"},
{"sa-east-1", "us"},
// Unknown defaults to us
{"me-south-1", "us"},
}
for _, tt := range tests {
t.Run(tt.region, func(t *testing.T) {
assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region))
})
}
}
func TestResolveBedrockModelID(t *testing.T) {
t.Run("default alias resolves and adjusts region", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "eu-west-1",
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
require.True(t, ok)
assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID)
})
t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "ap-southeast-2",
"model_mapping": map[string]any{
"claude-*": "claude-opus-4-6",
},
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-6-thinking")
require.True(t, ok)
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
})
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
"aws_force_global": "true",
"model_mapping": map[string]any{
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
},
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6")
require.True(t, ok)
assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID)
})
t.Run("direct bedrock model id passes through", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0")
require.True(t, ok)
assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID)
})
t.Run("unsupported alias returns false", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
_, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022")
assert.False(t, ok)
})
}
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
})
t.Run("no duplicate when already present", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
count := 0
for _, t := range result {
if t == "interleaved-thinking-2025-05-14" {
count++
}
}
assert.Equal(t, 1, count)
})
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "computer-use-2025-11-24")
})
t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject advanced-tool-use for input examples", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
// pure tool-search scenarios inject the Bedrock-specific token directly instead of going through the advanced-tool-use transform
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
// mixed scenarios use advanced-tool-use, which the filter later converts to tool-search-tool
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
assert.NotContains(t, result, "tool-search-tool-2025-10-19")
})
t.Run("no injection for regular tools", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Empty(t, result)
})
t.Run("no injection when no features detected", func(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Empty(t, result)
})
t.Run("preserves existing tokens", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`)
existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"}
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "context-1m-2025-08-07")
assert.Contains(t, result, "compact-2026-01-12")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
})
}
func TestResolveBedrockBetaTokens(t *testing.T) {
t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
})
}
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
found := false
for _, v := range arr {
if v.String() == "interleaved-thinking-2025-05-14" {
found = true
}
}
assert.True(t, found, "interleaved-thinking should be auto-injected")
})
t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
names := make([]string, len(arr))
for i, v := range arr {
names[i] = v.String()
}
assert.Contains(t, names, "context-1m-2025-08-07")
assert.Contains(t, names, "interleaved-thinking-2025-05-14")
})
}
func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
tests := []struct {
name string
modelID string
region string
expect string
}{
// US region — no change needed
{"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"},
// EU region — replace us → eu
{"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
{"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"},
// APAC region — jp and au have dedicated prefixes per AWS docs
{"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"},
{"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"},
{"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"},
// eu → us (user manually set eu prefix, moved to us region)
{"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"},
// global prefix — replace to match region
{"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
// No known prefix — leave unchanged
{"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"},
// GovCloud — uses independent us-gov prefix
{"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"},
{"govcloud already correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"},
// Force global (special region value)
{"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"},
{"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region))
})
}
}

View File

@@ -0,0 +1,67 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
)
// BedrockSigner signs Bedrock requests using AWS SigV4
type BedrockSigner struct {
credentials aws.Credentials
region string
signer *v4.Signer
}
// NewBedrockSigner creates a BedrockSigner
func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner {
return &BedrockSigner{
credentials: aws.Credentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
},
region: region,
signer: v4.NewSigner(),
}
}
// NewBedrockSignerFromAccount creates a BedrockSigner from Account credentials
func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) {
accessKeyID := account.GetCredential("aws_access_key_id")
if accessKeyID == "" {
return nil, fmt.Errorf("aws_access_key_id not found in credentials")
}
secretAccessKey := account.GetCredential("aws_secret_access_key")
if secretAccessKey == "" {
return nil, fmt.Errorf("aws_secret_access_key not found in credentials")
}
region := account.GetCredential("aws_region")
if region == "" {
region = defaultBedrockRegion
}
sessionToken := account.GetCredential("aws_session_token") // optional
return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil
}
// SignRequest applies an AWS SigV4 signature to the HTTP request.
// Important constraint: before calling this method, req should carry only AWS-relevant headers
// (e.g. Content-Type, Accept). Non-AWS headers (e.g. anthropic-beta) are included in the signature
// calculation, and if the Bedrock endpoint does not recognize them, signature verification may fail.
// litellm filters headers via _filter_headers_for_aws_signature; in the current implementation
// buildUpstreamRequestBedrock sets only Content-Type and Accept, so this is safe.
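// Usage sketch (illustrative only; the credentials are placeholders and the URL follows the
// BuildBedrockURL format used elsewhere in this package — the real call path builds the request
// via buildUpstreamRequestBedrock):
//
//	body := []byte(`{"anthropic_version":"bedrock-2023-05-31","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`)
//	req, _ := http.NewRequestWithContext(ctx, http.MethodPost,
//		"https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke",
//		bytes.NewReader(body))
//	req.Header.Set("Content-Type", "application/json")
//	req.Header.Set("Accept", "application/json")
//	signer := NewBedrockSigner("AKIA...", "secret-key", "", "us-east-1")
//	err := signer.SignRequest(ctx, req, body)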
func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error {
payloadHash := sha256Hash(body)
return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now())
}
func sha256Hash(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}

View File

@@ -0,0 +1,35 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_access_key_id": "test-akid",
"aws_secret_access_key": "test-secret",
},
}
signer, err := NewBedrockSignerFromAccount(account)
require.NoError(t, err)
require.NotNil(t, signer)
assert.Equal(t, defaultBedrockRegion, signer.region)
}
func TestFilterBetaTokens(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"}
filterSet := map[string]struct{}{
"tool-search-tool-2025-10-19": {},
}
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet))
assert.Equal(t, tokens, filterBetaTokens(tokens, nil))
assert.Nil(t, filterBetaTokens(nil, filterSet))
}

View File

@@ -0,0 +1,414 @@
package service
import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
"hash/crc32"
"io"
"net/http"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// handleBedrockStreamingResponse handles the EventStream response of Bedrock InvokeModelWithResponseStream.
// Bedrock returns the AWS EventStream binary format; in each event's payload, chunk.bytes is base64-encoded
// Claude SSE event JSON. This method decodes it and writes the events to the client as standard SSE.
func (s *GatewayService) handleBedrockStreamingResponse(
ctx context.Context,
resp *http.Response,
c *gin.Context,
account *Account,
startTime time.Time,
model string,
) (*streamingResult, error) {
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-amzn-requestid"); v != "" {
c.Header("x-request-id", v)
}
usage := &ClaudeUsage{}
var firstTokenMs *int
clientDisconnected := false
// Bedrock EventStream uses the application/vnd.amazon.eventstream binary format.
// Each frame: total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4).
// Because the response is binary-framed, line-based scanning for JSON chunks is not reliable;
// we use a dedicated EventStream decoder to parse it correctly.
decoder := newBedrockEventStreamDecoder(resp.Body)
type decodeEvent struct {
payload []byte
err error
}
events := make(chan decodeEvent, 16)
done := make(chan struct{})
sendEvent := func(ev decodeEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt atomic.Int64
lastReadAt.Store(time.Now().UnixNano())
go func() {
defer close(events)
for {
payload, err := decoder.Decode()
if err != nil {
if err == io.EOF {
return
}
_ = sendEvent(decodeEvent{err: err})
return
}
lastReadAt.Store(time.Now().UnixNano())
if !sendEvent(decodeEvent{payload: payload}) {
return
}
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
if !clientDisconnected {
flusher.Flush()
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
if ev.err != nil {
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err)
}
// the payload is JSON; extract chunk.bytes (base64-encoded Claude SSE event data)
sseData := extractBedrockChunkData(ev.payload)
if sseData == nil {
continue
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// convert the Bedrock-specific amazon-bedrock-invocationMetrics into standard Anthropic usage format
// and strip the field so it is not passed through to the client
sseData = transformBedrockInvocationMetrics(sseData)
// parse the SSE event data to extract usage
s.parseSSEUsagePassthrough(string(sseData), usage)
// determine the SSE event type
eventType := gjson.GetBytes(sseData, "type").String()
// write in standard SSE format
if !clientDisconnected {
var writeErr error
if eventType != "" {
_, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData)
} else {
_, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData)
}
if writeErr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID)
} else {
flusher.Flush()
}
}
case <-intervalCh:
lastRead := time.Unix(0, lastReadAt.Load())
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
}
// extractBedrockChunkData extracts the Claude SSE event data from a Bedrock EventStream payload.
// Bedrock payload format: {"bytes":"<base64-encoded-json>"}
func extractBedrockChunkData(payload []byte) []byte {
b64 := gjson.GetBytes(payload, "bytes").String()
if b64 == "" {
return nil
}
decoded, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
return nil
}
return decoded
}
// transformBedrockInvocationMetrics converts the Bedrock-specific amazon-bedrock-invocationMetrics
// into standard Anthropic usage format and removes the field from the SSE data.
//
// A message_delta event returned by Bedrock Invoke may contain:
//
// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}
//
// which is converted to:
//
// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}}
func transformBedrockInvocationMetrics(data []byte) []byte {
metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics")
if !metrics.Exists() || !metrics.IsObject() {
return data
}
// remove the Bedrock-specific field
data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics")
// do not overwrite an existing standard usage field
if gjson.GetBytes(data, "usage").Exists() {
return data
}
// convert camelCase → snake_case and write into usage
inputTokens := metrics.Get("inputTokenCount")
outputTokens := metrics.Get("outputTokenCount")
if inputTokens.Exists() {
data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int())
}
if outputTokens.Exists() {
data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int())
}
return data
}
// bedrockEventStreamDecoder decodes AWS EventStream binary frames.
// EventStream frame format:
//
// [total_byte_length: 4 bytes]
// [headers_byte_length: 4 bytes]
// [prelude_crc: 4 bytes]
// [headers: variable]
// [payload: variable]
// [message_crc: 4 bytes]
type bedrockEventStreamDecoder struct {
reader *bufio.Reader
}
func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder {
return &bedrockEventStreamDecoder{
reader: bufio.NewReaderSize(r, 64*1024),
}
}
// Decode reads the next EventStream frame and returns the payload of chunk-type events
func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) {
for {
// read the prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes
prelude := make([]byte, 12)
if _, err := io.ReadFull(d.reader, prelude); err != nil {
return nil, err
}
// verify the prelude CRC (AWS EventStream uses standard CRC32 / IEEE)
preludeCRC := bedrockReadUint32(prelude[8:12])
if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC {
return nil, fmt.Errorf("eventstream prelude CRC mismatch")
}
totalLength := bedrockReadUint32(prelude[0:4])
headersLength := bedrockReadUint32(prelude[4:8])
if totalLength < 16 { // minimum: 12 prelude + 4 message_crc
return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength)
}
// read headers + payload + message_crc
remaining := int(totalLength) - 12
if remaining <= 0 {
continue
}
data := make([]byte, remaining)
if _, err := io.ReadFull(d.reader, data); err != nil {
return nil, err
}
// verify the message CRC (covers prelude + headers + payload)
messageCRC := bedrockReadUint32(data[len(data)-4:])
h := crc32.New(crc32IEEETable)
_, _ = h.Write(prelude)
_, _ = h.Write(data[:len(data)-4])
if h.Sum32() != messageCRC {
return nil, fmt.Errorf("eventstream message CRC mismatch")
}
// parse headers
headers := data[:headersLength]
payload := data[headersLength : len(data)-4] // strip message_crc
// extract :event-type from the headers
eventType := extractEventStreamHeaderValue(headers, ":event-type")
// handle only chunk events
if eventType == "chunk" {
// the payload is complete JSON containing the bytes field
return payload, nil
}
// check for exception events
exceptionType := extractEventStreamHeaderValue(headers, ":exception-type")
if exceptionType != "" {
return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload))
}
messageType := extractEventStreamHeaderValue(headers, ":message-type")
if messageType == "exception" || messageType == "error" {
return nil, fmt.Errorf("bedrock error: %s", string(payload))
}
// skip other event types (e.g. initial-response)
}
}
// extractEventStreamHeaderValue extracts the string value of the named header from EventStream headers binary data.
// EventStream header format:
//
// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable]
//
// value_type = 7 means a string value whose first 2 bytes are the length
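// Worked example of that layout: the header ":event-type" = "chunk" is encoded as
//
//	0x0B ":event-type"  0x07  0x00 0x05 "chunk"
//	^name_len=11        ^type=7 (string)  ^value_len=5 (big-endian), then the value
//
// for which extractEventStreamHeaderValue(headers, ":event-type") returns "chunk".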
func extractEventStreamHeaderValue(headers []byte, targetName string) string {
pos := 0
for pos < len(headers) {
if pos >= len(headers) {
break
}
nameLen := int(headers[pos])
pos++
if pos+nameLen > len(headers) {
break
}
name := string(headers[pos : pos+nameLen])
pos += nameLen
if pos >= len(headers) {
break
}
valueType := headers[pos]
pos++
switch valueType {
case 7: // string
if pos+2 > len(headers) {
return ""
}
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
pos += 2
if pos+valueLen > len(headers) {
return ""
}
value := string(headers[pos : pos+valueLen])
pos += valueLen
if name == targetName {
return value
}
case 0: // bool true
if name == targetName {
return "true"
}
case 1: // bool false
if name == targetName {
return "false"
}
case 2: // byte
pos++
if name == targetName {
return ""
}
case 3: // short
pos += 2
if name == targetName {
return ""
}
case 4: // int
pos += 4
if name == targetName {
return ""
}
case 5: // long
pos += 8
if name == targetName {
return ""
}
case 6: // bytes
if pos+2 > len(headers) {
return ""
}
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
pos += 2 + valueLen
case 8: // timestamp
pos += 8
case 9: // uuid
pos += 16
default:
return "" // 未知类型,无法继续解析
}
}
return ""
}
// crc32IEEETable is the CRC32 / IEEE table used by AWS EventStream.
var crc32IEEETable = crc32.MakeTable(crc32.IEEE)
func bedrockReadUint32(b []byte) uint32 {
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
}
func bedrockReadUint16(b []byte) uint16 {
return uint16(b[0])<<8 | uint16(b[1])
}

View File

@@ -0,0 +1,261 @@
package service
import (
"bytes"
"encoding/base64"
"encoding/binary"
"hash/crc32"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestExtractBedrockChunkData(t *testing.T) {
t.Run("valid base64 payload", func(t *testing.T) {
original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`
b64 := base64.StdEncoding.EncodeToString([]byte(original))
payload := []byte(`{"bytes":"` + b64 + `"}`)
result := extractBedrockChunkData(payload)
require.NotNil(t, result)
assert.JSONEq(t, original, string(result))
})
t.Run("empty bytes field", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"bytes":""}`))
assert.Nil(t, result)
})
t.Run("no bytes field", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"other":"value"}`))
assert.Nil(t, result)
})
t.Run("invalid base64", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`))
assert.Nil(t, result)
})
}
func TestTransformBedrockInvocationMetrics(t *testing.T) {
t.Run("converts metrics to usage", func(t *testing.T) {
input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result := transformBedrockInvocationMetrics([]byte(input))
// amazon-bedrock-invocationMetrics should be removed
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
// usage should be set
assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int())
assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int())
// original fields preserved
assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String())
assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String())
})
t.Run("no metrics present", func(t *testing.T) {
input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}`
result := transformBedrockInvocationMetrics([]byte(input))
assert.JSONEq(t, input, string(result))
})
t.Run("does not overwrite existing usage", func(t *testing.T) {
input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result := transformBedrockInvocationMetrics([]byte(input))
// metrics removed but existing usage preserved
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int())
})
}
func TestExtractEventStreamHeaderValue(t *testing.T) {
// Build a header with :event-type = "chunk" (string type = 7)
buildStringHeader := func(name, value string) []byte {
var buf bytes.Buffer
// name length (1 byte)
_ = buf.WriteByte(byte(len(name)))
// name
_, _ = buf.WriteString(name)
// value type (7 = string)
_ = buf.WriteByte(7)
// value length (2 bytes, big-endian)
_ = binary.Write(&buf, binary.BigEndian, uint16(len(value)))
// value
_, _ = buf.WriteString(value)
return buf.Bytes()
}
t.Run("find string header", func(t *testing.T) {
headers := buildStringHeader(":event-type", "chunk")
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
})
t.Run("header not found", func(t *testing.T) {
headers := buildStringHeader(":event-type", "chunk")
assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type"))
})
t.Run("multiple headers", func(t *testing.T) {
var buf bytes.Buffer
_, _ = buf.Write(buildStringHeader(":content-type", "application/json"))
_, _ = buf.Write(buildStringHeader(":event-type", "chunk"))
_, _ = buf.Write(buildStringHeader(":message-type", "event"))
headers := buf.Bytes()
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type"))
assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type"))
})
t.Run("empty headers", func(t *testing.T) {
assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type"))
})
}
func TestBedrockEventStreamDecoder(t *testing.T) {
crc32IeeeTab := crc32.MakeTable(crc32.IEEE)
// Build a valid EventStream frame with correct CRC32/IEEE checksums.
buildFrame := func(eventType string, payload []byte) []byte {
// Build headers
var headersBuf bytes.Buffer
// :event-type header
_ = headersBuf.WriteByte(byte(len(":event-type")))
_, _ = headersBuf.WriteString(":event-type")
_ = headersBuf.WriteByte(7) // string type
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType)))
_, _ = headersBuf.WriteString(eventType)
// :message-type header
_ = headersBuf.WriteByte(byte(len(":message-type")))
_, _ = headersBuf.WriteString(":message-type")
_ = headersBuf.WriteByte(7)
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event")))
_, _ = headersBuf.WriteString("event")
headers := headersBuf.Bytes()
headersLen := uint32(len(headers))
// total = 12 (prelude) + headers + payload + 4 (message_crc)
totalLen := uint32(12 + len(headers) + len(payload) + 4)
// Prelude: total_length(4) + headers_length(4)
var preludeBuf bytes.Buffer
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
preludeBytes := preludeBuf.Bytes()
preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab)
// Build frame: prelude + prelude_crc + headers + payload
var frame bytes.Buffer
_, _ = frame.Write(preludeBytes)
_ = binary.Write(&frame, binary.BigEndian, preludeCRC)
_, _ = frame.Write(headers)
_, _ = frame.Write(payload)
// Message CRC covers everything before itself
messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab)
_ = binary.Write(&frame, binary.BigEndian, messageCRC)
return frame.Bytes()
}
t.Run("decode chunk event", func(t *testing.T) {
payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test")
frame := buildFrame("chunk", payload)
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
result, err := decoder.Decode()
require.NoError(t, err)
assert.Equal(t, payload, result)
})
t.Run("skip non-chunk events", func(t *testing.T) {
// Write initial-response followed by chunk
var buf bytes.Buffer
_, _ = buf.Write(buildFrame("initial-response", []byte(`{}`)))
chunkPayload := []byte(`{"bytes":"aGVsbG8="}`)
_, _ = buf.Write(buildFrame("chunk", chunkPayload))
decoder := newBedrockEventStreamDecoder(&buf)
result, err := decoder.Decode()
require.NoError(t, err)
assert.Equal(t, chunkPayload, result)
})
t.Run("EOF on empty input", func(t *testing.T) {
decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil))
_, err := decoder.Decode()
assert.Equal(t, io.EOF, err)
})
t.Run("corrupted prelude CRC", func(t *testing.T) {
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
// Corrupt the prelude CRC (bytes 8-11)
frame[8] ^= 0xFF
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "prelude CRC mismatch")
})
t.Run("corrupted message CRC", func(t *testing.T) {
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
// Corrupt the message CRC (last 4 bytes)
frame[len(frame)-1] ^= 0xFF
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "message CRC mismatch")
})
t.Run("castagnoli encoded frame is rejected", func(t *testing.T) {
castagnoliTab := crc32.MakeTable(crc32.Castagnoli)
payload := []byte(`{"bytes":"dGVzdA=="}`)
var headersBuf bytes.Buffer
_ = headersBuf.WriteByte(byte(len(":event-type")))
_, _ = headersBuf.WriteString(":event-type")
_ = headersBuf.WriteByte(7)
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk")))
_, _ = headersBuf.WriteString("chunk")
headers := headersBuf.Bytes()
headersLen := uint32(len(headers))
totalLen := uint32(12 + len(headers) + len(payload) + 4)
var preludeBuf bytes.Buffer
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
preludeBytes := preludeBuf.Bytes()
var frame bytes.Buffer
_, _ = frame.Write(preludeBytes)
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab))
_, _ = frame.Write(headers)
_, _ = frame.Write(payload)
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab))
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes()))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "prelude CRC mismatch")
})
}
func TestBuildBedrockURL(t *testing.T) {
t.Run("stream URL with colon in model ID", func(t *testing.T) {
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true)
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url)
})
t.Run("non-stream URL with colon in model ID", func(t *testing.T) {
url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false)
assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url)
})
t.Run("model ID without colon", func(t *testing.T) {
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true)
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url)
})
}

View File

@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
}
@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
if aggErr != nil {
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
if usageErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
}
if aggErr == nil && usageErr == nil {
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
if dedupErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
}
if aggErr == nil && usageErr == nil && dedupErr == nil {
s.lastRetentionCleanup.Store(now)
}
}

View File

@@ -12,12 +12,18 @@ import (
type dashboardAggregationRepoTestStub struct {
aggregateCalls int
recomputeCalls int
cleanupUsageCalls int
cleanupDedupCalls int
ensurePartitionCalls int
lastStart time.Time
lastEnd time.Time
watermark time.Time
aggregateErr error
cleanupAggregatesErr error
cleanupUsageErr error
cleanupDedupErr error
ensurePartitionErr error
}
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
}
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
s.recomputeCalls++
return s.AggregateRange(ctx, start, end)
}
@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
s.cleanupUsageCalls++
return s.cleanupUsageErr
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
s.cleanupDedupCalls++
return s.cleanupDedupErr
}
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
s.ensurePartitionCalls++
return s.ensurePartitionErr
}
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupUsageCalls)
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Enabled: true,
IntervalSeconds: 60,
LookbackSeconds: 120,
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
UsageBillingDedupDays: 2,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.runScheduledAggregation()
require.Equal(t, 1, repo.ensurePartitionCalls)
require.Equal(t, 0, repo.aggregateCalls)
}
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {

Some files were not shown because too many files have changed in this diff Show More