Compare commits

75 Commits

Author SHA1 Message Date
Wesley Liddick
474165d7aa Merge pull request #1043 from touwaeriol/pr/antigravity-credits-overages
feat: Antigravity AI Credits overages handling & balance display
2026-03-16 09:22:19 +08:00
Wesley Liddick
94e067a2e2 Merge pull request #1040 from 0xObjc/codex/fix-user-spending-ranking-others
fix(admin): polish spending ranking and usage defaults
2026-03-16 09:19:46 +08:00
Wesley Liddick
4293c89166 Merge pull request #1036 from Ethan0x0000/feat/usage-endpoint-distribution
fix: record endpoint info for all API surfaces & unify normalization via middleware
2026-03-16 09:17:32 +08:00
Wesley Liddick
ec82c37da5 Merge pull request #1042 from touwaeriol/feat/unified-oauth-refresh-api
feat: unified OAuth token refresh API with distributed locking
2026-03-16 09:00:42 +08:00
erio
552a4b998a fix: resolve golangci-lint issues (gofmt, errcheck)
- Fix gofmt alignment in admin_service.go and trailing newline in
  antigravity_credits_overages.go
- Suppress errcheck for fmt.Sscanf in client.go GetMinimumAmount
2026-03-16 05:15:27 +08:00
erio
0d2061b268 fix: remove ClaudeMax references not yet in upstream/main
Remove SimulateClaudeMaxEnabled field and related logic from
admin_service.go, and remove applyClaudeMaxCacheBillingPolicyToUsage,
applyClaudeMaxNonStreamingRewrite, setupClaudeMaxStreamingHook calls
from antigravity_gateway_service.go. These symbols are not yet
available in upstream/main.
2026-03-16 05:01:42 +08:00
erio
8a260defc2 refactor: replace sync.Map credits state with AICredits rate limit key
Replace process-memory sync.Map + per-model runtime state with a single
"AICredits" key in model_rate_limits, making credits exhaustion fully
isomorphic with model-level rate limiting.

Scheduler: rate-limited accounts with overages enabled + credits available
are now scheduled instead of excluded.

Forwarding: when model is rate-limited + credits available, inject credits
proactively without waiting for a 429 round trip.

Storage: credits exhaustion stored as model_rate_limits["AICredits"] with
5h duration, reusing SetModelRateLimit/isRateLimitActiveForKey.

Frontend: show credits_active (yellow) when model rate-limited but
credits available, credits_exhausted (red) when AICredits key active.

Tests: add unit tests for shouldMarkCreditsExhausted, injectEnabledCreditTypes,
clearCreditsExhausted, and update existing overages tests.
2026-03-16 04:58:58 +08:00
SilentFlower
e14c87597a feat: simplify AI Credits display logic and enhance UI presentation 2026-03-16 04:58:46 +08:00
SilentFlower
f3f19d35aa feat: enhance Antigravity account overages handling and improve UI credit display 2026-03-16 04:58:35 +08:00
SilentFlower
ced90e1d84 feat: add AI Credits balance handling and update model status indicators 2026-03-16 04:58:23 +08:00
SilentFlower
17e4033340 feat: implement resolveCreditsOveragesModelKey function to stabilize model key resolution for credit overages 2026-03-16 04:58:12 +08:00
erio
044d3a013d fix: suppress SA4006 unused value warning in Path A branch 2026-03-16 01:38:06 +08:00
erio
1fc9dd7b68 feat: unified OAuth token refresh API with distributed locking
Introduce OAuthRefreshAPI as the single entry point for all OAuth token
refresh operations, eliminating the race condition where background
refresh and inline refresh could simultaneously use the same
refresh_token (fixes #1035).

Key changes:
- Add OAuthRefreshExecutor interface extending TokenRefresher with CacheKey
- Add OAuthRefreshAPI.RefreshIfNeeded with lock → DB re-read → double-check flow
- Add ProviderRefreshPolicy / BackgroundRefreshPolicy strategy types
- Simplify all 4 TokenProviders to delegate to OAuthRefreshAPI
- Rewrite TokenRefreshService.refreshWithRetry to use unified API path
- Add MergeCredentials and BuildClaudeAccountCredentials helpers
- Add 40 unit tests covering all new and modified code paths
2026-03-16 01:31:54 +08:00
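The lock, re-read, double-check flow described above can be sketched as follows. This is an illustration under assumptions: the `token`, `store`, and `refresher` types are hypothetical, and a single in-process `sync.Mutex` stands in for the distributed lock, so the sketch shows only the ordering of the steps.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// token is a hypothetical stand-in for stored OAuth credentials.
type token struct {
	Access    string
	ExpiresAt time.Time
}

// store is a hypothetical stand-in for the accounts table.
type store struct {
	mu     sync.Mutex
	tokens map[string]token
}

func (s *store) load(key string) token {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.tokens[key]
}

func (s *store) save(key string, t token) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tokens[key] = t
}

// refresher serializes refreshes. The mutex stands in for a distributed lock.
type refresher struct {
	db      *store
	lock    sync.Mutex
	refresh func(old token) token // hypothetical upstream refresh call
	skew    time.Duration         // refresh when less than this much lifetime remains
}

func (r *refresher) refreshIfNeeded(key string, now time.Time) token {
	r.lock.Lock() // step 1: take the lock
	defer r.lock.Unlock()
	cur := r.db.load(key)                // step 2: re-read the stored token under the lock
	if cur.ExpiresAt.Sub(now) > r.skew { // step 3: double-check; another caller may have refreshed
		return cur
	}
	next := r.refresh(cur)
	r.db.save(key, next)
	return next
}

// countRefreshes runs concurrent callers against one expiring token and
// returns how many upstream refresh calls were made.
func countRefreshes(callers int) int32 {
	now := time.Date(2026, time.March, 16, 0, 0, 0, 0, time.UTC)
	var calls int32
	r := &refresher{
		db:   &store{tokens: map[string]token{"acct-1": {Access: "old", ExpiresAt: now}}},
		skew: time.Minute,
		refresh: func(token) token {
			atomic.AddInt32(&calls, 1)
			return token{Access: "new", ExpiresAt: now.Add(time.Hour)}
		},
	}
	var wg sync.WaitGroup
	for i := 0; i < callers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			r.refreshIfNeeded("acct-1", now)
		}()
	}
	wg.Wait()
	return atomic.LoadInt32(&calls)
}

func main() {
	fmt.Println("upstream refresh calls:", countRefreshes(8))
}
```

The re-read under the lock is what keeps a single refresh_token from being used twice: callers that lose the race see the fresh token and return without calling upstream.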
Peter
8147866c09 fix(admin): polish spending ranking and usage defaults 2026-03-16 00:17:47 +08:00
Ethan0x0000
7bd1972f94 refactor: migrate all handlers to shared endpoint normalization middleware
- Apply InboundEndpointMiddleware to all gateway route groups
- Replace normalizedOpenAIInboundEndpoint/normalizedOpenAIUpstreamEndpoint and normalizedGatewayInboundEndpoint/normalizedGatewayUpstreamEndpoint with GetInboundEndpoint/GetUpstreamEndpoint
- Remove 4 old constants and 4 old normalization functions (-70 lines)
- Migrate existing endpoint normalization test to new API

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-15 22:13:42 +08:00
Ethan0x0000
2c9dcfe27b refactor: add unified endpoint normalization infrastructure
Introduce endpoint.go with shared constants, NormalizeInboundEndpoint, DeriveUpstreamEndpoint, InboundEndpointMiddleware, and context helpers. This replaces the two separate normalization implementations (OpenAI and Gateway) with a single source of truth. Includes comprehensive test coverage.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-15 22:13:31 +08:00
Ethan0x0000
1b79b0f3ff feat: add InboundEndpoint/UpstreamEndpoint fields to non-OpenAI usage records
Extend RecordUsageInput and RecordUsageLongContextInput structs with InboundEndpoint and UpstreamEndpoint so that Claude, Gemini, and Sora handlers can record endpoint info alongside OpenAI handlers.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-15 22:13:22 +08:00
Ethan0x0000
c637e6cf31 fix: use half-open date ranges for DST-safe usage queries
Replace t.Add(24*time.Hour - time.Nanosecond) with t.AddDate(0, 0, 1) and use SQL < instead of <= for end-of-day boundaries. This avoids edge-case misses around DST transitions.

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-15 22:13:12 +08:00
Wesley Liddick
d3a9f5bb88 Merge pull request #1027 from touwaeriol/feat/ignore-insufficient-balance-errors
feat(ops): add ignore insufficient balance errors toggle and extract error constants
2026-03-15 19:10:18 +08:00
Wesley Liddick
7eb0415a8a Merge pull request #1028 from IanShaw027/fix/open-issues-cleanup
fix: resolve multiple issues - Gemini schema compatibility, bulk-edit whitelist, Docker tool support, and quota limit field handling (Fix/open issues cleanup)
2026-03-15 19:09:49 +08:00
erio
bdbc8fa08f fix(ops): align constant declarations for gofmt compliance 2026-03-15 18:55:14 +08:00
erio
63f3af0f94 fix(ops): match "insufficient account balance" in error filter
The upstream Gemini API returns "Insufficient account balance" which
doesn't contain the substring "insufficient balance". Add explicit
match for the full phrase to ensure the filter works correctly.
2026-03-15 18:45:48 +08:00
IanShaw027
686f890fbf style: fix gofmt formatting issues 2026-03-15 18:42:32 +08:00
shaw
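220fbe6544 fix: restore the window stats display accidentally removed from UsageProgressBar
Commit 0debe0a8, while fixing the OpenAI WS usage window refresh issue, accidentally deleted
the window stats rendering logic and formatting functions in UsageProgressBar.

Restore the stats row above the progress bar (requests, tokens, account cost, user cost)
and the 4 corresponding formatting computed properties.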
2026-03-15 18:29:23 +08:00
shaw
ae44a94325 fix: add UI configuration of the email-sending domain for the password reset feature 2026-03-15 17:52:29 +08:00
IanShaw
3718d6dcd4 Merge branch 'Wei-Shaw:main' into fix/open-issues-cleanup 2026-03-15 17:49:20 +08:00
IanShaw027
90b3838173 fix: remove the patternProperties field, which Gemini does not support #795 2026-03-15 17:46:58 +08:00
IanShaw027
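19d3ecc76f fix: fix the mismatch between the displayed and actual model whitelist when bulk-editing accounts #982
When bulk-editing accounts, the UI displayed plain model names (e.g. GPT-5) while dated model names were actually persisted to the database.

Core changes:
1. The bulk-edit whitelist no longer uses the outdated hand-written model list in BulkEditAccountModal.vue
   - Removed the hard-coded allModels and presetMappings lists (200+ lines in total)
   - Reuse the ModelWhitelistSelector.vue component directly

2. The ModelWhitelistSelector component supports combined multi-platform filtering
   - Added a platforms prop that accepts multiple platforms
   - Added a normalizedPlatforms computed property to handle single-platform and multi-platform cases uniformly
   - availableOptions dynamically filters the model list by the union of the selected platforms
   - fillRelated can fill in related models for multiple platforms at once

3. Model mapping presets are now generated dynamically
   - filteredPresets now uses getPresetMappingsByPlatform to generate presets per platform from the unified model source
   - No longer depends on the hand-written preset list in the modal

Resulting behavior:
- The models shown and selected in the UI are exactly the models sent to the backend
- Fully resolves the "displayed vs. actual mismatch" on the bulk-edit path
- The model list and mapping presets always stay in sync with the system definitions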
2026-03-15 17:46:58 +08:00
IanShaw027
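6fba4ebb13 fix: add the pg_dump and psql tools to Dockerfile.goreleaser #1002
Add the PostgreSQL client tools to the runtime image to support database backup and restore inside the container.

Changes:
- Use a multi-stage build to copy the pg_dump and psql binaries from the postgres:18-alpine image
- Add the required dependency libraries (libpq, zstd-libs, lz4-libs, krb5-libs, libldap, libedit)
- Upgrade the base image to alpine:3.21
- Copy the libpq.so.5 shared library so the tools run correctly

This allows database backup and restore to run directly in the runtime container without access to the Docker socket.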
2026-03-15 17:46:58 +08:00
IanShaw027
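c31974c913 fix: tolerate some quota limit fields being empty #1021
Fix an error raised when the three limits (daily, weekly, monthly) were not all filled in.

Changes:
- Backend: add an optionalLimitField type that handles null and empty-string values, so that some limit fields may be left empty
- Frontend: add a normalizeOptionalLimit function that normalizes limit input, mapping null, empty strings, and invalid numbers to null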
2026-03-15 17:46:58 +08:00
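One way such an optional field could be decoded on the backend is sketched below. This is an assumption-laden illustration: the `optionalLimit` type, the `limits` struct, and the JSON field names are hypothetical and are not taken from the project's optionalLimitField implementation.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// optionalLimit is a hypothetical sketch: a nil Value means "no limit set".
type optionalLimit struct {
	Value *float64
}

// UnmarshalJSON accepts a number, null, or an empty string.
func (o *optionalLimit) UnmarshalJSON(b []byte) error {
	b = bytes.TrimSpace(b)
	if bytes.Equal(b, []byte("null")) || bytes.Equal(b, []byte(`""`)) {
		o.Value = nil
		return nil
	}
	var v float64
	if err := json.Unmarshal(b, &v); err != nil {
		return err
	}
	o.Value = &v
	return nil
}

// limits uses placeholder JSON names for the three quota fields.
type limits struct {
	Daily   optionalLimit `json:"daily"`
	Weekly  optionalLimit `json:"weekly"`
	Monthly optionalLimit `json:"monthly"`
}

func parseLimits(s string) (limits, error) {
	var l limits
	err := json.Unmarshal([]byte(s), &l)
	return l, err
}

func main() {
	l, err := parseLimits(`{"daily": 5, "weekly": "", "monthly": 0}`)
	if err != nil {
		panic(err)
	}
	fmt.Println(*l.Daily.Value, l.Weekly.Value == nil, *l.Monthly.Value)
}
```

Keeping the value behind a pointer lets an explicit 0 stay distinct from "not set", which a plain float64 cannot express.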
erio
6177fa5dd8 fix(i18n): correct insufficient balance error hint text
Remove misleading "upstream" wording - the error is about client API key
user balance, not upstream account balance.
2026-03-15 17:41:51 +08:00
erio
cfe72159d0 feat(ops): add ignore insufficient balance errors toggle and extract error constants
- Add 5th error filter switch IgnoreInsufficientBalanceErrors to suppress
  upstream insufficient balance / insufficient_quota errors from ops log
- Extract hardcoded error strings into package-level constants for
  shouldSkipOpsErrorLog, normalizeOpsErrorType, classifyOpsPhase, and
  classifyOpsIsBusinessLimited
- Define ErrNoAvailableAccounts sentinel error and replace all
  errors.New("no available accounts") call sites
- Update tests to use require.ErrorIs with the sentinel error
2026-03-15 17:26:18 +08:00
Wesley Liddick
8321e4a647 Merge pull request #1023 from YanzheL/fix/claude-output-effort-logging
fix: extract and log Claude output_config.effort in usage records
2026-03-15 16:45:37 +08:00
Wesley Liddick
3084330d0c Merge pull request #1019 from Ethan0x0000/feat/usage-endpoint-distribution
feat: add endpoint metadata and usage endpoint distribution insights
2026-03-15 16:42:03 +08:00
Wesley Liddick
b566649e79 Merge pull request #1025 from touwaeriol/fix/rate-limit-nil-window-reset
fix(billing): treat nil rate limit window as expired to prevent usage accumulation
2026-03-15 16:33:14 +08:00
Wesley Liddick
10a6180e4a Merge pull request #1026 from touwaeriol/fix/group-quota-clear
fix(billing): allow clearing group quota limits and treat 0 as zero-limit
2026-03-15 16:33:00 +08:00
Wesley Liddick
cbe9e78977 Merge pull request #1007 from StarryKira/fix/streaming-failover-corruption
fix(gateway): prevent streaming failover splice corruption from sending clients a double message_start (fix issue #991)
2026-03-15 16:29:31 +08:00
Wesley Liddick
74145b1f39 Merge pull request #1017 from SsageParuders/fix/bedrock-account-quota
fix: Bedrock account quota limits not taking effect
2026-03-15 16:28:42 +08:00
Elysia
359e56751b add tests 2026-03-15 16:21:49 +08:00
erio
5899784aa4 fix(billing): allow clearing group quota limits and treat 0 as zero-limit
Previously, v-model.number produced "" when input was cleared, causing
JSON decode errors on the backend. Also, normalizeLimit treated 0 as
"unlimited" which prevented setting a zero quota. Now "" is converted
to null (unlimited) in frontend, and 0 is preserved as a valid limit.

Closes Wei-Shaw/sub2api#1021
2026-03-15 16:15:15 +08:00
erio
9e8959c56d fix(billing): treat nil rate limit window as expired to prevent usage accumulation
When Redis cache is populated from DB with a NULL window_1d_start, the
Lua increment script only updates usage counters without setting window
timestamps. IsWindowExpired(nil) previously returned false, so the
accumulated usage was never reset across time windows, effectively
turning usage_1d into a lifetime counter. Once this exceeded
rate_limit_1d the key was incorrectly blocked with "日限额已用完" ("daily limit exhausted").

Fixes Wei-Shaw/sub2api#1022
2026-03-15 14:04:13 +08:00
YanzheL
1bff2292a6 fix: extract and log Claude output_config.effort in usage records
Claude's output_config.effort parameter (low/medium/high/max) was not
being extracted from requests or logged in the reasoning_effort column
of usage logs. Only the OpenAI path populated this field.

Changes:
- Extract output_config.effort in ParseGatewayRequest
- Add ReasoningEffort field to ForwardResult
- Populate reasoning_effort in both RecordUsage and RecordUsageWithLongContext
- Guard against overwriting service-set effort values in handler
- Update stale comments that described reasoning_effort as OpenAI-only
- Add unit tests for extraction, normalization, and persistence
2026-03-15 12:55:37 +08:00
Ethan0x0000
cf9247754e test: fix usage repo stubs for unit builds 2026-03-15 12:51:34 +08:00
Ethan0x0000
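eefab15958 feat: improve endpoint observability and distribution statistics for usage records
Unify the inbound, upstream, and path endpoint distributions into a consistent card interaction on the usage records page, and complete the endpoint metadata and statistics pipeline, improving troubleshooting and traffic analysis efficiency.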
2026-03-15 11:26:42 +08:00
Elysia
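0e23732631 fix(gateway): prevent streaming failover splice corruption from sending clients a double message_start
When the upstream returns event:error midway through an SSE stream, handleStreamingResponse has already
written some SSE events to the client, but the previous failover logic would still switch to the next account
and write a complete stream, so the client received two message_start events and produced a 400 error.

Fix: record c.Writer.Size() before each Forward call. If the writer's byte count has grown after Forward
returns UpstreamFailoverError, SSE content has been irrevocably sent to the client, so call
handleFailoverExhausted directly to send an SSE error event and terminate the stream instead of
continuing to fail over.

The ping-only scenario is unaffected: ping bytes written while waiting for a slot are the same before and
after Forward, so the normal failover flow proceeds as usual.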

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-14 22:49:23 +08:00
SsageParuders
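37c044fb4b fix: Bedrock account quota limits not taking effect; quota counter always $0.00
The quota update condition in applyUsageBillingEffects() only checked AccountTypeAPIKey and
missed AccountTypeBedrock, so the quota counter for Bedrock accounts never incremented.
Extend the condition to cover both the APIKey and Bedrock types.

Also add an AWS Bedrock option to the account filter dropdown in the frontend.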
2026-03-14 22:47:44 +08:00
shaw
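6da5fa01b9 fix(frontend): fix the save button in the ops settings dialog being always disabled
The backend defaults to alert.enabled=true with empty recipients, which frontend validation treated as
an error and used to block the save button. Remove this blocking validation and instead automatically
disable email notification configs that have no recipients on save.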
2026-03-14 20:39:29 +08:00
shaw
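616930f9d3 refactor(frontend): merge the backup and data management pages into tabs of the settings page
Integrate the standalone /admin/backup and /admin/data-management pages into the settings page
as the "Backup" and "Sora Storage" tabs, reducing sidebar entries and centralizing configuration.

- Remove the AppLayout wrapper from BackupView and DataManagementView
- Embed them in SettingsView as child components, switching tabs with v-show
- Delete the standalone routes and sidebar menu entries
- Hide the main save button on the backup/data tabs (each has its own save)
- Adjust the tab bar styling to fit 7 tabs, with a thin scrollbar on desktop
- Clean up unused icon components and i18n keys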
2026-03-14 20:22:39 +08:00
Wesley Liddick
b9c31fa7c4 Merge pull request #999 from InCerryGit/fix/enc_coot
fix: handle invalid encrypted content error and retry logic.
2026-03-14 19:29:07 +08:00
Wesley Liddick
17b339972c Merge pull request #1000 from touwaeriol/fix/ops-agg-tuning
fix(ops): tune aggregation constants to prevent PG overload
2026-03-14 19:03:03 +08:00
shaw
39f8bd91b9 fix: remove unused saveRecords method to pass lint 2026-03-14 19:01:27 +08:00
Wesley Liddick
aa4e37d085 Merge pull request #966 from GuangYiDing/feat/db-backup-restore
feat: scheduled database backup and restore (S3-compatible storage, with Cloudflare R2 support)
2026-03-14 18:58:56 +08:00
erio
f59b66b7d4 fix(ops): tune aggregation constants to prevent PG overload
Increase MAX(bucket_start) query timeout from 3s to 5s to reduce
timeout-induced fallbacks. Shrink backfill window from 30 days to
1 hour so that fallback recomputation stays lightweight instead of
scanning the entire retention range.
2026-03-14 18:47:37 +08:00
InCerry
8f0ea7a02d Merge branch 'main' into fix/enc_coot 2026-03-14 18:46:33 +08:00
Wesley Liddick
a1dc00890e Merge pull request #944 from miraserver/feat/backend-mode
feat: add Backend Mode toggle to disable user self-service
2026-03-14 17:53:54 +08:00
Wesley Liddick
dfbcc363d1 Merge pull request #969 from wucm667/feat/quota-fixed-reset-mode
feat: support a fixed-time reset mode for account quotas
2026-03-14 17:52:56 +08:00
Rose Ding
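1047f973d5 fix: refactor the database backup service per review feedback (security + architecture + robustness)
1. Encrypted S3 credential storage: encrypt SecretAccessKey with SecretEncryptor (AES-256-GCM)
   to prevent S3 credentials leaking in backup files; compatible with legacy unencrypted data
2. Fix the saveRecord race condition: add a recordsMu mutex protecting load/save of records
3. Server-side verification for restore: the handler layer requires re-entering the admin password,
   verified with bcrypt; the frontend shows a password prompt
4. Abstract pg_dump/psql/S3 operations behind interfaces: define the DBDumper and BackupObjectStore interfaces,
   with implementations in the repository layer, following the project's dependency-injection architecture
5. Switch to streaming to avoid OOM on large databases: backup is pg_dump stdout -> gzip -> io.Pipe ->
   S3 upload; restore is S3 download -> gzip reader -> psql stdin, with no full in-memory load
6. loadRecords distinguishes "no data" from "corrupted data": a JSON parse failure returns an explicit error
7. Add 18 unit tests for core logic: covering encryption, concurrency, streaming backup/restore, error handling, etc.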

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-14 17:48:21 +08:00
Wesley Liddick
e32977dd73 Merge pull request #997 from SsageParuders/refactor/bedrock-channel-merge
refactor: merge bedrock-apikey into unified bedrock channel with auth_mode
2026-03-14 17:48:01 +08:00
wucm667
b5f78ec1e8 feat: implement the SQL expressions for fixed-time reset mode and add related unit tests 2026-03-14 17:37:34 +08:00
SsageParuders
e0f290fdc8 style: fix gofmt formatting for account type constants 2026-03-14 17:32:53 +08:00
Wesley Liddick
fc00a4e3b2 Merge pull request #959 from touwaeriol/feat/antigravity-403-detection
feat(antigravity): add 403 forbidden status detection and display
2026-03-14 17:23:22 +08:00
Wesley Liddick
db1f6ded88 Merge pull request #961 from 0xObjc/codex/ops-openai-token-visibility
feat(ops): make OpenAI token stats optional
2026-03-14 17:23:01 +08:00
SsageParuders
4644af2ccc refactor: merge bedrock-apikey into bedrock with auth_mode credential
Consolidate two separate channel types (bedrock + bedrock-apikey) into
a single "AWS Bedrock" channel. Authentication mode is now distinguished
by credentials.auth_mode ("sigv4" | "apikey") instead of separate types.

Backend:
- Remove AccountTypeBedrockAPIKey constant
- IsBedrock() simplified; IsBedrockAPIKey() checks auth_mode
- Add IsAPIKeyOrBedrock() helper to eliminate repeated type checks
- Extend pool mode, quota scheduling, and billing to bedrock
- Add RetryableOnSameAccount to handleBedrockUpstreamErrors
- Add "bedrock" scope to Beta Policy for independent control

Frontend:
- Merge two buttons into one "AWS Bedrock" with auth mode radio
- Badge displays "Anthropic | AWS"
- Pool mode and quota limit UI available for bedrock
- Quota display in account list (usage bars, capacity badges, reset)
- Remove all bedrock-apikey type references
2026-03-14 17:13:30 +08:00
InCerry
e4a4dfd038 Merge remote-tracking branch 'origin/main' into fix/enc_coot
# Conflicts:
#	backend/internal/service/openai_gateway_service.go
2026-03-14 13:04:24 +08:00
InCerry
2666422b99 fix: handle invalid encrypted content error and retry logic. 2026-03-14 11:42:42 +08:00
erio
45456fa24c fix: restore OAuth 401 temp-unschedulable for Gemini, update Antigravity tests
The 403 detection PR changed the 401 handler condition from
`account.Type == AccountTypeOAuth` to
`account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI`,
which accidentally excluded Gemini OAuth from the temp-unschedulable path.

Fix: use `!= PlatformAntigravity` instead, preserving Gemini behavior
while correctly excluding Antigravity (whose 401 is handled by
applyErrorPolicy's temp_unschedulable_rules).

Update tests to reflect Antigravity's new 401 semantics:
- HandleUpstreamError: Antigravity OAuth 401 now uses SetError
- CheckErrorPolicy: Antigravity 401 second hit stays TempUnscheduled
- DB fallback: split into Gemini (escalates) and Antigravity (stays temp)
2026-03-14 02:21:22 +08:00
erio
6344fa2a86 feat(antigravity): add 403 forbidden status detection, classification and display
Backend:
- Detect and classify 403 responses into three types:
  validation (account needs Google verification),
  violation (terms of service / banned),
  forbidden (generic 403)
- Extract verification/appeal URLs from 403 response body
  (structured JSON parsing with regex fallback)
- Add needs_verify, is_banned, needs_reauth, error_code fields
  to UsageInfo (omitempty for zero impact on other platforms)
- Handle 403 in request path: classify and permanently set account error
- Save validation_url in error_message for degraded path recovery
- Enrich usage with account error on both success and degraded paths
- Add singleflight dedup for usage requests with independent context
- Differentiate cache TTL: success/403 → 3min, errors → 1min
- Return degraded UsageInfo instead of HTTP 500 on quota fetch errors

Frontend:
- Display forbidden status badges with color coding (red for banned,
  amber for needs verification, gray for generic)
- Show clickable verification/appeal URL links
- Display needs_reauth and degraded error states in usage cell
- Add Antigravity tier label badge next to platform type

Tests:
- Comprehensive unit tests for classifyForbiddenType (7 cases)
- Unit tests for extractValidationURL (8 cases including unicode escapes)
- Integration test for FetchQuota forbidden path
2026-03-13 18:22:45 +08:00
Peter
29b0e4a8a5 feat(ops): allow hiding alert events 2026-03-13 17:18:04 +08:00
Rose Ding
f7177be3b6 fix: golangci-lint fixes (gofmt formatting + errcheck return value checks)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-13 12:01:05 +08:00
Rose Ding
875b417fde fix: add the backupSvc parameter missing from provideCleanup in wire_gen_test.go
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-13 11:53:46 +08:00
wucm667
2573107b32 refactor: change the map type in ComputeQuotaResetAt and ValidateQuotaResetConfig from map[string]interface{} to map[string]any 2026-03-13 11:44:49 +08:00
wucm667
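5b85005945 feat: support a fixed-time reset mode for account quotas
- Backend adds two quota reset modes, rolling and fixed, for daily and weekly quotas
- In fixed mode, the reset time (hour), reset day of week (weekly quota), and time zone (IANA) are configurable
- account_repo.go uses SQL expressions to handle expiry checks and reset-time advancement for both modes
- Add helper functions such as ComputeQuotaResetAt / ValidateQuotaResetConfig
- Add the related fields to the DTO layer and map them fully in mappers
- Frontend QuotaLimitCard adds a rolling/fixed toggle UI and a time zone selector
- CreateAccountModal / EditAccountModal pass the new config fields through
- i18n (zh/en) adds the corresponding translation entries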
2026-03-13 11:12:37 +08:00
Rose Ding
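53ad1645cf feat: scheduled database backup and restore (S3-compatible storage, with Cloudflare R2 support)
Add an admin-only database backup and restore feature:
- Full PostgreSQL backup (pg_dump), gzip-compressed and uploaded to S3-compatible storage
- Support manual backups and cron-scheduled backups
- Support restoring from a backup (psql --single-transaction)
- Automatic expiry cleanup of backup files (14 days by default)
- Full frontend management page (S3 config, schedule config, backup list, restore/download/delete)
- Built-in Cloudflare R2 setup tutorial dialog
- Dockerfile copies pg_dump/psql from the postgres image via a multi-stage build to keep versions consistent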

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-13 10:38:19 +08:00
Peter
af9c4a7dd0 feat(ops): make openai token stats optional 2026-03-13 04:11:58 +08:00
John Doe
6826149a8f feat: add Backend Mode toggle to disable user self-service
Add a system-wide "Backend Mode" that disables user self-registration
and self-service while keeping admin panel and API gateway fully
functional. When enabled, only admin can log in; all user-facing
routes return 403.

Backend:
- New setting key `backend_mode_enabled` with atomic cached reads (60s TTL)
- BackendModeUserGuard middleware blocks non-admin authenticated routes
- BackendModeAuthGuard middleware blocks registration/password-reset auth routes
- Login/Login2FA/RefreshToken handlers reject non-admin when enabled
- TokenPairWithUser struct for role-aware token refresh
- 20 unit tests (middleware + service layer)

Frontend:
- Router guards redirect unauthenticated users to /login
- Admin toggle in Settings page
- Login page hides register link and footer in backend mode
- 9 unit tests for router guard logic
- i18n support (en/zh)

27 files changed, 833 insertions(+), 17 deletions(-)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-12 02:42:57 +03:00
158 changed files with 12784 additions and 1821 deletions

View File

@@ -9,6 +9,7 @@
ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.26.1-alpine
ARG ALPINE_IMAGE=alpine:3.21
ARG POSTGRES_IMAGE=postgres:18-alpine
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
./cmd/server
# -----------------------------------------------------------------------------
# Stage 3: Final Runtime Image
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
# -----------------------------------------------------------------------------
FROM ${POSTGRES_IMAGE} AS pg-client
# -----------------------------------------------------------------------------
# Stage 4: Final Runtime Image
# -----------------------------------------------------------------------------
FROM ${ALPINE_IMAGE}
@@ -86,8 +92,20 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
RUN apk add --no-cache \
ca-certificates \
tzdata \
libpq \
zstd-libs \
lz4-libs \
krb5-libs \
libldap \
libedit \
&& rm -rf /var/cache/apk/*
# Copy pg_dump and psql from the same postgres image used in docker-compose
# This ensures version consistency between backup tools and the database server
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api

View File

@@ -5,7 +5,12 @@
# It only packages the pre-built binary, no compilation needed.
# =============================================================================
FROM alpine:3.19
ARG ALPINE_IMAGE=alpine:3.21
ARG POSTGRES_IMAGE=postgres:18-alpine
FROM ${POSTGRES_IMAGE} AS pg-client
FROM ${ALPINE_IMAGE}
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
LABEL description="Sub2API - AI API Gateway Platform"
@@ -16,8 +21,20 @@ RUN apk add --no-cache \
ca-certificates \
tzdata \
curl \
libpq \
zstd-libs \
lz4-libs \
krb5-libs \
libldap \
libedit \
&& rm -rf /var/cache/apk/*
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
# restore work in the runtime container without requiring Docker socket access.
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api

View File

@@ -94,6 +94,7 @@ func provideCleanup(
antigravityOAuth *service.AntigravityOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -230,6 +231,12 @@ func provideCleanup(
}
return nil
}},
{"BackupService", func() error {
if backupSvc != nil {
backupSvc.Stop()
}
return nil
}},
}
infraSteps := []cleanupStep{

View File

@@ -124,6 +124,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
@@ -132,11 +133,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageCache := service.NewUsageCache()
identityCache := repository.NewIdentityCache(redisClient)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
@@ -146,6 +147,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
dbDumper := repository.NewPgDumper(configConfig)
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
backupHandler := admin.NewBackupHandler(backupService, userService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -162,10 +167,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
@@ -201,7 +206,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -228,11 +233,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -285,6 +290,7 @@ func provideCleanup(
antigravityOAuth *service.AntigravityOAuthService,
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -420,6 +426,12 @@ func provideCleanup(
}
return nil
}},
{"BackupService", func() error {
if backupSvc != nil {
backupSvc.Stop()
}
return nil
}},
}
infraSteps := []cleanupStep{

View File

@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
antigravityOAuthSvc,
nil, // openAIGateway
nil, // scheduledTestRunner
nil, // backupSvc
)
require.NotPanics(t, func() {

View File

@@ -27,12 +27,11 @@ const (
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth account (full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token account (inference-only scope)
AccountTypeAPIKey = "apikey" // API Key account
AccountTypeUpstream = "upstream" // Upstream passthrough account (connects to the upstream via Base URL + API Key)
AccountTypeBedrock = "bedrock" // AWS Bedrock account (connects to Bedrock via SigV4 signing)
AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key account (connects to Bedrock via Bearer Token)
AccountTypeOAuth = "oauth" // OAuth account (full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token account (inference-only scope)
AccountTypeAPIKey = "apikey" // API Key account
AccountTypeUpstream = "upstream" // Upstream passthrough account (connects to the upstream via Base URL + API Key)
AccountTypeBedrock = "bedrock" // AWS Bedrock account (connects to Bedrock via SigV4 signing or API Key, distinguished by credentials.auth_mode)
)
// Redeem type constants

View File

@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`

View File

@@ -0,0 +1,204 @@
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type BackupHandler struct {
backupService *service.BackupService
userService *service.UserService
}
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
return &BackupHandler{
backupService: backupService,
userService: userService,
}
}
// ─── S3 configuration ───
func (h *BackupHandler) GetS3Config(c *gin.Context) {
cfg, err := h.backupService.GetS3Config(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
var req service.BackupS3Config
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
var req service.BackupS3Config
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.backupService.TestS3Connection(c.Request.Context(), req)
if err != nil {
response.Success(c, gin.H{"ok": false, "message": err.Error()})
return
}
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
}
// ─── Scheduled backups ───
func (h *BackupHandler) GetSchedule(c *gin.Context) {
cfg, err := h.backupService.GetSchedule(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
var req service.BackupScheduleConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
// ─── Backup operations ───
type CreateBackupRequest struct {
ExpireDays *int `json:"expire_days"` // nil = use the default (14), 0 = never expire
}
func (h *BackupHandler) CreateBackup(c *gin.Context) {
var req CreateBackupRequest
_ = c.ShouldBindJSON(&req) // an empty body is allowed
expireDays := 14 // expires after 14 days by default
if req.ExpireDays != nil {
expireDays = *req.ExpireDays
}
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, record)
}
func (h *BackupHandler) ListBackups(c *gin.Context) {
records, err := h.backupService.ListBackups(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
if records == nil {
records = []service.BackupRecord{}
}
response.Success(c, gin.H{"items": records})
}
func (h *BackupHandler) GetBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, record)
}
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"url": url})
}
// ─── Restore operations (require re-entering the admin password) ───
type RestoreBackupRequest struct {
Password string `json:"password" binding:"required"`
}
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
var req RestoreBackupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "password is required for restore operation")
return
}
// Get the current admin user ID from the context
sub, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "unauthorized")
return
}
// Load the admin user and verify the password
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if !user.CheckPassword(req.Password) {
response.BadRequest(c, "incorrect admin password")
return
}
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"restored": true})
}

View File

@@ -512,6 +512,8 @@ func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
payload := gin.H{
"ranking": ranking.Ranking,
"total_actual_cost": ranking.TotalActualCost,
"total_requests": ranking.TotalRequests,
"total_tokens": ranking.TotalTokens,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
}

View File

@@ -61,6 +61,8 @@ func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
return &usagestats.UserSpendingRankingResponse{
Ranking: s.ranking,
TotalActualCost: s.rankingTotal,
TotalRequests: 44,
TotalTokens: 1234,
}, nil
}
@@ -164,6 +166,8 @@ func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 50, repo.rankingLimit)
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
require.Contains(t, rec.Body.String(), "\"total_requests\":44")
require.Contains(t, rec.Body.String(), "\"total_tokens\":1234")
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)

View File

@@ -1,6 +1,9 @@
package admin
import (
"bytes"
"encoding/json"
"fmt"
"strconv"
"strings"
@@ -16,6 +19,55 @@ type GroupHandler struct {
adminService service.AdminService
}
type optionalLimitField struct {
set bool
value *float64
}
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
f.set = true
trimmed := bytes.TrimSpace(data)
if bytes.Equal(trimmed, []byte("null")) {
f.value = nil
return nil
}
var number float64
if err := json.Unmarshal(trimmed, &number); err == nil {
f.value = &number
return nil
}
var text string
if err := json.Unmarshal(trimmed, &text); err == nil {
text = strings.TrimSpace(text)
if text == "" {
f.value = nil
return nil
}
number, err = strconv.ParseFloat(text, 64)
if err != nil {
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
}
f.value = &number
return nil
}
return fmt.Errorf("invalid limit value: %s", string(trimmed))
}
func (f optionalLimitField) ToServiceInput() *float64 {
if !f.set {
return nil
}
if f.value != nil {
return f.value
}
zero := 0.0
return &zero
}
// NewGroupHandler creates a new admin group handler
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
return &GroupHandler{
@@ -25,15 +77,15 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
// CreateGroupRequest represents create group request
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// Image generation billing config (used by the antigravity and gemini platforms; a negative value clears the setting)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
@@ -62,16 +114,16 @@ type CreateGroupRequest struct {
// UpdateGroupRequest represents update group request
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// Image generation billing config (used by the antigravity and gemini platforms; a negative value clears the setting)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
@@ -191,9 +243,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
@@ -244,9 +296,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
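
The `optionalLimitField` type introduced in this file distinguishes three request states for each limit: key absent, key present but `null` or empty, and key present with a number (optionally sent as a string). The following is a minimal standalone sketch of that pattern, not the project's code: the names `optionalLimit`, `groupRequest`, and `decodeDailyLimit` are hypothetical, and how the service interprets a zero limit is not shown in this diff.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"strconv"
	"strings"
)

// optionalLimit is a hypothetical stand-in that mirrors optionalLimitField.
type optionalLimit struct {
	set   bool     // true once the JSON key was present in the body
	value *float64 // nil when the key was null or an empty string
}

// UnmarshalJSON accepts null, a JSON number, or a numeric string.
// encoding/json only calls it when the key is present, so "set" stays
// false for absent keys.
func (f *optionalLimit) UnmarshalJSON(data []byte) error {
	f.set = true
	trimmed := bytes.TrimSpace(data)
	if bytes.Equal(trimmed, []byte("null")) {
		f.value = nil
		return nil
	}
	var number float64
	if err := json.Unmarshal(trimmed, &number); err == nil {
		f.value = &number
		return nil
	}
	var text string
	if err := json.Unmarshal(trimmed, &text); err != nil {
		return fmt.Errorf("invalid limit value: %s", trimmed)
	}
	text = strings.TrimSpace(text)
	if text == "" {
		f.value = nil
		return nil
	}
	number, err := strconv.ParseFloat(text, 64)
	if err != nil {
		return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
	}
	f.value = &number
	return nil
}

// toServiceInput maps the three states to a *float64:
// absent -> nil, null/empty -> pointer to 0, number -> pointer to it.
func (f optionalLimit) toServiceInput() *float64 {
	if !f.set {
		return nil
	}
	if f.value != nil {
		return f.value
	}
	zero := 0.0
	return &zero
}

// groupRequest is a hypothetical request body with a single limit field.
type groupRequest struct {
	DailyLimitUSD optionalLimit `json:"daily_limit_usd"`
}

// decodeDailyLimit parses a body and returns the value handed to the service.
func decodeDailyLimit(body string) (*float64, error) {
	var req groupRequest
	if err := json.Unmarshal([]byte(body), &req); err != nil {
		return nil, err
	}
	return req.DailyLimitUSD.toServiceInput(), nil
}

func main() {
	bodies := []string{
		`{}`,
		`{"daily_limit_usd":null}`,
		`{"daily_limit_usd":"12.5"}`,
		`{"daily_limit_usd":3}`,
	}
	for _, body := range bodies {
		v, err := decodeDailyLimit(body)
		switch {
		case err != nil:
			fmt.Printf("%s: error: %v\n", body, err)
		case v == nil:
			fmt.Printf("%s: nil (key absent)\n", body)
		default:
			fmt.Printf("%s: %v\n", body, *v)
		}
	}
}
```

With a plain `*float64` field, an absent key and an explicit `null` both decode to `nil`, so a client has no way to say "clear this limit" as opposed to "leave it unchanged". The wrapper makes the two cases separable.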

View File

@@ -80,6 +80,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
FrontendURL: settings.FrontendURL,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
@@ -125,6 +126,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
BackendModeEnabled: settings.BackendModeEnabled,
})
}
@@ -136,6 +138,7 @@ type UpdateSettingsRequest struct {
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
FrontendURL string `json:"frontend_url"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP two-factor authentication
@@ -199,6 +202,9 @@ type UpdateSettingsRequest struct {
// Group isolation
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
// Backend Mode
BackendModeEnabled bool `json:"backend_mode_enabled"`
}
// UpdateSettings updates the system settings
@@ -322,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// Frontend URL validation
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
if req.FrontendURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
return
}
}
// Custom menu item validation
const (
maxCustomMenuItems = 20
@@ -433,6 +448,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
FrontendURL: req.FrontendURL,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
@@ -473,6 +489,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
BackendModeEnabled: req.BackendModeEnabled,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -526,6 +543,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
FrontendURL: updatedSettings.FrontendURL,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
@@ -571,6 +589,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
BackendModeEnabled: updatedSettings.BackendModeEnabled,
})
}
@@ -608,6 +627,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
if before.FrontendURL != after.FrontendURL {
changed = append(changed, "frontend_url")
}
if before.TotpEnabled != after.TotpEnabled {
changed = append(changed, "totp_enabled")
}
@@ -725,6 +747,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
changed = append(changed, "allow_ungrouped_key_scheduling")
}
if before.BackendModeEnabled != after.BackendModeEnabled {
changed = append(changed, "backend_mode_enabled")
}
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
changed = append(changed, "purchase_subscription_enabled")
}

View File

@@ -159,8 +159,8 @@ func (h *UsageHandler) List(c *gin.Context) {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// Set end time to end of day
t = t.Add(24*time.Hour - time.Nanosecond)
// Use half-open range [start, end), move to next calendar day start (DST-safe).
t = t.AddDate(0, 0, 1)
endTime = &t
}
@@ -285,7 +285,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
// Align with the SQL condition created_at < end: use 00:00 of the next day as the upper bound (DST-safe).
endTime = endTime.AddDate(0, 0, 1)
} else {
period := c.DefaultQuery("period", "today")
switch period {
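
The two hunks above replace `Add(24*time.Hour - time.Nanosecond)` with `AddDate(0, 0, 1)` so that the end bound is the start of the next calendar day, matching a `created_at < end` query. The sketch below illustrates why the two differ on a DST transition day. It assumes the date is parsed in a zone that observes DST; the helper `dayBounds`, the zone name, and the date are illustrative and not taken from the handler.

```go
package main

import (
	"fmt"
	"time"
	_ "time/tzdata" // embed the zone database so LoadLocation works without system zoneinfo
)

// dayBounds returns the half-open range [start, end) covering one
// calendar day in the named time zone.
func dayBounds(date, tzName string) (start, end time.Time, err error) {
	loc, err := time.LoadLocation(tzName)
	if err != nil {
		return time.Time{}, time.Time{}, err
	}
	start, err = time.ParseInLocation("2006-01-02", date, loc)
	if err != nil {
		return time.Time{}, time.Time{}, err
	}
	// AddDate works on calendar fields, so the result is local midnight
	// of the next day even when the day is 23 or 25 hours long.
	return start, start.AddDate(0, 0, 1), nil
}

func main() {
	// 2025-03-09 is the US spring-forward day: the local day lasts 23 hours.
	start, end, err := dayBounds("2025-03-09", "America/New_York")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	// The previous approach adds a fixed duration and overshoots midnight.
	fixed := start.Add(24*time.Hour - time.Nanosecond)

	fmt.Println("start:          ", start)
	fmt.Println("fixed 24h - 1ns:", fixed)
	fmt.Println("AddDate(0,0,1): ", end)
	fmt.Println("day length:     ", end.Sub(start))
}
```

If the handler parses dates in UTC, both forms give the same calendar day; the half-open form still avoids relying on nanosecond precision for the inclusive upper bound.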

View File

@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
// Backend mode: only admin can login
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
h.respondWithTokenPair(c, user)
}
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
return
}
// Delete the login session
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
// Get the user
// Get the user (before session deletion so we can check backend mode)
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Backend mode: only admin can login (check BEFORE deleting session)
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
// Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
h.respondWithTokenPair(c, user)
}
@@ -447,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return
}
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
if frontendBaseURL == "" {
slog.Error("server.frontend_url not configured; cannot build password reset link")
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
response.InternalError(c, "Password reset is not configured")
return
}
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
return
}
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Backend mode: block non-admin token refresh
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
response.Success(c, RefreshTokenResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
ExpiresIn: tokenPair.ExpiresIn,
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresIn: result.ExpiresIn,
TokenType: "Bearer",
})
}

View File

@@ -264,8 +264,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
}
}
// Extract API Key account quota limits (only applies to the apikey type)
if a.Type == service.AccountTypeAPIKey {
// Extract account quota limits (applies to the apikey / bedrock types)
if a.IsAPIKeyOrBedrock() {
if limit := a.GetQuotaLimit(); limit > 0 {
out.QuotaLimit = &limit
used := a.GetQuotaUsed()
@@ -281,6 +281,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
used := a.GetQuotaWeeklyUsed()
out.QuotaWeeklyUsed = &used
}
// Fixed-time reset configuration
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
out.QuotaDailyResetMode = &mode
hour := a.GetQuotaDailyResetHour()
out.QuotaDailyResetHour = &hour
}
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
out.QuotaWeeklyResetMode = &mode
day := a.GetQuotaWeeklyResetDay()
out.QuotaWeeklyResetDay = &day
hour := a.GetQuotaWeeklyResetHour()
out.QuotaWeeklyResetHour = &hour
}
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
tz := a.GetQuotaResetTimezone()
out.QuotaResetTimezone = &tz
}
if a.Extra != nil {
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
out.QuotaDailyResetAt = &v
}
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
out.QuotaWeeklyResetAt = &v
}
}
}
return out
@@ -498,6 +523,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
Model: l.Model,
ServiceTier: l.ServiceTier,
ReasoningEffort: l.ReasoningEffort,
InboundEndpoint: l.InboundEndpoint,
UpstreamEndpoint: l.UpstreamEndpoint,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens,

View File

@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
t.Parallel()
serviceTier := "priority"
inboundEndpoint := "/v1/chat/completions"
upstreamEndpoint := "/v1/responses"
log := &service.UsageLog{
RequestID: "req_3",
Model: "gpt-5.4",
ServiceTier: &serviceTier,
InboundEndpoint: &inboundEndpoint,
UpstreamEndpoint: &upstreamEndpoint,
AccountRateMultiplier: f64Ptr(1.5),
}
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
require.NotNil(t, userDTO.ServiceTier)
require.Equal(t, serviceTier, *userDTO.ServiceTier)
require.NotNil(t, userDTO.InboundEndpoint)
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
require.NotNil(t, userDTO.UpstreamEndpoint)
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
require.NotNil(t, adminDTO.ServiceTier)
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
require.NotNil(t, adminDTO.InboundEndpoint)
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
require.NotNil(t, adminDTO.UpstreamEndpoint)
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
require.NotNil(t, adminDTO.AccountRateMultiplier)
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
}

View File

@@ -22,6 +22,7 @@ type SystemSettings struct {
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
FrontendURL string `json:"frontend_url"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP two-factor authentication
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // Whether the TOTP encryption key is configured
@@ -81,6 +82,9 @@ type SystemSettings struct {
// Group isolation
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
// Backend Mode
BackendModeEnabled bool `json:"backend_mode_enabled"`
}
type DefaultSubscriptionSetting struct {
@@ -111,6 +115,7 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version"`
}

View File

@@ -203,6 +203,16 @@ type Account struct {
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
// Quota fixed-time reset configuration
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -324,9 +334,13 @@ type UsageLog struct {
Model string `json:"model"`
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
ServiceTier *string `json:"service_tier,omitempty"`
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
// nil means not provided / not applicable.
// ReasoningEffort is the request's reasoning effort level.
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"`

View File

@@ -0,0 +1,174 @@
package handler
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ──────────────────────────────────────────────────────────
// Canonical inbound / upstream endpoint paths.
// All normalization and derivation reference this single set
// of constants — add new paths HERE when a new API surface
// is introduced.
// ──────────────────────────────────────────────────────────
const (
EndpointMessages = "/v1/messages"
EndpointChatCompletions = "/v1/chat/completions"
EndpointResponses = "/v1/responses"
EndpointGeminiModels = "/v1beta/models"
)
// gin.Context keys used by the middleware and helpers below.
const (
ctxKeyInboundEndpoint = "_gateway_inbound_endpoint"
)
// ──────────────────────────────────────────────────────────
// Normalization functions
// ──────────────────────────────────────────────────────────
// NormalizeInboundEndpoint maps a raw request path (which may carry
// prefixes like /antigravity, /openai, /sora) to its canonical form.
//
// "/antigravity/v1/messages" → "/v1/messages"
// "/v1/chat/completions" → "/v1/chat/completions"
// "/openai/v1/responses/foo" → "/v1/responses"
// "/v1beta/models/gemini:gen" → "/v1beta/models"
func NormalizeInboundEndpoint(path string) string {
path = strings.TrimSpace(path)
switch {
case strings.Contains(path, EndpointChatCompletions):
return EndpointChatCompletions
case strings.Contains(path, EndpointMessages):
return EndpointMessages
case strings.Contains(path, EndpointResponses):
return EndpointResponses
case strings.Contains(path, EndpointGeminiModels):
return EndpointGeminiModels
default:
return path
}
}
// DeriveUpstreamEndpoint determines the upstream endpoint from the
// account platform and the normalized inbound endpoint.
//
// Platform-specific rules:
// - OpenAI always forwards to /v1/responses (with optional subpath
// such as /v1/responses/compact preserved from the raw URL).
// - Anthropic → /v1/messages
// - Gemini → /v1beta/models
// - Sora → /v1/chat/completions
// - Antigravity routes may target either Claude or Gemini, so the
// inbound endpoint is used to distinguish.
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
inbound = strings.TrimSpace(inbound)
switch platform {
case service.PlatformOpenAI:
// OpenAI forwards everything to the Responses API.
// Preserve subresource suffix (e.g. /v1/responses/compact).
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
return EndpointResponses + suffix
}
return EndpointResponses
case service.PlatformAnthropic:
return EndpointMessages
case service.PlatformGemini:
return EndpointGeminiModels
case service.PlatformSora:
return EndpointChatCompletions
case service.PlatformAntigravity:
// Antigravity accounts serve both Claude and Gemini.
if inbound == EndpointGeminiModels {
return EndpointGeminiModels
}
return EndpointMessages
}
// Unknown platform — fall back to inbound.
return inbound
}
// responsesSubpathSuffix extracts the part after "/responses" in a raw
// request path, e.g. "/openai/v1/responses/compact" → "/compact".
// Returns "" when there is no meaningful suffix.
func responsesSubpathSuffix(rawPath string) string {
trimmed := strings.TrimRight(strings.TrimSpace(rawPath), "/")
idx := strings.LastIndex(trimmed, "/responses")
if idx < 0 {
return ""
}
suffix := trimmed[idx+len("/responses"):]
if suffix == "" || suffix == "/" {
return ""
}
if !strings.HasPrefix(suffix, "/") {
return ""
}
return suffix
}
// ──────────────────────────────────────────────────────────
// Middleware
// ──────────────────────────────────────────────────────────
// InboundEndpointMiddleware normalizes the request path and stores the
// canonical inbound endpoint in gin.Context so that every handler in
// the chain can read it via GetInboundEndpoint.
//
// Apply this middleware to all gateway route groups.
func InboundEndpointMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
path := c.FullPath()
if path == "" && c.Request != nil && c.Request.URL != nil {
path = c.Request.URL.Path
}
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(path))
c.Next()
}
}
// ──────────────────────────────────────────────────────────
// Context helpers — used by handlers before building
// RecordUsageInput / RecordUsageLongContextInput.
// ──────────────────────────────────────────────────────────
// GetInboundEndpoint returns the canonical inbound endpoint stored by
// InboundEndpointMiddleware. If the middleware did not run (e.g. in
// tests), it falls back to normalizing c.FullPath(), or the raw URL
// path when no route matched, on the fly. A nil context yields "".
func GetInboundEndpoint(c *gin.Context) string {
// Guard first: calling c.Get on a nil context would panic.
if c == nil {
return ""
}
if v, ok := c.Get(ctxKeyInboundEndpoint); ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
// Fallback: normalize on the fly.
path := c.FullPath()
if path == "" && c.Request != nil && c.Request.URL != nil {
path = c.Request.URL.Path
}
return NormalizeInboundEndpoint(path)
}
// GetUpstreamEndpoint derives the upstream endpoint from the context
// and the account platform. Handlers call this after scheduling an
// account, passing account.Platform.
func GetUpstreamEndpoint(c *gin.Context, platform string) string {
inbound := GetInboundEndpoint(c)
rawPath := ""
if c != nil && c.Request != nil && c.Request.URL != nil {
rawPath = c.Request.URL.Path
}
return DeriveUpstreamEndpoint(inbound, rawPath, platform)
}
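
The normalization and sub-path rules above can be tried outside the handler package with a small standard-library replica. This is only a sketch: `normalize`, `responsesSuffix`, and `canonicalEndpoints` are hypothetical names, and the match order is assumed to follow the `switch` in `NormalizeInboundEndpoint`.

```go
package main

import (
	"fmt"
	"strings"
)

// canonicalEndpoints lists the canonical paths in match order,
// mirroring the order of the switch cases in the diff.
var canonicalEndpoints = []string{
	"/v1/chat/completions",
	"/v1/messages",
	"/v1/responses",
	"/v1beta/models",
}

// normalize returns the first canonical endpoint contained in path,
// or the trimmed path itself when nothing matches.
func normalize(path string) string {
	path = strings.TrimSpace(path)
	for _, ep := range canonicalEndpoints {
		if strings.Contains(path, ep) {
			return ep
		}
	}
	return path
}

// responsesSuffix returns whatever follows the last "/responses"
// segment, or "" when there is no sub-resource.
func responsesSuffix(rawPath string) string {
	trimmed := strings.TrimRight(strings.TrimSpace(rawPath), "/")
	idx := strings.LastIndex(trimmed, "/responses")
	if idx < 0 {
		return ""
	}
	suffix := trimmed[idx+len("/responses"):]
	if !strings.HasPrefix(suffix, "/") {
		return ""
	}
	return suffix
}

func main() {
	for _, p := range []string{
		"/antigravity/v1/messages",
		"/openai/v1/responses/compact",
		"/v1/embeddings",
	} {
		fmt.Printf("%-32s inbound=%-16s suffix=%q\n", p, normalize(p), responsesSuffix(p))
	}
}
```

Because matching is substring-based, any route prefix (`/openai`, `/antigravity`, `/sora`) is dropped without being listed explicitly, and unknown paths pass through unchanged.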

View File

@@ -0,0 +1,159 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func init() { gin.SetMode(gin.TestMode) }
// ──────────────────────────────────────────────────────────
// NormalizeInboundEndpoint
// ──────────────────────────────────────────────────────────
func TestNormalizeInboundEndpoint(t *testing.T) {
tests := []struct {
path string
want string
}{
// Direct canonical paths.
{"/v1/messages", EndpointMessages},
{"/v1/chat/completions", EndpointChatCompletions},
{"/v1/responses", EndpointResponses},
{"/v1beta/models", EndpointGeminiModels},
// Prefixed paths (antigravity, openai, sora).
{"/antigravity/v1/messages", EndpointMessages},
{"/openai/v1/responses", EndpointResponses},
{"/openai/v1/responses/compact", EndpointResponses},
{"/sora/v1/chat/completions", EndpointChatCompletions},
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
// Gin route patterns with wildcards.
{"/v1beta/models/*modelAction", EndpointGeminiModels},
{"/v1/responses/*subpath", EndpointResponses},
// Unknown path is returned as-is.
{"/v1/embeddings", "/v1/embeddings"},
{"", ""},
{" /v1/messages ", EndpointMessages},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
require.Equal(t, tt.want, NormalizeInboundEndpoint(tt.path))
})
}
}
// ──────────────────────────────────────────────────────────
// DeriveUpstreamEndpoint
// ──────────────────────────────────────────────────────────
func TestDeriveUpstreamEndpoint(t *testing.T) {
tests := []struct {
name string
inbound string
rawPath string
platform string
want string
}{
// Anthropic.
{"anthropic messages", EndpointMessages, "/v1/messages", service.PlatformAnthropic, EndpointMessages},
// Gemini.
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
// Sora.
{"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
// OpenAI — always /v1/responses.
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
{"antigravity gemini", EndpointGeminiModels, "/antigravity/v1beta/models", service.PlatformAntigravity, EndpointGeminiModels},
// Unknown platform — passthrough.
{"unknown platform", "/v1/embeddings", "/v1/embeddings", "unknown", "/v1/embeddings"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, DeriveUpstreamEndpoint(tt.inbound, tt.rawPath, tt.platform))
})
}
}
// ──────────────────────────────────────────────────────────
// responsesSubpathSuffix
// ──────────────────────────────────────────────────────────
func TestResponsesSubpathSuffix(t *testing.T) {
tests := []struct {
raw string
want string
}{
{"/v1/responses", ""},
{"/v1/responses/", ""},
{"/v1/responses/compact", "/compact"},
{"/openai/v1/responses/compact/detail", "/compact/detail"},
{"/v1/messages", ""},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.raw, func(t *testing.T) {
require.Equal(t, tt.want, responsesSubpathSuffix(tt.raw))
})
}
}
// ──────────────────────────────────────────────────────────
// InboundEndpointMiddleware + context helpers
// ──────────────────────────────────────────────────────────
func TestInboundEndpointMiddleware(t *testing.T) {
router := gin.New()
router.Use(InboundEndpointMiddleware())
var captured string
router.POST("/v1/messages", func(c *gin.Context) {
captured = GetInboundEndpoint(c)
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, EndpointMessages, captured)
}
func TestGetInboundEndpoint_FallbackWithoutMiddleware(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/antigravity/v1/messages", nil)
// Middleware did not run — fallback to normalizing c.Request.URL.Path.
got := GetInboundEndpoint(c)
require.Equal(t, EndpointMessages, got)
}
func TestGetUpstreamEndpoint_FullFlow(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses/compact", nil)
// Simulate middleware.
c.Set(ctxKeyInboundEndpoint, NormalizeInboundEndpoint(c.Request.URL.Path))
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
require.Equal(t, "/v1/responses/compact", got)
}

View File

@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if fs.SwitchCount > 0 {
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
// Record the bytes written before Forward; if the count grows afterwards, SSE content has already been sent and failover must be disabled.
writerSizeBeforeForward := c.Writer.Size()
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
} else {
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
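// Streamed content has already reached the client and cannot be retracted; disable failover so two responses are not spliced into one corrupted stream.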
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
@@ -435,6 +442,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
if result.ReasoningEffort == nil {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
}
// Submit usage records through a bounded worker pool so the request hot path does not spawn unbounded goroutines.
h.submitUsageRecordTask(func(ctx context.Context) {
@@ -444,6 +457,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@@ -637,6 +652,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if fs.SwitchCount > 0 {
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
// Record the bytes written before Forward; if the count grows afterwards, SSE content has already been sent and failover must be disabled.
writerSizeBeforeForward := c.Writer.Size()
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
} else {
@@ -706,6 +723,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
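// Streamed content has already reached the client and cannot be retracted; disable failover so two responses are not spliced into one corrupted stream.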
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
return
}
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
switch action {
case FailoverContinue:
@@ -739,6 +761,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
if result.ReasoningEffort == nil {
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
}
// Submit usage records through a bounded worker pool so the request hot path does not spawn unbounded goroutines.
h.submitUsageRecordTask(func(ctx context.Context) {
@@ -748,6 +776,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
User: currentAPIKey.User,
Account: account,
Subscription: currentSubscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@@ -913,7 +943,7 @@ func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Ti
}
if s := c.Query("end_date"); s != "" {
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
endTime = t.Add(24*time.Hour - time.Second) // end of day
endTime = t.AddDate(0, 0, 1) // half-open range upper bound
}
}
return startTime, endTime

View File

@@ -0,0 +1,122 @@
package handler
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// partialMessageStartSSE simulates the first batch of SSE events already written by handleStreamingResponse.
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
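// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten verifies that
// when Forward has already written SSE content to the client before returning an UpstreamFailoverError,
// the failover guard stops the loop and sends an SSE error event instead of attempting another Forward.
// Specifically, it checks that:
// 1. the c.Writer.Size() condition fires correctly (the byte count has grown);
// 2. after handleFailoverExhausted is called with streamStarted=true, the response body ends with an SSE error event;
// 3. the response body contains exactly one message_start and no second one (guards against stream-splicing corruption).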
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// Step 1: record the writer size before Forward (simulates writerSizeBeforeForward := c.Writer.Size()).
sizeBeforeForward := c.Writer.Size()
require.Equal(t, -1, sizeBeforeForward, "gin writer initial Size should be -1 (no bytes written)")
// Step 2: simulate Forward having written partial SSE content to the client (message_start + content_block_start).
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
require.NoError(t, err)
// Step 3: verify that the guard condition holds (c.Writer.Size() != sizeBeforeForward).
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
"writer size must grow after SSE content is written; the guard condition should be true")
// Step 4: simulate an UpstreamFailoverError (upstream returns 403 mid-stream).
failoverErr := &service.UpstreamFailoverError{
StatusCode: http.StatusForbidden,
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
}
// Step 5: the guard fires → call handleFailoverExhausted (streamStarted=true).
h := &GatewayHandler{}
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
body := w.Body.String()
// Assertion A: the response body contains the message_start SSE event line written initially.
require.Contains(t, body, "event: message_start", "response body should contain the already-written message_start SSE event")
// Assertion B: the response body ends with an SSE error event (data: {"type":"error",...}\n\n).
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
"response body should end with a JSON object (the data field of the SSE error event)")
require.Contains(t, body, `"type":"error"`, "the end of the response body must contain an SSE error event")
// Assertion C: the SSE event line "event: message_start" appears exactly once (guards against a doubled message_start from stream splicing).
firstIdx := strings.Index(body, "event: message_start")
lastIdx := strings.LastIndex(body, "event: message_start")
assert.Equal(t, firstIdx, lastIdx,
"'event: message_start' must appear exactly once in the response body; failover splicing must not produce a second one")
}
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten mirrors the test above and
// verifies that the Gemini path behaves the same when it passes service.PlatformGemini (rather than account.Platform).
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
sizeBeforeForward := c.Writer.Size()
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
require.NoError(t, err)
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
failoverErr := &service.UpstreamFailoverError{
StatusCode: http.StatusForbidden,
}
h := &GatewayHandler{}
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
body := w.Body.String()
require.Contains(t, body, "event: message_start")
require.Contains(t, body, `"type":"error"`)
firstIdx := strings.Index(body, "event: message_start")
lastIdx := strings.LastIndex(body, "event: message_start")
assert.Equal(t, firstIdx, lastIdx, "the Gemini path must not emit a second message_start")
}
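// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered verifies the opposite scenario:
// when Forward returns an UpstreamFailoverError without having written any SSE content to the client,
// the guard condition (c.Writer.Size() != sizeBeforeForward) is false and failover must not be aborted.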
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
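// Simulate writerSizeBeforeForward (initially -1).
sizeBeforeForward := c.Writer.Size()
// Forward returns an error without writing any bytes (e.g. a 401 before the connection is established),
// so c.Writer.Size() is still -1.
// Guard condition: sizeBeforeForward == c.Writer.Size() → not triggered.
guardTriggered := c.Writer.Size() != sizeBeforeForward
require.False(t, guardTriggered,
"when no bytes were written the guard condition must be false so that normal failover can continue")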
}
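
The guard these tests exercise can be sketched without gin. The example below is a minimal illustration under stated assumptions: `countingWriter` is a hypothetical stand-in for gin's response writer (its `Size` reports -1 until the first write, matching the initial value asserted in the first test), and `tryAccounts` is a hypothetical failover loop, not the handler's actual one. The real handler additionally emits an SSE error event through `handleFailoverExhausted`; the sketch only returns an error.

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
)

const sseChunk = "event: message_start\ndata: {}\n\n"

var errUpstream = errors.New("upstream failed")

// countingWriter tracks how many body bytes reached the client.
// Size is -1 until the first write, mirroring gin's convention.
type countingWriter struct {
	http.ResponseWriter
	size int
}

func newCountingWriter(w http.ResponseWriter) *countingWriter {
	return &countingWriter{ResponseWriter: w, size: -1}
}

func (w *countingWriter) Write(p []byte) (int, error) {
	if w.size < 0 {
		w.size = 0
	}
	n, err := w.ResponseWriter.Write(p)
	w.size += n
	return n, err
}

func (w *countingWriter) Size() int { return w.size }

// forwardFn is a hypothetical upstream call that may write to the client.
type forwardFn func(w *countingWriter) error

func failBeforeWrite(w *countingWriter) error { return errUpstream }

func failMidStream(w *countingWriter) error {
	if _, err := w.Write([]byte(sseChunk)); err != nil {
		return err
	}
	return errUpstream
}

func succeed(w *countingWriter) error {
	_, err := w.Write([]byte(sseChunk))
	return err
}

// tryAccounts runs the attempts in order and fails over only while
// nothing has been written; it returns the index of the last attempt.
func tryAccounts(w *countingWriter, attempts ...forwardFn) (int, error) {
	for i, forward := range attempts {
		sizeBefore := w.Size()
		err := forward(w)
		if err == nil {
			return i, nil
		}
		if w.Size() != sizeBefore {
			// Bytes already reached the client: retrying would splice streams.
			return i, fmt.Errorf("stream already started, failover aborted: %w", err)
		}
	}
	return -1, errors.New("all accounts failed")
}

// run executes one scenario and returns the attempt index, the bytes
// the client received, and the final error.
func run(attempts ...forwardFn) (int, string, error) {
	rec := httptest.NewRecorder()
	idx, err := tryAccounts(newCountingWriter(rec), attempts...)
	return idx, rec.Body.String(), err
}

func main() {
	idx, body, err := run(failBeforeWrite, succeed)
	fmt.Printf("clean failure:      attempt=%d bytes=%d err=%v\n", idx, len(body), err)

	idx, body, err = run(failMidStream, succeed)
	fmt.Printf("mid-stream failure: attempt=%d bytes=%d err=%v\n", idx, len(body), err)
}
```

Comparing sizes with `!=` rather than `>` also catches the transition from -1 to 0, which happens when headers are flushed before any body bytes.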

View File

@@ -504,6 +504,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// Submit usage records through a bounded worker pool so the request hot path does not spawn unbounded goroutines.
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
@@ -511,6 +513,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,

View File

@@ -12,6 +12,7 @@ type AdminHandlers struct {
Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
DataManagement *admin.DataManagementHandler
Backup *admin.BackupHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler

View File

@@ -256,14 +256,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"),

View File

@@ -0,0 +1,56 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint verifies that the
// unified GetUpstreamEndpoint helper produces the same results as the
// former normalizedOpenAIUpstreamEndpoint for OpenAI platform requests.
func TestOpenAIUpstreamEndpoint_ViaGetUpstreamEndpoint(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
path string
want string
}{
{
name: "responses root maps to responses upstream",
path: "/v1/responses",
want: EndpointResponses,
},
{
name: "responses compact keeps compact suffix",
path: "/openai/v1/responses/compact",
want: "/v1/responses/compact",
},
{
name: "responses nested suffix preserved",
path: "/openai/v1/responses/compact/detail",
want: "/v1/responses/compact/detail",
},
{
name: "non responses path uses platform fallback",
path: "/v1/messages",
want: EndpointResponses,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
got := GetUpstreamEndpoint(c, service.PlatformOpenAI)
require.Equal(t, tt.want, got)
})
}
}

View File

@@ -362,6 +362,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@@ -738,6 +740,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@@ -1235,6 +1239,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: GetInboundEndpoint(c),
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),

View File

@@ -26,6 +26,22 @@ const (
opsStreamKey = "ops_stream"
opsRequestBodyKey = "ops_request_body"
opsAccountIDKey = "ops_account_id"
// Match strings for error filtering, shared by shouldSkipOpsErrorLog and error classification.
opsErrContextCanceled = "context canceled"
opsErrNoAvailableAccounts = "no available accounts"
opsErrInvalidAPIKey = "invalid_api_key"
opsErrAPIKeyRequired = "api_key_required"
opsErrInsufficientBalance = "insufficient balance"
opsErrInsufficientAccountBalance = "insufficient account balance"
opsErrInsufficientQuota = "insufficient_quota"
// Upstream error codes used by error classification (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited).
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
opsCodeUserInactive = "USER_INACTIVE"
)
const (
@@ -1024,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
return errType
}
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE":
case opsCodeInsufficientBalance:
return "billing_error"
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
return "subscription_error"
default:
return "api_error"
@@ -1038,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
// Standardized phases: request|auth|routing|upstream|network|internal
// Map billing/concurrency/response => request; scheduling => routing.
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
return "request"
}
@@ -1057,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
case "upstream_error", "overloaded_error":
return "upstream"
case "api_error":
if strings.Contains(msg, "no available accounts") {
if strings.Contains(msg, opsErrNoAvailableAccounts) {
return "routing"
}
return "internal"
@@ -1103,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
return true
}
if phase == "billing" || phase == "concurrency" {
@@ -1197,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
// Check if context canceled errors should be ignored (client disconnects)
if settings.IgnoreContextCanceled {
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
return true
}
}
// Check if "no available accounts" errors should be ignored
if settings.IgnoreNoAvailableAccounts {
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
return true
}
}
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
if settings.IgnoreInvalidApiKeyErrors {
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
return true
}
}
// Check if insufficient balance errors should be ignored
if settings.IgnoreInsufficientBalanceErrors {
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
return true
}
}
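
The insufficient-balance filter added in this hunk can be read as "any of these markers appears in the body, or either balance marker appears in the message". The sketch below restates that with a small helper. `containsAny` and `skipInsufficientBalance` are hypothetical names, and the lower-casing step is an assumption based on the `msgLower` / `bodyLower` variable names.

```go
package main

import (
	"fmt"
	"strings"
)

const (
	errInsufficientBalance        = "insufficient balance"
	errInsufficientAccountBalance = "insufficient account balance"
	errInsufficientQuota          = "insufficient_quota"
)

// containsAny reports whether s contains at least one of the needles.
func containsAny(s string, needles ...string) bool {
	for _, n := range needles {
		if strings.Contains(s, n) {
			return true
		}
	}
	return false
}

// skipInsufficientBalance mirrors the condition in the diff: the quota
// marker is matched against the body only, the balance markers against
// both the body and the message.
func skipInsufficientBalance(message, body string) bool {
	msgLower := strings.ToLower(message)
	bodyLower := strings.ToLower(body)
	return containsAny(bodyLower, errInsufficientBalance, errInsufficientAccountBalance, errInsufficientQuota) ||
		containsAny(msgLower, errInsufficientBalance, errInsufficientAccountBalance)
}

func main() {
	fmt.Println(skipInsufficientBalance("Insufficient Balance", ""))
	fmt.Println(skipInsufficientBalance("", `{"error":{"code":"insufficient_quota"}}`))
	fmt.Println(skipInsufficientBalance("rate limited", ""))
}
```

Both balance markers are needed because "insufficient account balance" does not contain "insufficient balance" as a substring.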

View File

@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
Version: h.version,
})
}

View File

@@ -400,6 +400,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// Submit usage records through a bounded worker pool so the request hot path does not spawn unbounded goroutines.
h.submitUsageRecordTask(func(ctx context.Context) {
@@ -409,6 +411,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
User: apiKey.User,
Account: account,
Subscription: subscription,
InboundEndpoint: inboundEndpoint,
UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,

View File

@@ -334,6 +334,14 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return []usagestats.EndpointStat{}, nil
}
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return []usagestats.EndpointStat{}, nil
}
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, nil
}

View File

@@ -114,8 +114,8 @@ func (h *UsageHandler) List(c *gin.Context) {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// Set end time to end of day
t = t.Add(24*time.Hour - time.Nanosecond)
// Use half-open range [start, end), move to next calendar day start (DST-safe).
t = t.AddDate(0, 0, 1)
endTime = &t
}
@@ -227,8 +227,8 @@ func (h *UsageHandler) Stats(c *gin.Context) {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// Set the end time to the end of the day.
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
// Align with the SQL condition created_at < end by using 00:00 of the next day as the upper bound (DST-safe).
endTime = endTime.AddDate(0, 0, 1)
} else {
// 使用 period 参数
period := c.DefaultQuery("period", "today")

View File

@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
accountHandler *admin.AccountHandler,
announcementHandler *admin.AnnouncementHandler,
dataManagementHandler *admin.DataManagementHandler,
backupHandler *admin.BackupHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
Account: accountHandler,
Announcement: announcementHandler,
DataManagement: dataManagementHandler,
Backup: backupHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
admin.NewAccountHandler,
admin.NewAnnouncementHandler,
admin.NewDataManagementHandler,
admin.NewBackupHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,

View File

@@ -19,6 +19,16 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
)
// ForbiddenError indicates that the upstream returned 403 Forbidden.
type ForbiddenError struct {
StatusCode int
Body string
}
func (e *ForbiddenError) Error() string {
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
}
// NewAPIRequestWithURL creates an Antigravity API request (v1internal endpoint) using the given base URL.
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
// Build the URL; streaming requests add the ?alt=sse parameter.
@@ -114,10 +124,68 @@ type IneligibleTier struct {
type LoadCodeAssistResponse struct {
CloudAICompanionProject string `json:"cloudaicompanionProject"`
CurrentTier *TierInfo `json:"currentTier,omitempty"`
PaidTier *TierInfo `json:"paidTier,omitempty"`
PaidTier *PaidTierInfo `json:"paidTier,omitempty"`
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
}
// PaidTierInfo describes the paid tier, including the AI Credits balance.
type PaidTierInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
AvailableCredits []AvailableCredit `json:"availableCredits,omitempty"`
}
// UnmarshalJSON accepts paidTier as either a string or an object.
func (p *PaidTierInfo) UnmarshalJSON(data []byte) error {
data = bytes.TrimSpace(data)
if len(data) == 0 || string(data) == "null" {
return nil
}
if data[0] == '"' {
var id string
if err := json.Unmarshal(data, &id); err != nil {
return err
}
p.ID = id
return nil
}
type alias PaidTierInfo
var raw alias
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
*p = PaidTierInfo(raw)
return nil
}
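
The string-or-object decoding used by `PaidTierInfo.UnmarshalJSON` can be exercised in isolation. The following is a reduced sketch: the `tier` and `response` types, the `decode` helper, and the tier IDs are all hypothetical and stand in for the real structs.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// tier is a cut-down stand-in for PaidTierInfo.
type tier struct {
	ID   string `json:"id"`
	Name string `json:"name"`
}

// UnmarshalJSON accepts either a bare string (taken as the ID) or a
// full JSON object.
func (t *tier) UnmarshalJSON(data []byte) error {
	data = bytes.TrimSpace(data)
	if len(data) == 0 || string(data) == "null" {
		return nil
	}
	if data[0] == '"' {
		return json.Unmarshal(data, &t.ID)
	}
	// The alias type has no methods, so this call does not recurse.
	type alias tier
	var raw alias
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}
	*t = tier(raw)
	return nil
}

type response struct {
	PaidTier *tier `json:"paidTier,omitempty"`
}

// decode parses a payload and returns its paidTier field.
func decode(payload string) (*tier, error) {
	var r response
	if err := json.Unmarshal([]byte(payload), &r); err != nil {
		return nil, err
	}
	return r.PaidTier, nil
}

func main() {
	for _, payload := range []string{
		`{"paidTier":"tier-a"}`,
		`{"paidTier":{"id":"tier-a","name":"Tier A"}}`,
		`{"paidTier":null}`,
	} {
		t, err := decode(payload)
		fmt.Printf("%-46s -> %+v (err=%v)\n", payload, t, err)
	}
}
```

The local alias type is what keeps the object branch from calling `UnmarshalJSON` again; any other JSON kind, such as a number, still surfaces as a decoding error.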
// AvailableCredit represents a single AI Credits balance entry.
type AvailableCredit struct {
CreditType string `json:"creditType,omitempty"`
CreditAmount string `json:"creditAmount,omitempty"`
MinimumCreditAmountForUsage string `json:"minimumCreditAmountForUsage,omitempty"`
}
// GetAmount parses creditAmount as a float64.
func (c *AvailableCredit) GetAmount() float64 {
if c.CreditAmount == "" {
return 0
}
var value float64
_, _ = fmt.Sscanf(c.CreditAmount, "%f", &value)
return value
}
// GetMinimumAmount parses minimumCreditAmountForUsage as a float64.
func (c *AvailableCredit) GetMinimumAmount() float64 {
if c.MinimumCreditAmountForUsage == "" {
return 0
}
var value float64
_, _ = fmt.Sscanf(c.MinimumCreditAmountForUsage, "%f", &value)
return value
}
// OnboardUserRequest is the onboardUser request.
type OnboardUserRequest struct {
TierID string `json:"tierId"`
@@ -147,6 +215,14 @@ func (r *LoadCodeAssistResponse) GetTier() string {
return ""
}
// GetAvailableCredits returns the list of AI Credits balances in the paid tier.
func (r *LoadCodeAssistResponse) GetAvailableCredits() []AvailableCredit {
if r.PaidTier == nil {
return nil
}
return r.PaidTier.AvailableCredits
}
// Client is the Antigravity API client.
type Client struct {
httpClient *http.Client
@@ -514,7 +590,20 @@ type ModelQuotaInfo struct {
// ModelInfo holds model information.
type ModelInfo struct {
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
DisplayName string `json:"displayName,omitempty"`
SupportsImages *bool `json:"supportsImages,omitempty"`
SupportsThinking *bool `json:"supportsThinking,omitempty"`
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
Recommended *bool `json:"recommended,omitempty"`
MaxTokens *int `json:"maxTokens,omitempty"`
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
}
// DeprecatedModelInfo holds forwarding information for a deprecated model.
type DeprecatedModelInfo struct {
NewModelID string `json:"newModelId"`
}
// FetchAvailableModelsRequest is the fetchAvailableModels request.
@@ -524,7 +613,8 @@ type FetchAvailableModelsRequest struct {
// FetchAvailableModelsResponse is the fetchAvailableModels response.
type FetchAvailableModelsResponse struct {
Models map[string]ModelInfo `json:"models"`
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
}
// FetchAvailableModels fetches the available models and quota information, returning the parsed struct and the raw JSON.
@@ -573,6 +663,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
continue
}
if resp.StatusCode == http.StatusForbidden {
return nil, nil, &ForbiddenError{
StatusCode: resp.StatusCode,
Body: string(respBodyBytes),
}
}
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
}
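
The custom `UnmarshalJSON` on `PaidTierInfo` above accepts `paidTier` as either a bare string or an object. A minimal, self-contained sketch of that pattern follows; the `tier` type and `decodeTier` helper are hypothetical stand-ins reduced to two fields, not the project's actual types.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// tier is a hypothetical, reduced stand-in for PaidTierInfo.
type tier struct {
	ID   string `json:"id"`
	Name string `json:"name"`
}

// UnmarshalJSON accepts either a bare JSON string (treated as the ID)
// or a full JSON object.
func (t *tier) UnmarshalJSON(data []byte) error {
	data = bytes.TrimSpace(data)
	if len(data) == 0 || string(data) == "null" {
		return nil
	}
	if data[0] == '"' {
		return json.Unmarshal(data, &t.ID)
	}
	// The alias type carries no methods, so this call does not recurse
	// back into UnmarshalJSON.
	type alias tier
	var raw alias
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}
	*t = tier(raw)
	return nil
}

// decodeTier is a small helper (an assumption for this demo) that decodes one payload.
func decodeTier(payload string) (tier, error) {
	var t tier
	err := json.Unmarshal([]byte(payload), &t)
	return t, err
}

func main() {
	for _, p := range []string{`"g1-pro-tier"`, `{"id":"g1-pro-tier","name":"Pro"}`} {
		t, err := decodeTier(p)
		fmt.Printf("id=%q name=%q err=%v\n", t.ID, t.Name, err)
	}
}
```

The local `alias` type is what keeps the object branch from calling the custom method again; without it the decoder would recurse until the stack overflows.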

@@ -190,7 +190,7 @@ func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) {
func TestGetTier_PaidTier优先(t *testing.T) {
resp := &LoadCodeAssistResponse{
CurrentTier: &TierInfo{ID: "free-tier"},
PaidTier: &TierInfo{ID: "g1-pro-tier"},
PaidTier: &PaidTierInfo{ID: "g1-pro-tier"},
}
if got := resp.GetTier(); got != "g1-pro-tier" {
t.Errorf("应返回 paidTier: got %s", got)
@@ -209,7 +209,7 @@ func TestGetTier_回退到CurrentTier(t *testing.T) {
func TestGetTier_PaidTier为空ID(t *testing.T) {
resp := &LoadCodeAssistResponse{
CurrentTier: &TierInfo{ID: "free-tier"},
PaidTier: &TierInfo{ID: ""},
PaidTier: &PaidTierInfo{ID: ""},
}
// When paidTier.ID is empty, GetTier should fall back to currentTier.
if got := resp.GetTier(); got != "free-tier" {
@@ -217,6 +217,32 @@ func TestGetTier_PaidTier为空ID(t *testing.T) {
}
}
func TestGetAvailableCredits(t *testing.T) {
resp := &LoadCodeAssistResponse{
PaidTier: &PaidTierInfo{
ID: "g1-pro-tier",
AvailableCredits: []AvailableCredit{
{
CreditType: "GOOGLE_ONE_AI",
CreditAmount: "25",
MinimumCreditAmountForUsage: "5",
},
},
},
}
credits := resp.GetAvailableCredits()
if len(credits) != 1 {
t.Fatalf("AI Credits 数量不匹配: got %d", len(credits))
}
if credits[0].GetAmount() != 25 {
t.Errorf("CreditAmount 解析不正确: got %v", credits[0].GetAmount())
}
if credits[0].GetMinimumAmount() != 5 {
t.Errorf("MinimumCreditAmountForUsage 解析不正确: got %v", credits[0].GetMinimumAmount())
}
}
func TestGetTier_两者都为nil(t *testing.T) {
resp := &LoadCodeAssistResponse{}
if got := resp.GetTier(); got != "" {

@@ -81,6 +81,15 @@ type ModelStat struct {
ActualCost float64 `json:"actual_cost"` // actual amount charged
}
// EndpointStat represents usage statistics for a single request endpoint.
type EndpointStat struct {
Endpoint string `json:"endpoint"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // standard billing
ActualCost float64 `json:"actual_cost"` // actual amount charged
}
// GroupStat represents usage statistics for a single group
type GroupStat struct {
GroupID int64 `json:"group_id"`
@@ -116,6 +125,8 @@ type UserSpendingRankingItem struct {
type UserSpendingRankingResponse struct {
Ranking []UserSpendingRankingItem `json:"ranking"`
TotalActualCost float64 `json:"total_actual_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
}
// APIKeyUsageTrendPoint represents API key usage trend data point
@@ -179,15 +190,18 @@ type UsageLogFilters struct {
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
AverageDurationMs float64 `json:"average_duration_ms"`
Endpoints []EndpointStat `json:"endpoints,omitempty"`
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
}
// BatchUserUsageStats represents usage stats for a single user
@@ -254,7 +268,9 @@ type AccountUsageSummary struct {
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse struct {
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
Endpoints []EndpointStat `json:"endpoints"`
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
}

@@ -1727,8 +1727,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
const dailyExpiredExpr = `(
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
END
)`
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
const weeklyExpiredExpr = `(
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
END
)`
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
const nextDailyResetAtExpr = `(
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
THEN to_char((
-- Compute today's reset point in the configured timezone, then pick next future one
CASE WHEN NOW() >= (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- NOW() is at or past today's reset point → next reset is tomorrow
THEN (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
+ '1 day'::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- NOW() is before today's reset point → next reset is today
ELSE (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
END
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
ELSE NULL END
)`
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
const nextWeeklyResetAtExpr = `(
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
THEN to_char((
-- Compute this week's reset point in the configured timezone
-- Step 1: get today's date at reset hour in configured tz
-- Step 2: compute days forward to target weekday
-- Step 3: if same day but past reset hour, advance 7 days
CASE
WHEN (
-- days_forward = (target_day - current_day + 7) % 7
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
+ 7) % 7
) = 0 AND NOW() >= (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- Same weekday and past reset hour → next week
THEN (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
+ '7 days'::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
ELSE (
-- Advance to target weekday this week (or next if days_forward > 0)
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
+ ((
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
+ 7) % 7
) || ' days')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
END
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
ELSE NULL END
)`
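// IncrementQuotaUsed atomically increments the account's quota usage across three dimensions: total, daily, and weekly.
// Daily and weekly usage is reset to 0 when the period has expired, then incremented.
// Two reset modes are supported: rolling window (rolling) and fixed time (fixed).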
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
rows, err := r.sql.QueryContext(ctx,
`UPDATE accounts SET extra = (
@@ -1739,31 +1827,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
CASE WHEN `+dailyExpiredExpr+`
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
CASE WHEN `+dailyExpiredExpr+`
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
-- In fixed mode, update the next reset time when a reset occurs
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
-- Weekly quota: only processed when quota_weekly_limit > 0
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
CASE WHEN `+weeklyExpiredExpr+`
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
CASE WHEN `+weeklyExpiredExpr+`
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
-- In fixed mode, update the next reset time when a reset occurs
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
@@ -1796,12 +1888,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
}
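// ResetQuotaUsed resets the account's quota usage to 0 across all dimensions.
// It keeps the fixed-reset-mode configuration fields (quota_daily_reset_mode, etc.) and clears only the usage counters, window start times, and computed reset_at timestamps.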
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`,
id)
if err != nil {

@@ -0,0 +1,98 @@
package repository
import (
"context"
"fmt"
"io"
"os/exec"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// PgDumper implements service.DBDumper using pg_dump/psql
type PgDumper struct {
cfg *config.DatabaseConfig
}
// NewPgDumper creates a new PgDumper
func NewPgDumper(cfg *config.Config) service.DBDumper {
return &PgDumper{cfg: &cfg.Database}
}
// Dump executes pg_dump and returns a streaming reader of the output
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
args := []string{
"-h", d.cfg.Host,
"-p", fmt.Sprintf("%d", d.cfg.Port),
"-U", d.cfg.User,
"-d", d.cfg.DBName,
"--no-owner",
"--no-acl",
"--clean",
"--if-exists",
}
cmd := exec.CommandContext(ctx, "pg_dump", args...)
if d.cfg.Password != "" {
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
}
if d.cfg.SSLMode != "" {
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("create stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("start pg_dump: %w", err)
}
// Return a ReadCloser that reads stdout and waits for the process to exit on Close.
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
}
// Restore executes psql to restore from a streaming reader
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
args := []string{
"-h", d.cfg.Host,
"-p", fmt.Sprintf("%d", d.cfg.Port),
"-U", d.cfg.User,
"-d", d.cfg.DBName,
"--single-transaction",
}
cmd := exec.CommandContext(ctx, "psql", args...)
if d.cfg.Password != "" {
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
}
if d.cfg.SSLMode != "" {
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
}
cmd.Stdin = data
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("%v: %s", err, string(output))
}
return nil
}
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
type cmdReadCloser struct {
io.ReadCloser
cmd *exec.Cmd
}
func (c *cmdReadCloser) Close() error {
// Close the pipe first
_ = c.ReadCloser.Close()
// Wait for the process to exit
if err := c.cmd.Wait(); err != nil {
return fmt.Errorf("pg_dump exited with error: %w", err)
}
return nil
}

@@ -0,0 +1,116 @@
package repository
import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
type S3BackupStore struct {
client *s3.Client
bucket string
}
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
region := cfg.Region
if region == "" {
region = "auto" // default region for Cloudflare R2
}
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
awsconfig.WithRegion(region),
awsconfig.WithCredentialsProvider(
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
),
)
if err != nil {
return nil, fmt.Errorf("load aws config: %w", err)
}
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
if cfg.Endpoint != "" {
o.BaseEndpoint = &cfg.Endpoint
}
if cfg.ForcePathStyle {
o.UsePathStyle = true
}
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
})
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
}
}
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
// Read the whole body to determine its size (S3 PutObject needs to know the content length).
data, err := io.ReadAll(body)
if err != nil {
return 0, fmt.Errorf("read body: %w", err)
}
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
Bucket: &s.bucket,
Key: &key,
Body: bytes.NewReader(data),
ContentType: &contentType,
})
if err != nil {
return 0, fmt.Errorf("S3 PutObject: %w", err)
}
return int64(len(data)), nil
}
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &s.bucket,
Key: &key,
})
if err != nil {
return nil, fmt.Errorf("S3 GetObject: %w", err)
}
return result.Body, nil
}
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: &s.bucket,
Key: &key,
})
return err
}
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
presignClient := s3.NewPresignClient(s.client)
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
Bucket: &s.bucket,
Key: &key,
}, s3.WithPresignExpires(expiry))
if err != nil {
return "", fmt.Errorf("presign url: %w", err)
}
return result.URL, nil
}
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
Bucket: &s.bucket,
})
if err != nil {
return fmt.Errorf("S3 HeadBucket failed: %w", err)
}
return nil
}

@@ -132,7 +132,7 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
}
}
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
return err
}

@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
var usageLogInsertArgTypes = [...]string{
"bigint",
@@ -65,6 +65,8 @@ var usageLogInsertArgTypes = [...]string{
"text",
"text",
"text",
"text",
"text",
"boolean",
"timestamptz",
}
@@ -304,6 +306,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
) VALUES (
@@ -312,7 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -732,11 +736,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
) AS (VALUES `)
args := make([]any, 0, len(keys)*37)
args := make([]any, 0, len(keys)*38)
argPos := 1
for idx, key := range keys {
if idx > 0 {
@@ -799,6 +805,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
)
@@ -837,6 +845,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
FROM input
@@ -915,11 +925,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
) AS (VALUES `)
args := make([]any, 0, len(preparedList)*36)
args := make([]any, 0, len(preparedList)*38)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -979,6 +991,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
)
@@ -1017,6 +1031,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
FROM input
@@ -1063,6 +1079,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
created_at
) VALUES (
@@ -1071,7 +1089,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1101,6 +1119,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
mediaType := nullString(log.MediaType)
serviceTier := nullString(log.ServiceTier)
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint)
var requestIDArg any
if requestID != "" {
@@ -1147,6 +1167,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
mediaType,
serviceTier,
reasoningEffort,
inboundEndpoint,
upstreamEndpoint,
log.CacheTTLOverridden,
createdAt,
},
@@ -2139,7 +2161,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
actual_cost,
requests,
tokens,
COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost
COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost,
COALESCE(SUM(requests) OVER (), 0) as total_requests,
COALESCE(SUM(tokens) OVER (), 0) as total_tokens
FROM user_spend
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
LIMIT $3
@@ -2150,7 +2174,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
actual_cost,
requests,
tokens,
total_actual_cost
total_actual_cost,
total_requests,
total_tokens
FROM ranked
ORDER BY actual_cost DESC, tokens DESC, user_id ASC
`
@@ -2168,9 +2194,11 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
ranking := make([]UserSpendingRankingItem, 0)
totalActualCost := 0.0
totalRequests := int64(0)
totalTokens := int64(0)
for rows.Next() {
var row UserSpendingRankingItem
if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost); err != nil {
if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost, &totalRequests, &totalTokens); err != nil {
return nil, err
}
ranking = append(ranking, row)
@@ -2182,6 +2210,8 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
return &UserSpendingRankingResponse{
Ranking: ranking,
TotalActualCost: totalActualCost,
TotalRequests: totalRequests,
TotalTokens: totalTokens,
}, nil
}
@@ -2505,7 +2535,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
args = append(args, *filters.EndTime)
}
@@ -2982,7 +3012,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at <= $2
WHERE created_at >= $1 AND created_at < $2
`
stats := &UsageStats{}
@@ -3040,7 +3070,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
args = append(args, *filters.EndTime)
}
@@ -3080,6 +3110,35 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
stats.TotalAccountCost = &totalAccountCost
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
start := time.Unix(0, 0).UTC()
if filters.StartTime != nil {
start = *filters.StartTime
}
end := time.Now().UTC()
if filters.EndTime != nil {
end = *filters.EndTime
}
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
if endpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetStatsWithFilters: %v", endpointErr)
endpoints = []EndpointStat{}
}
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
if upstreamEndpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetStatsWithFilters: %v", upstreamEndpointErr)
upstreamEndpoints = []EndpointStat{}
}
endpointPaths, endpointPathErr := r.getEndpointPathStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
if endpointPathErr != nil {
logger.LegacyPrintf("repository.usage_log", "getEndpointPathStatsWithFilters failed in GetStatsWithFilters: %v", endpointPathErr)
endpointPaths = []EndpointStat{}
}
stats.Endpoints = endpoints
stats.UpstreamEndpoints = upstreamEndpoints
stats.EndpointPaths = endpointPaths
return stats, nil
}
@@ -3092,6 +3151,163 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// EndpointStat represents endpoint usage statistics row.
type EndpointStat = usagestats.EndpointStat
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
SELECT
COALESCE(NULLIF(TRIM(%s), ''), 'unknown') AS endpoint,
COUNT(*) AS requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, endpointColumn, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
args = append(args, userID)
}
if apiKeyID > 0 {
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
args = append(args, apiKeyID)
}
if accountID > 0 {
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if model != "" {
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY endpoint ORDER BY requests DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results = make([]EndpointStat, 0)
for rows.Next() {
var row EndpointStat
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
SELECT
CONCAT(
COALESCE(NULLIF(TRIM(inbound_endpoint), ''), 'unknown'),
' -> ',
COALESCE(NULLIF(TRIM(upstream_endpoint), ''), 'unknown')
) AS endpoint,
COUNT(*) AS requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
args = append(args, userID)
}
if apiKeyID > 0 {
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
args = append(args, apiKeyID)
}
if accountID > 0 {
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if model != "" {
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY endpoint ORDER BY requests DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results = make([]EndpointStat, 0)
for rows.Next() {
var row EndpointStat
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// GetEndpointStatsWithFilters returns inbound endpoint statistics with optional filters.
func (r *usageLogRepository) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
return r.getEndpointStatsByColumnWithFilters(ctx, "inbound_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
}
// GetUpstreamEndpointStatsWithFilters returns upstream endpoint statistics with optional filters.
func (r *usageLogRepository) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
return r.getEndpointStatsByColumnWithFilters(ctx, "upstream_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
}
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
@@ -3254,11 +3470,23 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
if err != nil {
models = []ModelStat{}
}
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
if endpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetAccountUsageStats: %v", endpointErr)
endpoints = []EndpointStat{}
}
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
if upstreamEndpointErr != nil {
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetAccountUsageStats: %v", upstreamEndpointErr)
upstreamEndpoints = []EndpointStat{}
}
resp = &AccountUsageStatsResponse{
History: history,
Summary: summary,
Models: models,
History: history,
Summary: summary,
Models: models,
Endpoints: endpoints,
UpstreamEndpoints: upstreamEndpoints,
}
return resp, nil
}
@@ -3541,6 +3769,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
mediaType sql.NullString
serviceTier sql.NullString
reasoningEffort sql.NullString
inboundEndpoint sql.NullString
upstreamEndpoint sql.NullString
cacheTTLOverridden bool
createdAt time.Time
)
@@ -3581,6 +3811,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&mediaType,
&serviceTier,
&reasoningEffort,
&inboundEndpoint,
&upstreamEndpoint,
&cacheTTLOverridden,
&createdAt,
); err != nil {
@@ -3656,6 +3888,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if reasoningEffort.Valid {
log.ReasoningEffort = &reasoningEffort.String
}
if inboundEndpoint.Valid {
log.InboundEndpoint = &inboundEndpoint.String
}
if upstreamEndpoint.Valid {
log.UpstreamEndpoint = &upstreamEndpoint.String
}
return log, nil
}
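
The filter-building pattern used by the endpoint statistics queries above can be sketched in isolation. This is a minimal illustration, not the repository's code: the `filters` struct and the helpers `appendEqFilter` and `applyFilters` are hypothetical names introduced here. Each active filter appends an `AND column = $n` clause whose placeholder index is taken from the current length of `args`, so the SQL text and the argument slice stay in step.

```go
package main

import "fmt"

// filters is a simplified, hypothetical stand-in for the long parameter
// list used in the diff (userID, apiKeyID, accountID, groupID, model, ...).
type filters struct {
	UserID    int64
	AccountID int64
	Model     string
}

// appendEqFilter appends " AND <column> = $n" to query, where n is the
// 1-based position the value will take in args. The column name is
// interpolated into the SQL text, so it must be a trusted constant and
// never user input.
func appendEqFilter(query string, args []any, column string, value any) (string, []any) {
	query += fmt.Sprintf(" AND %s = $%d", column, len(args)+1)
	return query, append(args, value)
}

// applyFilters adds one clause per active filter. As in the diff, numeric
// IDs are applied only when > 0 and the model only when non-empty.
func applyFilters(query string, args []any, f filters) (string, []any) {
	if f.UserID > 0 {
		query, args = appendEqFilter(query, args, "user_id", f.UserID)
	}
	if f.AccountID > 0 {
		query, args = appendEqFilter(query, args, "account_id", f.AccountID)
	}
	if f.Model != "" {
		query, args = appendEqFilter(query, args, "model", f.Model)
	}
	return query, args
}
```

Starting from a base query that already binds `$1` and `$2` for the time range, a call with only `AccountID` and `Model` set would add `account_id = $3` and `model = $4`. Deriving the index from `len(args)+1` avoids hand-maintained counters when filters are optional.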


@@ -73,6 +73,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // media_type
sqlmock.AnyArg(), // service_tier
sqlmock.AnyArg(), // reasoning_effort
sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden,
createdAt,
).
@@ -141,6 +143,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
serviceTier,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.CacheTTLOverridden,
createdAt,
).
@@ -255,10 +259,10 @@ func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost", "total_requests", "total_tokens"}).
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0, int64(30), int64(2600)).
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0, int64(30), int64(2600)).
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0, int64(30), int64(2600))
mock.ExpectQuery("WITH user_spend AS \\(").
WithArgs(start, end, 12).
@@ -273,6 +277,8 @@ func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
},
TotalActualCost: 40.0,
TotalRequests: 30,
TotalTokens: 2600,
}, got)
require.NoError(t, mock.ExpectationsWereMet())
}
@@ -376,6 +382,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})
@@ -415,6 +423,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{Valid: true, String: "flex"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})
@@ -454,6 +464,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
now,
}})


@@ -100,6 +100,10 @@ var ProviderSet = wire.NewSet(
// Encryptors
NewAESEncryptor,
// Backup infrastructure
NewPgDumper,
NewS3BackupStoreFactory,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
ProvidePricingRemoteClient,


@@ -493,6 +493,7 @@ func TestAPIContracts(t *testing.T) {
"registration_email_suffix_whitelist": [],
"promo_code_enabled": true,
"password_reset_enabled": false,
"frontend_url": "",
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
@@ -537,6 +538,7 @@ func TestAPIContracts(t *testing.T) {
"purchase_subscription_url": "",
"min_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"custom_menu_items": []
}
}`,
@@ -1623,6 +1625,14 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, errors.New("not implemented")
}


@@ -0,0 +1,51 @@
package middleware
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
role, _ := GetUserRoleFromContext(c)
if role == "admin" {
c.Next()
return
}
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
c.Abort()
}
}
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
// Allows: login, login/2fa, logout, refresh (admin needs these).
// Blocks: register, forgot-password, reset-password, OAuth, etc.
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
path := c.Request.URL.Path
// Allow login, 2FA, logout, and refresh
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
for _, suffix := range allowedSuffixes {
if strings.HasSuffix(path, suffix) {
c.Next()
return
}
}
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
c.Abort()
}
}
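
The allowlist check in `BackendModeAuthGuard` can be exercised without gin. The sketch below copies the four suffixes from the diff; the helper name `isAuthPathAllowed` is hypothetical and exists only for this illustration.

```go
package main

import "strings"

// allowedAuthSuffixes copies the allowlist from BackendModeAuthGuard.
var allowedAuthSuffixes = []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}

// isAuthPathAllowed reports whether path ends with an allowlisted suffix.
// It is a hypothetical helper written for this sketch.
func isAuthPathAllowed(path string) bool {
	for _, suffix := range allowedAuthSuffixes {
		if strings.HasSuffix(path, suffix) {
			return true
		}
	}
	return false
}
```

Because the match is on the suffix only, any path that ends in one of these strings passes regardless of its prefix, and a path with a trailing slash (for example `/api/v1/auth/login/`) does not match. Whether that matters depends on how the router normalizes paths, which this sketch does not cover.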


@@ -0,0 +1,239 @@
//go:build unit
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type bmSettingRepo struct {
values map[string]string
}
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
v, ok := r.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return v, nil
}
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
panic("unexpected Set call")
}
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
if r.values == nil {
r.values = make(map[string]string, len(settings))
}
for key, value := range settings {
r.values[key] = value
}
return nil
}
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
panic("unexpected Delete call")
}
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
t.Helper()
repo := &bmSettingRepo{
values: map[string]string{
service.SettingKeyBackendModeEnabled: enabled,
},
}
svc := service.NewSettingService(repo, &config.Config{})
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
BackendModeEnabled: enabled == "true",
}))
return svc
}
func stringPtr(v string) *string {
return &v
}
func TestBackendModeUserGuard(t *testing.T) {
tests := []struct {
name string
nilService bool
enabled string
role *string
wantStatus int
}{
{
name: "disabled_allows_all",
enabled: "false",
role: stringPtr("user"),
wantStatus: http.StatusOK,
},
{
name: "nil_service_allows_all",
nilService: true,
role: stringPtr("user"),
wantStatus: http.StatusOK,
},
{
name: "enabled_admin_allowed",
enabled: "true",
role: stringPtr("admin"),
wantStatus: http.StatusOK,
},
{
name: "enabled_user_blocked",
enabled: "true",
role: stringPtr("user"),
wantStatus: http.StatusForbidden,
},
{
name: "enabled_no_role_blocked",
enabled: "true",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_empty_role_blocked",
enabled: "true",
role: stringPtr(""),
wantStatus: http.StatusForbidden,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
if tc.role != nil {
role := *tc.role
r.Use(func(c *gin.Context) {
c.Set(string(ContextKeyUserRole), role)
c.Next()
})
}
var svc *service.SettingService
if !tc.nilService {
svc = newBackendModeSettingService(t, tc.enabled)
}
r.Use(BackendModeUserGuard(svc))
r.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
r.ServeHTTP(w, req)
require.Equal(t, tc.wantStatus, w.Code)
})
}
}
func TestBackendModeAuthGuard(t *testing.T) {
tests := []struct {
name string
nilService bool
enabled string
path string
wantStatus int
}{
{
name: "disabled_allows_all",
enabled: "false",
path: "/api/v1/auth/register",
wantStatus: http.StatusOK,
},
{
name: "nil_service_allows_all",
nilService: true,
path: "/api/v1/auth/register",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_login",
enabled: "true",
path: "/api/v1/auth/login",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_login_2fa",
enabled: "true",
path: "/api/v1/auth/login/2fa",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_logout",
enabled: "true",
path: "/api/v1/auth/logout",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_refresh",
enabled: "true",
path: "/api/v1/auth/refresh",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_register",
enabled: "true",
path: "/api/v1/auth/register",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_blocks_forgot_password",
enabled: "true",
path: "/api/v1/auth/forgot-password",
wantStatus: http.StatusForbidden,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
var svc *service.SettingService
if !tc.nilService {
svc = newBackendModeSettingService(t, tc.enabled)
}
r.Use(BackendModeAuthGuard(svc))
r.Any("/*path", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
r.ServeHTTP(w, req)
require.Equal(t, tc.wantStatus, w.Code)
})
}
}


@@ -107,9 +107,9 @@ func registerRoutes(
v1 := r.Group("/api/v1")
// Register routes for each module
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
}


@@ -58,6 +58,9 @@ func RegisterAdminRoutes(
// Data management
registerDataManagementRoutes(admin, h)
// Database backup and restore
registerBackupRoutes(admin, h)
// Ops monitoring
registerOpsRoutes(admin, h)
@@ -440,6 +443,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
backup := admin.Group("/backups")
{
// S3 storage configuration
backup.GET("/s3-config", h.Admin.Backup.GetS3Config)
backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config)
backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection)
// Scheduled backup configuration
backup.GET("/schedule", h.Admin.Backup.GetSchedule)
backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule)
// Backup operations
backup.POST("", h.Admin.Backup.CreateBackup)
backup.GET("", h.Admin.Backup.ListBackups)
backup.GET("/:id", h.Admin.Backup.GetBackup)
backup.DELETE("/:id", h.Admin.Backup.DeleteBackup)
backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL)
// Restore operations
backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup)
}
}
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
system := admin.Group("/system")
{


@@ -6,6 +6,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
h *handler.Handlers,
jwtAuth servermiddleware.JWTAuthMiddleware,
redisClient *redis.Client,
settingService *service.SettingService,
) {
// Create the rate limiter
rateLimiter := middleware.NewRateLimiter(redisClient)
// Public endpoints
auth := v1.Group("/auth")
auth.Use(servermiddleware.BackendModeAuthGuard(settingService))
{
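// Register, login, 2FA, and verification-code sending are all high-risk entry points, so add a server-side fallback rate limit (fail-close when Redis is unavailable)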
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
@@ -78,6 +81,7 @@ func RegisterAuthRoutes(
// Current-user info (authentication required)
authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService))
{
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// Revoke all sessions (authentication required)


@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
c.Next()
}),
redisClient,
nil,
)
return router


@@ -30,6 +30,7 @@ func RegisterGatewayRoutes(
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
endpointNorm := handler.InboundEndpointMiddleware()
// Middleware that blocks ungrouped keys (the error response format depends on the protocol)
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
@@ -40,6 +41,7 @@ func RegisterGatewayRoutes(
gateway.Use(bodyLimit)
gateway.Use(clientRequestID)
gateway.Use(opsErrorLogger)
gateway.Use(endpointNorm)
gateway.Use(gin.HandlerFunc(apiKeyAuth))
gateway.Use(requireGroupAnthropic)
{
@@ -80,6 +82,7 @@ func RegisterGatewayRoutes(
gemini.Use(bodyLimit)
gemini.Use(clientRequestID)
gemini.Use(opsErrorLogger)
gemini.Use(endpointNorm)
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
gemini.Use(requireGroupGoogle)
{
@@ -90,11 +93,11 @@ func RegisterGatewayRoutes(
}
// OpenAI Responses API (aliases without the /v1 prefix)
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// OpenAI Chat Completions API (alias without the /v1 prefix)
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
// Antigravity model list
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
@@ -104,6 +107,7 @@ func RegisterGatewayRoutes(
antigravityV1.Use(bodyLimit)
antigravityV1.Use(clientRequestID)
antigravityV1.Use(opsErrorLogger)
antigravityV1.Use(endpointNorm)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
antigravityV1.Use(requireGroupAnthropic)
@@ -118,6 +122,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(clientRequestID)
antigravityV1Beta.Use(opsErrorLogger)
antigravityV1Beta.Use(endpointNorm)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
antigravityV1Beta.Use(requireGroupGoogle)
@@ -132,6 +137,7 @@ func RegisterGatewayRoutes(
soraV1.Use(soraBodyLimit)
soraV1.Use(clientRequestID)
soraV1.Use(opsErrorLogger)
soraV1.Use(endpointNorm)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
soraV1.Use(requireGroupAnthropic)


@@ -3,6 +3,7 @@ package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
settingService *service.SettingService,
) {
if h.SoraClient == nil {
return
@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
authenticated := v1.Group("/sora")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{
authenticated.POST("/generate", h.SoraClient.Generate)
authenticated.GET("/generations", h.SoraClient.ListGenerations)


@@ -3,6 +3,7 @@ package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -12,9 +13,11 @@ func RegisterUserRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
settingService *service.SettingService,
) {
authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{
// User endpoints
user := authenticated.Group("/user")


@@ -3,6 +3,7 @@ package service
import (
"encoding/json"
"errors"
"hash/fnv"
"reflect"
"sort"
@@ -656,7 +657,7 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
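// IsPoolMode reports whether pool mode is enabled for an API Key account.
// In pool mode, upstream errors do not mark the local account's status; the request is retried on the same account instead.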
func (a *Account) IsPoolMode() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
if !a.IsAPIKeyOrBedrock() || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["pool_mode"]; ok {
@@ -771,11 +772,16 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
}
func (a *Account) IsBedrock() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey)
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
}
func (a *Account) IsBedrockAPIKey() bool {
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey
return a.IsBedrock() && a.GetCredential("auth_mode") == "apikey"
}
// IsAPIKeyOrBedrock reports whether the account type supports features such as quotas and pool mode.
func (a *Account) IsAPIKeyOrBedrock() bool {
return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock
}
func (a *Account) IsOpenAI() bool {
@@ -895,6 +901,22 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
return false
}
// IsOveragesEnabled reports whether AI Credits overage requests are enabled for an Antigravity account.
func (a *Account) IsOveragesEnabled() bool {
if a.Platform != PlatformAntigravity {
return false
}
if a.Extra == nil {
return false
}
if v, ok := a.Extra["allow_overages"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// IsOpenAIPassthroughEnabled reports whether "automatic passthrough (replace authentication only)" is enabled for an OpenAI account.
//
// New field: accounts.extra.openai_passthrough.
@@ -1274,6 +1296,240 @@ func (a *Account) getExtraTime(key string) time.Time {
return time.Time{}
}
// getExtraString reads the string value for the given key from Extra.
func (a *Account) getExtraString(key string) string {
if a.Extra == nil {
return ""
}
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// getExtraInt reads the int value for the given key from Extra.
func (a *Account) getExtraInt(key string) int {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra[key]; ok {
return int(parseExtraFloat64(v))
}
return 0
}
// GetQuotaDailyResetMode returns the daily quota reset mode: "rolling" (default) or "fixed".
func (a *Account) GetQuotaDailyResetMode() string {
if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" {
return "fixed"
}
return "rolling"
}
// GetQuotaDailyResetHour returns the hour of the fixed daily reset (0-23); defaults to 0.
func (a *Account) GetQuotaDailyResetHour() int {
return a.getExtraInt("quota_daily_reset_hour")
}
// GetQuotaWeeklyResetMode returns the weekly quota reset mode: "rolling" (default) or "fixed".
func (a *Account) GetQuotaWeeklyResetMode() string {
if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" {
return "fixed"
}
return "rolling"
}
// GetQuotaWeeklyResetDay returns the weekday of the fixed reset (0=Sunday, 1=Monday, ..., 6=Saturday); defaults to 1 (Monday).
func (a *Account) GetQuotaWeeklyResetDay() int {
if a.Extra == nil {
return 1
}
if _, ok := a.Extra["quota_weekly_reset_day"]; !ok {
return 1
}
return a.getExtraInt("quota_weekly_reset_day")
}
// GetQuotaWeeklyResetHour returns the hour of the fixed weekly quota reset (0-23); defaults to 0.
func (a *Account) GetQuotaWeeklyResetHour() int {
return a.getExtraInt("quota_weekly_reset_hour")
}
// GetQuotaResetTimezone returns the IANA timezone name used for fixed resets; defaults to "UTC".
func (a *Account) GetQuotaResetTimezone() string {
if tz := a.getExtraString("quota_reset_timezone"); tz != "" {
return tz
}
return "UTC"
}
// nextFixedDailyReset returns the next fixed daily reset time after the given time `after`.
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz)
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
if !after.Before(today) {
return today.AddDate(0, 0, 1)
}
return today
}
// lastFixedDailyReset returns the most recent fixed daily reset time at or before now.
func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time {
t := now.In(tz)
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
if now.Before(today) {
return today.AddDate(0, 0, -1)
}
return today
}
// nextFixedWeeklyReset returns the next fixed weekly reset time after the given time `after`.
// day: 0=Sunday, 1=Monday, ..., 6=Saturday
func nextFixedWeeklyReset(day, hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz)
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
currentDay := int(todayReset.Weekday())
daysForward := (day - currentDay + 7) % 7
if daysForward == 0 && !after.Before(todayReset) {
daysForward = 7
}
return todayReset.AddDate(0, 0, daysForward)
}
// lastFixedWeeklyReset returns the most recent fixed weekly reset time at or before now.
func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time {
t := now.In(tz)
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
currentDay := int(todayReset.Weekday())
daysBack := (currentDay - day + 7) % 7
if daysBack == 0 && now.Before(todayReset) {
daysBack = 7
}
return todayReset.AddDate(0, 0, -daysBack)
}
// isFixedDailyPeriodExpired reports whether the daily quota period has expired under fixed-time mode.
func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool {
if periodStart.IsZero() {
return true
}
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
if err != nil {
tz = time.UTC
}
lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now())
return periodStart.Before(lastReset)
}
// isFixedWeeklyPeriodExpired reports whether the weekly quota period has expired under fixed-time mode.
func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool {
if periodStart.IsZero() {
return true
}
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
if err != nil {
tz = time.UTC
}
lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now())
return periodStart.Before(lastReset)
}
// ComputeQuotaResetAt computes quota_daily_reset_at / quota_weekly_reset_at from the current configuration and stores them in extra.
// Called when the account configuration is saved.
func ComputeQuotaResetAt(extra map[string]any) {
now := time.Now()
tzName, _ := extra["quota_reset_timezone"].(string)
if tzName == "" {
tzName = "UTC"
}
tz, err := time.LoadLocation(tzName)
if err != nil {
tz = time.UTC
}
// Fixed reset time for the daily quota
if mode, _ := extra["quota_daily_reset_mode"].(string); mode == "fixed" {
hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"]))
if hour < 0 || hour > 23 {
hour = 0
}
resetAt := nextFixedDailyReset(hour, tz, now)
extra["quota_daily_reset_at"] = resetAt.UTC().Format(time.RFC3339)
} else {
delete(extra, "quota_daily_reset_at")
}
// Fixed reset time for the weekly quota
if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" {
day := 1 // default: Monday
if d, ok := extra["quota_weekly_reset_day"]; ok {
day = int(parseExtraFloat64(d))
}
if day < 0 || day > 6 {
day = 1
}
hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"]))
if hour < 0 || hour > 23 {
hour = 0
}
resetAt := nextFixedWeeklyReset(day, hour, tz, now)
extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339)
} else {
delete(extra, "quota_weekly_reset_at")
}
}
// ValidateQuotaResetConfig validates the fixed quota reset configuration.
func ValidateQuotaResetConfig(extra map[string]any) error {
if extra == nil {
return nil
}
// Validate the timezone
if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" {
if _, err := time.LoadLocation(tz); err != nil {
return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name")
}
}
// Daily quota reset mode
if mode, ok := extra["quota_daily_reset_mode"].(string); ok {
if mode != "rolling" && mode != "fixed" {
return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'")
}
}
// Daily quota reset hour
if v, ok := extra["quota_daily_reset_hour"]; ok {
hour := int(parseExtraFloat64(v))
if hour < 0 || hour > 23 {
return errors.New("quota_daily_reset_hour must be between 0 and 23")
}
}
// Weekly quota reset mode
if mode, ok := extra["quota_weekly_reset_mode"].(string); ok {
if mode != "rolling" && mode != "fixed" {
return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'")
}
}
// Weekly quota reset weekday
if v, ok := extra["quota_weekly_reset_day"]; ok {
day := int(parseExtraFloat64(v))
if day < 0 || day > 6 {
return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)")
}
}
// Weekly quota reset hour
if v, ok := extra["quota_weekly_reset_hour"]; ok {
hour := int(parseExtraFloat64(v))
if hour < 0 || hour > 23 {
return errors.New("quota_weekly_reset_hour must be between 0 and 23")
}
}
return nil
}
// HasAnyQuotaLimit reports whether a quota limit is configured in any dimension.
func (a *Account) HasAnyQuotaLimit() bool {
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
@@ -1296,14 +1552,26 @@ func (a *Account) IsQuotaExceeded() bool {
// Daily quota (an expired period counts as not exceeded; the next increment resets it)
if limit := a.GetQuotaDailyLimit(); limit > 0 {
start := a.getExtraTime("quota_daily_start")
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
var expired bool
if a.GetQuotaDailyResetMode() == "fixed" {
expired = a.isFixedDailyPeriodExpired(start)
} else {
expired = isPeriodExpired(start, 24*time.Hour)
}
if !expired && a.GetQuotaDailyUsed() >= limit {
return true
}
}
// Weekly quota
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
start := a.getExtraTime("quota_weekly_start")
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
var expired bool
if a.GetQuotaWeeklyResetMode() == "fixed" {
expired = a.isFixedWeeklyPeriodExpired(start)
} else {
expired = isPeriodExpired(start, 7*24*time.Hour)
}
if !expired && a.GetQuotaWeeklyUsed() >= limit {
return true
}
}


@@ -0,0 +1,516 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// nextFixedDailyReset
// ---------------------------------------------------------------------------
func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) {
tz := time.UTC
// 2026-03-14 06:00 UTC, reset hour = 9
after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_AtResetHour(t *testing.T) {
tz := time.UTC
// Exactly at reset hour → should return tomorrow
after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_AfterResetHour(t *testing.T) {
tz := time.UTC
// After reset hour → should return tomorrow
after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_MidnightReset(t *testing.T) {
tz := time.UTC
// Reset at hour 0 (midnight), currently 23:59
after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz)
got := nextFixedDailyReset(0, tz, after)
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) {
tz, err := time.LoadLocation("Asia/Shanghai")
require.NoError(t, err)
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC)
got := nextFixedDailyReset(9, tz, after)
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// lastFixedDailyReset
// ---------------------------------------------------------------------------
func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// Before today's 9:00 → yesterday 9:00
want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedDailyReset_AtResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// At exactly 9:00 → today 9:00
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedDailyReset_AfterResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// After 9:00 → today 9:00
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// nextFixedWeeklyReset
// ---------------------------------------------------------------------------
func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) {
tz := time.UTC
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday = 2026-03-16
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Today at 9:00
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday at 9:00
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday at 9:00
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) {
tz := time.UTC
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday = 2026-03-23
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_Sunday(t *testing.T) {
tz := time.UTC
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(0, 0, tz, after)
// Next Sunday = 2026-03-15
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// lastFixedWeeklyReset
// ---------------------------------------------------------------------------
func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Today at 9:00
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Last Monday at 9:00 = 2026-03-09
want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) {
tz := time.UTC
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Last Monday = 2026-03-16
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// isFixedDailyPeriodExpired
// ---------------------------------------------------------------------------
func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
assert.True(t, a.isFixedDailyPeriodExpired(time.Time{}))
}
func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started after the most recent reset → not expired
// (This test uses a time very close to "now", which is after the last reset)
periodStart := time.Now().Add(-1 * time.Minute)
assert.False(t, a.isFixedDailyPeriodExpired(periodStart))
}
func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 3 days ago → definitely expired
periodStart := time.Now().Add(-72 * time.Hour)
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
}
func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "Invalid/Timezone",
}}
// Invalid timezone falls back to UTC
periodStart := time.Now().Add(-72 * time.Hour)
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
}
// ---------------------------------------------------------------------------
// isFixedWeeklyPeriodExpired
// ---------------------------------------------------------------------------
func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{}))
}
func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 1 minute ago → not expired
periodStart := time.Now().Add(-1 * time.Minute)
assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart))
}
func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 10 days ago → definitely expired
periodStart := time.Now().Add(-240 * time.Hour)
assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart))
}
// ---------------------------------------------------------------------------
// ValidateQuotaResetConfig
// ---------------------------------------------------------------------------
func TestValidateQuotaResetConfig_NilExtra(t *testing.T) {
assert.NoError(t, ValidateQuotaResetConfig(nil))
}
func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) {
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{}))
}
func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "Asia/Shanghai",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
}
func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
}
func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) {
extra := map[string]any{
"quota_reset_timezone": "Not/A/Timezone",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_reset_timezone")
}
func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "invalid",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_mode")
}
func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_hour": float64(24),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
}
func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_hour": float64(-1),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
}
func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_mode": "unknown",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_mode")
}
func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_day": float64(7),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
}
func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_day": float64(-1),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
}
func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_hour": float64(25),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_hour")
}
func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) {
// All boundary values should be valid
extra := map[string]any{
"quota_daily_reset_hour": float64(23),
"quota_weekly_reset_day": float64(0), // Sunday
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "UTC",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
extra2 := map[string]any{
"quota_daily_reset_hour": float64(0),
"quota_weekly_reset_day": float64(6), // Saturday
"quota_weekly_reset_hour": float64(23),
}
assert.NoError(t, ValidateQuotaResetConfig(extra2))
}
// ---------------------------------------------------------------------------
// ComputeQuotaResetAt
// ---------------------------------------------------------------------------
func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
}
ComputeQuotaResetAt(extra)
_, hasDailyResetAt := extra["quota_daily_reset_at"]
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at")
assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at")
}
func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
"quota_daily_reset_at": "2026-03-14T09:00:00Z",
"quota_weekly_reset_at": "2026-03-16T09:00:00Z",
}
ComputeQuotaResetAt(extra)
_, hasDailyResetAt := extra["quota_daily_reset_at"]
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at")
assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at")
}
func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok, "quota_daily_reset_at should be set")
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Reset time should be in the future
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
// Reset hour should be 9 UTC
assert.Equal(t, 9, resetAt.UTC().Hour())
}
func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1), // Monday
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_weekly_reset_at"].(string)
require.True(t, ok, "quota_weekly_reset_at should be set")
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Reset time should be in the future
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
// Reset day should be Monday
assert.Equal(t, time.Monday, resetAt.UTC().Weekday())
}
func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) {
tz, err := time.LoadLocation("Asia/Shanghai")
require.NoError(t, err)
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "Asia/Shanghai",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// In Shanghai timezone, the hour should be 9
assert.Equal(t, 9, resetAt.In(tz).Hour())
}
func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(12),
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Default timezone is UTC
assert.Equal(t, 12, resetAt.UTC().Hour())
}
func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(99),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Invalid hour → clamped to 0
assert.Equal(t, 0, resetAt.UTC().Hour())
}


@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"log"
"log/slog"
"math/rand/v2"
"net/http"
"strings"
@@ -44,6 +45,8 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
@@ -100,6 +103,7 @@ type antigravityUsageCache struct {
const (
apiCacheTTL = 3 * time.Minute
apiErrorCacheTTL = 1 * time.Minute // negative-cache TTL: errors such as 429 are cached for 1 minute
antigravityErrorTTL = 1 * time.Minute // Antigravity error-cache TTL (recoverable errors)
apiQueryMaxJitter = 800 * time.Millisecond // maximum random delay for usage queries
windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute
@@ -108,11 +112,12 @@ const (
// UsageCache bundles the caches related to account usage
type UsageCache struct {
apiCache sync.Map // accountID -> *apiUsageCache
windowStatsCache sync.Map // accountID -> *windowStatsCache
antigravityCache sync.Map // accountID -> *antigravityUsageCache
apiFlight singleflight.Group // prevents concurrent requests for the same account from stampeding past the cache
openAIProbeCache sync.Map // accountID -> time.Time
apiCache sync.Map // accountID -> *apiUsageCache
windowStatsCache sync.Map // accountID -> *windowStatsCache
antigravityCache sync.Map // accountID -> *antigravityUsageCache
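apiFlight singleflight.Group // prevents concurrent requests for the same account from stampeding past the cache (Anthropic)
antigravityFlight singleflight.Group // prevents concurrent requests for the same Antigravity account from stampeding past the cache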
openAIProbeCache sync.Map // accountID -> time.Time
}
// NewUsageCache creates a UsageCache instance
@@ -149,6 +154,25 @@ type AntigravityModelQuota struct {
ResetTime string `json:"reset_time"` // reset time, ISO 8601
}
// AntigravityModelDetail holds detailed capability information for a single Antigravity model
type AntigravityModelDetail struct {
DisplayName string `json:"display_name,omitempty"`
SupportsImages *bool `json:"supports_images,omitempty"`
SupportsThinking *bool `json:"supports_thinking,omitempty"`
ThinkingBudget *int `json:"thinking_budget,omitempty"`
Recommended *bool `json:"recommended,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
}
// AICredit represents the AI Credits balance information of an Antigravity account.
type AICredit struct {
CreditType string `json:"credit_type,omitempty"`
Amount float64 `json:"amount,omitempty"`
MinimumBalance float64 `json:"minimum_balance,omitempty"`
}
// UsageInfo holds account usage information
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // last update time
@@ -164,6 +188,36 @@ type UsageInfo struct {
// Antigravity per-model quotas
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
// Antigravity account-level information
SubscriptionTier string `json:"subscription_tier,omitempty"` // normalized subscription tier: FREE/PRO/ULTRA/UNKNOWN
SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // raw subscription tier name from upstream
// Antigravity detailed model capabilities (same keys as antigravity_quota)
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
// Antigravity AI Credits balance
AICredits []AICredit `json:"ai_credits,omitempty"`
// Antigravity forwarding rules for deprecated models (old_model_id -> new_model_id)
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
// Whether the Antigravity account is forbidden by upstream (HTTP 403)
IsForbidden bool `json:"is_forbidden,omitempty"`
ForbiddenReason string `json:"forbidden_reason,omitempty"`
ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden"
ValidationURL string `json:"validation_url,omitempty"` // verification/appeal link
// Status flags (derived from ForbiddenType / HTTP status codes)
NeedsVerify bool `json:"needs_verify,omitempty"` // manual verification required (forbidden_type=validation)
IsBanned bool `json:"is_banned,omitempty"` // account banned (forbidden_type=violation)
NeedsReauth bool `json:"needs_reauth,omitempty"` // token invalid, re-authorization required (401)
// Machine-readable error code: forbidden / unauthenticated / rate_limited / network_error
ErrorCode string `json:"error_code,omitempty"`
// Error message from fetching usage (returned as a degraded response rather than a 500)
Error string `json:"error,omitempty"`
}
// ClaudeUsageResponse is the usage structure returned by the Anthropic API
@@ -648,34 +702,157 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
return &UsageInfo{UpdatedAt: &now}, nil
}
// 1. Check the cache (10 minutes)
// 1. Check the cache
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
// Recompute RemainingSeconds
usage := cache.usageInfo
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
if cache, ok := cached.(*antigravityUsageCache); ok {
ttl := antigravityCacheTTL(cache.usageInfo)
if time.Since(cache.timestamp) < ttl {
usage := cache.usageInfo
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
}
return usage, nil
}
return usage, nil
}
}
// 2. Get the proxy URL
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
// 2. Use singleflight to guard against a concurrent cache stampede
flightKey := fmt.Sprintf("ag-usage:%d", account.ID)
result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) {
// Re-check the cache (it may have been filled while waiting)
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
if cache, ok := cached.(*antigravityUsageCache); ok {
ttl := antigravityCacheTTL(cache.usageInfo)
if time.Since(cache.timestamp) < ttl {
usage := cache.usageInfo
// Recompute RemainingSeconds so a stale countdown is not returned
recalcAntigravityRemainingSeconds(usage)
return usage, nil
}
}
}
// 3. Call the API to fetch the quota
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
if err != nil {
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
}
// Use an independent context so that one caller's cancel does not fail every request sharing the flight
fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer fetchCancel()
// 4. Cache the result
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
usageInfo: result.UsageInfo,
timestamp: time.Now(),
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account)
fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL)
if err != nil {
degraded := buildAntigravityDegradedUsage(err)
enrichUsageWithAccountError(degraded, account)
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
usageInfo: degraded,
timestamp: time.Now(),
})
return degraded, nil
}
enrichUsageWithAccountError(fetchResult.UsageInfo, account)
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
usageInfo: fetchResult.UsageInfo,
timestamp: time.Now(),
})
return fetchResult.UsageInfo, nil
})
return result.UsageInfo, nil
if flightErr != nil {
return nil, flightErr
}
usage, ok := result.(*UsageInfo)
if !ok || usage == nil {
now := time.Now()
return &UsageInfo{UpdatedAt: &now}, nil
}
return usage, nil
}
// recalcAntigravityRemainingSeconds recomputes RemainingSeconds for each window in an Antigravity UsageInfo.
// It refreshes the countdown when reading from the cache so a stale value is not returned.
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
if info == nil {
return
}
if info.FiveHour != nil && info.FiveHour.ResetsAt != nil {
remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds())
if remaining < 0 {
remaining = 0
}
info.FiveHour.RemainingSeconds = remaining
}
}
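// antigravityCacheTTL picks the cache TTL based on the UsageInfo contents.
// A 403 forbidden state is stable, so it is cached as long as a success (3 minutes);
// other errors (401/network) may recover quickly, so they are cached for 1 minute.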
func antigravityCacheTTL(info *UsageInfo) time.Duration {
if info == nil {
return antigravityErrorTTL
}
if info.IsForbidden {
return apiCacheTTL // a banned/needs-verification state does not change quickly
}
if info.ErrorCode != "" || info.Error != "" {
return antigravityErrorTTL
}
return apiCacheTTL
}
// buildAntigravityDegradedUsage builds a degraded UsageInfo from a FetchQuota error
func buildAntigravityDegradedUsage(err error) *UsageInfo {
now := time.Now()
errMsg := fmt.Sprintf("usage API error: %v", err)
slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err)
info := &UsageInfo{
UpdatedAt: &now,
Error: errMsg,
}
// Infer error_code and the status flags from the error message.
// The error format comes from antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..."
errStr := err.Error()
switch {
case strings.Contains(errStr, "HTTP 401") ||
strings.Contains(errStr, "UNAUTHENTICATED") ||
strings.Contains(errStr, "invalid_grant"):
info.ErrorCode = errorCodeUnauthenticated
info.NeedsReauth = true
case strings.Contains(errStr, "HTTP 429"):
info.ErrorCode = errorCodeRateLimited
default:
info.ErrorCode = errorCodeNetworkError
}
return info
}
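// enrichUsageWithAccountError corrects the UsageInfo using the account's error state.
// Case 1 (success path): FetchAvailableModels returns normally, but the account is already marked as error because of a 403.
//
// The forbidden/validation information must be attached to the otherwise normal usage data.
//
// Case 2 (degraded path): a banned account's OAuth token is invalid, so FetchAvailableModels returns 401.
//
// The degraded logic sets needs_reauth, but the account is actually 403 banned/needs verification, so the state must be overridden.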
func enrichUsageWithAccountError(info *UsageInfo, account *Account) {
if info == nil || account == nil || account.Status != StatusError {
return
}
msg := strings.ToLower(account.ErrorMessage)
if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") &&
!strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") {
return
}
fbType := classifyForbiddenType(account.ErrorMessage)
info.IsForbidden = true
info.ForbiddenType = fbType
info.ForbiddenReason = account.ErrorMessage
info.NeedsVerify = fbType == forbiddenTypeValidation
info.IsBanned = fbType == forbiddenTypeViolation
info.ValidationURL = extractValidationURL(account.ErrorMessage)
info.ErrorCode = errorCodeForbidden
info.NeedsReauth = false
}
// addWindowStats adds window-period statistics to the usage data


@@ -368,6 +368,10 @@ type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
type proxyQualityTarget struct {
Target string
URL string
@@ -445,10 +449,6 @@ type userGroupRateBatchReader interface {
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
}
type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAdminService creates a new AdminService
func NewAdminService(
userRepo UserRepository,
@@ -832,7 +832,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType = SubscriptionTypeStandard
}
// Limit fields: 0 and nil both mean "unlimited"
// Limit fields: nil/negative means "unlimited", 0 means "no usage allowed", a positive number is the actual limit
dailyLimit := normalizeLimit(input.DailyLimitUSD)
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
@@ -944,9 +944,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil
}
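// normalizeLimit converts 0 or a negative number to nil (meaning unlimited)
// normalizeLimit converts a negative number to nil (meaning unlimited); 0 is kept (meaning a limit of zero)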
func normalizeLimit(limit *float64) *float64 {
if limit == nil || *limit <= 0 {
if limit == nil || *limit < 0 {
return nil
}
return limit
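The changed condition (`< 0` instead of `<= 0`) splits what used to be one "unlimited" case into three outcomes. A small sketch of the new semantics, using a hypothetical copy of the function and a hypothetical `describeLimitSketch` helper for display:

```go
package main

import "fmt"

// normalizeLimitSketch is a hypothetical copy of the updated rule:
// nil or negative means unlimited (nil); zero and positive values are kept.
func normalizeLimitSketch(limit *float64) *float64 {
	if limit == nil || *limit < 0 {
		return nil
	}
	return limit
}

// describeLimitSketch renders the three possible outcomes as text.
func describeLimitSketch(limit *float64) string {
	switch n := normalizeLimitSketch(limit); {
	case n == nil:
		return "unlimited"
	case *n == 0:
		return "no usage allowed"
	default:
		return fmt.Sprintf("limit %.2f USD", *n)
	}
}

func main() {
	neg, zero, pos := -1.0, 0.0, 12.5
	for _, l := range []*float64{nil, &neg, &zero, &pos} {
		fmt.Println(describeLimitSketch(l))
	}
}
```

Because UpdateGroup now applies the normalized value without a nil guard, a request that omits a limit field would clear that limit; per the comment in the diff, this relies on the frontend always sending all three fields.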
@@ -1058,16 +1058,11 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType
}
// Limit fields: 0 and nil both mean "unlimited"; a positive number is the actual limit
if input.DailyLimitUSD != nil {
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
}
if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
}
if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
}
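// Limit fields: nil/negative means "unlimited", 0 means "no usage allowed", a positive number is the actual limit
// The frontend always sends these three fields, so no nil guard is needed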
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
// Image generation billing config: a negative number clears the value (use the default price)
if input.ImagePrice1K != nil {
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
@@ -1462,6 +1457,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
Status: StatusActive,
Schedulable: true,
}
// Precompute the next reset time for fixed-time resets
if account.Extra != nil {
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
return nil, err
}
ComputeQuotaResetAt(account.Extra)
}
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
expiresAt := time.Unix(*input.ExpiresAt, 0)
account.ExpiresAt = &expiresAt
@@ -1514,6 +1516,7 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if err != nil {
return nil, err
}
wasOveragesEnabled := account.IsOveragesEnabled()
if input.Name != "" {
account.Name = input.Name
@@ -1535,6 +1538,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
}
}
account.Extra = input.Extra
if account.Platform == PlatformAntigravity && wasOveragesEnabled && !account.IsOveragesEnabled() {
delete(account.Extra, "antigravity_credits_overages") // clean up legacy overages runtime state
// Clear the AICredits rate-limit key
if rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any); ok {
delete(rawLimits, creditsExhaustedKey)
}
}
if account.Platform == PlatformAntigravity && !wasOveragesEnabled && account.IsOveragesEnabled() {
delete(account.Extra, modelRateLimitsKey)
delete(account.Extra, "antigravity_credits_overages") // clean up legacy overages runtime state
}
// Validate and precompute the next reset time for fixed-time resets
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
return nil, err
}
ComputeQuotaResetAt(account.Extra)
}
if input.ProxyID != nil {
// 0 means clear the proxy (the frontend sends 0 rather than null to express the intent to clear)


@@ -0,0 +1,123 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type updateAccountOveragesRepoStub struct {
mockAccountRepoForGemini
account *Account
updateCalls int
}
func (r *updateAccountOveragesRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
return r.account, nil
}
func (r *updateAccountOveragesRepoStub) Update(ctx context.Context, account *Account) error {
r.updateCalls++
r.account = account
return nil
}
func TestUpdateAccount_DisableOveragesClearsAICreditsKey(t *testing.T) {
accountID := int64(101)
repo := &updateAccountOveragesRepoStub{
account: &Account{
ID: accountID,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Extra: map[string]any{
"allow_overages": true,
"mixed_scheduling": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
},
}
svc := &adminServiceImpl{accountRepo: repo}
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
Extra: map[string]any{
"mixed_scheduling": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 1, repo.updateCalls)
require.False(t, updated.IsOveragesEnabled())
// After overages is disabled, the AICredits key should be cleared.
rawLimits, ok := repo.account.Extra[modelRateLimitsKey].(map[string]any)
if ok {
_, exists := rawLimits[creditsExhaustedKey]
require.False(t, exists, "disabling overages should clear the AICredits rate-limit key")
}
// Regular model rate limits should be kept.
require.True(t, ok)
_, exists := rawLimits["claude-sonnet-4-5"]
require.True(t, exists, "regular model rate limits should be kept")
}
func TestUpdateAccount_EnableOveragesClearsModelRateLimitsBeforePersist(t *testing.T) {
accountID := int64(102)
repo := &updateAccountOveragesRepoStub{
account: &Account{
ID: accountID,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Extra: map[string]any{
"mixed_scheduling": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
},
},
},
}
svc := &adminServiceImpl{accountRepo: repo}
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
Extra: map[string]any{
"mixed_scheduling": true,
"allow_overages": true,
},
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 1, repo.updateCalls)
require.True(t, updated.IsOveragesEnabled())
_, exists := repo.account.Extra[modelRateLimitsKey]
require.False(t, exists, "enabling overages should clear stale model rate limits before persisting")
}


@@ -0,0 +1,234 @@
package service
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
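// creditsExhaustedKey is the special key in model_rate_limits that marks credits as exhausted.
// It is fully isomorphic with regular model rate limits: it is read and written via SetModelRateLimit / isRateLimitActiveForKey.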
creditsExhaustedKey = "AICredits"
creditsExhaustedDuration = 5 * time.Hour
)
type antigravity429Category string
const (
antigravity429Unknown antigravity429Category = "unknown"
antigravity429RateLimited antigravity429Category = "rate_limited"
antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
)
var (
antigravityQuotaExhaustedKeywords = []string{
"quota_exhausted",
"quota exhausted",
}
creditsExhaustedKeywords = []string{
"google_one_ai",
"insufficient credit",
"insufficient credits",
"not enough credit",
"not enough credits",
"credit exhausted",
"credits exhausted",
"credit balance",
"minimumcreditamountforusage",
"minimum credit amount for usage",
"minimum credit",
}
)
// isCreditsExhausted reports whether the account's AICredits rate-limit key is active (i.e. credits are exhausted).
func (a *Account) isCreditsExhausted() bool {
if a == nil {
return false
}
return a.isRateLimitActiveForKey(creditsExhaustedKey)
}
// setCreditsExhausted marks the account's credits as exhausted: it writes model_rate_limits["AICredits"] and updates the cache.
func (s *AntigravityGatewayService) setCreditsExhausted(ctx context.Context, account *Account) {
if account == nil || account.ID == 0 {
return
}
resetAt := time.Now().Add(creditsExhaustedDuration)
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, creditsExhaustedKey, resetAt); err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "set credits exhausted failed: account=%d err=%v", account.ID, err)
return
}
s.updateAccountModelRateLimitInCache(ctx, account, creditsExhaustedKey, resetAt)
logger.LegacyPrintf("service.antigravity_gateway", "credits_exhausted_marked account=%d reset_at=%s",
account.ID, resetAt.UTC().Format(time.RFC3339))
}
// clearCreditsExhausted clears the account's AICredits rate-limit key.
func (s *AntigravityGatewayService) clearCreditsExhausted(ctx context.Context, account *Account) {
if account == nil || account.ID == 0 || account.Extra == nil {
return
}
rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any)
if !ok {
return
}
if _, exists := rawLimits[creditsExhaustedKey]; !exists {
return
}
delete(rawLimits, creditsExhaustedKey)
account.Extra[modelRateLimitsKey] = rawLimits
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
modelRateLimitsKey: rawLimits,
}); err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "clear credits exhausted failed: account=%d err=%v", account.ID, err)
}
}
// classifyAntigravity429 classifies an Antigravity 429 response as quota exhausted, rate limited, or unknown.
func classifyAntigravity429(body []byte) antigravity429Category {
if len(body) == 0 {
return antigravity429Unknown
}
lowerBody := strings.ToLower(string(body))
for _, keyword := range antigravityQuotaExhaustedKeywords {
if strings.Contains(lowerBody, keyword) {
return antigravity429QuotaExhausted
}
}
if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted {
return antigravity429RateLimited
}
return antigravity429Unknown
}
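// injectEnabledCreditTypes injects the AI Credits type into an already-serialized v1internal JSON body.
// It returns nil when the body is not a JSON object (a literal null would otherwise leave payload nil).
func injectEnabledCreditTypes(body []byte) []byte {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil || payload == nil {
return nil
}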
payload["enabledCreditTypes"] = []string{"GOOGLE_ONE_AI"}
result, err := json.Marshal(payload)
if err != nil {
return nil
}
return result
}
// resolveCreditsOveragesModelKey resolves the overages-state model key for the current request.
func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstreamModelName, requestedModel string) string {
modelKey := strings.TrimSpace(upstreamModelName)
if modelKey != "" {
return modelKey
}
if account == nil {
return ""
}
modelKey = resolveFinalAntigravityModelKey(ctx, account, requestedModel)
if strings.TrimSpace(modelKey) != "" {
return modelKey
}
return resolveAntigravityModelKey(requestedModel)
}
// shouldMarkCreditsExhausted decides whether a failed credits request should mark the account's credits as exhausted.
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
if reqErr != nil || resp == nil {
return false
}
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
return false
}
if isURLLevelRateLimit(respBody) {
return false
}
if info := parseAntigravitySmartRetryInfo(respBody); info != nil {
return false
}
bodyLower := strings.ToLower(string(respBody))
for _, keyword := range creditsExhaustedKeywords {
if strings.Contains(bodyLower, keyword) {
return true
}
}
return false
}
type creditsOveragesRetryResult struct {
handled bool
resp *http.Response
}
// attemptCreditsOveragesRetry retries the request with AI Credits injected once free-quota exhaustion is confirmed.
func (s *AntigravityGatewayService) attemptCreditsOveragesRetry(
p antigravityRetryLoopParams,
baseURL string,
modelName string,
waitDuration time.Duration,
originalStatusCode int,
respBody []byte,
) *creditsOveragesRetryResult {
creditsBody := injectEnabledCreditTypes(p.body)
if creditsBody == nil {
return &creditsOveragesRetryResult{handled: false}
}
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)",
p.prefix, modelKey, p.account.ID)
creditsReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, creditsBody)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d build_request_err=%v",
p.prefix, modelKey, p.account.ID, err)
return &creditsOveragesRetryResult{handled: true}
}
creditsResp, err := p.httpUpstream.Do(creditsReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if err == nil && creditsResp != nil && creditsResp.StatusCode < 400 {
s.clearCreditsExhausted(p.ctx, p.account)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d credit_overages_success model=%s account=%d",
p.prefix, creditsResp.StatusCode, modelKey, p.account.ID)
return &creditsOveragesRetryResult{handled: true, resp: creditsResp}
}
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, creditsResp, err)
return &creditsOveragesRetryResult{handled: true}
}
func (s *AntigravityGatewayService) handleCreditsRetryFailure(
ctx context.Context,
prefix string,
modelKey string,
account *Account,
creditsResp *http.Response,
reqErr error,
) {
var creditsRespBody []byte
creditsStatusCode := 0
if creditsResp != nil {
creditsStatusCode = creditsResp.StatusCode
if creditsResp.Body != nil {
creditsRespBody, _ = io.ReadAll(io.LimitReader(creditsResp.Body, 64<<10))
_ = creditsResp.Body.Close()
}
}
if shouldMarkCreditsExhausted(creditsResp, creditsRespBody, reqErr) && account != nil {
s.setCreditsExhausted(ctx, account)
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=true status=%d body=%s",
prefix, modelKey, account.ID, creditsStatusCode, truncateForLog(creditsRespBody, 200))
return
}
if account != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=false status=%d err=%v body=%s",
prefix, modelKey, account.ID, creditsStatusCode, reqErr, truncateForLog(creditsRespBody, 200))
}
}

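injectEnabledCreditTypes round-trips the body through map[string]any, which decodes every JSON number as float64, so integers above 2^53 could lose precision when re-encoded. Whether v1internal payloads ever carry such values is not known from this diff. If it matters, one possible variant decodes with UseNumber so numeric literals are preserved. This is a sketch only; injectCreditTypes is a hypothetical name and not part of the change.

```go
package main

import (
	"bytes"
	"encoding/json"
	"io"
)

// injectCreditTypes is a hypothetical variant of injectEnabledCreditTypes.
// It returns nil when body is not exactly one JSON object.
func injectCreditTypes(body []byte) []byte {
	dec := json.NewDecoder(bytes.NewReader(body))
	dec.UseNumber() // keep numeric literals as json.Number instead of float64

	var payload map[string]any
	if err := dec.Decode(&payload); err != nil || payload == nil {
		return nil // invalid JSON, empty input, non-object, or literal null
	}
	// Decoder.Decode stops after the first value; reject trailing data so the
	// behavior stays close to json.Unmarshal.
	if _, err := dec.Token(); err != io.EOF {
		return nil
	}

	payload["enabledCreditTypes"] = []string{"GOOGLE_ONE_AI"}
	out, err := json.Marshal(payload)
	if err != nil {
		return nil
	}
	return out
}
```

As with the original, json.Marshal emits map keys in sorted order, so the key order of the outgoing body may differ from the incoming one.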

@@ -0,0 +1,538 @@
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
func TestClassifyAntigravity429(t *testing.T) {
t.Run("explicit quota exhausted", func(t *testing.T) {
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
})
t.Run("structured rate limit", func(t *testing.T) {
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
require.Equal(t, antigravity429RateLimited, classifyAntigravity429(body))
})
t.Run("unknown 429", func(t *testing.T) {
body := []byte(`{"error":{"message":"too many requests"}}`)
require.Equal(t, antigravity429Unknown, classifyAntigravity429(body))
})
}
func TestIsCreditsExhausted_UsesAICreditsKey(t *testing.T) {
t.Run("credits available when the AICredits key is absent", func(t *testing.T) {
account := &Account{
ID: 1,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
},
}
require.False(t, account.isCreditsExhausted())
})
t.Run("credits exhausted when the AICredits key is active", func(t *testing.T) {
account := &Account{
ID: 2,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
creditsExhaustedKey: map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
require.True(t, account.isCreditsExhausted())
})
t.Run("credits available when the AICredits key has expired", func(t *testing.T) {
account := &Account{
ID: 3,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
creditsExhaustedKey: map[string]any{
"rate_limited_at": time.Now().Add(-6 * time.Hour).UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
require.False(t, account.isCreditsExhausted())
})
}
func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 101,
Name: "acc-101",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
},
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-opus-4-6": "claude-sonnet-4-5",
},
},
}
respBody := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-opus-4-6","request":{}}`),
httpUpstream: upstream,
accountRepo: repo,
requestedModel: "claude-opus-4-6",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)
require.Nil(t, result.switchError)
require.Len(t, upstream.requestBodies, 1)
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
require.Empty(t, repo.modelRateLimitCalls, "a successful overages retry must not write regular model_rate_limits")
}
func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 102,
Name: "acc-102",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
},
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)
require.Len(t, upstream.requestBodies, 1)
require.NotContains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
require.Empty(t, repo.extraUpdateCalls)
require.Empty(t, repo.modelRateLimitCalls)
}
func TestAntigravityRetryLoop_ModelRateLimited_InjectsCredits(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
antigravity.BaseURLs = []string{"https://ag-1.test"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
},
},
errors: []error{nil},
}
// Model rate-limited + overages enabled + no AICredits key → credits should be injected directly.
account := &Account{
ID: 103,
Name: "acc-103",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err)
require.NotNil(t, result)
require.Len(t, upstream.requestBodies, 1)
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
}
func TestAntigravityRetryLoop_CreditsExhausted_DoesNotInject(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
antigravity.BaseURLs = []string{"https://ag-1.test"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
// Model rate-limited + overages enabled + AICredits key active → must not inject credits; should switch accounts.
account := &Account{
ID: 104,
Name: "acc-104",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
_, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
// Model rate-limited + credits exhausted → should produce an account-switch error.
require.Error(t, err)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
}
func TestAntigravityRetryLoop_CreditErrorMarksExhausted(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
antigravity.BaseURLs = []string{"https://ag-1.test"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
repo := &stubAntigravityAccountRepo{}
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusForbidden,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`)),
},
},
errors: []error{nil},
}
// Model rate-limited + overages enabled + credits available → credits are injected, but the upstream reports insufficient credits.
account := &Account{
ID: 105,
Name: "acc-105",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{accountRepo: repo}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
httpUpstream: upstream,
accountRepo: repo,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err)
require.NotNil(t, result)
// Verify that the AICredits key was written to the database via SetModelRateLimit.
require.Len(t, repo.modelRateLimitCalls, 1, "the AICredits key should be written via SetModelRateLimit")
require.Equal(t, creditsExhaustedKey, repo.modelRateLimitCalls[0].modelKey)
}
func TestShouldMarkCreditsExhausted(t *testing.T) {
t.Run("does not mark when reqErr is non-nil", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusForbidden}
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), io.ErrUnexpectedEOF))
})
t.Run("does not mark when resp is nil", func(t *testing.T) {
require.False(t, shouldMarkCreditsExhausted(nil, []byte(`{"error":"Insufficient credits"}`), nil))
})
t.Run("does not mark on 5xx responses", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusInternalServerError}
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
})
t.Run("does not mark on 408 RequestTimeout", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusRequestTimeout}
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
})
t.Run("does not mark on URL-level rate limit", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"message":"Resource has been exhausted"}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
t.Run("does not mark on structured rate limit", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
t.Run("marks when a credits keyword is present", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusForbidden}
for _, keyword := range []string{
"Insufficient GOOGLE_ONE_AI credits",
"insufficient credit balance",
"not enough credits for this request",
"Credits exhausted",
"minimumCreditAmountForUsage requirement not met",
} {
body := []byte(`{"error":{"message":"` + keyword + `"}}`)
require.True(t, shouldMarkCreditsExhausted(resp, body, nil), "should mark for keyword: %s", keyword)
}
})
t.Run("does not mark without a credits keyword", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusForbidden}
body := []byte(`{"error":{"message":"permission denied"}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
}
func TestInjectEnabledCreditTypes(t *testing.T) {
t.Run("injects into valid JSON", func(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","request":{}}`)
result := injectEnabledCreditTypes(body)
require.NotNil(t, result)
require.Contains(t, string(result), `"enabledCreditTypes"`)
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
})
t.Run("returns nil for invalid JSON", func(t *testing.T) {
require.Nil(t, injectEnabledCreditTypes([]byte(`not json`)))
})
t.Run("returns nil for empty body", func(t *testing.T) {
require.Nil(t, injectEnabledCreditTypes([]byte{}))
})
t.Run("overwrites existing enabledCreditTypes", func(t *testing.T) {
body := []byte(`{"enabledCreditTypes":["OLD"],"model":"test"}`)
result := injectEnabledCreditTypes(body)
require.NotNil(t, result)
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
require.NotContains(t, string(result), `OLD`)
})
}
func TestClearCreditsExhausted(t *testing.T) {
t.Run("no-op when account is nil", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), nil)
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("no-op when Extra is nil", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), &Account{ID: 1})
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("no-op without modelRateLimitsKey", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), &Account{
ID: 1,
Extra: map[string]any{"some_key": "value"},
})
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("no-op without the AICredits key", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), &Account{
ID: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
},
},
})
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("deletes the AICredits key and calls UpdateExtra", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{
ID: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
svc.clearCreditsExhausted(context.Background(), account)
require.Len(t, repo.extraUpdateCalls, 1)
// The AICredits key should be deleted.
rawLimits := account.Extra[modelRateLimitsKey].(map[string]any)
_, exists := rawLimits[creditsExhaustedKey]
require.False(t, exists, "the AICredits key should be deleted")
// Regular model rate limits should be kept.
_, exists = rawLimits["claude-sonnet-4-5"]
require.True(t, exists, "regular model rate limits should be kept")
})
}


@@ -188,9 +188,29 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
return &smartRetryResult{action: smartRetryActionContinueURL}
}
category := antigravity429Unknown
if resp.StatusCode == http.StatusTooManyRequests {
category = classifyAntigravity429(respBody)
}
// Decide whether to trigger a smart retry.
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody)
// AI Credits overage requests:
// switch to credits only when the upstream explicitly reports that the free quota is exhausted.
if resp.StatusCode == http.StatusTooManyRequests &&
category == antigravity429QuotaExhausted &&
p.account.IsOveragesEnabled() &&
!p.account.isCreditsExhausted() {
result := s.attemptCreditsOveragesRetry(p, baseURL, modelName, waitDuration, resp.StatusCode, respBody)
if result.handled && result.resp != nil {
return &smartRetryResult{
action: smartRetryActionBreakWithResp,
resp: result.resp,
}
}
}
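// Case 1: retryDelay >= threshold: rate-limit the model and switch accounts.
if shouldRateLimitModel {
// Single-account 503 backoff-retry mode: set no rate limit and do not switch accounts; wait and retry in place instead.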
@@ -532,14 +552,31 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
// antigravityRetryLoop runs the retry loop with URL fallback.
func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
// Pre-check: model rate-limited + overages enabled + credits not exhausted → inject AI Credits directly.
overagesInjected := false
if p.requestedModel != "" && p.account.Platform == PlatformAntigravity &&
p.account.IsOveragesEnabled() && !p.account.isCreditsExhausted() &&
p.account.isModelRateLimitedWithContext(p.ctx, p.requestedModel) {
if creditsBody := injectEnabledCreditTypes(p.body); creditsBody != nil {
p.body = creditsBody
overagesInjected = true
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: model_rate_limited_credits_inject model=%s account=%d (injecting enabledCreditTypes)",
p.prefix, p.requestedModel, p.account.ID)
}
}
// Pre-check: if the account is already rate-limited, return the switch signal immediately.
if p.requestedModel != "" {
if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 {
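// Single-account 503 backoff-retry mode: skip the rate-limit pre-check and send the request directly.
// The rate limit set by the first request exists so the multi-account scheduler skips this account; it is meaningless in single-account mode.
// If the upstream is still unavailable, handleSmartRetry → handleSingleAccountRetryInPlace
// waits and retries in place at the service layer, so there is no need to wait here in the pre-check.
if isSingleAccountRetry(p.ctx) {
// Requests with credits injected are no longer blocked by the regular model rate-limit pre-check.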
if overagesInjected {
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: credits_injected_ignore_rate_limit remaining=%v model=%s account=%d",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
} else if isSingleAccountRetry(p.ctx) {
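// Single-account 503 backoff-retry mode: skip the rate-limit pre-check and send the request directly.
// The rate limit set by the first request exists so the multi-account scheduler skips this account; it is meaningless in single-account mode.
// If the upstream is still unavailable, handleSmartRetry → handleSingleAccountRetryInPlace
// waits and retries in place at the service layer, so there is no need to wait here in the pre-check.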
logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)",
p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID)
} else {
@@ -631,6 +668,15 @@ urlFallbackLoop:
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if overagesInjected && shouldMarkCreditsExhausted(resp, respBody, nil) {
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, "", p.requestedModel)
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}, nil)
}
// ★ Unified entry point: custom error codes + temporarily unschedulable
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
if policyErr != nil {

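The pre-check added to antigravityRetryLoop, together with the retry-loop tests, suggests the following decision order. This is a simplified sketch under stated assumptions: decidePreCheck and the action names are hypothetical, the platform and requestedModel guards are omitted, and the real code consults two different rate-limit helpers that are collapsed into a single flag here.

```go
package main

// preCheckAction names what the retry loop does before the first upstream call.
type preCheckAction string

const (
	actionSend          preCheckAction = "send"           // send the request unchanged
	actionInjectCredits preCheckAction = "inject_credits" // add enabledCreditTypes, then send
	actionSwitchAccount preCheckAction = "switch_account" // return the account-switch signal
)

// decidePreCheck condenses the branch order seen in the diff:
// credits injection wins over the single-account retry bypass,
// which in turn wins over switching accounts.
func decidePreCheck(modelRateLimited, overagesEnabled, creditsExhausted, singleAccountRetry bool) preCheckAction {
	switch {
	case !modelRateLimited:
		return actionSend
	case overagesEnabled && !creditsExhausted:
		return actionInjectCredits
	case singleAccountRetry:
		return actionSend
	default:
		return actionSwitchAccount
	}
}
```

Read this way, a rate-limited account with overages enabled stays usable until the AICredits key is set, which matches TestAntigravityRetryLoop_ModelRateLimited_InjectsCredits and TestAntigravityRetryLoop_CreditsExhausted_DoesNotInject.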

@@ -2,12 +2,29 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
const (
forbiddenTypeValidation = "validation"
forbiddenTypeViolation = "violation"
forbiddenTypeForbidden = "forbidden"
// Machine-readable error codes
errorCodeForbidden = "forbidden"
errorCodeUnauthenticated = "unauthenticated"
errorCodeRateLimited = "rate_limited"
errorCodeNetworkError = "network_error"
)
// AntigravityQuotaFetcher fetches quota from the Antigravity API.
type AntigravityQuotaFetcher struct {
proxyRepo ProxyRepository
@@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
// Call the API to fetch quota.
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
if err != nil {
// 403 Forbidden: do not return an error; return a result flagged is_forbidden instead.
var forbiddenErr *antigravity.ForbiddenError
if errors.As(err, &forbiddenErr) {
now := time.Now()
fbType := classifyForbiddenType(forbiddenErr.Body)
return &QuotaResult{
UsageInfo: &UsageInfo{
UpdatedAt: &now,
IsForbidden: true,
ForbiddenReason: forbiddenErr.Body,
ForbiddenType: fbType,
ValidationURL: extractValidationURL(forbiddenErr.Body),
NeedsVerify: fbType == forbiddenTypeValidation,
IsBanned: fbType == forbiddenTypeViolation,
ErrorCode: errorCodeForbidden,
},
}, nil
}
return nil, err
}
// Call LoadCodeAssist to fetch the subscription tier and AI Credits balance (non-critical path; a failure does not affect the main flow).
tierRaw, tierNormalized, loadResp := f.fetchSubscriptionTier(ctx, client, accessToken)
// Convert to UsageInfo
usageInfo := f.buildUsageInfo(modelsResp)
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized, loadResp)
return &QuotaResult{
UsageInfo: usageInfo,
@@ -52,15 +90,53 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
}, nil
}
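The 403 branch in FetchQuota relies on `errors.As` to turn one specific error type into a normal result while every other error still propagates. The sketch below isolates that pattern. `forbiddenError`, `quotaStatus`, `toQuotaStatus`, and `errTimeout` are hypothetical stand-ins, not the real `antigravity.ForbiddenError` or `QuotaResult` types.

```go
package main

import (
	"errors"
	"fmt"
)

// forbiddenError is a hypothetical stand-in for antigravity.ForbiddenError.
type forbiddenError struct {
	StatusCode int
	Body       string
}

func (e *forbiddenError) Error() string {
	return fmt.Sprintf("forbidden: status %d", e.StatusCode)
}

// quotaStatus is a hypothetical, trimmed-down result type.
type quotaStatus struct {
	IsForbidden     bool
	ForbiddenReason string
}

// errTimeout is an example of an error that must still propagate.
var errTimeout = errors.New("timeout")

// toQuotaStatus converts a (possibly wrapped) forbiddenError into a normal
// result and passes every other error through unchanged.
func toQuotaStatus(err error) (*quotaStatus, error) {
	var fe *forbiddenError
	if errors.As(err, &fe) {
		return &quotaStatus{IsForbidden: true, ForbiddenReason: fe.Body}, nil
	}
	return nil, err
}

func main() {
	// errors.As also matches when the typed error is wrapped with %w.
	wrapped := fmt.Errorf("fetch models: %w", &forbiddenError{StatusCode: 403, Body: "Access denied"})
	status, err := toQuotaStatus(wrapped)
	fmt.Println(status.IsForbidden, status.ForbiddenReason, err)

	_, err = toQuotaStatus(errTimeout)
	fmt.Println(err)
}
```

Because `errors.As` walks the wrap chain, the caller keeps working if an intermediate layer later wraps the error with `%w`.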
// buildUsageInfo converts the API response into UsageInfo
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
now := time.Now()
info := &UsageInfo{
UpdatedAt: &now,
AntigravityQuota: make(map[string]*AntigravityModelQuota),
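// fetchSubscriptionTier fetches the account's subscription tier and returns empty strings on failure.
// It also returns the LoadCodeAssistResponse so that the AI Credits balance can be extracted.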
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string, loadResp *antigravity.LoadCodeAssistResponse) {
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
if err != nil {
slog.Warn("failed to fetch subscription tier", "error", err)
return "", "", nil
}
if loadResp == nil {
return "", "", nil
}
// Iterate over all models and populate AntigravityQuota
raw = loadResp.GetTier() // existing method; paidTier takes precedence over currentTier
normalized = normalizeTier(raw)
return raw, normalized, loadResp
}
// normalizeTier normalizes the raw tier string to FREE/PRO/ULTRA/UNKNOWN
func normalizeTier(raw string) string {
if raw == "" {
return ""
}
lower := strings.ToLower(raw)
switch {
case strings.Contains(lower, "ultra"):
return "ULTRA"
case strings.Contains(lower, "pro"):
return "PRO"
case strings.Contains(lower, "free"):
return "FREE"
default:
return "UNKNOWN"
}
}
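normalizeTier matches substrings in a fixed order, so the first matching case wins and any string that merely contains "pro" is classified as PRO. The sketch below copies the switch into a standalone program to make that visible. The inputs "approved-tier" and "pro-free-trial" are made-up examples, not tier IDs known to be returned by the API.

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeTier is a standalone copy of the substring-based switch in the diff.
func normalizeTier(raw string) string {
	if raw == "" {
		return ""
	}
	lower := strings.ToLower(raw)
	switch {
	case strings.Contains(lower, "ultra"):
		return "ULTRA"
	case strings.Contains(lower, "pro"):
		return "PRO"
	case strings.Contains(lower, "free"):
		return "FREE"
	default:
		return "UNKNOWN"
	}
}

func main() {
	inputs := []string{
		"g1-ultra-tier",
		"Google AI Pro",
		"FREE-TIER",
		"approved-tier",  // hypothetical: contains "pro" inside "approved"
		"pro-free-trial", // hypothetical: "pro" is checked before "free"
	}
	for _, raw := range inputs {
		fmt.Printf("%q -> %q\n", raw, normalizeTier(raw))
	}
}
```

If tier IDs ever stop following the `g1-<name>-tier` shape, matching on delimited tokens instead of raw substrings would be one way to avoid such accidental hits.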
// buildUsageInfo converts the API response into UsageInfo.
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string, loadResp *antigravity.LoadCodeAssistResponse) *UsageInfo {
now := time.Now()
info := &UsageInfo{
UpdatedAt: &now,
AntigravityQuota: make(map[string]*AntigravityModelQuota),
AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail),
SubscriptionTier: tierNormalized,
SubscriptionTierRaw: tierRaw,
}
// Iterate over all models and populate AntigravityQuota and AntigravityQuotaDetails
for modelName, modelInfo := range modelsResp.Models {
if modelInfo.QuotaInfo == nil {
continue
@@ -73,6 +149,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
Utilization: utilization,
ResetTime: modelInfo.QuotaInfo.ResetTime,
}
// Populate detailed model capability information
detail := &AntigravityModelDetail{
DisplayName: modelInfo.DisplayName,
SupportsImages: modelInfo.SupportsImages,
SupportsThinking: modelInfo.SupportsThinking,
ThinkingBudget: modelInfo.ThinkingBudget,
Recommended: modelInfo.Recommended,
MaxTokens: modelInfo.MaxTokens,
MaxOutputTokens: modelInfo.MaxOutputTokens,
SupportedMimeTypes: modelInfo.SupportedMimeTypes,
}
info.AntigravityQuotaDetails[modelName] = detail
}
// Forwarding rules for deprecated models
if len(modelsResp.DeprecatedModelIDs) > 0 {
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs))
for oldID, deprecated := range modelsResp.DeprecatedModelIDs {
info.ModelForwardingRules[oldID] = deprecated.NewModelID
}
}
// Also set FiveHour for backward-compatible display (taken from the primary model)
@@ -94,6 +191,16 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
}
}
if loadResp != nil {
for _, credit := range loadResp.GetAvailableCredits() {
info.AICredits = append(info.AICredits, AICredit{
CreditType: credit.CreditType,
Amount: credit.GetAmount(),
MinimumBalance: credit.GetMinimumAmount(),
})
}
}
return info
}
@@ -108,3 +215,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco
}
return proxy.URL()
}
// classifyForbiddenType determines the forbidden type from the 403 response body
func classifyForbiddenType(body string) string {
lower := strings.ToLower(body)
switch {
case strings.Contains(lower, "validation_required") ||
strings.Contains(lower, "verify your account") ||
strings.Contains(lower, "validation_url"):
return forbiddenTypeValidation
case strings.Contains(lower, "terms of service") ||
strings.Contains(lower, "violation"):
return forbiddenTypeViolation
default:
return forbiddenTypeForbidden
}
}
// urlPattern extracts a URL from the 403 response body (fallback path)
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)
// extractValidationURL extracts the validation/appeal link from the 403 response JSON
func extractValidationURL(body string) string {
// 1. Try structured JSON extraction: /error/details[*]/metadata/validation_url or appeal_url
var parsed struct {
Error struct {
Details []struct {
Metadata map[string]string `json:"metadata"`
} `json:"details"`
} `json:"error"`
}
if json.Unmarshal([]byte(body), &parsed) == nil {
for _, detail := range parsed.Error.Details {
if u := detail.Metadata["validation_url"]; u != "" {
return u
}
if u := detail.Metadata["appeal_url"]; u != "" {
return u
}
}
}
// 2. Fallback: match a URL with the regex
lower := strings.ToLower(body)
if !strings.Contains(lower, "validation") &&
!strings.Contains(lower, "verify") &&
!strings.Contains(lower, "appeal") {
return ""
}
// Decode common escapes before matching
normalized := strings.ReplaceAll(body, `\u0026`, "&")
if m := urlPattern.FindString(normalized); m != "" {
return m
}
return ""
}
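extractValidationURL works in two stages: a structured JSON lookup first, then a keyword-gated regex fallback for non-JSON bodies. The standalone sketch below reproduces that order under the name `extractURL` (a hypothetical name, as are the sample bodies in `main`) so the two stages can be exercised without the rest of the package.

```go
package main

import (
	"encoding/json"
	"fmt"
	"regexp"
	"strings"
)

// urlPattern is the same fallback pattern as in the diff; it only matches https URLs.
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)

// extractURL is a standalone sketch of the two-stage extraction.
func extractURL(body string) string {
	// Stage 1: structured lookup in error.details[*].metadata.
	var parsed struct {
		Error struct {
			Details []struct {
				Metadata map[string]string `json:"metadata"`
			} `json:"details"`
		} `json:"error"`
	}
	if json.Unmarshal([]byte(body), &parsed) == nil {
		for _, d := range parsed.Error.Details {
			for _, key := range []string{"validation_url", "appeal_url"} {
				if u := d.Metadata[key]; u != "" {
					return u
				}
			}
		}
	}
	// Stage 2: regex fallback, only when the body mentions a relevant keyword.
	lower := strings.ToLower(body)
	if !strings.Contains(lower, "validation") &&
		!strings.Contains(lower, "verify") &&
		!strings.Contains(lower, "appeal") {
		return ""
	}
	// Decode the escaped ampersand that appears in non-JSON text.
	normalized := strings.ReplaceAll(body, `\u0026`, "&")
	return urlPattern.FindString(normalized)
}

func main() {
	bodies := []string{
		`{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
		`Please verify your account at https://accounts.google.com/verify`,
		`Error at https://example.com/something`,
	}
	for _, b := range bodies {
		fmt.Printf("%q\n", extractURL(b))
	}
}
```

The fallback returns the first https URL in the body, so a body that mentions "verify" but lists an unrelated link first would yield that unrelated link. The keyword gate reduces this risk but does not remove it.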

View File

@@ -0,0 +1,522 @@
//go:build unit
package service
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// ---------------------------------------------------------------------------
// normalizeTier
// ---------------------------------------------------------------------------
func TestNormalizeTier(t *testing.T) {
tests := []struct {
name string
raw string
expected string
}{
{name: "empty string", raw: "", expected: ""},
{name: "free-tier", raw: "free-tier", expected: "FREE"},
{name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"},
{name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"},
{name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"},
{name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"},
{name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"},
{name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"},
{name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeTier(tt.raw)
require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw)
})
}
}
// ---------------------------------------------------------------------------
// buildUsageInfo
// ---------------------------------------------------------------------------
func aqfBoolPtr(v bool) *bool { return &v }
func aqfIntPtr(v int) *int { return &v }
func TestBuildUsageInfo_BasicModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.75,
ResetTime: "2026-03-08T12:00:00Z",
},
DisplayName: "Claude Sonnet 4",
SupportsImages: aqfBoolPtr(true),
SupportsThinking: aqfBoolPtr(false),
ThinkingBudget: aqfIntPtr(0),
Recommended: aqfBoolPtr(true),
MaxTokens: aqfIntPtr(200000),
MaxOutputTokens: aqfIntPtr(16384),
SupportedMimeTypes: map[string]bool{
"image/png": true,
"image/jpeg": true,
},
},
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
ResetTime: "2026-03-08T15:00:00Z",
},
DisplayName: "Gemini 2.5 Pro",
MaxTokens: aqfIntPtr(1000000),
MaxOutputTokens: aqfIntPtr(65536),
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", nil)
// Basic fields
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
require.Equal(t, "PRO", info.SubscriptionTier)
require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw)
// AntigravityQuota
require.Len(t, info.AntigravityQuota, 2)
sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, sonnetQuota)
require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25
require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime)
geminiQuota := info.AntigravityQuota["gemini-2.5-pro"]
require.NotNil(t, geminiQuota)
require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50
require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime)
// AntigravityQuotaDetails
require.Len(t, info.AntigravityQuotaDetails, 2)
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"]
require.NotNil(t, sonnetDetail)
require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName)
require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages)
require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking)
require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget)
require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended)
require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens)
require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens)
require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes)
geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"]
require.NotNil(t, geminiDetail)
require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName)
require.Nil(t, geminiDetail.SupportsImages)
require.Nil(t, geminiDetail.SupportsThinking)
require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens)
require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens)
}
func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 1.0,
},
},
},
DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{
"claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"},
"claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.Len(t, info.ModelForwardingRules, 2)
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"])
}
func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"some-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
}
func TestBuildUsageInfo_EmptyModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info)
require.NotNil(t, info.AntigravityQuota)
require.Empty(t, info.AntigravityQuota)
require.NotNil(t, info.AntigravityQuotaDetails)
require.Empty(t, info.AntigravityQuotaDetails)
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
}
func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"model-without-quota": {
DisplayName: "No Quota Model",
// QuotaInfo is nil
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info)
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too")
}
func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
// When the first priority model exists, it should be used for FiveHour
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.40,
ResetTime: "2026-03-08T18:00:00Z",
},
},
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.80,
ResetTime: "2026-03-08T12:00:00Z",
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
// claude-sonnet-4-20250514 is first in priority list, so it should be used
expectedUtilization := (1.0 - 0.80) * 100 // 20
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime")
}
func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.60,
ResetTime: "2026-03-08T14:00:00Z",
},
},
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.30,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour)
expectedUtilization := (1.0 - 0.60) * 100 // 40
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
}
func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// Only gemini-2.5-pro exists (third in priority list)
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.30,
},
},
"other-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.90,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour)
expectedUtilization := (1.0 - 0.30) * 100 // 70
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
}
func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// None of the priority models exist
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"some-other-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
}
func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
ResetTime: "", // empty reset time
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour)
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
require.Equal(t, 0, info.FiveHour.RemainingSeconds)
}
func TestBuildUsageInfo_FullUtilization(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.0, // fully used
ResetTime: "2026-03-08T12:00:00Z",
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, quota)
require.Equal(t, 100, quota.Utilization)
}
func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 1.0, // fully available
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, quota)
require.Equal(t, 0, quota.Utilization)
}
func TestBuildUsageInfo_AICredits(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{},
}
loadResp := &antigravity.LoadCodeAssistResponse{
PaidTier: &antigravity.PaidTierInfo{
ID: "g1-pro-tier",
AvailableCredits: []antigravity.AvailableCredit{
{
CreditType: "GOOGLE_ONE_AI",
CreditAmount: "25",
MinimumCreditAmountForUsage: "5",
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", loadResp)
require.Len(t, info.AICredits, 1)
require.Equal(t, "GOOGLE_ONE_AI", info.AICredits[0].CreditType)
require.Equal(t, 25.0, info.AICredits[0].Amount)
require.Equal(t, 5.0, info.AICredits[0].MinimumBalance)
}
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
// Simulates what FetchQuota does when it encounters a 403:
// FetchAvailableModels returns ForbiddenError → FetchQuota should return is_forbidden=true
forbiddenErr := &antigravity.ForbiddenError{
StatusCode: 403,
Body: "Access denied",
}
// Verify that ForbiddenError works with errors.As
var target *antigravity.ForbiddenError
require.True(t, errors.As(forbiddenErr, &target))
require.Equal(t, 403, target.StatusCode)
require.Equal(t, "Access denied", target.Body)
require.Contains(t, forbiddenErr.Error(), "403")
}
// ---------------------------------------------------------------------------
// classifyForbiddenType
// ---------------------------------------------------------------------------
func TestClassifyForbiddenType(t *testing.T) {
tests := []struct {
name string
body string
expected string
}{
{
name: "VALIDATION_REQUIRED keyword",
body: `{"error":{"message":"VALIDATION_REQUIRED"}}`,
expected: "validation",
},
{
name: "verify your account",
body: `Please verify your account to continue`,
expected: "validation",
},
{
name: "contains validation_url field",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`,
expected: "validation",
},
{
name: "terms of service violation",
body: `Your account has been suspended for Terms of Service violation`,
expected: "violation",
},
{
name: "violation keyword",
body: `Account suspended due to policy violation`,
expected: "violation",
},
{
name: "generic 403",
body: `Access denied`,
expected: "forbidden",
},
{
name: "empty body",
body: "",
expected: "forbidden",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := classifyForbiddenType(tt.body)
require.Equal(t, tt.expected, got)
})
}
}
// ---------------------------------------------------------------------------
// extractValidationURL
// ---------------------------------------------------------------------------
func TestExtractValidationURL(t *testing.T) {
tests := []struct {
name string
body string
expected string
}{
{
name: "structured validation_url",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
expected: "https://accounts.google.com/verify?token=abc",
},
{
name: "structured appeal_url",
body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`,
expected: "https://support.google.com/appeal/123",
},
{
name: "validation_url takes priority over appeal_url",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`,
expected: "https://v.com",
},
{
name: "fallback regex with verify keyword",
body: `Please verify your account at https://accounts.google.com/verify`,
expected: "https://accounts.google.com/verify",
},
{
name: "no URL in generic forbidden",
body: `Access denied`,
expected: "",
},
{
name: "empty body",
body: "",
expected: "",
},
{
name: "URL present but no validation keywords",
body: `Error at https://example.com/something`,
expected: "",
},
{
name: "unicode escaped ampersand",
body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`,
expected: "https://accounts.google.com/verify?a=1&b=2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractValidationURL(tt.body)
require.Equal(t, tt.expected, got)
})
}
}

View File

@@ -32,6 +32,10 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste
return false
}
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
// Antigravity + overages enabled + credits not exhausted → allow scheduling (credits are available)
if a.Platform == PlatformAntigravity && a.IsOveragesEnabled() && !a.isCreditsExhausted() {
return true
}
return false
}
return true
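The new branch changes the scheduling rule for rate-limited accounts: a model-level rate limit no longer excludes an Antigravity account if overages are enabled and credits remain. A minimal truth-table sketch follows. The `account` struct, its flattened boolean fields, and the platform string "antigravity" are assumptions for illustration; the real code calls methods on `*Account`.

```go
package main

import "fmt"

// account is a hypothetical, flattened stand-in for the real Account type.
type account struct {
	platform         string
	modelRateLimited bool
	overagesEnabled  bool
	creditsExhausted bool
}

// schedulable mirrors the decision in the hunk: a rate-limited account is
// still schedulable when it can pay for the request with AI Credits.
func schedulable(a account) bool {
	if a.modelRateLimited {
		return a.platform == "antigravity" && a.overagesEnabled && !a.creditsExhausted
	}
	return true
}

func main() {
	cases := []account{
		{platform: "antigravity", modelRateLimited: true, overagesEnabled: true},
		{platform: "antigravity", modelRateLimited: true, overagesEnabled: true, creditsExhausted: true},
		{platform: "antigravity", modelRateLimited: true},
		{platform: "other", modelRateLimited: true, overagesEnabled: true},
		{platform: "other"},
	}
	for _, c := range cases {
		fmt.Printf("%+v -> %v\n", c, schedulable(c))
	}
}
```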

View File

@@ -76,10 +76,16 @@ type modelRateLimitCall struct {
resetAt time.Time
}
type extraUpdateCall struct {
accountID int64
updates map[string]any
}
type stubAntigravityAccountRepo struct {
AccountRepository
rateCalls []rateLimitCall
modelRateLimitCalls []modelRateLimitCall
extraUpdateCalls []extraUpdateCall
}
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
@@ -92,6 +98,11 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i
return nil
}
func (s *stubAntigravityAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
s.extraUpdateCalls = append(s.extraUpdateCalls, extraUpdateCall{accountID: id, updates: updates})
return nil
}
func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
t.Setenv(antigravityForwardBaseURLEnv, "")

View File

@@ -32,15 +32,23 @@ func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID
// mockSmartRetryUpstream is a mock upstream for handleSmartRetry tests
type mockSmartRetryUpstream struct {
responses []*http.Response
errors []error
callIdx int
calls []string
responses []*http.Response
errors []error
callIdx int
calls []string
requestBodies [][]byte
}
func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
idx := m.callIdx
m.calls = append(m.calls, req.URL.String())
if req != nil && req.Body != nil {
body, _ := io.ReadAll(req.Body)
m.requestBodies = append(m.requestBodies, body)
req.Body = io.NopCloser(bytes.NewReader(body))
} else {
m.requestBodies = append(m.requestBodies, nil)
}
m.callIdx++
if idx < len(m.responses) {
return m.responses[idx], m.errors[idx]
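The mock now records each request body and then restores `req.Body`, because reading the body consumes it. The helper below is a sketch of that read-and-restore step in isolation. `captureBody` and `newDemoRequest` are hypothetical names and the URL is a placeholder.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"strings"
)

// captureBody reads the request body for inspection and then replaces it with
// a fresh reader so later consumers still see the full payload.
func captureBody(req *http.Request) ([]byte, error) {
	if req == nil || req.Body == nil {
		return nil, nil
	}
	body, err := io.ReadAll(req.Body)
	if err != nil {
		return nil, err
	}
	_ = req.Body.Close()
	req.Body = io.NopCloser(bytes.NewReader(body))
	return body, nil
}

// newDemoRequest builds a POST request with the given payload (placeholder URL).
func newDemoRequest(payload string) *http.Request {
	req, err := http.NewRequest(http.MethodPost, "https://example.invalid/v1", strings.NewReader(payload))
	if err != nil {
		panic(err)
	}
	return req
}

func main() {
	req := newDemoRequest(`{"model":"m"}`)
	first, _ := captureBody(req)
	second, _ := captureBody(req) // still readable because the body was restored
	fmt.Println(string(first) == string(second))
}
```

Without the restore step, a second read of `req.Body` would return an empty payload, which would hide the injected fields the tests want to assert on.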

View File

@@ -3,7 +3,6 @@ package service
import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
@@ -17,15 +16,18 @@ const (
antigravityBackfillCooldown = 5 * time.Minute
)
// AntigravityTokenCache is the token cache interface (reuses the GeminiTokenCache interface definition)
// AntigravityTokenCache token cache interface.
type AntigravityTokenCache = GeminiTokenCache
// AntigravityTokenProvider manages the access_token of Antigravity accounts
// AntigravityTokenProvider manages access_token for antigravity accounts.
type AntigravityTokenProvider struct {
accountRepo AccountRepository
tokenCache AntigravityTokenCache
antigravityOAuthService *AntigravityOAuthService
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
backfillCooldown sync.Map // key: accountID -> last attempt time
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
refreshPolicy ProviderRefreshPolicy
}
func NewAntigravityTokenProvider(
@@ -37,10 +39,22 @@ func NewAntigravityTokenProvider(
accountRepo: accountRepo,
tokenCache: tokenCache,
antigravityOAuthService: antigravityOAuthService,
refreshPolicy: AntigravityProviderRefreshPolicy(),
}
}
// GetAccessToken returns a valid access_token
// SetRefreshAPI injects unified OAuth refresh API and executor.
func (p *AntigravityTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
p.refreshAPI = api
p.executor = executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func (p *AntigravityTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
// GetAccessToken returns a valid access_token.
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
@@ -48,7 +62,8 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if account.Platform != PlatformAntigravity {
return "", errors.New("not an antigravity account")
}
// upstream type: read api_key directly from credentials and skip the OAuth refresh flow
// upstream accounts use static api_key and never refresh oauth token.
if account.Type == AccountTypeUpstream {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
@@ -62,46 +77,38 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
cacheKey := AntigravityTokenCacheKey(account)
// 1. Try the cache first
// 1) Try cache first.
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
// 2. Refresh if the token is about to expire
// 2) Refresh if needed (pre-expiry skew).
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
if needsRefresh && p.tokenCache != nil {
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, antigravityTokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
// default policy: continue with existing token.
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// Re-check the cache after acquiring the lock (another worker may have refreshed already)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// Load the latest account data from the database
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
if p.antigravityOAuthService == nil {
return "", errors.New("antigravity oauth service not configured")
}
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", err
}
p.mergeCredentials(account, tokenInfo)
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
@@ -110,32 +117,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("access_token not found in credentials")
}
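// If the account has no project_id yet, try to backfill it online to avoid
// "Invalid project resource name projects/" when requesting daily/sandbox.
// Only loadProjectIDWithRetry is called; the OAuth token is not refreshed, and a cooldown prevents frequent retries.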
// Backfill project_id online when missing, with cooldown to avoid hammering.
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
if p.shouldAttemptBackfill(account.ID) {
p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] failed to persist backfilled project_id: %v", updateErr)
slog.Warn("antigravity_project_id_backfill_persist_failed",
"account_id", account.ID,
"error", updateErr,
)
}
}
}
}
// 3. Store in the cache (write only after verifying the version, to avoid a race between the async refresh task and the request thread)
// 3) Populate cache with TTL.
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// Version is stale; use the latest token from the DB
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// Do not write to the cache; let the next request handle it again
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
@@ -156,18 +162,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil
}
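The rewritten refresh step in GetAccessToken has three outcomes: the refresh failed, another worker holds the lock, or the refresh ran. The caller-side policy decides what happens in the first two. The sketch below models only that decision. Every type and name in it (`policy`, `refreshResult`, `resolveToken`, `errRefresh`) is hypothetical and simplified; it is not the `OAuthRefreshAPI` or `ProviderRefreshPolicy` API from the PR.

```go
package main

import (
	"errors"
	"fmt"
)

// policy is a hypothetical, simplified version of the caller-side refresh policy.
type policy struct {
	returnOnError          bool // fail the request when the refresh fails
	waitForCacheOnLockHeld bool // read the cache when another worker holds the lock
}

// refreshResult is a hypothetical summary of a refresh attempt.
type refreshResult struct {
	lockHeld bool   // another worker is refreshing right now
	token    string // the new token when this caller performed the refresh
}

var errRefresh = errors.New("refresh failed")

// resolveToken picks the token to use after a refresh attempt.
func resolveToken(existing string, res refreshResult, refreshErr error, p policy, cacheLookup func() (string, bool)) (string, error) {
	switch {
	case refreshErr != nil:
		if p.returnOnError {
			return "", refreshErr
		}
		return existing, nil // tolerate the failure and keep the current token
	case res.lockHeld:
		if p.waitForCacheOnLockHeld {
			if tok, ok := cacheLookup(); ok {
				return tok, nil
			}
		}
		return existing, nil // default: continue with the existing token
	default:
		return res.token, nil
	}
}

func main() {
	cache := func() (string, bool) { return "cached", true }
	tok, err := resolveToken("old", refreshResult{lockHeld: true}, nil, policy{waitForCacheOnLockHeld: true}, cache)
	fmt.Println(tok, err)
	tok, err = resolveToken("old", refreshResult{}, errRefresh, policy{returnOnError: true}, cache)
	fmt.Println(tok, err)
}
```

Keeping the policy on the caller side lets the request path and a background refresher share one refresh routine while reacting differently to errors and lock contention.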
// mergeCredentials merges credentials built from tokenInfo into account, keeping existing fields that are not overwritten
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
}
// shouldAttemptBackfill checks whether a project_id backfill should be attempted; no retry during the cooldown period
// shouldAttemptBackfill checks backfill cooldown.
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
if v, ok := p.backfillCooldown.Load(accountID); ok {
if lastAttempt, ok := v.(time.Time); ok {
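The hunk is cut off inside shouldAttemptBackfill, so the exact comparison is not visible here. The following is a sketch of a per-account cooldown built on `sync.Map`, assuming an attempt is allowed once the full cooldown has elapsed. The `cooldown` type and its helpers are hypothetical, and the current time is passed in explicitly (the real code most likely reads the clock itself) so the behavior is deterministic.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// cooldown tracks the last attempt per account ID.
type cooldown struct {
	window time.Duration
	last   sync.Map // key: int64 account ID -> time.Time of the last attempt
}

// newCooldown uses a 5 minute window, matching antigravityBackfillCooldown in the diff.
func newCooldown() *cooldown {
	return &cooldown{window: 5 * time.Minute}
}

// shouldAttempt reports whether the cooldown for id has elapsed at time now.
func (c *cooldown) shouldAttempt(id int64, now time.Time) bool {
	if v, ok := c.last.Load(id); ok {
		if lastAttempt, ok := v.(time.Time); ok {
			return now.Sub(lastAttempt) >= c.window // assumption: inclusive boundary
		}
	}
	return true // never attempted
}

// markAttempted records an attempt for id at time now.
func (c *cooldown) markAttempted(id int64, now time.Time) {
	c.last.Store(id, now)
}

// minutesAfter returns a fixed reference time plus n minutes.
func minutesAfter(n int) time.Time {
	base := time.Date(2026, 3, 16, 9, 0, 0, 0, time.UTC)
	return base.Add(time.Duration(n) * time.Minute)
}

func main() {
	c := newCooldown()
	fmt.Println(c.shouldAttempt(1, minutesAfter(0)))
	c.markAttempted(1, minutesAfter(0))
	fmt.Println(c.shouldAttempt(1, minutesAfter(4)))
	fmt.Println(c.shouldAttempt(1, minutesAfter(5)))
}
```

Marking the attempt before the network call, as the diff does with markBackfillAttempted, means a failed backfill also starts the cooldown.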

View File

@@ -25,6 +25,11 @@ func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthServi
}
}
// CacheKey returns the cache key used for the distributed lock
func (r *AntigravityTokenRefresher) CacheKey(account *Account) string {
return AntigravityTokenCacheKey(account)
}
// CanRefresh checks whether this account can be refreshed
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
@@ -58,11 +63,7 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
// Merge the old credentials, keeping fields that are absent from the new credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
newCredentials = MergeCredentials(account.Credentials, newCredentials)
// Special handling for project_id: if the new value is empty but the old value is not, keep the old value.
// This ensures project_id is not lost even if LoadCodeAssist fails.
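The body of MergeCredentials is not shown in this diff. Judging from the loop it replaces, it keeps old keys that the new credentials lack and lets new values win otherwise. The sketch below implements that reading as `mergeCredentials`, a hypothetical lowercase stand-in; the `map[string]any` type is also an assumption. It shows why project_id needs the extra handling: a plain merge lets an empty new value overwrite a good old one.

```go
package main

import "fmt"

// mergeCredentials returns a new map with every key from fresh plus any key
// from old that fresh does not define. Neither input is modified.
func mergeCredentials(old, fresh map[string]any) map[string]any {
	merged := make(map[string]any, len(old)+len(fresh))
	for k, v := range old {
		merged[k] = v
	}
	for k, v := range fresh {
		merged[k] = v // new values win, even when they are empty
	}
	return merged
}

func main() {
	old := map[string]any{"project_id": "p-1", "email": "a@example.com"}
	fresh := map[string]any{"access_token": "new", "project_id": ""}

	merged := mergeCredentials(old, fresh)
	fmt.Printf("email=%v access_token=%v project_id=%q\n",
		merged["email"], merged["access_token"], merged["project_id"])

	// Guard mirroring the comment in the diff: keep the old project_id
	// when the refreshed value came back empty.
	if merged["project_id"] == "" && old["project_id"] != "" {
		merged["project_id"] = old["project_id"]
	}
	fmt.Printf("project_id after guard=%q\n", merged["project_id"])
}
```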

View File

@@ -22,8 +22,9 @@ const (
)
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
// A nil windowStart is treated as expired — no initialized window means any accumulated usage is stale.
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
return windowStart != nil && time.Since(*windowStart) >= duration
return windowStart == nil || time.Since(*windowStart) >= duration
}
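The fix flips the nil case: a window that was never started now counts as expired, so usage left in the counter is treated as stale. A small sketch of the effect on effective usage follows. `isWindowExpired` copies the new one-liner, while `effectiveUsage`, `hoursAgo`, and `window5h` are hypothetical helpers, and the real EffectiveUsage methods are not shown in this diff.

```go
package main

import (
	"fmt"
	"time"
)

const window5h = 5 * time.Hour

// isWindowExpired copies the corrected predicate: nil means expired.
func isWindowExpired(windowStart *time.Time, duration time.Duration) bool {
	return windowStart == nil || time.Since(*windowStart) >= duration
}

// effectiveUsage returns 0 for an expired (or never started) window and the
// raw counter otherwise.
func effectiveUsage(raw float64, windowStart *time.Time, duration time.Duration) float64 {
	if isWindowExpired(windowStart, duration) {
		return 0
	}
	return raw
}

// hoursAgo returns a pointer to the time h hours before now.
func hoursAgo(h int) *time.Time {
	t := time.Now().Add(-time.Duration(h) * time.Hour)
	return &t
}

func main() {
	fmt.Println(effectiveUsage(5.0, nil, window5h))         // no window yet
	fmt.Println(effectiveUsage(5.0, hoursAgo(1), window5h)) // active window
	fmt.Println(effectiveUsage(5.0, hoursAgo(6), window5h)) // expired window
}
```

With the old predicate (`windowStart != nil && ...`) the first call would have returned the raw 5.0, which is the behavior the updated table tests no longer expect.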
type APIKey struct {

View File

@@ -15,10 +15,10 @@ func TestIsWindowExpired(t *testing.T) {
want bool
}{
{
name: "nil window start",
name: "nil window start (treated as expired)",
start: nil,
duration: RateLimitWindow5h,
want: false,
want: true,
},
{
name: "active window (started 1h ago, 5h window)",
@@ -113,7 +113,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
want7d: 0,
},
{
name: "nil window starts return raw usage",
name: "nil window starts return 0 (stale usage reset)",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
@@ -122,9 +122,9 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
Window1dStart: nil,
Window7dStart: nil,
},
want5h: 5.0,
want1d: 10.0,
want7d: 50.0,
want5h: 0,
want1d: 0,
want7d: 0,
},
{
name: "mixed: 5h expired, 1d active, 7d nil",
@@ -138,7 +138,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
},
want5h: 0,
want1d: 10.0,
want7d: 50.0,
want7d: 0,
},
{
name: "zero usage with active windows",
@@ -210,7 +210,7 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
want7d: 0,
},
{
name: "nil window starts return raw usage",
name: "nil window starts return 0 (stale usage reset)",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
@@ -219,9 +219,9 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
Window1dStart: nil,
Window7dStart: nil,
},
-want5h: 3.0,
-want1d: 8.0,
-want7d: 40.0,
+want5h: 0,
+want1d: 0,
+want7d: 0,
},
}
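
The behavior change in the two hunks above can be sketched in isolation. This is a minimal illustration, not the repository's code: `IsWindowExpired` is copied from the diff, while `effectiveUsage`, `hoursAgo`, and `window5h` are hypothetical names introduced here, on the assumption that the `EffectiveUsage` methods under test report 0 for an expired window and the raw value otherwise.

```go
package main

import (
	"fmt"
	"time"
)

// window5h is a stand-in for RateLimitWindow5h (assumed to be five hours).
const window5h = 5 * time.Hour

// IsWindowExpired is copied from the diff: a nil start now counts as expired.
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
	return windowStart == nil || time.Since(*windowStart) >= duration
}

// effectiveUsage is a hypothetical helper: usage in an expired (or never
// initialized) window is reported as 0, otherwise the raw value is returned.
func effectiveUsage(usage float64, windowStart *time.Time, duration time.Duration) float64 {
	if IsWindowExpired(windowStart, duration) {
		return 0
	}
	return usage
}

// hoursAgo is a hypothetical helper returning a pointer to a past timestamp.
func hoursAgo(h int) *time.Time {
	t := time.Now().Add(-time.Duration(h) * time.Hour)
	return &t
}

func main() {
	fmt.Println(effectiveUsage(5.0, nil, window5h))         // nil start: usage is stale
	fmt.Println(effectiveUsage(5.0, hoursAgo(1), window5h)) // active window: raw usage
	fmt.Println(effectiveUsage(5.0, hoursAgo(6), window5h)) // expired window: reset
}
```

With the previous `windowStart != nil && ...` condition, the first call would have returned the raw usage, which is what the renamed test cases used to assert.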

View File

@@ -1087,6 +1087,12 @@ type TokenPair struct {
ExpiresIn int `json:"expires_in"` // Access token validity period
}
+// TokenPairWithUser extends TokenPair with user role for backend mode checks
+type TokenPairWithUser struct {
+TokenPair
+UserRole string
+}
// GenerateTokenPair generates an access token and refresh token pair.
// familyID: optional token family ID, used to preserve the family relationship during token rotation.
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
@@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
// RefreshTokenPair refreshes the token pair using a refresh token.
// Implements token rotation: every refresh issues a new refresh token, and the old token is invalidated immediately.
-func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
+func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) {
// Check whether refreshTokenCache is available
if s.refreshTokenCache == nil {
return nil, ErrRefreshTokenInvalid
@@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
}
// Generate a new token pair, keeping the same family ID
-return s.GenerateTokenPair(ctx, user, data.FamilyID)
+pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID)
+if err != nil {
+return nil, err
+}
+return &TokenPairWithUser{
+TokenPair: *pair,
+UserRole: user.Role,
+}, nil
}
// RevokeRefreshToken revokes a single refresh token
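
A short sketch of what embedding `TokenPair` in `TokenPairWithUser` implies for callers. Only `ExpiresIn` and its JSON tag are visible in this diff; the `AccessToken` and `RefreshToken` fields, their tags, and the `publicJSON` helper are assumptions made for illustration, and how the real handler serializes the response is not shown here.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// TokenPair: only ExpiresIn appears in the diff; the other two fields are assumed.
type TokenPair struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	ExpiresIn    int    `json:"expires_in"`
}

// TokenPairWithUser matches the shape added in the diff.
type TokenPairWithUser struct {
	TokenPair
	UserRole string
}

// publicJSON is a hypothetical helper: it marshals only the embedded pair,
// so the role stays server-side. Marshaling the outer struct directly would
// flatten the embedded fields and also emit a "UserRole" key.
func publicJSON(p *TokenPairWithUser) (string, error) {
	b, err := json.Marshal(p.TokenPair)
	if err != nil {
		return "", err
	}
	return string(b), nil
}

func main() {
	p := &TokenPairWithUser{
		TokenPair: TokenPair{AccessToken: "a", RefreshToken: "r", ExpiresIn: 3600},
		UserRole:  "admin",
	}
	// Fields of the embedded struct are promoted, so existing callers that
	// read p.ExpiresIn keep compiling after the return type change.
	fmt.Println(p.ExpiresIn, p.UserRole)
	s, err := publicJSON(p)
	fmt.Println(s, err)
}
```

Because `UserRole` carries no JSON tag, any caller that serializes the whole `TokenPairWithUser` would expose it under the key `UserRole`; whether that is intended depends on the handler, which this hunk does not show.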

View File

@@ -0,0 +1,770 @@
package service
import (
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"sort"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/robfig/cron/v3"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
settingKeyBackupS3Config = "backup_s3_config"
settingKeyBackupSchedule = "backup_schedule"
settingKeyBackupRecords = "backup_records"
maxBackupRecords = 100
)
var (
ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured")
ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found")
ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress")
ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress")
ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted")
ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted")
)
// ─── Interface definitions ───
// DBDumper abstracts database dump/restore operations
type DBDumper interface {
Dump(ctx context.Context) (io.ReadCloser, error)
Restore(ctx context.Context, data io.Reader) error
}
// BackupObjectStore abstracts object storage for backup files
type BackupObjectStore interface {
Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error)
Download(ctx context.Context, key string) (io.ReadCloser, error)
Delete(ctx context.Context, key string) error
PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error)
HeadBucket(ctx context.Context) error
}
// BackupObjectStoreFactory creates an object store from S3 config
type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error)
// ─── Data models ───
// BackupS3Config is the S3-compatible storage configuration (supports Cloudflare R2).
type BackupS3Config struct {
Endpoint string `json:"endpoint"` // e.g. https://<account_id>.r2.cloudflarestorage.com
Region string `json:"region"` // use "auto" for R2
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key,omitempty"` //nolint:revive // field name follows AWS convention
Prefix string `json:"prefix"` // S3 key prefix, e.g. "backups/"
ForcePathStyle bool `json:"force_path_style"`
}
// IsConfigured reports whether the required fields are set.
func (c *BackupS3Config) IsConfigured() bool {
return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != ""
}
// BackupScheduleConfig is the scheduled backup configuration.
type BackupScheduleConfig struct {
Enabled bool `json:"enabled"`
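CronExpr string `json:"cron_expr"` // cron expression, e.g. "0 2 * * *" (daily at 02:00)
RetainDays int `json:"retain_days"` // days until backup files expire; default 14, 0 = no automatic cleanup
RetainCount int `json:"retain_count"` // maximum number of backups to keep; 0 = unlimited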
}
// BackupRecord is a single backup record.
type BackupRecord struct {
ID string `json:"id"`
Status string `json:"status"` // pending, running, completed, failed
BackupType string `json:"backup_type"` // postgres
FileName string `json:"file_name"`
S3Key string `json:"s3_key"`
SizeBytes int64 `json:"size_bytes"`
TriggeredBy string `json:"triggered_by"` // manual, scheduled
ErrorMsg string `json:"error_message,omitempty"`
StartedAt string `json:"started_at"`
FinishedAt string `json:"finished_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"` // expiry time
}
// BackupService is the database backup and restore service.
type BackupService struct {
settingRepo SettingRepository
dbCfg *config.DatabaseConfig
encryptor SecretEncryptor
storeFactory BackupObjectStoreFactory
dumper DBDumper
mu sync.Mutex
store BackupObjectStore
s3Cfg *BackupS3Config
backingUp bool
restoring bool
recordsMu sync.Mutex // guards load/save operations on records
cronMu sync.Mutex
cronSched *cron.Cron
cronEntryID cron.EntryID
}
func NewBackupService(
settingRepo SettingRepository,
cfg *config.Config,
encryptor SecretEncryptor,
storeFactory BackupObjectStoreFactory,
dumper DBDumper,
) *BackupService {
return &BackupService{
settingRepo: settingRepo,
dbCfg: &cfg.Database,
encryptor: encryptor,
storeFactory: storeFactory,
dumper: dumper,
}
}
// Start starts the scheduled backup scheduler.
func (s *BackupService) Start() {
s.cronSched = cron.New()
s.cronSched.Start()
// Load the existing schedule configuration
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
schedule, err := s.GetSchedule(ctx)
if err != nil {
logger.LegacyPrintf("service.backup", "[Backup] failed to load scheduled backup config: %v", err)
return
}
if schedule.Enabled && schedule.CronExpr != "" {
if err := s.applyCronSchedule(schedule); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] failed to apply scheduled backup config: %v", err)
}
}
}
// Stop stops scheduled backups.
func (s *BackupService) Stop() {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched != nil {
s.cronSched.Stop()
}
}
// ─── S3 configuration management ───
func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) {
cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
if cfg == nil {
return &BackupS3Config{}, nil
}
// Redact the secret before returning
cfg.SecretAccessKey = ""
return cfg, nil
}
func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) {
// If no secret is provided, keep the existing value
if cfg.SecretAccessKey == "" {
old, _ := s.loadS3Config(ctx)
if old != nil {
cfg.SecretAccessKey = old.SecretAccessKey
}
} else {
// Encrypt SecretAccessKey
encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey)
if err != nil {
return nil, fmt.Errorf("encrypt secret: %w", err)
}
cfg.SecretAccessKey = encrypted
}
data, err := json.Marshal(cfg)
if err != nil {
return nil, fmt.Errorf("marshal s3 config: %w", err)
}
if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil {
return nil, fmt.Errorf("save s3 config: %w", err)
}
// Clear the cached S3 client
s.mu.Lock()
s.store = nil
s.s3Cfg = nil
s.mu.Unlock()
cfg.SecretAccessKey = ""
return &cfg, nil
}
func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error {
// If no secret is provided, use the saved one
if cfg.SecretAccessKey == "" {
old, _ := s.loadS3Config(ctx)
if old != nil {
cfg.SecretAccessKey = old.SecretAccessKey
}
}
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required")
}
store, err := s.storeFactory(ctx, &cfg)
if err != nil {
return err
}
return store.HeadBucket(ctx)
}
// ─── Scheduled backup management ───
func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule)
if err != nil || raw == "" {
return &BackupScheduleConfig{}, nil
}
var cfg BackupScheduleConfig
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
return &BackupScheduleConfig{}, nil
}
return &cfg, nil
}
func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) {
if cfg.Enabled && cfg.CronExpr == "" {
return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled")
}
// Validate the cron expression
if cfg.CronExpr != "" {
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
if _, err := parser.Parse(cfg.CronExpr); err != nil {
return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err))
}
}
data, err := json.Marshal(cfg)
if err != nil {
return nil, fmt.Errorf("marshal schedule config: %w", err)
}
if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil {
return nil, fmt.Errorf("save schedule config: %w", err)
}
// Apply or stop the scheduled job
if cfg.Enabled {
if err := s.applyCronSchedule(&cfg); err != nil {
return nil, err
}
} else {
s.removeCronSchedule()
}
return &cfg, nil
}
func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched == nil {
return fmt.Errorf("cron scheduler not initialized")
}
// Remove the old job
if s.cronEntryID != 0 {
s.cronSched.Remove(s.cronEntryID)
s.cronEntryID = 0
}
entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() {
s.runScheduledBackup()
})
if err != nil {
return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err))
}
s.cronEntryID = entryID
logger.LegacyPrintf("service.backup", "[Backup] scheduled backup enabled: %s", cfg.CronExpr)
return nil
}
func (s *BackupService) removeCronSchedule() {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched != nil && s.cronEntryID != 0 {
s.cronSched.Remove(s.cronEntryID)
s.cronEntryID = 0
logger.LegacyPrintf("service.backup", "[Backup] scheduled backup disabled")
}
}
func (s *BackupService) runScheduledBackup() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
// Read the expiry days from the scheduled backup configuration
schedule, _ := s.GetSchedule(ctx)
expireDays := 14 // expires after 14 days by default
if schedule != nil && schedule.RetainDays > 0 {
expireDays = schedule.RetainDays
}
logger.LegacyPrintf("service.backup", "[Backup] starting scheduled backup, expiry days: %d", expireDays)
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
if err != nil {
logger.LegacyPrintf("service.backup", "[Backup] scheduled backup failed: %v", err)
return
}
logger.LegacyPrintf("service.backup", "[Backup] scheduled backup completed: id=%s size=%d", record.ID, record.SizeBytes)
// Clean up expired backups (reusing the already-loaded schedule)
if schedule == nil {
return
}
if err := s.cleanupOldBackups(ctx, schedule); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] failed to clean up expired backups: %v", err)
}
}
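// ─── Backup/restore core ───
// CreateBackup creates a full database backup and uploads it to S3 (streamed).
// expireDays: number of days until the backup expires; 0 = never expires, default 14.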
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
s.mu.Lock()
if s.backingUp {
s.mu.Unlock()
return nil, ErrBackupInProgress
}
s.backingUp = true
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.backingUp = false
s.mu.Unlock()
}()
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
if s3Cfg == nil || !s3Cfg.IsConfigured() {
return nil, ErrBackupS3NotConfigured
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return nil, fmt.Errorf("init object store: %w", err)
}
now := time.Now()
backupID := uuid.New().String()[:8]
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
s3Key := s.buildS3Key(s3Cfg, fileName)
var expiresAt string
if expireDays > 0 {
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
}
record := &BackupRecord{
ID: backupID,
Status: "running",
BackupType: "postgres",
FileName: fileName,
S3Key: s3Key,
TriggeredBy: triggeredBy,
StartedAt: now.Format(time.RFC3339),
ExpiresAt: expiresAt,
}
// Streamed execution: pg_dump -> gzip -> S3 upload
dumpReader, err := s.dumper.Dump(ctx)
if err != nil {
record.Status = "failed"
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, record)
return record, fmt.Errorf("pg_dump: %w", err)
}
// Use io.Pipe to stream the gzip-compressed data to the S3 upload
pr, pw := io.Pipe()
var gzipErr error
go func() {
gzWriter := gzip.NewWriter(pw)
_, gzipErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
}
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
}
if gzipErr != nil {
_ = pw.CloseWithError(gzipErr)
} else {
_ = pw.Close()
}
}()
contentType := "application/gzip"
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
if err != nil {
record.Status = "failed"
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
if gzipErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
}
record.ErrorMsg = errMsg
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, record)
return record, fmt.Errorf("backup upload: %w", err)
}
record.SizeBytes = sizeBytes
record.Status = "completed"
record.FinishedAt = time.Now().Format(time.RFC3339)
if err := s.saveRecord(ctx, record); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] failed to save backup record: %v", err)
}
return record, nil
}
// RestoreBackup downloads a backup from S3 and stream-restores it into the database.
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
s.mu.Lock()
if s.restoring {
s.mu.Unlock()
return ErrRestoreInProgress
}
s.restoring = true
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.restoring = false
s.mu.Unlock()
}()
record, err := s.GetBackupRecord(ctx, backupID)
if err != nil {
return err
}
if record.Status != "completed" {
return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
}
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return err
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return fmt.Errorf("init object store: %w", err)
}
// Stream the download from S3
body, err := objectStore.Download(ctx, record.S3Key)
if err != nil {
return fmt.Errorf("S3 download failed: %w", err)
}
defer func() { _ = body.Close() }()
// Stream-decompress gzip -> psql (without loading all data into memory)
gzReader, err := gzip.NewReader(body)
if err != nil {
return fmt.Errorf("gzip reader: %w", err)
}
defer func() { _ = gzReader.Close() }()
// Streamed restore
if err := s.dumper.Restore(ctx, gzReader); err != nil {
return fmt.Errorf("pg restore: %w", err)
}
return nil
}
// ─── Backup record management ───
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
records, err := s.loadRecords(ctx)
if err != nil {
return nil, err
}
// Return in descending order (newest first)
sort.Slice(records, func(i, j int) bool {
return records[i].StartedAt > records[j].StartedAt
})
return records, nil
}
func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) {
records, err := s.loadRecords(ctx)
if err != nil {
return nil, err
}
for i := range records {
if records[i].ID == backupID {
return &records[i], nil
}
}
return nil, ErrBackupNotFound
}
func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, err := s.loadRecordsLocked(ctx)
if err != nil {
return err
}
var found *BackupRecord
var remaining []BackupRecord
for i := range records {
if records[i].ID == backupID {
found = &records[i]
} else {
remaining = append(remaining, records[i])
}
}
if found == nil {
return ErrBackupNotFound
}
// Delete from S3
if found.S3Key != "" && found.Status == "completed" {
s3Cfg, err := s.loadS3Config(ctx)
if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() {
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err == nil {
_ = objectStore.Delete(ctx, found.S3Key)
}
}
}
return s.saveRecordsLocked(ctx, remaining)
}
// GetBackupDownloadURL returns a presigned download URL for the backup file.
func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) {
record, err := s.GetBackupRecord(ctx, backupID)
if err != nil {
return "", err
}
if record.Status != "completed" {
return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed")
}
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return "", err
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return "", err
}
url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour)
if err != nil {
return "", fmt.Errorf("presign url: %w", err)
}
return url, nil
}
// ─── Internal methods ───
func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
if err != nil || raw == "" {
return nil, nil //nolint:nilnil // no config is a valid state
}
var cfg BackupS3Config
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
return nil, ErrBackupS3ConfigCorrupt
}
// Decrypt SecretAccessKey
if cfg.SecretAccessKey != "" {
decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey)
if err != nil {
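// Backward compatibility with old unencrypted data: if decryption fails, keep the original value
logger.LegacyPrintf("service.backup", "[Backup] failed to decrypt S3 SecretAccessKey (possibly old unencrypted data): %v", err)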
} else {
cfg.SecretAccessKey = decrypted
}
}
return &cfg, nil
}
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.store != nil && s.s3Cfg != nil {
return s.store, nil
}
if cfg == nil {
return nil, ErrBackupS3NotConfigured
}
store, err := s.storeFactory(ctx, cfg)
if err != nil {
return nil, err
}
s.store = store
s.s3Cfg = cfg
return store, nil
}
func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string {
prefix := strings.TrimRight(cfg.Prefix, "/")
if prefix == "" {
prefix = "backups"
}
return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName)
}
// loadRecords loads backup records, distinguishing "no data" from "corrupted data".
func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
return s.loadRecordsLocked(ctx)
}
// loadRecordsLocked loads records; the caller must already hold recordsMu.
func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords)
if err != nil || raw == "" {
return nil, nil //nolint:nilnil // no records is a valid state
}
var records []BackupRecord
if err := json.Unmarshal([]byte(raw), &records); err != nil {
return nil, ErrBackupRecordsCorrupt
}
return records, nil
}
// saveRecordsLocked saves records; the caller must already hold recordsMu.
func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error {
data, err := json.Marshal(records)
if err != nil {
return err
}
return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data))
}
// saveRecord saves a single record (protected by the mutex).
func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, _ := s.loadRecordsLocked(ctx)
// Update the existing record, or append
found := false
for i := range records {
if records[i].ID == record.ID {
records[i] = *record
found = true
break
}
}
if !found {
records = append(records, *record)
}
// Cap the number of records
if len(records) > maxBackupRecords {
records = records[len(records)-maxBackupRecords:]
}
return s.saveRecordsLocked(ctx, records)
}
func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error {
if schedule == nil {
return nil
}
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, err := s.loadRecordsLocked(ctx)
if err != nil {
return err
}
// Sort by time, newest first
sort.Slice(records, func(i, j int) bool {
return records[i].StartedAt > records[j].StartedAt
})
var toDelete []BackupRecord
var toKeep []BackupRecord
for i, r := range records {
shouldDelete := false
// Clean up by retained count
if schedule.RetainCount > 0 && i >= schedule.RetainCount {
shouldDelete = true
}
// Clean up by retained days
if schedule.RetainDays > 0 && r.StartedAt != "" {
startedAt, err := time.Parse(time.RFC3339, r.StartedAt)
if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour {
shouldDelete = true
}
}
if shouldDelete && r.Status == "completed" {
toDelete = append(toDelete, r)
} else {
toKeep = append(toKeep, r)
}
}
// Delete the files from S3
for _, r := range toDelete {
if r.S3Key != "" {
_ = s.deleteS3Object(ctx, r.S3Key)
}
}
if len(toDelete) > 0 {
logger.LegacyPrintf("service.backup", "[Backup] automatically cleaned up %d expired backups", len(toDelete))
return s.saveRecordsLocked(ctx, toKeep)
}
return nil
}
func (s *BackupService) deleteS3Object(ctx context.Context, key string) error {
s3Cfg, err := s.loadS3Config(ctx)
if err != nil || s3Cfg == nil {
return nil
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return err
}
return objectStore.Delete(ctx, key)
}
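
The `pg_dump -> gzip -> S3 upload` streaming in `CreateBackup` relies on `io.Pipe`. The sketch below isolates that pattern using only the standard library. It is an illustration under assumptions, not the service's code: `gzipStream` and `roundTrip` are hypothetical names, a `bytes.Buffer` stands in for the object store, and the compression error is handed to the reader through `CloseWithError` instead of a shared variable.

```go
package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
	"strings"
)

// gzipStream returns a reader that yields the gzip-compressed form of src.
// Compression runs in a goroutine and is driven by the consumer's reads, so
// the full dump is never buffered in memory.
func gzipStream(src io.Reader) io.Reader {
	pr, pw := io.Pipe()
	go func() {
		gz := gzip.NewWriter(pw)
		_, err := io.Copy(gz, src)
		if closeErr := gz.Close(); err == nil {
			err = closeErr
		}
		// CloseWithError(nil) behaves like Close: the reader sees io.EOF.
		// A non-nil error is returned from the consumer's next Read.
		_ = pw.CloseWithError(err)
	}()
	return pr
}

// roundTrip compresses s through the pipe and decompresses it again,
// standing in for "upload to S3" followed by "restore from S3".
func roundTrip(s string) (string, error) {
	var uploaded bytes.Buffer // stand-in for the object store
	if _, err := io.Copy(&uploaded, gzipStream(strings.NewReader(s))); err != nil {
		return "", err
	}
	gz, err := gzip.NewReader(&uploaded)
	if err != nil {
		return "", err
	}
	defer gz.Close()
	out, err := io.ReadAll(gz)
	return string(out), err
}

func main() {
	out, err := roundTrip("-- PostgreSQL dump\nCREATE TABLE test (id int);\n")
	fmt.Printf("%q %v\n", out, err)
}
```

One design point worth noting: if the consumer stops reading early (for example, a failed upload), the writing goroutine stays blocked until the read side of the pipe is closed, so a caller using this pattern would typically close the pipe reader on the error path.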

View File

@@ -0,0 +1,528 @@
//go:build unit
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// ─── Mocks ───
type mockSettingRepo struct {
mu sync.Mutex
data map[string]string
}
func newMockSettingRepo() *mockSettingRepo {
return &mockSettingRepo{data: make(map[string]string)}
}
func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) {
m.mu.Lock()
defer m.mu.Unlock()
v, ok := m.data[key]
if !ok {
return nil, ErrSettingNotFound
}
return &Setting{Key: key, Value: v}, nil
}
func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
v, ok := m.data[key]
if !ok {
return "", nil
}
return v, nil
}
func (m *mockSettingRepo) Set(_ context.Context, key, value string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.data[key] = value
return nil
}
func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
result := make(map[string]string)
for _, k := range keys {
if v, ok := m.data[k]; ok {
result[k] = v
}
}
return result, nil
}
func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
m.mu.Lock()
defer m.mu.Unlock()
for k, v := range settings {
m.data[k] = v
}
return nil
}
func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
result := make(map[string]string, len(m.data))
for k, v := range m.data {
result[k] = v
}
return result, nil
}
func (m *mockSettingRepo) Delete(_ context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.data, key)
return nil
}
// plainEncryptor performs only a base64-like wrapping; for testing
type plainEncryptor struct{}
func (e *plainEncryptor) Encrypt(plaintext string) (string, error) {
return "ENC:" + plaintext, nil
}
func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) {
if strings.HasPrefix(ciphertext, "ENC:") {
return strings.TrimPrefix(ciphertext, "ENC:"), nil
}
return ciphertext, fmt.Errorf("not encrypted")
}
type mockDumper struct {
dumpData []byte
dumpErr error
restored []byte
restErr error
}
func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) {
if m.dumpErr != nil {
return nil, m.dumpErr
}
return io.NopCloser(bytes.NewReader(m.dumpData)), nil
}
func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
if m.restErr != nil {
return m.restErr
}
d, err := io.ReadAll(data)
if err != nil {
return err
}
m.restored = d
return nil
}
type mockObjectStore struct {
objects map[string][]byte
mu sync.Mutex
}
func newMockObjectStore() *mockObjectStore {
return &mockObjectStore{objects: make(map[string][]byte)}
}
func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) {
data, err := io.ReadAll(body)
if err != nil {
return 0, err
}
m.mu.Lock()
m.objects[key] = data
m.mu.Unlock()
return int64(len(data)), nil
}
func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) {
m.mu.Lock()
data, ok := m.objects[key]
m.mu.Unlock()
if !ok {
return nil, fmt.Errorf("not found: %s", key)
}
return io.NopCloser(bytes.NewReader(data)), nil
}
func (m *mockObjectStore) Delete(_ context.Context, key string) error {
m.mu.Lock()
delete(m.objects, key)
m.mu.Unlock()
return nil
}
func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) {
return "https://presigned.example.com/" + key, nil
}
func (m *mockObjectStore) HeadBucket(_ context.Context) error {
return nil
}
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "test",
DBName: "testdb",
},
}
factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) {
return store, nil
}
return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper)
}
func seedS3Config(t *testing.T, repo *mockSettingRepo) {
t.Helper()
cfg := BackupS3Config{
Bucket: "test-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "ENC:secret123",
Prefix: "backups",
}
data, _ := json.Marshal(cfg)
require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data)))
}
// ─── Tests ───
func TestBackupService_S3ConfigEncryption(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// Save the config -> SecretAccessKey should be encrypted
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "my-secret",
Prefix: "backups",
})
require.NoError(t, err)
// Read the stored value directly from the database; it should be encrypted
raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config)
var stored BackupS3Config
require.NoError(t, json.Unmarshal([]byte(raw), &stored))
require.Equal(t, "ENC:my-secret", stored.SecretAccessKey)
// Fetching through GetS3Config should return the redacted config
cfg, err := svc.GetS3Config(context.Background())
require.NoError(t, err)
require.Empty(t, cfg.SecretAccessKey)
require.Equal(t, "my-bucket", cfg.Bucket)
// loadS3Config should decrypt internally
internal, err := svc.loadS3Config(context.Background())
require.NoError(t, err)
require.Equal(t, "my-secret", internal.SecretAccessKey)
}
func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// First save a config that includes a secret
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "original-secret",
})
require.NoError(t, err)
// Then update without providing a secret; the original value should be kept
_, err = svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID-NEW",
})
require.NoError(t, err)
internal, err := svc.loadS3Config(context.Background())
require.NoError(t, err)
require.Equal(t, "original-secret", internal.SecretAccessKey)
require.Equal(t, "AKID-NEW", internal.AccessKeyID)
}
func TestBackupService_SaveRecordConcurrency(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
var wg sync.WaitGroup
n := 20
wg.Add(n)
for i := 0; i < n; i++ {
go func(idx int) {
defer wg.Done()
record := &BackupRecord{
ID: fmt.Sprintf("rec-%d", idx),
Status: "completed",
StartedAt: time.Now().Format(time.RFC3339),
}
_ = svc.saveRecord(context.Background(), record)
}(i)
}
wg.Wait()
records, err := svc.loadRecords(context.Background())
require.NoError(t, err)
require.Len(t, records, n)
}
func TestBackupService_LoadRecords_Empty(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
records, err := svc.loadRecords(context.Background())
require.NoError(t, err)
require.Nil(t, records) // returns nil when there is no data
}
func TestBackupService_LoadRecords_Corrupted(t *testing.T) {
repo := newMockSettingRepo()
_ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{")
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
records, err := svc.loadRecords(context.Background())
require.Error(t, err) // corrupted data should return an error
require.Nil(t, records)
}
func TestBackupService_CreateBackup_Streaming(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
require.Equal(t, "completed", record.Status)
require.Greater(t, record.SizeBytes, int64(0))
require.NotEmpty(t, record.S3Key)
// Verify that the file really exists in S3
store.mu.Lock()
require.Len(t, store.objects, 1)
store.mu.Unlock()
}
func TestBackupService_CreateBackup_DumpFailure(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.Error(t, err)
require.Equal(t, "failed", record.Status)
require.Contains(t, record.ErrorMsg, "pg_dump")
}
func TestBackupService_CreateBackup_NoS3Config(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
_, err := svc.CreateBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupS3NotConfigured)
}
func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
// Use a slow dumper to simulate a backup in progress
dumper := &mockDumper{dumpData: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// Manually set the backingUp flag
svc.mu.Lock()
svc.backingUp = true
svc.mu.Unlock()
_, err := svc.CreateBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupInProgress)
}
func TestBackupService_RestoreBackup_Streaming(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// First create a backup
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// Restore
err = svc.RestoreBackup(context.Background(), record.ID)
require.NoError(t, err)
// Verify that the data psql received matches the original dump content
require.Equal(t, dumpContent, string(dumper.restored))
}
func TestBackupService_RestoreBackup_NotCompleted(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// Manually insert a failed record
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: "fail-1",
Status: "failed",
})
err := svc.RestoreBackup(context.Background(), "fail-1")
require.Error(t, err)
}
func TestBackupService_DeleteBackup(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "data"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// The file should exist in S3
store.mu.Lock()
require.Len(t, store.objects, 1)
store.mu.Unlock()
// Delete
err = svc.DeleteBackup(context.Background(), record.ID)
require.NoError(t, err)
// The file should have been deleted from S3
store.mu.Lock()
require.Len(t, store.objects, 0)
store.mu.Unlock()
// The record should no longer exist
_, err = svc.GetBackupRecord(context.Background(), record.ID)
require.ErrorIs(t, err, ErrBackupNotFound)
}
func TestBackupService_GetDownloadURL(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &mockDumper{dumpData: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
url, err := svc.GetBackupDownloadURL(context.Background(), record.ID)
require.NoError(t, err)
require.Contains(t, url, "https://presigned.example.com/")
}
func TestBackupService_ListBackups_Sorted(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
now := time.Now()
for i := 0; i < 3; i++ {
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: fmt.Sprintf("rec-%d", i),
Status: "completed",
StartedAt: now.Add(time.Duration(i) * time.Hour).Format(time.RFC3339),
})
}
records, err := svc.ListBackups(context.Background())
require.NoError(t, err)
require.Len(t, records, 3)
// Newest first.
require.Equal(t, "rec-2", records[0].ID)
require.Equal(t, "rec-0", records[2].ID)
}
func TestBackupService_TestS3Connection(t *testing.T) {
repo := newMockSettingRepo()
store := newMockObjectStore()
svc := newTestBackupService(repo, &mockDumper{}, store)
err := svc.TestS3Connection(context.Background(), BackupS3Config{
Bucket: "test",
AccessKeyID: "ak",
SecretAccessKey: "sk",
})
require.NoError(t, err)
}
func TestBackupService_TestS3Connection_Incomplete(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
err := svc.TestS3Connection(context.Background(), BackupS3Config{
Bucket: "test",
})
require.Error(t, err)
require.Contains(t, err.Error(), "incomplete")
}
func TestBackupService_Schedule_CronValidation(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
svc.cronSched = nil // cron not initialized
// Enabled, but the cron expression is empty.
_, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
Enabled: true,
CronExpr: "",
})
require.Error(t, err)
// Invalid cron expression.
_, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
Enabled: true,
CronExpr: "invalid",
})
require.Error(t, err)
}
func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
repo := newMockSettingRepo()
_ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!")
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
cfg, err := svc.loadS3Config(context.Background())
require.Error(t, err)
require.Nil(t, cfg)
}

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"log/slog"
"strconv"
"strings"
"time"
)
@@ -15,14 +14,17 @@ const (
claudeLockWaitTime = 200 * time.Millisecond
)
// ClaudeTokenCache is the token cache interface (reuses the GeminiTokenCache interface definition).
// ClaudeTokenCache token cache interface.
type ClaudeTokenCache = GeminiTokenCache
// ClaudeTokenProvider manages the access_token of Claude (Anthropic) OAuth accounts.
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
type ClaudeTokenProvider struct {
accountRepo AccountRepository
tokenCache ClaudeTokenCache
oauthService *OAuthService
accountRepo AccountRepository
tokenCache ClaudeTokenCache
oauthService *OAuthService
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
refreshPolicy ProviderRefreshPolicy
}
func NewClaudeTokenProvider(
@@ -31,13 +33,25 @@ func NewClaudeTokenProvider(
oauthService *OAuthService,
) *ClaudeTokenProvider {
return &ClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
oauthService: oauthService,
accountRepo: accountRepo,
tokenCache: tokenCache,
oauthService: oauthService,
refreshPolicy: ClaudeProviderRefreshPolicy(),
}
}
// GetAccessToken gets a valid access_token.
// SetRefreshAPI injects unified OAuth refresh API and executor.
func (p *ClaudeTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
p.refreshAPI = api
p.executor = executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func (p *ClaudeTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
// GetAccessToken returns a valid access_token.
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
@@ -48,7 +62,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
cacheKey := ClaudeTokenCacheKey(account)
// 1. Try the cache first.
// 1) Try cache first.
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
@@ -60,114 +74,39 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
// 2. Refresh if the token is about to expire.
// 2) Refresh if needed (pre-expiry skew).
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// After acquiring the lock, check the cache again (another worker may have refreshed already).
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, claudeTokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
// Fetch the latest account info from the database.
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true // cannot refresh; mark as failed
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
// On refresh failure, log a warning but do not return an error immediately; try the existing token.
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true // refresh failed; mark so that a short TTL is used
} else {
// Build new credentials, preserving the existing fields.
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if lockErr != nil {
// A Redis error prevented acquiring the lock; degrade to a lock-free refresh (only when the token is near expiry).
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
// Check whether ctx has been canceled.
if ctx.Err() != nil {
return "", ctx.Err()
}
// Fetch the latest account info from the database.
if p.accountRepo != nil {
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
}
expiresAt = account.GetCredentialAsTime("expires_at")
// Perform the lock-free refresh only when expires_at has passed or is close.
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
refreshFailed = true
} else {
// Build new credentials, preserving the existing fields.
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
time.Sleep(claudeLockWaitTime)
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
} else {
// Failed to acquire the lock (held by another worker); wait 200ms, then retry reading the cache.
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
} else if lockErr != nil {
slog.Warn("claude_token_lock_failed", "account_id", account.ID, "error", lockErr)
} else {
time.Sleep(claudeLockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
@@ -181,22 +120,23 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("access_token not found in credentials")
}
// 3. Write to the cache (verify the version before writing, to avoid a race between the async refresh task and the request thread).
// 3) Populate cache with TTL.
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// The version is stale; use the latest token from the DB.
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// Do not write to the cache; let the next request handle it again.
} else {
ttl := 30 * time.Minute
if refreshFailed {
// On refresh failure use a short TTL, so an invalid token is not cached for long and does not cause 401 flapping.
ttl = time.Minute
if p.refreshPolicy.FailureTTL > 0 {
ttl = p.refreshPolicy.FailureTTL
} else {
ttl = time.Minute
}
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)

@@ -29,12 +29,11 @@ const (
// Account type constants
const (
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth account (full scope: profile + inference)
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token account (inference-only scope)
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key account
AccountTypeUpstream = domain.AccountTypeUpstream // Upstream passthrough account (connects to the upstream via Base URL + API Key)
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock account (connects to Bedrock via SigV4 signing)
AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key account (connects to Bedrock via Bearer Token)
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth account (full scope: profile + inference)
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token account (inference-only scope)
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key account
AccountTypeUpstream = domain.AccountTypeUpstream // Upstream passthrough account (connects to the upstream via Base URL + API Key)
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock account (connects to Bedrock via SigV4 signing or API Key; distinguished by credentials.auth_mode)
)
// Redeem type constants
@@ -81,6 +80,7 @@ const (
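SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // registration email suffix whitelist (JSON array)
SettingKeyPromoCodeEnabled = "promo_code_enabled" // whether promo codes are enabled
SettingKeyPasswordResetEnabled = "password_reset_enabled" // whether password reset is enabled (requires email verification to be enabled first)
SettingKeyFrontendURL = "frontend_url" // frontend base URL, used to build the password-reset link in emails
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // whether invitation-code registration is enabled
// Email service settings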
@@ -221,6 +221,9 @@ const (
// SettingKeyAllowUngroupedKeyScheduling allows scheduling of ungrouped API keys (default false: ungrouped keys return 403).
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
// SettingKeyBackendModeEnabled enables Backend mode: user registration and self-service are disabled, and only admins can log in.
SettingKeyBackendModeEnabled = "backend_mode_enabled"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) {
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
// Antigravity 401 does not go through the upgrade logic (it is controlled by the temp_unschedulable_rules in applyErrorPolicy),
// so the second hit still returns TempUnscheduled.
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
account: &Account{
ID: 15,
Type: AccountTypeOAuth,
@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) {
},
statusCode: 401,
body: []byte(`unauthorized`),
expected: ErrorPolicyNone,
expected: ErrorPolicyTempUnscheduled,
},
{
name: "temp_unschedulable_body_miss_returns_none",

@@ -440,7 +440,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available accounts")
require.ErrorIs(t, err, ErrNoAvailableAccounts)
}
// TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded tests the case where all accounts are excluded.
@@ -1073,7 +1073,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAccounts(t *testing.
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformAnthropic)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available accounts")
require.ErrorIs(t, err, ErrNoAvailableAccounts)
}
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
@@ -1734,7 +1734,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available accounts")
require.ErrorIs(t, err, ErrNoAvailableAccounts)
})
t.Run("混合调度-不支持模型返回错误", func(t *testing.T) {
@@ -2290,7 +2290,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
require.ErrorIs(t, err, ErrNoAvailableAccounts)
})
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {

@@ -369,3 +369,54 @@ func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 0, usageRepo.calls)
}
func TestGatewayServiceRecordUsage_ReasoningEffortPersisted(t *testing.T) {
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
effort := "max"
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "effort_test",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 5,
},
Model: "claude-opus-4-6",
Duration: time.Second,
ReasoningEffort: &effort,
},
APIKey: &APIKey{ID: 1},
User: &User{ID: 1},
Account: &Account{ID: 1},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.NotNil(t, usageRepo.lastLog.ReasoningEffort)
require.Equal(t, "max", *usageRepo.lastLog.ReasoningEffort)
}
func TestGatewayServiceRecordUsage_ReasoningEffortNil(t *testing.T) {
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "no_effort_test",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 5,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 1},
User: &User{ID: 1},
Account: &Account{ID: 1},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Nil(t, usageRepo.lastLog.ReasoningEffort)
}

@@ -60,6 +60,7 @@ type ParsedRequest struct {
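Messages []any // messages array
HasSystem bool // whether a system field is present (null also counts as explicitly provided)
ThinkingEnabled bool // whether thinking is enabled (on some platforms this affects the final model name)
OutputEffort string // output_config.effort (the Claude API's reasoning-effort control)
MaxTokens int // max_tokens value (used to intercept probe requests)
SessionContext *SessionContext // optional request-context discriminator (behavior is unchanged when nil)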
@@ -116,6 +117,9 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
parsed.ThinkingEnabled = true
}
// output_config.effort: the Claude API's reasoning-effort control parameter
parsed.OutputEffort = strings.TrimSpace(gjson.Get(jsonStr, "output_config.effort").String())
// max_tokens: only integer values are accepted
maxTokensResult := gjson.Get(jsonStr, "max_tokens")
if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number {
@@ -747,6 +751,21 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
return newBody
}
// NormalizeClaudeOutputEffort normalizes Claude's output_config.effort value.
// Returns nil for empty or unrecognized values.
func NormalizeClaudeOutputEffort(raw string) *string {
value := strings.ToLower(strings.TrimSpace(raw))
if value == "" {
return nil
}
switch value {
case "low", "medium", "high", "max":
return &value
default:
return nil
}
}
// =========================
// Thinking Budget Rectifier
// =========================

@@ -972,6 +972,76 @@ func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) {
}
}
func TestParseGatewayRequest_OutputEffort(t *testing.T) {
tests := []struct {
name string
body string
wantEffort string
}{
{
name: "output_config.effort present",
body: `{"model":"claude-opus-4-6","output_config":{"effort":"medium"},"messages":[]}`,
wantEffort: "medium",
},
{
name: "output_config.effort max",
body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`,
wantEffort: "max",
},
{
name: "output_config without effort",
body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`,
wantEffort: "",
},
{
name: "no output_config",
body: `{"model":"claude-opus-4-6","messages":[]}`,
wantEffort: "",
},
{
name: "effort with whitespace trimmed",
body: `{"model":"claude-opus-4-6","output_config":{"effort":" high "},"messages":[]}`,
wantEffort: "high",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parsed, err := ParseGatewayRequest([]byte(tt.body), "")
require.NoError(t, err)
require.Equal(t, tt.wantEffort, parsed.OutputEffort)
})
}
}
func TestNormalizeClaudeOutputEffort(t *testing.T) {
tests := []struct {
input string
want *string
}{
{"low", strPtr("low")},
{"medium", strPtr("medium")},
{"high", strPtr("high")},
{"max", strPtr("max")},
{"LOW", strPtr("low")},
{"Max", strPtr("max")},
{" medium ", strPtr("medium")},
{"", nil},
{"unknown", nil},
{"xhigh", nil},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := NormalizeClaudeOutputEffort(tt.input)
if tt.want == nil {
require.Nil(t, got)
} else {
require.NotNil(t, got)
require.Equal(t, *tt.want, *got)
}
})
}
}
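The table test above calls a `strPtr` helper that does not appear in this diff. A minimal sketch of how the pieces fit, assuming `strPtr` simply returns the address of its argument; `describe` is a hypothetical printing helper added only for this illustration, and `NormalizeClaudeOutputEffort` is copied from the hunk above:

```go
package main

import (
	"fmt"
	"strings"
)

// strPtr is an assumed implementation of the helper used by the test.
func strPtr(s string) *string { return &s }

// NormalizeClaudeOutputEffort lowercases and trims raw, returning nil
// for empty or unrecognized values (same logic as in the diff).
func NormalizeClaudeOutputEffort(raw string) *string {
	value := strings.ToLower(strings.TrimSpace(raw))
	switch value {
	case "low", "medium", "high", "max":
		return &value
	default:
		return nil
	}
}

// describe (hypothetical) renders a *string so nil is visible when printed.
func describe(p *string) string {
	if p == nil {
		return "<nil>"
	}
	return *p
}

func main() {
	for _, in := range []string{"LOW", " medium ", "xhigh", ""} {
		fmt.Printf("%q -> %s\n", in, describe(NormalizeClaudeOutputEffort(in)))
	}
}
```

Returning `*string` rather than `string` lets callers distinguish "no effort recorded" from a concrete level, which is what the `ReasoningEffortNil` test relies on.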
func BenchmarkParseGatewayRequest_New_Large(b *testing.B) {
data := buildLargeJSON()
b.SetBytes(int64(len(data)))

@@ -346,6 +346,9 @@ var systemBlockFilterPrefixes = []string{
"x-anthropic-billing-header",
}
// ErrNoAvailableAccounts indicates that no accounts are available.
var ErrNoAvailableAccounts = errors.New("no available accounts")
// ErrClaudeCodeOnly indicates that the group only allows Claude Code clients.
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
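The sentinel error lets callers match with `errors.Is` instead of comparing message substrings, including when the error is wrapped with `%w` and extra detail. A minimal sketch of that behavior; `selectAccount` and `isNoAccounts` are hypothetical stand-ins written for this illustration, not functions from the codebase:

```go
package main

import (
	"errors"
	"fmt"
)

// ErrNoAvailableAccounts mirrors the sentinel introduced in this diff.
var ErrNoAvailableAccounts = errors.New("no available accounts")

// selectAccount (hypothetical) always fails. When a model name is given
// it wraps the sentinel with %w, as the selection functions in the diff do.
func selectAccount(requestedModel string) error {
	if requestedModel != "" {
		return fmt.Errorf("%w supporting model: %s", ErrNoAvailableAccounts, requestedModel)
	}
	return ErrNoAvailableAccounts
}

// isNoAccounts reports whether err is, or wraps, ErrNoAvailableAccounts.
func isNoAccounts(err error) bool {
	return errors.Is(err, ErrNoAvailableAccounts)
}

func main() {
	err := selectAccount("claude-3-5-sonnet-20241022")
	fmt.Println(err)               // message keeps the original prefix
	fmt.Println(isNoAccounts(err)) // true: the wrapped sentinel is found

	// A fresh error with the same text is not the sentinel.
	fmt.Println(isNoAccounts(errors.New("no available accounts"))) // false
}
```

Because `%w` is placed first in the format string, the rendered message still begins with "no available accounts", so existing substring checks keep passing.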
@@ -492,6 +495,7 @@ type ForwardResult struct {
Duration time.Duration
FirstTokenMs *int // time to first token (streaming requests)
ClientDisconnect bool // whether the client disconnected during streaming
ReasoningEffort *string
// Image generation billing fields (used by image generation models)
ImageCount int // number of images generated
@@ -1204,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, err
}
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
return nil, ErrNoAvailableAccounts
}
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
@@ -1552,7 +1556,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
if len(candidates) == 0 {
return nil, errors.New("no available accounts")
return nil, ErrNoAvailableAccounts
}
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
@@ -1641,7 +1645,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
}, nil
}
return nil, errors.New("no available accounts")
return nil, ErrNoAvailableAccounts
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
@@ -2173,10 +2177,10 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
}
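// isAccountSchedulableForQuota checks whether an API Key account is within its quota limit.
// Applies to apikey accounts configured with quota_limit.
// isAccountSchedulableForQuota checks whether the account is within its quota limit.
// Applies to apikey and bedrock accounts configured with quota_limit.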
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
if account.Type != AccountTypeAPIKey {
if !account.IsAPIKeyOrBedrock() {
return true
}
return !account.IsQuotaExceeded()
@@ -2851,9 +2855,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if selected == nil {
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
if requestedModel != "" {
return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats))
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
}
return nil, errors.New("no available accounts")
return nil, ErrNoAvailableAccounts
}
// 4. Establish the sticky binding.
@@ -3089,9 +3093,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if selected == nil {
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true)
if requestedModel != "" {
return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats))
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
}
return nil, errors.New("no available accounts")
return nil, ErrNoAvailableAccounts
}
// 4. Establish the sticky binding.
@@ -3532,9 +3536,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
}
return apiKey, "apikey", nil
case AccountTypeBedrock:
return "", "bedrock", nil // Bedrock uses SigV4 signing; no token is needed
case AccountTypeBedrockAPIKey:
return "", "bedrock-apikey", nil // Bedrock API Key uses a Bearer Token, handled by forwardBedrock
return "", "bedrock", nil // Bedrock uses SigV4 signing or an API Key, handled by forwardBedrock
default:
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
}
@@ -5186,7 +5188,7 @@ func (s *GatewayService) forwardBedrock(
if account.IsBedrockAPIKey() {
bedrockAPIKey = account.GetCredential("api_key")
if bedrockAPIKey == "" {
return nil, fmt.Errorf("api_key not found in bedrock-apikey credentials")
return nil, fmt.Errorf("api_key not found in bedrock credentials")
}
} else {
signer, err = NewBedrockSignerFromAccount(account)
@@ -5375,8 +5377,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
Message: extractUpstreamErrorMessage(respBody),
})
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
@@ -5398,8 +5401,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
Message: extractUpstreamErrorMessage(respBody),
})
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
@@ -5808,9 +5812,10 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
return betaPolicyResult{}
}
isOAuth := account.IsOAuth()
isBedrock := account.IsBedrock()
var result betaPolicyResult
for _, rule := range settings.Rules {
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
switch rule.Action {
@@ -5870,14 +5875,16 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
}
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
func betaPolicyScopeMatches(scope string, isOAuth bool) bool {
func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
switch scope {
case BetaPolicyScopeAll:
return true
case BetaPolicyScopeOAuth:
return isOAuth
case BetaPolicyScopeAPIKey:
return !isOAuth
return !isOAuth && !isBedrock
case BetaPolicyScopeBedrock:
return isBedrock
default:
return true // unknown scope → match all (fail-open)
}
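With the extra `isBedrock` argument, the `apikey` scope no longer matches Bedrock accounts, and unknown scopes still match everything. A minimal sketch of the resulting truth table; the string values of the scope constants are assumed for illustration, since their definitions are not part of this diff:

```go
package main

import "fmt"

// Scope values below are assumptions; the real constants live elsewhere.
const (
	BetaPolicyScopeAll     = "all"
	BetaPolicyScopeOAuth   = "oauth"
	BetaPolicyScopeAPIKey  = "apikey"
	BetaPolicyScopeBedrock = "bedrock"
)

// betaPolicyScopeMatches is copied from the new side of the hunk above.
func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
	switch scope {
	case BetaPolicyScopeAll:
		return true
	case BetaPolicyScopeOAuth:
		return isOAuth
	case BetaPolicyScopeAPIKey:
		return !isOAuth && !isBedrock
	case BetaPolicyScopeBedrock:
		return isBedrock
	default:
		return true // unknown scope: match all (fail-open)
	}
}

func main() {
	accounts := []struct {
		name               string
		isOAuth, isBedrock bool
	}{
		{"oauth", true, false},
		{"apikey", false, false},
		{"bedrock", false, true},
	}
	scopes := []string{BetaPolicyScopeAll, BetaPolicyScopeOAuth, BetaPolicyScopeAPIKey, BetaPolicyScopeBedrock, "typo"}
	for _, a := range accounts {
		for _, s := range scopes {
			fmt.Printf("account=%-8s scope=%-8s match=%v\n", a.name, s, betaPolicyScopeMatches(s, a.isOAuth, a.isBedrock))
		}
	}
}
```

The fail-open default means a misspelled scope applies its rule to every account type, so scope values are worth validating where rules are saved.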
@@ -5959,12 +5966,13 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
return nil
}
isOAuth := account.IsOAuth()
isBedrock := account.IsBedrock()
tokenSet := buildBetaTokenSet(tokens)
for _, rule := range settings.Rules {
if rule.Action != BetaPolicyActionBlock {
continue
}
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
if _, present := tokenSet[rule.BetaToken]; present {
@@ -6125,6 +6133,29 @@ func extractUpstreamErrorMessage(body []byte) string {
return gjson.GetBytes(body, "message").String()
}
func extractUpstreamErrorCode(body []byte) string {
if code := strings.TrimSpace(gjson.GetBytes(body, "error.code").String()); code != "" {
return code
}
inner := strings.TrimSpace(gjson.GetBytes(body, "error.message").String())
if !strings.HasPrefix(inner, "{") {
return ""
}
if code := strings.TrimSpace(gjson.Get(inner, "error.code").String()); code != "" {
return code
}
if lastBrace := strings.LastIndex(inner, "}"); lastBrace >= 0 {
if code := strings.TrimSpace(gjson.Get(inner[:lastBrace+1], "error.code").String()); code != "" {
return code
}
}
return ""
}
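`extractUpstreamErrorCode` looks for `error.code` in three places: the body itself, a JSON document embedded as a string in `error.message`, and that embedded document cut off at its last closing brace when trailing text follows it. A minimal sketch of the same lookup order using only the standard library; it swaps `gjson` for `encoding/json`, and `extractCode`, `parseEnvelope`, and `errorEnvelope` are hypothetical names. Unlike `gjson`, this version assumes `code` is a JSON string and rejects malformed input outright:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// errorEnvelope (hypothetical) models {"error":{"code":...,"message":...}}.
type errorEnvelope struct {
	Error struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	} `json:"error"`
}

// parseEnvelope decodes raw; ok is false when raw is not valid JSON
// or a field has an unexpected type.
func parseEnvelope(raw string) (env errorEnvelope, ok bool) {
	if err := json.Unmarshal([]byte(raw), &env); err != nil {
		return errorEnvelope{}, false
	}
	return env, true
}

// extractCode follows the lookup order of extractUpstreamErrorCode.
func extractCode(body string) string {
	outer, ok := parseEnvelope(body)
	if !ok {
		return ""
	}
	// 1) Direct error.code.
	if code := strings.TrimSpace(outer.Error.Code); code != "" {
		return code
	}
	// 2) error.message holding an embedded JSON document.
	inner := strings.TrimSpace(outer.Error.Message)
	if !strings.HasPrefix(inner, "{") {
		return ""
	}
	if env, ok := parseEnvelope(inner); ok {
		if code := strings.TrimSpace(env.Error.Code); code != "" {
			return code
		}
	}
	// 3) Same document, truncated at the last closing brace.
	if last := strings.LastIndex(inner, "}"); last >= 0 {
		if env, ok := parseEnvelope(inner[:last+1]); ok {
			return strings.TrimSpace(env.Error.Code)
		}
	}
	return ""
}

func main() {
	fmt.Println(extractCode(`{"error":{"code":"rate_limited","message":"slow down"}}`))
	fmt.Println(extractCode(`{"error":{"message":"{\"error\":{\"code\":\"quota_exceeded\"}}"}}`))
	fmt.Println(extractCode(`{"error":{"message":"{\"error\":{\"code\":\"quota_exceeded\"}} (request id 1)"}}`))
}
```

In this sketch the third step matters because `json.Unmarshal` rejects trailing text after the top-level value, so the truncated retry is what recovers the code from the last input.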
func isCountTokensUnsupported404(statusCode int, body []byte) bool {
if statusCode != http.StatusNotFound {
return false
@@ -7099,6 +7130,8 @@ type RecordUsageInput struct {
User *User
Account *Account
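Subscription *UserSubscription // optional: subscription info
InboundEndpoint string // inbound endpoint (client request path)
UpstreamEndpoint string // upstream endpoint (normalized upstream path)
UserAgent string // User-Agent of the request
IPAddress string // client IP address of the request
RequestPayloadHash string // semantic hash of the request body, to reduce the risk of silent mis-deduplication when a request_id is reused by mistake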
@@ -7176,7 +7209,7 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
}
// 4. Account quota usage (account basis: TotalCost × account billing rate multiplier)
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
@@ -7264,7 +7297,7 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
}
if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
}
@@ -7496,6 +7529,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
@@ -7576,6 +7612,8 @@ type RecordUsageLongContextInput struct {
User *User
Account *Account
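Subscription *UserSubscription // optional: subscription info
InboundEndpoint string // inbound endpoint (client request path)
UpstreamEndpoint string // upstream endpoint (normalized upstream path)
UserAgent string // User-Agent of the request
IPAddress string // client IP address of the request
RequestPayloadHash string // semantic hash of the request body, to reduce the risk of silent mis-deduplication when a request_id is reused by mistake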
@@ -7672,6 +7710,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
AccountID: account.ID,
RequestID: requestID,
Model: result.Model,
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,

@@ -3235,7 +3235,7 @@ func cleanToolSchema(schema any) any {
for key, value := range v {
// Skip unsupported fields.
if key == "$schema" || key == "$id" || key == "$ref" ||
key == "additionalProperties" || key == "minLength" ||
key == "additionalProperties" || key == "patternProperties" || key == "minLength" ||
key == "maxLength" || key == "minItems" || key == "maxItems" {
continue
}

@@ -15,10 +15,14 @@ const (
geminiTokenCacheSkew = 5 * time.Minute
)
// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
type GeminiTokenProvider struct {
accountRepo AccountRepository
tokenCache GeminiTokenCache
geminiOAuthService *GeminiOAuthService
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
refreshPolicy ProviderRefreshPolicy
}
func NewGeminiTokenProvider(
@@ -30,9 +34,21 @@ func NewGeminiTokenProvider(
accountRepo: accountRepo,
tokenCache: tokenCache,
geminiOAuthService: geminiOAuthService,
refreshPolicy: GeminiProviderRefreshPolicy(),
}
}
// SetRefreshAPI injects unified OAuth refresh API and executor.
func (p *GeminiTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
p.refreshAPI = api
p.executor = executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func (p *GeminiTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
@@ -53,39 +69,31 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
// 2) Refresh if needed (pre-expiry skew).
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// Re-check after lock (another worker may have refreshed).
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, geminiTokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
slog.Debug("gemini_token_lock_held_use_old", "account_id", account.ID)
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew {
if p.geminiOAuthService == nil {
return "", errors.New("gemini oauth service not configured")
}
tokenInfo, err := p.geminiOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", err
}
newCredentials := p.geminiOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
_ = p.accountRepo.Update(ctx, account)
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
} else if lockErr != nil {
slog.Warn("gemini_token_lock_failed", "account_id", account.ID, "error", lockErr)
}
}
@@ -95,15 +103,14 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
// project_id is optional now:
// - If present: will use Code Assist API (requires project_id)
// - If absent: will use AI Studio API with OAuth token (like regular API key mode)
// Auto-detect project_id only if explicitly enabled via a credential flag
// - If present: use Code Assist API (requires project_id)
// - If absent: use AI Studio API with OAuth token.
projectID := strings.TrimSpace(account.GetCredential("project_id"))
autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
if projectID == "" && autoDetectProjectID {
if p.geminiOAuthService == nil {
return accessToken, nil // Fallback to AI Studio API mode
return accessToken, nil
}
var proxyURL string
@@ -132,17 +139,15 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
// 3) Populate cache with TTL (write only after verifying the version, to avoid a race between the async refresh task and request threads)
// 3) Populate cache with TTL.
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// Version is stale; use the latest token from the DB
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// Do not write to the cache; let the next request handle it again
} else {
ttl := 30 * time.Minute
if expiresAt != nil {


@@ -13,6 +13,11 @@ func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiToke
return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService}
}
// CacheKey returns the cache key used for the distributed lock.
func (r *GeminiTokenRefresher) CacheKey(account *Account) string {
return GeminiTokenCacheKey(account)
}
func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth
}
@@ -35,11 +40,7 @@ func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (m
}
newCredentials := r.geminiOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
newCredentials = MergeCredentials(account.Credentials, newCredentials)
return newCredentials, nil
}
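The hunk above replaces the hand-written copy loop with `MergeCredentials`. As a minimal sketch of what that helper does, the function body below is copied from the new file in this change; the sample keys and values in `main` are made up for illustration. One detail worth noticing is that a non-nil `newCreds` map is filled in place and returned, so the caller's map and the result are the same map.

```go
package main

import "fmt"

// MergeCredentials is copied from the helper introduced in this change:
// keys present only in oldCreds are carried over into newCreds.
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
	if newCreds == nil {
		newCreds = make(map[string]any)
	}
	for k, v := range oldCreds {
		if _, exists := newCreds[k]; !exists {
			newCreds[k] = v
		}
	}
	return newCreds
}

func main() {
	// Illustrative values only; real credentials come from the OAuth service.
	oldCreds := map[string]any{"access_token": "old", "project_id": "p-1"}
	newCreds := map[string]any{"access_token": "new"}

	merged := MergeCredentials(oldCreds, newCreds)

	// The refreshed token wins, while project_id survives from the old map.
	fmt.Println(merged["access_token"], merged["project_id"]) // new p-1

	// newCreds was mutated in place: it now holds both keys.
	fmt.Println(len(newCreds)) // 2
}
```

Because the new map is mutated, a caller that needs the untouched refresh response afterwards would have to copy it before merging.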


@@ -0,0 +1,159 @@
package service
import (
"context"
"fmt"
"log/slog"
"strconv"
"time"
)
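// OAuthRefreshExecutor is the OAuth refresh executor implemented by each platform.
// It is a superset of the TokenRefresher interface: it adds a CacheKey method used for the distributed lock.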
type OAuthRefreshExecutor interface {
TokenRefresher
// CacheKey returns the cache key used for the distributed lock (the same key the TokenProvider uses).
CacheKey(account *Account) string
}
const refreshLockTTL = 30 * time.Second
// OAuthRefreshResult is the unified refresh result.
type OAuthRefreshResult struct {
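Refreshed bool // a refresh was actually performed
NewCredentials map[string]any // credentials after the refresh (nil means not refreshed)
Account *Account // latest account re-read from the DB
LockHeld bool // the lock is held by another worker (no refresh performed)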
}
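// OAuthRefreshAPI is the unified entry point for OAuth token refresh.
// It encapsulates the shared logic: distributed locking, DB re-read, and the already-refreshed check.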
type OAuthRefreshAPI struct {
accountRepo AccountRepository
tokenCache GeminiTokenCache // optional; nil = no locking
}
// NewOAuthRefreshAPI creates the unified refresh API.
func NewOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
return &OAuthRefreshAPI{
accountRepo: accountRepo,
tokenCache: tokenCache,
}
}
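// RefreshIfNeeded refreshes the OAuth token on demand under a distributed lock.
//
// Flow:
// 1. Acquire the distributed lock
// 2. Re-read the latest account from the DB (avoids using a stale refresh_token)
// 3. Re-check whether a refresh is still needed
// 4. Call executor.Refresh() to run the platform-specific refresh logic
// 5. Set _token_version and update the DB
// 6. Release the lock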
func (api *OAuthRefreshAPI) RefreshIfNeeded(
ctx context.Context,
account *Account,
executor OAuthRefreshExecutor,
refreshWindow time.Duration,
) (*OAuthRefreshResult, error) {
cacheKey := executor.CacheKey(account)
// 1. Acquire the distributed lock
lockAcquired := false
if api.tokenCache != nil {
acquired, lockErr := api.tokenCache.AcquireRefreshLock(ctx, cacheKey, refreshLockTTL)
if lockErr != nil {
// Redis error: degrade to a lock-free refresh
slog.Warn("oauth_refresh_lock_failed_degraded",
"account_id", account.ID,
"cache_key", cacheKey,
"error", lockErr,
)
} else if !acquired {
// The lock is held by another worker
return &OAuthRefreshResult{LockHeld: true}, nil
} else {
lockAcquired = true
defer func() { _ = api.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
}
}
// 2. Re-read the latest account from the DB (under the lock, so the latest refresh_token is used)
freshAccount, err := api.accountRepo.GetByID(ctx, account.ID)
if err != nil {
slog.Warn("oauth_refresh_db_reread_failed",
"account_id", account.ID,
"error", err,
)
// Degrade to the passed-in account
freshAccount = account
} else if freshAccount == nil {
freshAccount = account
}
// 3. Re-check whether a refresh is still needed (another path may already have refreshed)
if !executor.NeedsRefresh(freshAccount, refreshWindow) {
return &OAuthRefreshResult{
Account: freshAccount,
}, nil
}
// 4. Run the platform-specific refresh logic
newCredentials, refreshErr := executor.Refresh(ctx, freshAccount)
if refreshErr != nil {
return nil, refreshErr
}
// 5. Set the version stamp and update the DB
if newCredentials != nil {
newCredentials["_token_version"] = time.Now().UnixMilli()
freshAccount.Credentials = newCredentials
if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil {
slog.Error("oauth_refresh_update_failed",
"account_id", freshAccount.ID,
"error", updateErr,
)
return nil, fmt.Errorf("oauth refresh succeeded but DB update failed: %w", updateErr)
}
}
_ = lockAcquired // suppress unused warning when tokenCache is nil
return &OAuthRefreshResult{
Refreshed: true,
NewCredentials: newCredentials,
Account: freshAccount,
}, nil
}
// MergeCredentials copies fields that exist in the old credentials but not in the new map into the new map.
func MergeCredentials(oldCreds, newCreds map[string]any) map[string]any {
if newCreds == nil {
newCreds = make(map[string]any)
}
for k, v := range oldCreds {
if _, exists := newCreds[k]; !exists {
newCreds[k] = v
}
}
return newCreds
}
// BuildClaudeAccountCredentials builds the OAuth credentials map for the Claude platform.
// It fills the gap left by the Claude platform having no BuildAccountCredentials method.
func BuildClaudeAccountCredentials(tokenInfo *TokenInfo) map[string]any {
creds := map[string]any{
"access_token": tokenInfo.AccessToken,
"token_type": tokenInfo.TokenType,
"expires_in": strconv.FormatInt(tokenInfo.ExpiresIn, 10),
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
}
if tokenInfo.RefreshToken != "" {
creds["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
creds["scope"] = tokenInfo.Scope
}
return creds
}
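The flow documented on `RefreshIfNeeded` (acquire the lock, re-read state, re-check, refresh, persist, release) can be sketched in isolation. The example below is only an illustration under stated assumptions: `memLock`, `token`, `outcome`, and `refreshIfNeeded` are hypothetical in-process stand-ins, not the types from this change, and the Redis-backed lock, the repository, and the degraded lock-error path are left out.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// memLock is a hypothetical in-process stand-in for the distributed lock.
type memLock struct {
	mu   sync.Mutex
	held map[string]bool
}

func newMemLock() *memLock { return &memLock{held: map[string]bool{}} }

// tryAcquire reports whether the caller obtained the lock for key.
func (l *memLock) tryAcquire(key string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.held[key] {
		return false
	}
	l.held[key] = true
	return true
}

func (l *memLock) release(key string) {
	l.mu.Lock()
	defer l.mu.Unlock()
	delete(l.held, key)
}

// token is a hypothetical stored credential.
type token struct {
	Value     string
	ExpiresAt time.Time
}

// outcome mirrors the three results described above:
// refreshed, lock held by someone else, or already fresh.
type outcome struct {
	Refreshed bool
	LockHeld  bool
	Token     token
}

// refreshIfNeeded locks, re-reads, re-checks, refreshes, then stores.
func refreshIfNeeded(
	lock *memLock,
	key string,
	load func() token,
	store func(token),
	refresh func(token) (token, error),
	window time.Duration,
) (outcome, error) {
	if !lock.tryAcquire(key) {
		// Another worker is refreshing; the caller decides what to do.
		return outcome{LockHeld: true}, nil
	}
	defer lock.release(key)

	current := load() // re-read under the lock
	if time.Until(current.ExpiresAt) > window {
		// Someone else already refreshed; nothing to do.
		return outcome{Token: current}, nil
	}
	fresh, err := refresh(current)
	if err != nil {
		return outcome{}, err
	}
	store(fresh)
	return outcome{Refreshed: true, Token: fresh}, nil
}

func main() {
	stored := token{Value: "old", ExpiresAt: time.Now().Add(time.Minute)}
	load := func() token { return stored }
	store := func(t token) { stored = t }
	refresh := func(token) (token, error) {
		return token{Value: "new", ExpiresAt: time.Now().Add(time.Hour)}, nil
	}

	lock := newMemLock()
	first, _ := refreshIfNeeded(lock, "acct:1", load, store, refresh, 3*time.Minute)
	second, _ := refreshIfNeeded(lock, "acct:1", load, store, refresh, 3*time.Minute)

	// The first call refreshes; the second sees the fresh token and skips.
	fmt.Println(first.Refreshed, second.Refreshed) // true false
}
```

Returning a distinct `LockHeld` outcome instead of blocking keeps the policy decision (wait for the cache, or use the existing token) with the caller, which appears to be how the provider-side refresh policies in this change are meant to be used.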


@@ -0,0 +1,395 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------- mock helpers ----------
// refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests.
type refreshAPIAccountRepo struct {
mockAccountRepoForGemini
account *Account // returned by GetByID
getByIDErr error
updateErr error
updateCalls int
}
func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) {
if r.getByIDErr != nil {
return nil, r.getByIDErr
}
return r.account, nil
}
func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error {
r.updateCalls++
return r.updateErr
}
// refreshAPIExecutorStub implements OAuthRefreshExecutor for tests.
type refreshAPIExecutorStub struct {
needsRefresh bool
credentials map[string]any
err error
refreshCalls int
}
func (e *refreshAPIExecutorStub) CanRefresh(_ *Account) bool { return true }
func (e *refreshAPIExecutorStub) NeedsRefresh(_ *Account, _ time.Duration) bool {
return e.needsRefresh
}
func (e *refreshAPIExecutorStub) Refresh(_ context.Context, _ *Account) (map[string]any, error) {
e.refreshCalls++
if e.err != nil {
return nil, e.err
}
return e.credentials, nil
}
func (e *refreshAPIExecutorStub) CacheKey(account *Account) string {
return "test:api:" + account.Platform
}
// refreshAPICacheStub implements GeminiTokenCache for OAuthRefreshAPI tests.
type refreshAPICacheStub struct {
lockResult bool
lockErr error
releaseCalls int
}
func (c *refreshAPICacheStub) GetAccessToken(context.Context, string) (string, error) {
return "", nil
}
func (c *refreshAPICacheStub) SetAccessToken(context.Context, string, string, time.Duration) error {
return nil
}
func (c *refreshAPICacheStub) DeleteAccessToken(context.Context, string) error { return nil }
func (c *refreshAPICacheStub) AcquireRefreshLock(context.Context, string, time.Duration) (bool, error) {
return c.lockResult, c.lockErr
}
func (c *refreshAPICacheStub) ReleaseRefreshLock(context.Context, string) error {
c.releaseCalls++
return nil
}
// ========== RefreshIfNeeded tests ==========
func TestRefreshIfNeeded_Success(t *testing.T) {
account := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "new-token"},
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed)
require.NotNil(t, result.NewCredentials)
require.Equal(t, "new-token", result.NewCredentials["access_token"])
require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set
require.Equal(t, 1, repo.updateCalls) // DB updated
require.Equal(t, 1, cache.releaseCalls) // lock released
require.Equal(t, 1, executor.refreshCalls)
}
func TestRefreshIfNeeded_LockHeld(t *testing.T) {
account := &Account{ID: 2, Platform: PlatformAnthropic}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: false} // lock not acquired
executor := &refreshAPIExecutorStub{needsRefresh: true}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.LockHeld)
require.False(t, result.Refreshed)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, executor.refreshCalls)
}
func TestRefreshIfNeeded_LockErrorDegrades(t *testing.T) {
account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockErr: errors.New("redis down")} // lock error
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "degraded-token"},
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed) // still refreshed (degraded mode)
require.Equal(t, 1, repo.updateCalls) // DB updated
require.Equal(t, 0, cache.releaseCalls) // no lock to release
require.Equal(t, 1, executor.refreshCalls)
}
func TestRefreshIfNeeded_NoCacheNoLock(t *testing.T) {
account := &Account{ID: 4, Platform: PlatformGemini, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{account: account}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "no-cache-token"},
}
api := NewOAuthRefreshAPI(repo, nil) // no cache = no lock
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed)
require.Equal(t, 1, repo.updateCalls)
}
func TestRefreshIfNeeded_AlreadyRefreshed(t *testing.T) {
account := &Account{ID: 5, Platform: PlatformAnthropic}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{needsRefresh: false} // already refreshed
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.False(t, result.Refreshed)
require.False(t, result.LockHeld)
require.NotNil(t, result.Account) // returns fresh account
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, executor.refreshCalls)
}
func TestRefreshIfNeeded_RefreshError(t *testing.T) {
account := &Account{ID: 6, Platform: PlatformAnthropic}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
err: errors.New("invalid_grant: token revoked"),
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "invalid_grant")
require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error
require.Equal(t, 1, cache.releaseCalls) // lock still released via defer
}
func TestRefreshIfNeeded_DBUpdateError(t *testing.T) {
account := &Account{ID: 7, Platform: PlatformGemini, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{
account: account,
updateErr: errors.New("db connection lost"),
}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "token"},
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "DB update failed")
require.Equal(t, 1, repo.updateCalls) // attempted
}
func TestRefreshIfNeeded_DBRereadFails(t *testing.T) {
account := &Account{ID: 8, Platform: PlatformAnthropic, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{
account: nil, // GetByID returns nil
getByIDErr: errors.New("db timeout"),
}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: map[string]any{"access_token": "fallback-token"},
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed)
require.Equal(t, 1, executor.refreshCalls) // still refreshes using passed-in account
}
func TestRefreshIfNeeded_NilCredentials(t *testing.T) {
account := &Account{ID: 9, Platform: PlatformGemini, Type: AccountTypeOAuth}
repo := &refreshAPIAccountRepo{account: account}
cache := &refreshAPICacheStub{lockResult: true}
executor := &refreshAPIExecutorStub{
needsRefresh: true,
credentials: nil, // Refresh returns nil credentials
}
api := NewOAuthRefreshAPI(repo, cache)
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute)
require.NoError(t, err)
require.True(t, result.Refreshed)
require.Nil(t, result.NewCredentials)
require.Equal(t, 0, repo.updateCalls) // no DB update when credentials are nil
}
// ========== MergeCredentials tests ==========
func TestMergeCredentials_Basic(t *testing.T) {
old := map[string]any{"a": "1", "b": "2", "c": "3"}
new := map[string]any{"a": "new", "d": "4"}
result := MergeCredentials(old, new)
require.Equal(t, "new", result["a"]) // new value preserved
require.Equal(t, "2", result["b"]) // old value kept
require.Equal(t, "3", result["c"]) // old value kept
require.Equal(t, "4", result["d"]) // new value preserved
}
func TestMergeCredentials_NilNew(t *testing.T) {
old := map[string]any{"a": "1"}
result := MergeCredentials(old, nil)
require.NotNil(t, result)
require.Equal(t, "1", result["a"])
}
func TestMergeCredentials_NilOld(t *testing.T) {
new := map[string]any{"a": "1"}
result := MergeCredentials(nil, new)
require.Equal(t, "1", result["a"])
}
func TestMergeCredentials_BothNil(t *testing.T) {
result := MergeCredentials(nil, nil)
require.NotNil(t, result)
require.Empty(t, result)
}
func TestMergeCredentials_NewOverridesOld(t *testing.T) {
old := map[string]any{"access_token": "old-token", "refresh_token": "old-refresh"}
new := map[string]any{"access_token": "new-token"}
result := MergeCredentials(old, new)
require.Equal(t, "new-token", result["access_token"]) // overridden
require.Equal(t, "old-refresh", result["refresh_token"]) // preserved
}
// ========== BuildClaudeAccountCredentials tests ==========
func TestBuildClaudeAccountCredentials_Full(t *testing.T) {
tokenInfo := &TokenInfo{
AccessToken: "at-123",
TokenType: "Bearer",
ExpiresIn: 3600,
ExpiresAt: 1700000000,
RefreshToken: "rt-456",
Scope: "openid",
}
creds := BuildClaudeAccountCredentials(tokenInfo)
require.Equal(t, "at-123", creds["access_token"])
require.Equal(t, "Bearer", creds["token_type"])
require.Equal(t, "3600", creds["expires_in"])
require.Equal(t, "1700000000", creds["expires_at"])
require.Equal(t, "rt-456", creds["refresh_token"])
require.Equal(t, "openid", creds["scope"])
}
func TestBuildClaudeAccountCredentials_Minimal(t *testing.T) {
tokenInfo := &TokenInfo{
AccessToken: "at-789",
TokenType: "Bearer",
ExpiresIn: 7200,
ExpiresAt: 1700003600,
}
creds := BuildClaudeAccountCredentials(tokenInfo)
require.Equal(t, "at-789", creds["access_token"])
require.Equal(t, "Bearer", creds["token_type"])
require.Equal(t, "7200", creds["expires_in"])
require.Equal(t, "1700003600", creds["expires_at"])
_, hasRefresh := creds["refresh_token"]
_, hasScope := creds["scope"]
require.False(t, hasRefresh, "refresh_token should not be set when empty")
require.False(t, hasScope, "scope should not be set when empty")
}
// ========== BackgroundRefreshPolicy tests ==========
func TestBackgroundRefreshPolicy_DefaultSkips(t *testing.T) {
p := DefaultBackgroundRefreshPolicy()
require.ErrorIs(t, p.handleLockHeld(), errRefreshSkipped)
require.ErrorIs(t, p.handleAlreadyRefreshed(), errRefreshSkipped)
}
func TestBackgroundRefreshPolicy_SuccessOverride(t *testing.T) {
p := BackgroundRefreshPolicy{
OnLockHeld: BackgroundSkipAsSuccess,
OnAlreadyRefresh: BackgroundSkipAsSuccess,
}
require.NoError(t, p.handleLockHeld())
require.NoError(t, p.handleAlreadyRefreshed())
}
// ========== ProviderRefreshPolicy tests ==========
func TestClaudeProviderRefreshPolicy(t *testing.T) {
p := ClaudeProviderRefreshPolicy()
require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError)
require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld)
require.Equal(t, time.Minute, p.FailureTTL)
}
func TestOpenAIProviderRefreshPolicy(t *testing.T) {
p := OpenAIProviderRefreshPolicy()
require.Equal(t, ProviderRefreshErrorUseExistingToken, p.OnRefreshError)
require.Equal(t, ProviderLockHeldWaitForCache, p.OnLockHeld)
require.Equal(t, time.Minute, p.FailureTTL)
}
func TestGeminiProviderRefreshPolicy(t *testing.T) {
p := GeminiProviderRefreshPolicy()
require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError)
require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld)
require.Equal(t, time.Duration(0), p.FailureTTL)
}
func TestAntigravityProviderRefreshPolicy(t *testing.T) {
p := AntigravityProviderRefreshPolicy()
require.Equal(t, ProviderRefreshErrorReturn, p.OnRefreshError)
require.Equal(t, ProviderLockHeldUseExistingToken, p.OnLockHeld)
require.Equal(t, time.Duration(0), p.FailureTTL)
}


@@ -725,7 +725,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}, len(candidates), topK, loadSkew, nil
}
return nil, len(candidates), topK, loadSkew, errors.New("no available accounts")
return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
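The hunk above swaps an ad hoc `errors.New("no available accounts")` for the shared `ErrNoAvailableAccounts` value. A minimal sketch of why that matters, assuming the sentinel is an ordinary package-level error: the declaration below is a local stand-in for illustration, and `selectAccount` and `selectForGroup` are hypothetical helpers, not functions from this change.

```go
package main

import (
	"errors"
	"fmt"
)

// ErrNoAvailableAccounts is a local stand-in for the shared sentinel
// referenced in the diff; its real declaration is not shown there.
var ErrNoAvailableAccounts = errors.New("no available accounts")

// selectAccount is a hypothetical selector that returns the sentinel
// when there is nothing to choose from.
func selectAccount(candidates []string) (string, error) {
	if len(candidates) == 0 {
		return "", ErrNoAvailableAccounts
	}
	return candidates[0], nil
}

// selectForGroup wraps the error with context using %w,
// which keeps the sentinel discoverable through errors.Is.
func selectForGroup(group string, candidates []string) (string, error) {
	account, err := selectAccount(candidates)
	if err != nil {
		return "", fmt.Errorf("group %s: %w", group, err)
	}
	return account, nil
}

func main() {
	_, err := selectForGroup("default", nil)

	// The wrapped sentinel still matches.
	fmt.Println(errors.Is(err, ErrNoAvailableAccounts)) // true

	// A freshly built error with the same text does not match.
	adHoc := errors.New("no available accounts")
	fmt.Println(errors.Is(adHoc, ErrNoAvailableAccounts)) // false
}
```

With a sentinel, callers can branch on the condition without comparing message strings, which is what the old `errors.New` call sites forced.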


@@ -226,6 +226,41 @@ func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T)
require.Equal(t, 1, userRepo.deductCalls)
}
func TestOpenAIGatewayServiceRecordUsage_IncludesEndpointMetadata(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
rateRepo := &openAIUserGroupRateRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_endpoint_metadata",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 2,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 1002,
Group: &Group{RateMultiplier: 1},
},
User: &User{ID: 2002},
Account: &Account{ID: 3002},
InboundEndpoint: " /v1/chat/completions ",
UpstreamEndpoint: " /v1/responses ",
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.NotNil(t, usageRepo.lastLog.InboundEndpoint)
require.Equal(t, "/v1/chat/completions", *usageRepo.lastLog.InboundEndpoint)
require.NotNil(t, usageRepo.lastLog.UpstreamEndpoint)
require.Equal(t, "/v1/responses", *usageRepo.lastLog.UpstreamEndpoint)
}
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) {
groupID := int64(12)
groupRate := 1.6


@@ -480,6 +480,7 @@ func classifyOpenAIWSReconnectReason(err error) (string, bool) {
"upgrade_required",
"ws_unsupported",
"auth_failed",
"invalid_encrypted_content",
"previous_response_not_found":
return reason, false
}
@@ -530,6 +531,14 @@ func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType st
}
switch reason {
case "invalid_encrypted_content":
if statusCode == 0 {
statusCode = http.StatusBadRequest
}
errType = "invalid_request_error"
if upstreamMessage == "" {
upstreamMessage = "encrypted content could not be verified"
}
case "previous_response_not_found":
if statusCode == 0 {
statusCode = http.StatusBadRequest
@@ -1303,7 +1312,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
return nil, err
}
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
return nil, ErrNoAvailableAccounts
}
isExcluded := func(accountID int64) bool {
@@ -1373,7 +1382,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
if len(candidates) == 0 {
return nil, errors.New("no available accounts")
return nil, ErrNoAvailableAccounts
}
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
@@ -1480,7 +1489,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}, nil
}
return nil, errors.New("no available accounts")
return nil, ErrNoAvailableAccounts
}
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
@@ -1924,6 +1933,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
var wsErr error
wsLastFailureReason := ""
wsPrevResponseRecoveryTried := false
wsInvalidEncryptedContentRecoveryTried := false
recoverPrevResponseNotFound := func(attempt int) bool {
if wsPrevResponseRecoveryTried {
return false
@@ -1956,6 +1966,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
)
return true
}
recoverInvalidEncryptedContent := func(attempt int) bool {
if wsInvalidEncryptedContentRecoveryTried {
return false
}
removedReasoningItems := trimOpenAIEncryptedReasoningItems(wsReqBody)
if !removedReasoningItems {
logOpenAIWSModeInfo(
"reconnect_invalid_encrypted_content_recovery_skip account_id=%d attempt=%d reason=missing_encrypted_reasoning_items",
account.ID,
attempt,
)
return false
}
previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id")
hasFunctionCallOutput := HasFunctionCallOutput(wsReqBody)
if previousResponseID != "" && !hasFunctionCallOutput {
delete(wsReqBody, "previous_response_id")
}
wsInvalidEncryptedContentRecoveryTried = true
logOpenAIWSModeInfo(
"reconnect_invalid_encrypted_content_recovery account_id=%d attempt=%d action=drop_encrypted_reasoning_items retry=1 previous_response_id_present=%v previous_response_id=%s previous_response_id_kind=%s has_function_call_output=%v dropped_previous_response_id=%v",
account.ID,
attempt,
previousResponseID != "",
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)),
hasFunctionCallOutput,
previousResponseID != "" && !hasFunctionCallOutput,
)
return true
}
retryBudget := s.openAIWSRetryTotalBudget()
retryStartedAt := time.Now()
wsRetryLoop:
@@ -1992,6 +2033,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) {
continue
}
if reason == "invalid_encrypted_content" && recoverInvalidEncryptedContent(attempt) {
continue
}
if retryable && attempt < maxAttempts {
backoff := s.openAIWSRetryBackoff(attempt)
if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget {
@@ -2075,126 +2119,143 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return nil, wsErr
}
// Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
releaseUpstreamCtx()
if err != nil {
return nil, err
}
httpInvalidEncryptedContentRetryTried := false
for {
// Build upstream request
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
releaseUpstreamCtx()
if err != nil {
return nil, err
}
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// Send request
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
if err != nil {
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// Handle error response
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
// Send request
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
if err != nil {
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
// Handle error response
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamCode := extractUpstreamErrorCode(respBody)
if !httpInvalidEncryptedContentRetryTried && resp.StatusCode == http.StatusBadRequest && upstreamCode == "invalid_encrypted_content" {
if trimOpenAIEncryptedReasoningItems(reqBody) {
body, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("serialize invalid_encrypted_content retry body: %w", err)
}
setOpsUpstreamRequestBody(c, body)
httpInvalidEncryptedContentRetryTried = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Retrying non-WSv2 request once after invalid_encrypted_content (account: %s)", account.Name)
continue
}
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Skip non-WSv2 invalid_encrypted_content retry because encrypted reasoning items are missing (account: %s)", account.Name)
}
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
}
defer func() { _ = resp.Body.Close() }()
// Handle normal response
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
if err != nil {
return nil, err
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
if err != nil {
return nil, err
}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
}
// Handle normal response
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
if err != nil {
return nil, err
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
if err != nil {
return nil, err
if usage == nil {
usage = &OpenAIUsage{}
}
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
serviceTier := extractOpenAIServiceTier(reqBody)
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
ServiceTier: serviceTier,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
OpenAIWSMode: false,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
if usage == nil {
usage = &OpenAIUsage{}
}
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
serviceTier := extractOpenAIServiceTier(reqBody)
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
ServiceTier: serviceTier,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
OpenAIWSMode: false,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
@@ -3756,6 +3817,109 @@ func buildOpenAIResponsesURL(base string) string {
return normalized + "/v1/responses"
}
func trimOpenAIEncryptedReasoningItems(reqBody map[string]any) bool {
if len(reqBody) == 0 {
return false
}
inputValue, has := reqBody["input"]
if !has {
return false
}
switch input := inputValue.(type) {
case []any:
filtered := input[:0]
changed := false
for _, item := range input {
nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item)
if itemChanged {
changed = true
}
if !keep {
continue
}
filtered = append(filtered, nextItem)
}
if !changed {
return false
}
if len(filtered) == 0 {
delete(reqBody, "input")
return true
}
reqBody["input"] = filtered
return true
case []map[string]any:
filtered := input[:0]
changed := false
for _, item := range input {
nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item)
if itemChanged {
changed = true
}
if !keep {
continue
}
nextMap, ok := nextItem.(map[string]any)
if !ok {
filtered = append(filtered, item)
continue
}
filtered = append(filtered, nextMap)
}
if !changed {
return false
}
if len(filtered) == 0 {
delete(reqBody, "input")
return true
}
reqBody["input"] = filtered
return true
case map[string]any:
nextItem, changed, keep := sanitizeEncryptedReasoningInputItem(input)
if !changed {
return false
}
if !keep {
delete(reqBody, "input")
return true
}
nextMap, ok := nextItem.(map[string]any)
if !ok {
return false
}
reqBody["input"] = nextMap
return true
default:
return false
}
}
func sanitizeEncryptedReasoningInputItem(item any) (next any, changed bool, keep bool) {
inputItem, ok := item.(map[string]any)
if !ok {
return item, false, true
}
itemType, _ := inputItem["type"].(string)
if strings.TrimSpace(itemType) != "reasoning" {
return item, false, true
}
_, hasEncryptedContent := inputItem["encrypted_content"]
if !hasEncryptedContent {
return item, false, true
}
delete(inputItem, "encrypted_content")
if len(inputItem) == 1 {
return nil, true, false
}
return inputItem, true, true
}
func IsOpenAIResponsesCompactPathForTest(c *gin.Context) bool {
return isOpenAIResponsesCompactPath(c)
}
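The hunk above strips `encrypted_content` from reasoning input items before a retry. The following is a minimal, self-contained sketch of that idea, not the project's implementation: the name `stripEncryptedReasoning` is hypothetical, and it only handles the `[]any` shape that `encoding/json` produces, whereas the diff also covers `[]map[string]any` and a single object.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// stripEncryptedReasoning (hypothetical name) removes "encrypted_content" from
// every reasoning item in body["input"]. An item left with only its "type" key
// is dropped. It reports whether the body was modified.
func stripEncryptedReasoning(body map[string]any) bool {
	items, ok := body["input"].([]any)
	if !ok {
		return false
	}
	kept := make([]any, 0, len(items))
	changed := false
	for _, raw := range items {
		item, isMap := raw.(map[string]any)
		itemType, _ := item["type"].(string) // reading a nil map is safe
		_, hasEncrypted := item["encrypted_content"]
		if !isMap || itemType != "reasoning" || !hasEncrypted {
			kept = append(kept, raw)
			continue
		}
		delete(item, "encrypted_content")
		changed = true
		if len(item) > 1 { // something besides "type" remains, e.g. a summary
			kept = append(kept, item)
		}
	}
	if !changed {
		return false
	}
	if len(kept) == 0 {
		delete(body, "input")
	} else {
		body["input"] = kept
	}
	return true
}

func main() {
	raw := []byte(`{"input":[{"type":"reasoning","encrypted_content":"gAAA"},{"type":"input_text","text":"hello"}]}`)
	var body map[string]any
	if err := json.Unmarshal(raw, &body); err != nil {
		panic(err)
	}
	changed := stripEncryptedReasoning(body)
	out, _ := json.Marshal(body)
	fmt.Println(changed, string(out))
}
```

Building a fresh `kept` slice instead of reusing `input[:0]` trades one allocation for not mutating the caller's backing array; either choice appears workable here.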
@@ -3864,6 +4028,8 @@ type OpenAIRecordUsageInput struct {
User *User
Account *Account
Subscription *UserSubscription
InboundEndpoint string
UpstreamEndpoint string
UserAgent string // User-Agent of the request
IPAddress string // client IP address of the request
RequestPayloadHash string
@@ -3942,6 +4108,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
Model: billingModel,
ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort,
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
@@ -3961,7 +4129,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// Add UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
@@ -4504,3 +4671,11 @@ func normalizeOpenAIReasoningEffort(raw string) string {
return ""
}
}
func optionalTrimmedStringPtr(raw string) *string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil
}
return &trimmed
}
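The `optionalTrimmedStringPtr` helper above maps blank endpoint strings to `nil` so that the usage log stores no value rather than an empty string. A small standalone sketch of that behavior follows; the sample inputs are illustrative assumptions only.

```go
package main

import (
	"fmt"
	"strings"
)

// optionalTrimmedStringPtr mirrors the helper in the diff: it trims the input
// and returns nil when nothing is left, otherwise a pointer to the trimmed copy.
func optionalTrimmedStringPtr(raw string) *string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return nil
	}
	return &trimmed
}

func main() {
	// Illustrative inputs; real values would come from the request context.
	for _, in := range []string{"  /v1/responses  ", "   ", ""} {
		if p := optionalTrimmedStringPtr(in); p != nil {
			fmt.Printf("%q -> %q\n", in, *p)
		} else {
			fmt.Printf("%q -> nil\n", in)
		}
	}
}
```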


@@ -20,7 +20,7 @@ const (
openAILockWarnThresholdMs = 250
)
// OpenAITokenRuntimeMetrics is a snapshot of the OpenAI token refresh and lock-contention protection metrics.
// OpenAITokenRuntimeMetrics is a snapshot of refresh and lock contention metrics.
type OpenAITokenRuntimeMetrics struct {
RefreshRequests int64
RefreshSuccess int64
@@ -72,15 +72,18 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
m.lastObservedUnixMs.Store(time.Now().UnixMilli())
}
// OpenAITokenCache is the token cache interface (reuses the GeminiTokenCache interface definition).
// OpenAITokenCache token cache interface.
type OpenAITokenCache = GeminiTokenCache
// OpenAITokenProvider manages the access_token of OpenAI OAuth accounts.
// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts.
type OpenAITokenProvider struct {
accountRepo AccountRepository
tokenCache OpenAITokenCache
openAIOAuthService *OpenAIOAuthService
metrics *openAITokenRuntimeMetricsStore
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
refreshPolicy ProviderRefreshPolicy
}
func NewOpenAITokenProvider(
@@ -93,9 +96,21 @@ func NewOpenAITokenProvider(
tokenCache: tokenCache,
openAIOAuthService: openAIOAuthService,
metrics: &openAITokenRuntimeMetricsStore{},
refreshPolicy: OpenAIProviderRefreshPolicy(),
}
}
// SetRefreshAPI injects unified OAuth refresh API and executor.
func (p *OpenAITokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
p.refreshAPI = api
p.executor = executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
if p == nil {
return OpenAITokenRuntimeMetrics{}
@@ -110,7 +125,7 @@ func (p *OpenAITokenProvider) ensureMetrics() {
}
}
// GetAccessToken gets a valid access_token.
// GetAccessToken returns a valid access_token.
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
p.ensureMetrics()
if account == nil {
@@ -122,7 +137,7 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
cacheKey := OpenAITokenCacheKey(account)
// 1. Try the cache first.
// 1) Try cache first.
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit", "account_id", account.ID)
@@ -134,114 +149,62 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
slog.Debug("openai_token_cache_miss", "account_id", account.ID)
// 2. Refresh if the token is about to expire.
// 2) Refresh if needed (pre-expiry skew).
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
p.metrics.refreshRequests.Add(1)
p.metrics.touchNow()
// Sora accounts skip OpenAI OAuth refresh and keep existing token path.
if account.Platform == PlatformSora {
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
refreshFailed = true
} else {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
p.metrics.refreshFailure.Add(1)
refreshFailed = true
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
p.metrics.lockContention.Add(1)
p.metrics.touchNow()
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
if waitErr != nil {
return "", waitErr
}
if strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
} else if result.Refreshed {
p.metrics.refreshSuccess.Add(1)
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
p.metrics.refreshRequests.Add(1)
p.metrics.touchNow()
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// After acquiring the lock, check the cache again (another worker may already have refreshed).
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// Fetch the latest account info from the database.
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if account.Platform == PlatformSora {
slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
// Sora accounts do not use OpenAI OAuth refresh; the Sora client's ST/RT recovery path handles them.
refreshFailed = true
} else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1)
refreshFailed = true // cannot refresh; mark as failed
} else {
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
// On refresh failure, log a warning but do not return an error immediately; try the existing token.
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
p.metrics.refreshFailure.Add(1)
refreshFailed = true // refresh failed; mark so that a short TTL is used
} else {
p.metrics.refreshSuccess.Add(1)
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if lockErr != nil {
// A Redis error prevented acquiring the lock; degrade to a lock-free refresh (only when the token is near expiry).
p.metrics.lockAcquireFailure.Add(1)
p.metrics.touchNow()
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
// Check whether ctx has been cancelled.
if ctx.Err() != nil {
return "", ctx.Err()
}
// Fetch the latest account info from the database.
if p.accountRepo != nil {
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
}
expiresAt = account.GetCredentialAsTime("expires_at")
// Only perform the lock-free refresh when expires_at has passed or is near expiry.
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if account.Platform == PlatformSora {
slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
// Sora accounts do not use OpenAI OAuth refresh; the Sora client's ST/RT recovery path handles them.
refreshFailed = true
} else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1)
refreshFailed = true
} else {
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
p.metrics.refreshFailure.Add(1)
refreshFailed = true
} else {
p.metrics.refreshSuccess.Add(1)
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
slog.Warn("openai_token_lock_failed", "account_id", account.ID, "error", lockErr)
} else {
// The lock is held by another worker: use short polling with jitter to reduce the tail-latency step caused by a fixed wait.
p.metrics.lockContention.Add(1)
p.metrics.touchNow()
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
@@ -260,22 +223,23 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("access_token not found in credentials")
}
// 3. Store in the cache (verify the version before writing, to avoid a race between the async refresh task and the request thread).
// 3) Populate cache with TTL.
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// The version is stale; use the latest token from the DB.
slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetOpenAIAccessToken()
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// Do not write to the cache; let the next request handle it again.
} else {
ttl := 30 * time.Minute
if refreshFailed {
// On refresh failure use a short TTL, so that an invalid token is not cached for long and does not cause 401 flapping.
ttl = time.Minute
if p.refreshPolicy.FailureTTL > 0 {
ttl = p.refreshPolicy.FailureTTL
} else {
ttl = time.Minute
}
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)


@@ -3922,6 +3922,8 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
return "ws_unsupported", true
case "websocket_connection_limit_reached":
return "ws_connection_limit_reached", true
case "invalid_encrypted_content":
return "invalid_encrypted_content", true
case "previous_response_not_found":
return "previous_response_not_found", true
}
@@ -3940,6 +3942,10 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") {
return "ws_connection_limit_reached", true
}
if strings.Contains(msg, "invalid_encrypted_content") ||
(strings.Contains(msg, "encrypted content") && strings.Contains(msg, "could not be verified")) {
return "invalid_encrypted_content", true
}
if strings.Contains(msg, "previous_response_not_found") ||
(strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) {
return "previous_response_not_found", true
@@ -3964,6 +3970,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
case strings.Contains(errType, "invalid_request"),
strings.Contains(code, "invalid_request"),
strings.Contains(code, "bad_request"),
code == "invalid_encrypted_content",
code == "previous_response_not_found":
return http.StatusBadRequest
case strings.Contains(errType, "authentication"),
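The message-based fallback added above classifies an upstream error when no machine-readable code is available. A minimal sketch of that matching, assuming the message is lowercased before comparison (the diff does not show where `msg` is normalized) and using the hypothetical name `classifyByMessage`:

```go
package main

import (
	"fmt"
	"strings"
)

// classifyByMessage (hypothetical name) maps a free-form upstream error
// message to a stable reason string. Lowercasing here is an assumption about
// how the real code prepares msg.
func classifyByMessage(raw string) (string, bool) {
	msg := strings.ToLower(strings.TrimSpace(raw))
	switch {
	case strings.Contains(msg, "invalid_encrypted_content"),
		strings.Contains(msg, "encrypted content") && strings.Contains(msg, "could not be verified"):
		return "invalid_encrypted_content", true
	case strings.Contains(msg, "previous_response_not_found"),
		strings.Contains(msg, "previous response") && strings.Contains(msg, "not found"):
		return "previous_response_not_found", true
	}
	return "", false
}

func main() {
	for _, m := range []string{
		"The encrypted content could not be verified.",
		"Previous response with id 'resp_1' not found.",
		"rate limit exceeded",
	} {
		reason, ok := classifyByMessage(m)
		fmt.Printf("%q -> %q %v\n", m, reason, ok)
	}
}
```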


@@ -1,6 +1,7 @@
package service
import (
"bytes"
"context"
"encoding/json"
"io"
@@ -19,6 +20,47 @@ import (
"github.com/tidwall/gjson"
)
type httpUpstreamSequenceRecorder struct {
mu sync.Mutex
bodies [][]byte
reqs []*http.Request
responses []*http.Response
errs []error
callCount int
}
func (u *httpUpstreamSequenceRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
u.mu.Lock()
defer u.mu.Unlock()
idx := u.callCount
u.callCount++
u.reqs = append(u.reqs, req)
if req != nil && req.Body != nil {
b, _ := io.ReadAll(req.Body)
u.bodies = append(u.bodies, b)
_ = req.Body.Close()
req.Body = io.NopCloser(bytes.NewReader(b))
} else {
u.bodies = append(u.bodies, nil)
}
if idx < len(u.errs) && u.errs[idx] != nil {
return nil, u.errs[idx]
}
if idx < len(u.responses) {
return u.responses[idx], nil
}
if len(u.responses) == 0 {
return nil, nil
}
return u.responses[len(u.responses)-1], nil
}
func (u *httpUpstreamSequenceRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) {
gin.SetMode(gin.TestMode)
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -143,6 +185,176 @@ func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testi
require.Equal(t, "client_protocol_http", reason)
}
func TestOpenAIGatewayService_Forward_HTTPIngressRetriesInvalidEncryptedContentOnce(t *testing.T) {
gin.SetMode(gin.TestMode)
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
defer wsFallbackServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
upstream := &httpUpstreamSequenceRecorder{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(
`{"error":{"code":"invalid_encrypted_content","type":"invalid_request_error","message":"The encrypted content could not be verified."}}`,
)),
},
{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(
`{"id":"resp_http_retry_ok","usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
)),
},
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
}
account := &Account{
ID: 102,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsFallbackServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me"}]},{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.OpenAIWSMode, "HTTP ingress should stay on HTTP forwarding")
require.Equal(t, 2, upstream.callCount, "after hitting invalid_encrypted_content, the HTTP path should retry exactly once")
require.Len(t, upstream.bodies, 2)
firstBody := upstream.bodies[0]
secondBody := upstream.bodies[1]
require.False(t, gjson.GetBytes(firstBody, "previous_response_id").Exists(), "the first HTTP request should still drop previous_response_id per the existing logic")
require.True(t, gjson.GetBytes(firstBody, "input.0.encrypted_content").Exists(), "the first request should not be sanitized before sending")
require.Equal(t, "keep me", gjson.GetBytes(firstBody, "input.0.summary.0.text").String())
require.False(t, gjson.GetBytes(secondBody, "previous_response_id").Exists(), "the HTTP exact retry should not bring back previous_response_id")
require.False(t, gjson.GetBytes(secondBody, "input.0.encrypted_content").Exists(), "the exact retry should remove reasoning.encrypted_content")
require.Equal(t, "keep me", gjson.GetBytes(secondBody, "input.0.summary.0.text").String(), "the exact retry should keep the valid reasoning summary")
require.Equal(t, "input_text", gjson.GetBytes(secondBody, "input.1.type").String(), "non-reasoning input should stay unchanged")
decision, _ := c.Get("openai_ws_transport_decision")
reason, _ := c.Get("openai_ws_transport_reason")
require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision)
require.Equal(t, "client_protocol_http", reason)
}
func TestOpenAIGatewayService_Forward_HTTPIngressRetriesWrappedInvalidEncryptedContentOnce(t *testing.T) {
gin.SetMode(gin.TestMode)
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
defer wsFallbackServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
upstream := &httpUpstreamSequenceRecorder{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(
`{"error":{"code":null,"message":"{\"error\":{\"message\":\"The encrypted content could not be verified.\",\"type\":\"invalid_request_error\",\"param\":null,\"code\":\"invalid_encrypted_content\"}}traceid: fb7ad1dbc7699c18f8a02f258f1af5ab","param":null,"type":"invalid_request_error"}}`,
)),
},
{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
"x-request-id": []string{"req_http_retry_wrapped_ok"},
},
Body: io.NopCloser(strings.NewReader(
`{"id":"resp_http_retry_wrapped_ok","usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
)),
},
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
}
account := &Account{
ID: 103,
Name: "openai-apikey-wrapped",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsFallbackServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry_wrapped","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me too"}]},{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.OpenAIWSMode, "HTTP ingress should stay on HTTP forwarding")
require.Equal(t, 2, upstream.callCount, "wrapped invalid_encrypted_content should also retry exactly once on the HTTP path")
require.Len(t, upstream.bodies, 2)
firstBody := upstream.bodies[0]
secondBody := upstream.bodies[1]
require.True(t, gjson.GetBytes(firstBody, "input.0.encrypted_content").Exists(), "the first request should not be sanitized before sending")
require.False(t, gjson.GetBytes(secondBody, "input.0.encrypted_content").Exists(), "wrapped exact retry should remove reasoning.encrypted_content")
require.Equal(t, "keep me too", gjson.GetBytes(secondBody, "input.0.summary.0.text").String(), "wrapped exact retry should keep the valid reasoning summary")
decision, _ := c.Get("openai_ws_transport_decision")
reason, _ := c.Get("openai_ws_transport_reason")
require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision)
require.Equal(t, "client_protocol_http", reason)
}
func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -1218,3 +1430,460 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOn
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "the first request should include previous_response_id")
require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "the recovery retry should remove previous_response_id")
}
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentRecoversOnce(t *testing.T) {
gin.SetMode(gin.TestMode)
var wsAttempts atomic.Int32
var wsRequestPayloads [][]byte
var wsRequestMu sync.Mutex
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempt := wsAttempts.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
reqRaw, _ := json.Marshal(req)
wsRequestMu.Lock()
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
wsRequestMu.Unlock()
if attempt == 1 {
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "invalid_encrypted_content",
"type": "invalid_request_error",
"message": "The encrypted content could not be verified.",
},
})
return
}
_ = conn.WriteJSON(map[string]any{
"type": "response.completed",
"response": map[string]any{
"id": "resp_ws_invalid_encrypted_content_recover_ok",
"model": "gpt-5.3-codex",
"usage": map[string]any{
"input_tokens": 1,
"output_tokens": 1,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
},
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 95,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":[{"type":"reasoning","encrypted_content":"gAAA"},{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_ws_invalid_encrypted_content_recover_ok", result.RequestID)
require.Nil(t, upstream.lastReq, "invalid_encrypted_content should not fall back to HTTP")
require.Equal(t, int32(2), wsAttempts.Load(), "invalid_encrypted_content should trigger one retry after sanitizing")
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "resp_ws_invalid_encrypted_content_recover_ok", gjson.Get(rec.Body.String(), "id").String())
wsRequestMu.Lock()
requests := append([][]byte(nil), wsRequestPayloads...)
wsRequestMu.Unlock()
require.Len(t, requests, 2)
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "the first request should keep previous_response_id")
require.True(t, gjson.GetBytes(requests[0], `input.0.encrypted_content`).Exists(), "the first request should keep the encrypted reasoning")
require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "the recovery retry should remove previous_response_id")
require.False(t, gjson.GetBytes(requests[1], `input.0.encrypted_content`).Exists(), "the recovery retry should remove the encrypted reasoning item")
require.Equal(t, "input_text", gjson.GetBytes(requests[1], `input.0.type`).String())
}
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentSkipsRecoveryWithoutReasoningItem(t *testing.T) {
gin.SetMode(gin.TestMode)
var wsAttempts atomic.Int32
var wsRequestPayloads [][]byte
var wsRequestMu sync.Mutex
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsAttempts.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
reqRaw, _ := json.Marshal(req)
wsRequestMu.Lock()
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
wsRequestMu.Unlock()
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "invalid_encrypted_content",
"type": "invalid_request_error",
"message": "The encrypted content could not be verified.",
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 96,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.Error(t, err)
require.Nil(t, result)
require.Nil(t, upstream.lastReq, "invalid_encrypted_content should not fall back to HTTP")
require.Equal(t, int32(1), wsAttempts.Load(), "the automatic recovery retry should be skipped when no encrypted reasoning item is present")
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, strings.ToLower(rec.Body.String()), "encrypted content")
wsRequestMu.Lock()
requests := append([][]byte(nil), wsRequestPayloads...)
wsRequestMu.Unlock()
require.Len(t, requests, 1)
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists())
require.False(t, gjson.GetBytes(requests[0], `input.0.encrypted_content`).Exists())
}
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentRecoversSingleObjectInputAndKeepsSummary(t *testing.T) {
gin.SetMode(gin.TestMode)
var wsAttempts atomic.Int32
var wsRequestPayloads [][]byte
var wsRequestMu sync.Mutex
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempt := wsAttempts.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
reqRaw, _ := json.Marshal(req)
wsRequestMu.Lock()
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
wsRequestMu.Unlock()
if attempt == 1 {
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "invalid_encrypted_content",
"type": "invalid_request_error",
"message": "The encrypted content could not be verified.",
},
})
return
}
_ = conn.WriteJSON(map[string]any{
"type": "response.completed",
"response": map[string]any{
"id": "resp_ws_invalid_encrypted_content_object_ok",
"model": "gpt-5.3-codex",
"usage": map[string]any{
"input_tokens": 1,
"output_tokens": 1,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
},
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 97,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me"}]}}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_ws_invalid_encrypted_content_object_ok", result.RequestID)
require.Nil(t, upstream.lastReq, "invalid_encrypted_content with a single-object input must not fall back to HTTP")
require.Equal(t, int32(2), wsAttempts.Load(), "a single-object reasoning input should also trigger one retry after sanitizing")
wsRequestMu.Lock()
requests := append([][]byte(nil), wsRequestPayloads...)
wsRequestMu.Unlock()
require.Len(t, requests, 2)
require.True(t, gjson.GetBytes(requests[0], `input.encrypted_content`).Exists(), "the first single-object request should keep encrypted_content")
require.True(t, gjson.GetBytes(requests[1], `input.summary.0.text`).Exists(), "the recovery retry should keep the reasoning summary")
require.False(t, gjson.GetBytes(requests[1], `input.encrypted_content`).Exists(), "the recovery retry should remove only encrypted_content")
require.Equal(t, "reasoning", gjson.GetBytes(requests[1], `input.type`).String())
require.False(t, gjson.GetBytes(requests[1], `previous_response_id`).Exists(), "the recovery retry should remove previous_response_id")
}
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentKeepsPreviousResponseIDForFunctionCallOutput(t *testing.T) {
gin.SetMode(gin.TestMode)
var wsAttempts atomic.Int32
var wsRequestPayloads [][]byte
var wsRequestMu sync.Mutex
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempt := wsAttempts.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
reqRaw, _ := json.Marshal(req)
wsRequestMu.Lock()
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
wsRequestMu.Unlock()
if attempt == 1 {
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "invalid_encrypted_content",
"type": "invalid_request_error",
"message": "The encrypted content could not be verified.",
},
})
return
}
_ = conn.WriteJSON(map[string]any{
"type": "response.completed",
"response": map[string]any{
"id": "resp_ws_invalid_encrypted_content_function_call_output_ok",
"model": "gpt-5.3-codex",
"usage": map[string]any{
"input_tokens": 1,
"output_tokens": 1,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
},
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 98,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_function_call","input":[{"type":"reasoning","encrypted_content":"gAAA"},{"type":"function_call_output","call_id":"call_123","output":"ok"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_ws_invalid_encrypted_content_function_call_output_ok", result.RequestID)
require.Nil(t, upstream.lastReq, "function_call_output + invalid_encrypted_content must not fall back to HTTP")
require.Equal(t, int32(2), wsAttempts.Load(), "should perform exactly one sanitized retry that keeps the anchor")
wsRequestMu.Lock()
requests := append([][]byte(nil), wsRequestPayloads...)
wsRequestMu.Unlock()
require.Len(t, requests, 2)
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "the first request should keep previous_response_id")
require.True(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "the function_call_output recovery retry must not remove previous_response_id")
require.False(t, gjson.GetBytes(requests[1], `input.0.encrypted_content`).Exists(), "the recovery retry should remove the reasoning encrypted_content")
require.Equal(t, "function_call_output", gjson.GetBytes(requests[1], `input.0.type`).String(), "after sanitizing, function_call_output should remain as the first input item")
require.Equal(t, "call_123", gjson.GetBytes(requests[1], `input.0.call_id`).String())
require.Equal(t, "ok", gjson.GetBytes(requests[1], `input.0.output`).String())
require.Equal(t, "resp_prev_function_call", gjson.GetBytes(requests[1], "previous_response_id").String())
}
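
The tests above pin down a recovery rule for `invalid_encrypted_content`: strip `encrypted_content` from reasoning items, retry once over WebSocket, and keep `previous_response_id` only when a `function_call_output` item still needs it as an anchor. The following is a minimal sketch of that rule, not the gateway's actual implementation. `sanitizeForRetry` and `stripEncrypted` are hypothetical names, and dropping a reasoning item that is left with nothing but its `type` is an assumption inferred from the third test.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// sanitizeForRetry is a hypothetical helper, not the gateway's real code.
// It returns the cleaned body and whether anything was stripped; when
// nothing was stripped the caller should not retry.
func sanitizeForRetry(body []byte) ([]byte, bool, error) {
	var req map[string]any
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, false, err
	}
	changed, keepAnchor := false, false
	switch input := req["input"].(type) {
	case map[string]any: // single-object input
		changed = stripEncrypted(input)
	case []any: // array input
		kept := make([]any, 0, len(input))
		for _, raw := range input {
			item, ok := raw.(map[string]any)
			if !ok {
				kept = append(kept, raw)
				continue
			}
			if item["type"] == "function_call_output" {
				keepAnchor = true
			}
			if stripEncrypted(item) {
				changed = true
				// Assumption: a reasoning item left with only "type" is dropped.
				if len(item) == 1 {
					continue
				}
			}
			kept = append(kept, item)
		}
		req["input"] = kept
	}
	if !changed {
		return body, false, nil
	}
	if !keepAnchor {
		// No function_call_output depends on the previous response.
		delete(req, "previous_response_id")
	}
	out, err := json.Marshal(req)
	return out, true, err
}

// stripEncrypted removes encrypted_content from a reasoning item and
// reports whether it removed anything.
func stripEncrypted(item map[string]any) bool {
	if item["type"] != "reasoning" {
		return false
	}
	if _, ok := item["encrypted_content"]; !ok {
		return false
	}
	delete(item, "encrypted_content")
	return true
}

func main() {
	body := []byte(`{"previous_response_id":"resp_prev","input":{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me"}]}}`)
	out, retry, err := sanitizeForRetry(body)
	fmt.Println(string(out), retry, err)
}
```

Returning a `changed` flag lets the caller skip the retry when the request carried no encrypted reasoning item, which matches the single-attempt expectation in the first test.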

View File

@@ -23,7 +23,7 @@ const (
opsAggDailyInterval = 1 * time.Hour
// Keep in sync with ops retention target (vNext default 30d).
opsAggBackfillWindow = 30 * 24 * time.Hour
opsAggBackfillWindow = 1 * time.Hour
// Recompute overlap to absorb late-arriving rows near boundaries.
opsAggHourlyOverlap = 2 * time.Hour
@@ -36,7 +36,7 @@ const (
// that may still receive late inserts.
opsAggSafeDelay = 5 * time.Minute
opsAggMaxQueryTimeout = 3 * time.Second
opsAggMaxQueryTimeout = 5 * time.Second
opsAggHourlyTimeout = 5 * time.Minute
opsAggDailyTimeout = 2 * time.Minute

View File

@@ -467,7 +467,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: selErr.Error()}
}
if selection == nil || selection.Account == nil {
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "no available accounts"}
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: ErrNoAvailableAccounts.Error()}
}
account := selection.Account

View File

@@ -368,11 +368,14 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
Aggregation: OpsAggregationSettings{
AggregationEnabled: false,
},
IgnoreCountTokensErrors: true, // count_tokens 404 is expected behavior; ignored by default
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
AutoRefreshEnabled: false,
AutoRefreshIntervalSec: 30,
IgnoreCountTokensErrors: true, // count_tokens 404 is expected behavior; ignored by default
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
IgnoreInsufficientBalanceErrors: false, // Not ignored by default; insufficient balance may need attention
DisplayOpenAITokenStats: false,
DisplayAlertEvents: true,
AutoRefreshEnabled: false,
AutoRefreshIntervalSec: 30,
}
}
@@ -438,7 +441,7 @@ func (s *OpsService) GetOpsAdvancedSettings(ctx context.Context) (*OpsAdvancedSe
return nil, err
}
cfg := &OpsAdvancedSettings{}
cfg := defaultOpsAdvancedSettings()
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
return defaultCfg, nil
}
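
The change from `&OpsAdvancedSettings{}` to `defaultOpsAdvancedSettings()` relies on how `json.Unmarshal` treats an existing struct: only keys present in the JSON are overwritten, so fields missing from a legacy payload keep their default values. A minimal sketch of that pattern follows; the `settings` type is a trimmed, hypothetical stand-in for `OpsAdvancedSettings`, and `load` is not the service's real function.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// settings is a trimmed stand-in for OpsAdvancedSettings.
type settings struct {
	IgnoreContextCanceled   bool `json:"ignore_context_canceled"`
	DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
	DisplayAlertEvents      bool `json:"display_alert_events"`
	AutoRefreshIntervalSec  int  `json:"auto_refresh_interval_seconds"`
}

func defaultSettings() *settings {
	return &settings{
		IgnoreContextCanceled:  true,
		DisplayAlertEvents:     true,
		AutoRefreshIntervalSec: 30,
	}
}

// load decodes raw on top of the defaults, so keys absent from raw
// keep their default values instead of falling back to zero values.
func load(raw string) *settings {
	cfg := defaultSettings()
	if err := json.Unmarshal([]byte(raw), cfg); err != nil {
		// cfg may be partially written on error; return fresh defaults.
		return defaultSettings()
	}
	return cfg
}

func main() {
	// Legacy payload saved before the display_* flags existed.
	cfg := load(`{"ignore_context_canceled":true,"auto_refresh_interval_seconds":60}`)
	fmt.Println(cfg.DisplayAlertEvents, cfg.DisplayOpenAITokenStats, cfg.AutoRefreshIntervalSec)
}
```

Decoding into a zero-valued struct would instead leave `DisplayAlertEvents` false for every legacy payload, which is the regression the backfill test in this diff guards against.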

View File

@@ -0,0 +1,97 @@
package service
import (
"context"
"encoding/json"
"testing"
)
func TestGetOpsAdvancedSettings_DefaultHidesOpenAITokenStats(t *testing.T) {
repo := newRuntimeSettingRepoStub()
svc := &OpsService{settingRepo: repo}
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
if err != nil {
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
}
if cfg.DisplayOpenAITokenStats {
t.Fatalf("DisplayOpenAITokenStats = true, want false by default")
}
if !cfg.DisplayAlertEvents {
t.Fatalf("DisplayAlertEvents = false, want true by default")
}
if repo.setCalls != 1 {
t.Fatalf("expected defaults to be persisted once, got %d", repo.setCalls)
}
}
func TestUpdateOpsAdvancedSettings_PersistsOpenAITokenStatsVisibility(t *testing.T) {
repo := newRuntimeSettingRepoStub()
svc := &OpsService{settingRepo: repo}
cfg := defaultOpsAdvancedSettings()
cfg.DisplayOpenAITokenStats = true
cfg.DisplayAlertEvents = false
updated, err := svc.UpdateOpsAdvancedSettings(context.Background(), cfg)
if err != nil {
t.Fatalf("UpdateOpsAdvancedSettings() error = %v", err)
}
if !updated.DisplayOpenAITokenStats {
t.Fatalf("DisplayOpenAITokenStats = false, want true")
}
if updated.DisplayAlertEvents {
t.Fatalf("DisplayAlertEvents = true, want false")
}
reloaded, err := svc.GetOpsAdvancedSettings(context.Background())
if err != nil {
t.Fatalf("GetOpsAdvancedSettings() after update error = %v", err)
}
if !reloaded.DisplayOpenAITokenStats {
t.Fatalf("reloaded DisplayOpenAITokenStats = false, want true")
}
if reloaded.DisplayAlertEvents {
t.Fatalf("reloaded DisplayAlertEvents = true, want false")
}
}
func TestGetOpsAdvancedSettings_BackfillsNewDisplayFlagsFromDefaults(t *testing.T) {
repo := newRuntimeSettingRepoStub()
svc := &OpsService{settingRepo: repo}
legacyCfg := map[string]any{
"data_retention": map[string]any{
"cleanup_enabled": false,
"cleanup_schedule": "0 2 * * *",
"error_log_retention_days": 30,
"minute_metrics_retention_days": 30,
"hourly_metrics_retention_days": 30,
},
"aggregation": map[string]any{
"aggregation_enabled": false,
},
"ignore_count_tokens_errors": true,
"ignore_context_canceled": true,
"ignore_no_available_accounts": false,
"ignore_invalid_api_key_errors": false,
"auto_refresh_enabled": false,
"auto_refresh_interval_seconds": 30,
}
raw, err := json.Marshal(legacyCfg)
if err != nil {
t.Fatalf("marshal legacy config: %v", err)
}
repo.values[SettingKeyOpsAdvancedSettings] = string(raw)
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
if err != nil {
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
}
if cfg.DisplayOpenAITokenStats {
t.Fatalf("DisplayOpenAITokenStats = true, want false default backfill")
}
if !cfg.DisplayAlertEvents {
t.Fatalf("DisplayAlertEvents = false, want true default backfill")
}
}

View File

@@ -92,14 +92,17 @@ type OpsAlertRuntimeSettings struct {
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
type OpsAdvancedSettings struct {
DataRetention OpsDataRetentionSettings `json:"data_retention"`
Aggregation OpsAggregationSettings `json:"aggregation"`
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
DataRetention OpsDataRetentionSettings `json:"data_retention"`
Aggregation OpsAggregationSettings `json:"aggregation"`
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
IgnoreInsufficientBalanceErrors bool `json:"ignore_insufficient_balance_errors"`
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
DisplayAlertEvents bool `json:"display_alert_events"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
}
type OpsDataRetentionSettings struct {

View File

@@ -149,8 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
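// Other 400 errors (e.g. parameter problems) are not handled and do not disable the account
case 401:
// For all OAuth accounts, invalidate the cache on 401 and force a refresh next time
if account.Type == AccountTypeOAuth {
// OAuth accounts become temporarily unschedulable on 401 (to allow a token refresh window); non-OAuth accounts keep the existing SetError behavior.
// Except Antigravity: its 401 is controlled by the temp_unschedulable_rules in applyErrorPolicy.
if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity {
// 1. Invalidate the cache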
if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
@@ -182,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
shouldDisable = true
} else {
// Non-OAuth accounts (APIKey): keep the existing SetError behavior
// Non-OAuth / Antigravity OAuth: keep the SetError behavior
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
@@ -199,11 +200,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
s.handleAuthError(ctx, account, msg)
shouldDisable = true
case 403:
// Forbidden: stop scheduling and record the error
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
logger.LegacyPrintf(
"service.ratelimit",
"[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s",
@@ -215,8 +211,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
upstreamMsg,
truncateForLog(responseBody, 1024),
)
s.handleAuthError(ctx, account, msg)
shouldDisable = true
shouldDisable = s.handle403(ctx, account, upstreamMsg, responseBody)
case 429:
s.handle429(ctx, account, headers, responseBody)
shouldDisable = false
@@ -621,6 +616,62 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
}
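// handle403 handles 403 Forbidden errors.
// On Antigravity it distinguishes validation/violation/generic types, all of which permanently disable the account via SetError;
// other platforms keep the existing SetError behavior.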
func (s *RateLimitService) handle403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
if account.Platform == PlatformAntigravity {
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
}
// Non-Antigravity platforms: keep the existing behavior
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
return true
}
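// handleAntigravity403 handles 403 errors on the Antigravity platform.
// validation (verification required) → permanent SetError (recover after manual Google verification)
// violation (account banned for a violation) → permanent SetError (needs manual handling)
// generic (generic forbidden) → permanent SetError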
func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
fbType := classifyForbiddenType(string(responseBody))
switch fbType {
case forbiddenTypeValidation:
// VALIDATION_REQUIRED: permanently disabled; restore manually after Google verification is completed by hand
msg := "Validation required (403): account needs Google verification"
if upstreamMsg != "" {
msg = "Validation required (403): " + upstreamMsg
}
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
msg += " | validation_url: " + validationURL
}
s.handleAuthError(ctx, account, msg)
return true
case forbiddenTypeViolation:
// Violation ban: permanently disabled; needs manual handling
msg := "Account violation (403): terms of service violation"
if upstreamMsg != "" {
msg = "Account violation (403): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
return true
default:
// Generic 403: keep the existing behavior
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
return true
}
}
// handleCustomErrorCode handles custom error codes and stops scheduling the account
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
@@ -1123,7 +1174,8 @@ func hasRecoverableRuntimeState(account *Account) bool {
if len(account.Extra) == 0 {
return false
}
return hasNonEmptyMapValue(account.Extra, "model_rate_limits") || hasNonEmptyMapValue(account.Extra, "antigravity_quota_scopes")
return hasNonEmptyMapValue(account.Extra, "model_rate_limits") ||
hasNonEmptyMapValue(account.Extra, "antigravity_quota_scopes")
}
func hasNonEmptyMapValue(extra map[string]any, key string) bool {
@@ -1213,7 +1265,8 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
}
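// The first 401 may make the account temporarily unschedulable (to allow a token refresh window);
// if the account was already temporarily unschedulable because of a 401, this hit escalates to an error (return false and let the default error logic handle it).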
if statusCode == http.StatusUnauthorized {
// Antigravity is skipped: its 401 is controlled by the temp_unschedulable_rules in applyErrorPolicy, so no escalation logic is needed.
if statusCode == http.StatusUnauthorized && account.Platform != PlatformAntigravity {
reason := account.TempUnschedulableReason
// The cache may not have a reason; fall back to reading it from the DB
if reason == "" {
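
The comments in this hunk describe a two-step policy for 401s: the first hit makes the account temporarily unschedulable, a repeat hit escalates to a permanent error, and Antigravity accounts skip the escalation. Below is a rough sketch of that decision, under stated assumptions: the reason payload shape is taken from the `{"status_code":401,"until_unix":...}` fixture in the tests, the `"antigravity"` platform string is a guess at the value of `PlatformAntigravity`, and `shouldEscalate401` is a hypothetical name rather than the service's real method.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

// tempUnschedReason mirrors the JSON fixture used in the tests.
type tempUnschedReason struct {
	StatusCode int   `json:"status_code"`
	UntilUnix  int64 `json:"until_unix"`
}

// shouldEscalate401 reports whether a 401 should be treated as a permanent
// error because the account was already temp-unscheduled for an earlier 401.
// loadFromDB is a hypothetical fallback used when the cached reason is empty.
func shouldEscalate401(statusCode int, platform, cachedReason string, loadFromDB func() string) bool {
	// Assumption: "antigravity" is the platform identifier.
	if statusCode != http.StatusUnauthorized || platform == "antigravity" {
		return false
	}
	reason := cachedReason
	if reason == "" && loadFromDB != nil {
		reason = loadFromDB() // the cache may be stale or empty
	}
	if reason == "" {
		return false // first hit: stay on the temp-unschedulable path
	}
	var prev tempUnschedReason
	if err := json.Unmarshal([]byte(reason), &prev); err != nil {
		return false
	}
	return prev.StatusCode == http.StatusUnauthorized
}

func main() {
	db := func() string { return `{"status_code":401,"until_unix":1735689600}` }
	fmt.Println(shouldEscalate401(401, "gemini", "", db))
	fmt.Println(shouldEscalate401(401, "antigravity", "", db))
}
```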

View File

@@ -27,34 +27,68 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason (cache miss),
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone.
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
// but DB account has a previous 401 record.
// Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error).
// Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules).
t.Run("gemini_escalates", func(t *testing.T) {
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "", // cache miss — reason is empty
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformGemini,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
},
}
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone")
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate")
})
t.Run("antigravity_stays_temp", func(t *testing.T) {
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled")
})
}
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {

View File

@@ -42,45 +42,56 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
}
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
tests := []struct {
name string
platform string
}{
{name: "gemini", platform: PlatformGemini},
{name: "antigravity", platform: PlatformAntigravity},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: tt.platform,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
t.Run("gemini", func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
},
}
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1)
})
}
require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1)
})
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
// Antigravity 401 is controlled by the temp_unschedulable_rules in applyErrorPolicy;
// in HandleUpstreamError it takes the SetError path.
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Empty(t, invalidator.accounts)
})
}
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {

Some files were not shown because too many files have changed in this diff