Compare commits


204 Commits
v0.1.97 ... dev

Author SHA1 Message Date
Wesley Liddick
04b4d7de3a Merge pull request #978 from touwaeriol/worktree-fix/group
fix: filter soft-deleted users in group rate multiplier query
2026-03-13 22:15:17 +08:00
erio
b11cfdc11f fix: filter soft-deleted users in the group-specific rate multiplier query
The GetByGroupID query JOINs the users table without filtering deleted_at,
so already-deleted users still appeared in the group-specific rate multiplier list.
2026-03-13 19:16:07 +08:00
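The fix above is a one-line predicate; a hedged Go sketch against an assumed schema (the user_group_rates table and column names are illustrative, not the project's actual tables):

```go
package sketch

import (
	"context"
	"database/sql"
)

// Exclude soft-deleted users inside the JOIN so they never reach the
// group-specific rate multiplier list. Table and column names are assumed.
func listGroupRateMultipliers(ctx context.Context, db *sql.DB, groupID int64) (*sql.Rows, error) {
	return db.QueryContext(ctx, `
		SELECT ugr.user_id, ugr.rate_multiplier
		FROM user_group_rates ugr
		JOIN users u ON u.id = ugr.user_id
		WHERE ugr.group_id = $1
		  AND u.deleted_at IS NULL`, // the predicate this fix adds
		groupID)
}
```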
erio
cdddbec629 chore: bump version to 0.1.96.1 2026-03-12 18:13:28 +08:00
erio
b3fe0506fb Merge branch 'release/custom-0.1.95' into release/custom-0.1.96 2026-03-12 18:12:47 +08:00
erio
8ca0e2772e feat: add per-user rate multiplier management to the group admin page
- Backend: add GET /admin/groups/:id/rate-multipliers API
- Frontend: add GroupRateMultipliersModal component supporting view/add/edit/delete of user-specific rate multipliers
- Add a "Rate Multipliers" button to the group list action column
- Fix a pre-existing argument mismatch in antigravity_gateway_service_test.go
2026-03-12 18:01:32 +08:00
erio
5f6e929d61 chore: bump version to 0.1.95.1 2026-03-11 03:24:05 +08:00
erio
4caa3c2701 merge: integrate upstream v0.1.95 with our customizations
Merge main (custom features) into release/custom-0.1.95 (upstream v0.1.95).
New upstream features: group subscription binding, multi-dimension quota (daily/weekly/total),
allow_messages_dispatch, default_mapped_model, recover state API.
Our customizations: simulate_claude_max_enabled, usage status detection, 403 validation handling.
2026-03-11 03:23:44 +08:00
erio
a9ecd7bcc6 chore: bump version to 0.1.91.1 2026-03-06 04:08:37 +08:00
erio
f89465fb39 Merge branch 'main' into release/custom-0.1.91
# Conflicts:
#	frontend/src/components/admin/account/AccountActionMenu.vue
#	frontend/src/views/admin/AccountsView.vue
2026-03-06 04:08:14 +08:00
erio
440c3f46a7 feat: add independent load_factor field for scheduling load calculation
- Separate load factor from concurrency: concurrency controls actual
  slot acquisition, load_factor controls load rate calculation
- Add EffectiveLoadFactor() method: LoadFactor > Concurrency > 1
- Add load_factor field to Create/Edit/BulkEdit account forms
- Fix RPM default value: auto-fill 15 when RPM enabled but not set
- Fix stale test compilation errors in server and handler packages
2026-03-06 03:42:24 +08:00
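The EffectiveLoadFactor() precedence named above (LoadFactor > Concurrency > 1) reads naturally as a small method; a minimal sketch, with assumed field types:

```go
package sketch

// Field types are assumptions; the precedence is from the commit message.
type Account struct {
	Concurrency int // controls actual slot acquisition
	LoadFactor  int // controls load-rate calculation only
}

// EffectiveLoadFactor: LoadFactor > Concurrency > 1.
func (a *Account) EffectiveLoadFactor() int {
	if a.LoadFactor > 0 {
		return a.LoadFactor
	}
	if a.Concurrency > 0 {
		return a.Concurrency
	}
	return 1
}
```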
erio
c746964936 chore: bump version to 0.1.90.8 2026-03-06 01:16:40 +08:00
erio
b2d6879b3f refactor: unify post-usage billing logic and fix account quota calculation
- Extract postUsageBilling() to consolidate billing logic across
  GatewayService.RecordUsage, RecordUsageWithLongContext, and
  OpenAIGatewayService.RecordUsage, eliminating ~120 lines of
  duplicated code
- Fix account quota to use TotalCost × accountRateMultiplier
  (was using raw TotalCost, inconsistent with account cost stats)
- Fix RecordUsageWithLongContext API Key quota only updating in
  balance mode (now updates regardless of billing type)
- Fix WebSocket client disconnect detection on Windows by adding
  "an established connection was aborted" to known disconnect errors
2026-03-06 00:54:48 +08:00
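The account quota fix above reduces to one multiplication; a minimal sketch, with all names assumed:

```go
package sketch

// Consume the rate-adjusted cost so account quota matches account cost
// stats; previously the raw TotalCost was accumulated.
func accrueAccountQuota(quotaUsed, totalCost, accountRateMultiplier float64) float64 {
	return quotaUsed + totalCost*accountRateMultiplier
}
```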
erio
ce4095904e chore: bump version to 0.1.90.7 2026-03-05 22:56:20 +08:00
erio
0d2dcbff11 fix: route antigravity apikey account test to native protocol
Antigravity APIKey accounts were incorrectly routed to
testAntigravityAccountConnection which calls AntigravityTokenProvider,
but the token provider only handles OAuth and Upstream types, causing
"not an antigravity oauth account" error.

Extract routeAntigravityTest to route APIKey accounts to native
Claude/Gemini test paths based on model prefix, matching the
gateway_handler routing logic for normal requests.
2026-03-05 22:44:21 +08:00
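A hedged sketch of the routing rule this commit describes; only routeAntigravityTest and testAntigravityAccountConnection are named in the commit, the rest is illustrative:

```go
package sketch

import "strings"

// APIKey accounts dispatch to native Claude/Gemini test paths by model
// prefix; OAuth and Upstream accounts keep the token-provider-backed test.
// Returning a path name is illustrative only.
func routeAntigravityTest(authType, model string) string {
	if authType == "apikey" {
		if strings.HasPrefix(model, "claude") {
			return "native-claude-test"
		}
		return "native-gemini-test"
	}
	return "antigravity-oauth-test" // handled by AntigravityTokenProvider
}
```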
erio
02dca78dbe refactor: extract QuotaLimitCard component for reuse in create and edit modals
- Extract quota limit card/toggle UI into QuotaLimitCard.vue component
- Use v-model pattern for clean parent-child data flow
- Integrate into both EditAccountModal and CreateAccountModal
- All apikey accounts (all platforms) now support quota limit on creation
- Bump version to 0.1.90.6
2026-03-05 22:13:56 +08:00
erio
8df299b767 feat: restyle API Key quota limit UI to card/toggle format
- Redesign quota limit section with card layout and toggle switch
- Add watch to clear quota value when toggle is disabled
- Add i18n keys for toggle labels and hints (zh/en)
- Bump version to 0.1.90.5
2026-03-05 22:01:40 +08:00
erio
95cf59b2f6 feat: add quota limit for API key accounts
- Add configurable spending limit (quota_limit) for apikey-type accounts
- Atomic quota accumulation via PostgreSQL JSONB operations on TotalCost
- Scheduler filters out over-quota accounts with outbox-triggered snapshot refresh
- Display quota usage ($used / $limit) in account capacity column
- Add "Reset Quota" action in account menu to reset usage to zero
- Editing account settings preserves quota_used (no accidental reset)
- Covers all 3 billing paths: Anthropic, Gemini, OpenAI RecordUsage

chore: bump version to 0.1.90.4
2026-03-05 21:48:37 +08:00
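The "atomic quota accumulation via PostgreSQL JSONB operations" could look roughly like this; the extras column and quota_used key are assumptions about the schema:

```go
package sketch

import (
	"context"
	"database/sql"
)

// Push the addition into PostgreSQL so concurrent billing paths cannot lose
// updates. The extras column and quota_used key are assumed schema details.
const addQuotaSQL = `
UPDATE accounts
SET extras = jsonb_set(
      COALESCE(extras, '{}'::jsonb),
      '{quota_used}',
      to_jsonb(COALESCE((extras->>'quota_used')::numeric, 0) + $1))
WHERE id = $2`

func accrueQuota(ctx context.Context, db *sql.DB, accountID int64, cost float64) error {
	_, err := db.ExecContext(ctx, addQuotaSQL, cost, accountID)
	return err
}
```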
erio
a0f643da0e fix: keep 500 status for upstream errors, pass details in response body
Upstream errors like 401/429 must not be passed through as HTTP status
codes because the frontend routes on status (401 triggers JWT logout).
Keep HTTP 500 but include upstream error details in the message.
2026-03-05 21:03:08 +08:00
erio
99331a5285 fix: throttle Anthropic usage queries and pass through upstream HTTP errors
- Frontend: queue Anthropic OAuth/setup-token usage requests by proxy
  with random 1-1.5s interval to prevent upstream 429
- Backend: return ApplicationError with actual upstream status code
  instead of wrapping all errors as 500
- Handle component unmount to skip stale updates on page navigation
2026-03-05 19:12:49 +08:00
erio
9245e197a8 chore: bump version to 0.1.90.3 2026-03-04 21:50:33 +08:00
erio
5fa1d0978f fix: remove lite=1 default from account list to restore concurrency display
Upstream v0.1.90 added lite=1 as default parameter for the account list
API, which causes the backend to skip all runtime queries (concurrency,
session count, RPM, window cost). This resulted in all accounts showing
0 concurrency in the admin UI.
2026-03-04 21:50:19 +08:00
erio
0bb683a6f1 chore: bump version to 0.1.90.2 2026-03-04 21:33:39 +08:00
erio
cb3958bac3 fix: use detached context for concurrency batch query to prevent all-zero display
The upstream v0.1.90 changed GetAccountConcurrencyBatch from individual
Lua script calls (which swallowed per-account errors) to a Redis pipeline
approach that propagates errors from rdb.Time() or pipe.Exec(). When the
HTTP request context is cancelled (e.g., browser abort), the entire batch
fails and the handler silently shows all concurrency as 0.

Fix: use context.WithTimeout(context.Background(), 3s) for the Redis call
so HTTP request cancellation doesn't affect the read-only concurrency query.
2026-03-04 21:33:18 +08:00
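A minimal sketch of the described fix, assuming a readBatch stand-in for the Redis pipeline call:

```go
package sketch

import (
	"context"
	"time"
)

// Detach the read-only Redis batch from the HTTP request context so a
// browser abort cannot cancel the pipeline and zero the concurrency display.
func concurrencyBatch(
	ids []int64,
	readBatch func(context.Context, []int64) (map[int64]int, error),
) (map[int64]int, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	return readBatch(ctx, ids)
}
```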
erio
91ebf95efa fix: add dark theme support for "open in new tab" button
The backdrop-blur background was hardcoded to white, making the button
invisible/mismatched in dark theme. Added dark:bg-dark-800/80 variant.
2026-03-04 21:31:15 +08:00
erio
7895dd65b2 fix: remove duplicate groupExistenceBatchReader declaration in admin_service.go 2026-03-04 20:24:12 +08:00
erio
09f8894906 fix: update go.sum with missing tiktoken-go dependencies 2026-03-04 20:21:09 +08:00
erio
45fdef9da7 fix: remove duplicate IdempotencyRecords declarations in ent/migrate/schema.go 2026-03-04 20:20:40 +08:00
erio
704f256bd7 docs: update CLAUDE.md - migrate build to production server with limited-builder 2026-03-04 20:14:20 +08:00
erio
a6026e7ac4 Merge tag 'v0.1.90' into merge/upstream-v0.1.90
Registration email domain whitelist goes live; major performance optimizations for large datasets in the admin backend.

- Registration email domain whitelist: admins can configure which email domains are allowed to register
- Keys page form filtering: the user /keys page supports filtering API Keys by criteria
- Settings page split into tabs: the admin settings page is organized into tabs by functional module

- Admin large-dataset loading performance: dashboard/user/account/Ops pages load significantly faster on large datasets
- Usage large-table pagination: avoid a full COUNT(*) by default, greatly reducing pagination query time
- Remove the duplicated normalizeAccountIDList; add unit tests for new components
- Clean up unused files and outdated docs to slim down the project structure
- Replace hardcoded English strings in EmailVerifyView with i18n calls

- Fix Anthropic 429 responses without a rate-limit reset time incorrectly marking accounts as rate-limited
- Fix custom menu pages not taking effect in the admin view
- Fix the Ops error detail dialog not showing the real upstream payload
- Fix the recharge/subscription menu icon display

# Conflicts:
#	.gitignore
#	backend/cmd/server/VERSION
#	backend/ent/group.go
#	backend/ent/runtime/runtime.go
#	backend/ent/schema/group.go
#	backend/go.sum
#	backend/internal/handler/admin/account_handler.go
#	backend/internal/handler/admin/dashboard_handler.go
#	backend/internal/pkg/usagestats/usage_log_types.go
#	backend/internal/repository/group_repo.go
#	backend/internal/repository/usage_log_repo.go
#	backend/internal/server/middleware/security_headers.go
#	backend/internal/server/router.go
#	backend/internal/service/account_usage_service.go
#	backend/internal/service/admin_service_bulk_update_test.go
#	backend/internal/service/dashboard_service.go
#	backend/internal/service/gateway_service.go
#	frontend/src/api/admin/dashboard.ts
#	frontend/src/components/account/BulkEditAccountModal.vue
#	frontend/src/components/charts/GroupDistributionChart.vue
#	frontend/src/components/layout/AppSidebar.vue
#	frontend/src/i18n/locales/en.ts
#	frontend/src/i18n/locales/zh.ts
#	frontend/src/views/admin/GroupsView.vue
#	frontend/src/views/admin/SettingsView.vue
#	frontend/src/views/admin/UsageView.vue
#	frontend/src/views/user/PurchaseSubscriptionView.vue
2026-03-04 19:58:38 +08:00
erio
1e03b2974a chore: bump version to 0.1.87.18 2026-03-02 01:23:12 +08:00
erio
daa7c783b9 fix(csp): add timeout ctx, preserve cache on error, validate scheme in extractOrigin 2026-03-02 01:15:29 +08:00
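A plausible shape for the extractOrigin scheme validation mentioned above (the function name is from the commit; the body is an assumption):

```go
package sketch

import "net/url"

// Only http(s) URLs with a host yield an origin for frame-src injection.
func extractOrigin(raw string) (string, bool) {
	u, err := url.Parse(raw)
	if err != nil || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
		return "", false
	}
	return u.Scheme + "://" + u.Host, true
}
```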
erio
8a82a2a648 feat(csp): auto-inject purchase_subscription_url origin into frame-src 2026-03-02 00:19:25 +08:00
erio
c3ac68af2a chore: bump version to 0.1.87.16 2026-03-01 20:14:48 +08:00
erio
ec3897b981 docs: add PR description format to CLAUDE.md and sync to AGENTS.md 2026-03-01 19:59:03 +08:00
erio
62486cee37 chore: bump version to 0.1.87.15 2026-03-01 19:49:12 +08:00
erio
7c5746ffbc feat(usage): add group usage distribution chart alongside model distribution
- Add GroupStat type to usagestats package
- Add GetGroupStatsWithFilters to UsageLogRepository interface and implement with LEFT JOIN groups
- Add GetGroupStats dashboard API endpoint (GET /admin/dashboard/groups)
- Add GroupDistributionChart.vue component mirroring ModelDistributionChart
- Rearrange UsageView layout: model + group in one row, token trend full-width below
- All filters (user, api_key, account, group, model, date range) apply to group stats
2026-03-01 19:49:01 +08:00
erio
47f7b0213b chore: bump version to 0.1.87.14 2026-03-01 18:16:41 +08:00
erio
8df42f7aab refactor(docs): move integration doc to docs/ and add download link in settings
- Move ADMIN_PAYMENT_INTEGRATION_API.md → docs/ADMIN_PAYMENT_INTEGRATION_API.md
- Update README.md reference path
- Add payment integration doc download link in admin settings UI (Purchase section)
- Add i18n keys: integrationDoc / integrationDocHint (zh + en)
2026-03-01 18:14:43 +08:00
erio
d666e05a6d refactor(purchase): use URL/searchParams only for purchase query merge 2026-03-01 18:14:43 +08:00
erio
c37edf2de5 docs+ui: add bilingual payment integration doc and rename purchase entry to recharge/subscription 2026-03-01 18:14:43 +08:00
erio
bb9af2465e feat(frontend): append purchase query params and make integration doc bilingual 2026-03-01 18:14:14 +08:00
erio
7af00864b3 feat(admin): add create-and-redeem API and payment integration docs 2026-03-01 18:13:43 +08:00
erio
81903e87e3 chore: bump version to 0.1.87.13 2026-03-01 16:37:12 +08:00
erio
0b96c7a65e fix(auth): replace submit turnstile widget with VerifyTurnstileForRegister
Port upstream's VerifyTurnstileForRegister which skips the duplicate
Turnstile check when email verify flow is already completed, instead of
requiring a second Turnstile widget on the verify page.
2026-03-01 16:34:14 +08:00
erio
34ccfe45ea chore: bump version to 0.1.87.12 2026-03-01 15:59:29 +08:00
erio
d4231150a9 fix: add groupExistenceBatchReader interface and update test stub for release branch compatibility 2026-03-01 15:53:45 +08:00
erio
2268e93aec fix: remove unused preload/snapshot functions and fix gofmt 2026-03-01 15:49:26 +08:00
erio
c976916b6d fix: remove dead code in BulkUpdateAccounts group binding loop 2026-03-01 15:49:15 +08:00
erio
0bbd003a18 fix: use i18n for mixed-channel warning messages and improve bulk pre-check
- BulkUpdate handler: add structured details to 409 response
- BulkUpdateAccounts: simplify to global pre-check before any DB write;
  remove per-account snapshot tracking which is no longer needed
- MixedChannelError.Error(): restore English message for API compatibility
- BulkEditAccountModal: use t() with details for both pre-check and 409
  fallback paths instead of displaying raw backend strings
- Update test to verify pre-check blocks on existing group conflicts
2026-03-01 15:49:05 +08:00
erio
336a844712 chore: bump version to 0.1.87.11 2026-03-01 13:02:24 +08:00
erio
51e7f262bd fix(auth): add submit Turnstile widget in email verify flow
- In the email verification-code flow, require completing Turnstile again before submitting registration
- Fix the token being cleared after sending the verification code, which left registration without a turnstile_token
- The submit/resend widgets are shown mutually exclusively via showResendTurnstile
2026-03-01 13:02:12 +08:00
erio
99663a3f20 fix: update mixed channel warning message 2026-03-01 12:57:56 +08:00
erio
4d88248091 chore: bump version to 0.1.87.10 2026-03-01 12:46:42 +08:00
erio
c143367e15 fix: handle mixed channel warning for multi-platform bulk edit
Previously, preCheckMixedChannelRisk() skipped when selectedPlatforms
had more than one entry, and the catch block in submitBulkUpdate had no
409 handling — so multi-platform conflicts just showed a generic error.

- Rename canPreCheck(): only call pre-check API for single-platform
  antigravity/anthropic selections (API requires a single platform param)
- Pass `built` into preCheckMixedChannelRisk() so pendingUpdatesForConfirm
  is set before returning false
- submitBulkUpdate: add 409 mixed_channel_warning catch as fallback for
  multi-platform case, saving baseUpdates for retry
- Remove needsMixedChannelCheck() gate on confirm_mixed_channel_risk flag;
  use mixedChannelConfirmed alone so multi-platform retry also works
2026-03-01 12:46:31 +08:00
erio
ca44389baa fix: bulk edit mixed channel warning not showing confirmation dialog
The response interceptor in client.ts transforms errors into plain
objects {status, code, message}, but catch blocks were checking
error.response?.status (AxiosError format) which never matched.

- Add error field passthrough in client.ts interceptor
- Refactor BulkEditAccountModal to use pre-check API (checkMixedChannelRisk)
  before submit, matching the single edit flow
- Fix EditAccountModal catch blocks to use interceptor error format
- Add bulk-update mixed channel unit tests
2026-03-01 12:46:31 +08:00
erio
05c2a65ef0 chore: bump version to 0.1.87.9 2026-03-01 12:37:07 +08:00
erio
3f21d204a7 refactor(purchase): use URL/searchParams only for purchase query merge 2026-03-01 02:19:13 +08:00
erio
c848f950ce docs+ui: add bilingual payment integration doc and rename purchase entry to recharge/subscription 2026-03-01 02:19:07 +08:00
erio
87bd765a57 chore: bump version to 0.1.87.8 2026-03-01 01:02:53 +08:00
erio
49325a769e feat(frontend): append purchase query params and make integration doc bilingual 2026-03-01 01:02:03 +08:00
erio
238d86f502 feat(admin): add create-and-redeem API and payment integration docs 2026-03-01 00:41:38 +08:00
erio
7064063230 fix: wrap handleSubmit in form onSubmit to fix TS type mismatch 2026-02-28 19:53:25 +08:00
erio
b1de4352a8 fix: update mixed channel test assertions for Chinese message 2026-02-28 19:41:55 +08:00
erio
7bf5c1cbcb chore: bump version to 0.1.87.7 2026-02-28 19:32:24 +08:00
erio
411e24146d feat: bulk update accounts pre-check mixed channel risk with confirm dialog
- Move mixed channel check before any DB writes in BulkUpdateAccounts
- Return 409 from BulkUpdate handler for MixedChannelError
- Add ConfirmDialog to BulkEditAccountModal for mixed channel warning
- Update mixed channel warning message to Chinese
2026-02-28 19:31:57 +08:00
erio
c7392fc80b fix: make purchase iframe fully fill container 2026-02-28 18:58:19 +08:00
erio
c37c68a341 feat: append auth token to purchase iframe url 2026-02-28 16:02:55 +08:00
erio
9230d3cbc9 fix: streamline purchase embed layout with floating open button 2026-02-28 15:26:16 +08:00
erio
19925e22d9 chore: bump version to 0.1.87.3 2026-02-28 14:09:02 +08:00
erio
65e0c1b258 fix: pass theme to purchase iframe and optimize embed layout 2026-02-28 14:08:53 +08:00
erio
94bdde32bb chore: bump version to 0.1.87.2 2026-02-28 10:37:42 +08:00
erio
14c80d26c6 fix: unify purchase url for iframe and new-tab 2026-02-28 10:37:25 +08:00
erio
9555a99d1c chore: bump version to 0.1.87.1 2026-02-27 21:24:03 +08:00
erio
0e69895603 Merge branch 'main' into release/custom-0.1.87
# Conflicts:
#	frontend/src/components/keys/UseKeyModal.vue
2026-02-27 21:20:22 +08:00
erio
cc3cf1d70a chore: bump version to 0.1.86.10 2026-02-27 21:02:24 +08:00
erio
3382d496e3 refactor: remove unused detectClaudeMaxCacheBillingOutcomeForUsage function 2026-02-27 20:56:33 +08:00
erio
e0b4b00dc1 chore: bump version to 0.1.86.9 2026-02-27 20:45:52 +08:00
erio
81d896bf78 fix: sync Antigravity ForwardResult.Usage with client response simulation
Apply Claude Max cache billing to usage before returning ForwardResult
in Antigravity Forward, ensuring RecordUsage gets the same simulated
usage that clients see. Restore apply+fallback in RecordUsage for
consistency across GatewayService and Antigravity paths.
2026-02-27 20:42:53 +08:00
erio
ec576fdbde chore: bump version to 0.1.86.8 2026-02-27 19:59:51 +08:00
erio
741eae59bb refactor: decouple claude max cache simulation from RecordUsage
Extract setupClaudeMaxStreamingHook and applyClaudeMaxNonStreamingRewrite
facade functions to helpers file. RecordUsage now uses detect-only (no
mutation), client response rewriting handled at Forward layer.
2026-02-27 19:59:36 +08:00
erio
505494b378 fix: update 2K image price placeholder from 0.134 to 0.201 2026-02-27 17:18:24 +08:00
erio
c1033c12bd chore: bump version to 0.1.86.6 2026-02-27 16:38:35 +08:00
erio
b789333b68 test: stabilize subscription progress day assertion 2026-02-27 16:35:37 +08:00
erio
61ef73cb12 refactor: isolate claude max response usage simulation by group toggle 2026-02-27 16:14:07 +08:00
erio
e71be7e0f1 fix: update image pricing tests for 2K tier and refactor claude max billing policy
- Update 5 test assertions to match new 2K default price ($0.201 = base * 1.5)
- Refactor claude max cache billing policy into reusable functions
2026-02-27 15:13:05 +08:00
erio
0302c03864 fix: add 2K image pricing at 1.5x base price 2026-02-27 15:00:18 +08:00
erio
d21fe54d55 chore: bump version to 0.1.86.5 2026-02-27 12:55:05 +08:00
erio
a70d3ff82d fix: update antigravity user-agent version to 1.19.6 2026-02-27 12:36:50 +08:00
erio
574359f1df chore: bump version to 0.1.86.4 2026-02-27 12:24:02 +08:00
erio
6da2f54e50 refactor: decouple claude max cache policy and add tokenizer 2026-02-27 12:18:22 +08:00
erio
886464b2e9 Merge branch 'feature/claude-max-simulation-review' into release/custom-0.1.86
# Conflicts:
#	backend/cmd/server/VERSION
2026-02-27 09:58:01 +08:00
erio
396044e354 fix: gofmt alignment in constants.go 2026-02-27 09:36:23 +08:00
erio
3d15202124 chore: bump version to 0.1.86.3 2026-02-27 09:30:58 +08:00
erio
756b09b6b8 feat: replace gemini-3-pro-image with gemini-3.1-flash-image
- Add migration 060 to update model_mapping for all antigravity accounts
- Remove gemini-3-pro-image and gemini-3-pro-image-preview mappings
- Add gemini-3.1-flash-image and gemini-3.1-flash-image-preview mappings
- Update frontend usage window to show GImage for new model
- Update isImageGenerationModel to support new model
2026-02-27 09:30:44 +08:00
erio
f4d3fadd6f chore: bump version to 0.1.86.2 2026-02-27 01:55:25 +08:00
erio
1fb6e9e830 feat: add claude max usage simulation with group switch 2026-02-27 01:54:54 +08:00
erio
78ac6a7a29 docs: add Star environment deployment info and create AGENTS.md
- Add the Star environment record to CLAUDE.md (port 8086, database star, Redis DB 4)
- Copy CLAUDE.md to AGENTS.md
2026-02-27 00:59:11 +08:00
erio
efe8810dff fix: remove duplicate model_mapping field in GeminiCredentials
The field was defined twice (line 563 and 585) due to merge conflict
with upstream v0.1.86, causing TypeScript compilation error TS2300.
2026-02-25 22:10:57 +08:00
erio
5c07e11473 refactor: remove unused UserAgent constant
The UserAgent constant was never referenced; all HTTP requests
use GetUserAgent() which reads from defaultUserAgentVersion
(configurable via ANTIGRAVITY_USER_AGENT_VERSION env var).
2026-02-25 20:08:44 +08:00
erio
d552ad7673 fix: remove post-merge duplicate code and fix mismatched closure signatures
- Remove the duplicated CheckMixedChannel function in account_handler.go
- Fix mismatched closure call arguments in gateway_handler.go (upstream switched to closure capture)
- Restore the missing postgres_data volume definition in docker-compose.yml
2026-02-25 19:24:56 +08:00
erio
496173da1f merge: merge upstream v0.1.86 into main 2026-02-25 19:02:10 +08:00
eriol touwa
1cdaf33272 Merge pull request #3 from touwaeriol/compact-model-ratelimit-badges
ui: reduce spacing in the model rate-limit badge area
2026-02-25 18:54:06 +08:00
erio
e2b3969492 ui: reduce spacing in the model rate-limit badge area to align with the usage-window column height 2026-02-25 18:51:49 +08:00
erio
8661bf8837 chore: bump version to 0.1.84.8 2026-02-24 14:46:06 +08:00
erio
292fa7a6d2 fix: Antigravity Claude 4.6 usage window not displayed
The model name list only contained claude-sonnet-4-5 and claude-opus-4-5-thinking;
after the 4.6 upgrade the backend returns claude-sonnet-4-6/claude-opus-4-6-thinking,
which the frontend failed to match, so the Claude usage bar was not displayed.
2026-02-24 14:45:55 +08:00
erio
39d7300a8e chore: bump version to 0.1.84.7 2026-02-22 23:28:19 +08:00
erio
0ad20c9489 fix: Antigravity account intercept_warmup_requests setting could not be saved
Extract the pure function applyInterceptWarmup to unify all call sites:
- Fix intercept_warmup_requests not being written on upstream account creation
- Fix the missing else branch that clears the flag when editing apikey accounts
- Add frontend and backend unit tests
- Fix a vitest.config.ts mergeConfig compatibility issue
2026-02-22 23:28:11 +08:00
erio
7eb3b23ddf chore: bump version to 0.1.84.6 2026-02-22 19:30:50 +08:00
erio
5b06542193 feat: update Antigravity UserAgent version to 1.18.4 2026-02-22 19:30:46 +08:00
erio
38dca4f787 chore: bump version to 0.1.84.5 2026-02-20 12:22:13 +08:00
erio
737d1ecf5b feat: add gemini-3.1-pro-high/low model support
- Backend default mapping adds gemini-3.1-pro-high, gemini-3.1-pro-low, gemini-3.1-pro-preview
- Frontend model list adds the gemini-3.1 series
- Database migration fully overwrites model_mapping
2026-02-20 12:22:02 +08:00
erio
f819cef6d5 chore: bump version to 0.1.84.4 2026-02-20 00:20:07 +08:00
erio
facae2a6db feat: add claude-sonnet-4-6 model support
- Backend default mapping, identity injection, migration files
- Frontend model list and quick-mapping buttons
2026-02-20 00:16:46 +08:00
liuxiongfeng
8b021c099d chore: bump version to 0.1.84.3 2026-02-19 22:47:32 +08:00
liuxiongfeng
8a625188ce Merge release/custom-0.1.83 into release/custom-0.1.84 2026-02-18 22:38:45 +08:00
liuxiongfeng
c0cfa6acde chore: bump antigravity user-agent to 1.16.5, version to 0.1.83.2 2026-02-18 14:14:29 +08:00
liuxiongfeng
0913cfc082 chore: bump version to 0.1.83.1 2026-02-14 22:03:26 +08:00
liuxiongfeng
59e8465325 Merge release/custom-0.1.81 into release/custom-0.1.83
- Resolved conflict in .github/workflows/security-scan.yml: use upstream .gosec.json config
- Resolved conflict in deploy/docker-compose.yml: keep Redis, remove PostgreSQL (use external DB)
2026-02-14 22:03:09 +08:00
liuxiongfeng
be6e8ff77b fix(ui): extend mixed channel check to both Antigravity and Anthropic accounts
Problem:
- When creating Anthropic Console accounts to groups with Antigravity accounts,
  the mixed channel warning was not triggered
- Only Antigravity accounts triggered the confirmation dialog

Solution:
- Renamed isAntigravityAccount to needsMixedChannelCheck
- Extended check to include both 'antigravity' and 'anthropic' platforms
- Updated all 8 call sites in CreateAccountModal and EditAccountModal

Changes:
- frontend/src/components/account/CreateAccountModal.vue: 4 updates
- frontend/src/components/account/EditAccountModal.vue: 4 updates

Testing:
- Frontend compilation: passed
- Backend unit tests: passed
- OAuth flow: confirmation parameter correctly passed in Step 2
- Backend compatibility: verified with checkMixedChannelRisk logic

Impact:
- Both Antigravity and Anthropic accounts now show mixed channel warning
- OAuth accounts correctly pass confirm_mixed_channel_risk parameter
- No breaking changes to existing functionality
2026-02-14 21:29:36 +08:00
liuxiongfeng
ebb85cf843 chore: bump version to 0.1.81.9 2026-02-13 04:15:12 +08:00
liuxiongfeng
ef959bc3c6 chore: bump version to 0.1.81.8 2026-02-13 03:12:31 +08:00
liuxiongfeng
8cb7356bbc feat: add mixed-channel precheck flow for antigravity accounts 2026-02-13 03:12:09 +08:00
liuxiongfeng
496545188a chore: bump version to 0.1.81.7 2026-02-13 00:26:37 +08:00
liuxiongfeng
739c80227a fix(ui): reset mixed channel warning on close 2026-02-13 00:23:15 +08:00
liuxiongfeng
6218eefd61 refactor(ui): extract mixed channel warning handler 2026-02-13 00:22:14 +08:00
liuxiongfeng
5715587baf chore: bump version to 0.1.81.6 2026-02-12 23:49:49 +08:00
liuxiongfeng
ae770a625b fix(ui): mixed channel confirm for upstream create 2026-02-12 23:49:41 +08:00
liuxiongfeng
428ee065d3 chore: bump version to 0.1.81.5 2026-02-12 20:09:57 +08:00
liuxiongfeng
320ca28f90 fix: antigravity 429 fallback uses final model key 2026-02-12 19:30:44 +08:00
liuxiongfeng
5e518f5fbd feat: allow antigravity warmup intercept toggle
- Show warmup-intercept toggle for antigravity accounts in admin UI
- Add unit tests verifying antigravity accounts are intercepted on /v1/messages
2026-02-12 18:56:55 +08:00
liuxiongfeng
24dcba1d72 chore: gofmt antigravity gateway tests 2026-02-12 18:56:31 +08:00
liuxiongfeng
1abc688cad chore: bump version to 0.1.81.4 2026-02-12 02:33:33 +08:00
liuxiongfeng
34936189d8 fix(antigravity): always bill by the mapped model and add regression tests
When an account has model_mapping configured, ensure billing uses the actual
mapped model rather than the user-requested model name, avoiding inaccurate billing.
2026-02-12 02:33:20 +08:00
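A minimal sketch of that billing rule, assuming model_mapping is a plain map:

```go
package sketch

// Bill by the mapped upstream model when a mapping exists, not by the
// client-requested name.
func billedModel(requested string, modelMapping map[string]string) string {
	if mapped, ok := modelMapping[requested]; ok && mapped != "" {
		return mapped
	}
	return requested
}
```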
liuxiongfeng
daf7bf3e8b docs: read the Admin API Key from the .env file instead of passing it in plaintext 2026-02-11 23:50:06 +08:00
liuxiongfeng
3a9f1c5796 chore: bump version to 0.1.81.3 2026-02-11 23:37:37 +08:00
liuxiongfeng
bb1e205516 Merge branch 'develop' into release/custom-0.1.81 2026-02-11 23:37:21 +08:00
liuxiongfeng
9af4a55176 fix: gofmt formatting in gateway_cache_integration_test.go 2026-02-11 23:35:26 +08:00
liuxiongfeng
51e903c34e Revert "fix: concurrency/queue panel supports platform/group filtering"
This reverts commit 86e600aa52.
2026-02-11 23:26:20 +08:00
liuxiongfeng
7f03319646 Revert "fix: concurrency/queue panel supports platform/group filtering"
This reverts commit 86e600aa52.
2026-02-11 23:06:59 +08:00
liuxiongfeng
0c33d18a4d chore: bump version to 0.1.81.2 2026-02-11 22:31:27 +08:00
liuxiongfeng
a747c63b8e feat: add gemini model mapping whitelist for apikey and bulk edit 2026-02-11 22:31:27 +08:00
liuxiongfeng
a03c361b04 feat: add gemini model mapping whitelist for apikey and bulk edit 2026-02-11 22:31:07 +08:00
liuxiongfeng
807d0018ef docs: document the account update/bulk-update APIs and add an API operation workflow
- Add docs for 1.11 update single account (PUT) and 1.12 bulk update accounts (POST)
- Add API operation workflow: for new requirements, explore the API first, update the docs, then perform the operation
2026-02-11 22:17:15 +08:00
liuxiongfeng
850e267763 chore: bump version to 0.1.81.1 2026-02-11 20:49:31 +08:00
liuxiongfeng
c75ae56f10 Merge branch 'develop' into release/custom-0.1.81 2026-02-11 20:49:18 +08:00
liuxiongfeng
caaed775aa ui: show at most 3 model rate-limit badges per row, wrapping the rest onto new lines 2026-02-11 18:51:51 +08:00
liuxiongfeng
d11b295729 chore: bump version to 0.1.80.1 2026-02-11 16:21:22 +08:00
liuxiongfeng
91d0059f8d Merge branch 'develop' into release/custom-0.1.80
# Conflicts:
#	backend/internal/pkg/antigravity/client.go
#	backend/internal/service/antigravity_oauth_service.go
#	backend/internal/service/antigravity_oauth_service_test.go
#	backend/internal/service/antigravity_token_provider.go
2026-02-11 16:20:56 +08:00
liuxiongfeng
44693d0dfb chore: bump version to 0.1.79.7 2026-02-11 14:39:29 +08:00
liuxiongfeng
c722212e12 fix: distinguish client disconnection from upstream retry failure
When client disconnects during upstream request, the error was
incorrectly reported as "Upstream request failed after retries".
Now checks context cancellation first and returns
"Client disconnected before upstream response" instead.
2026-02-11 14:39:14 +08:00
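The distinction above hinges on checking the request context before blaming the upstream; a hedged sketch:

```go
package sketch

import (
	"context"
	"errors"
	"fmt"
)

// A cancelled request context means the client went away, not that the
// upstream failed after retries.
func classifyForwardError(ctx context.Context, err error) error {
	if err == nil {
		return nil
	}
	if errors.Is(ctx.Err(), context.Canceled) {
		return fmt.Errorf("client disconnected before upstream response")
	}
	return fmt.Errorf("upstream request failed after retries: %w", err)
}
```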
liuxiongfeng
f12da65962 docs: add Admin API reference to CLAUDE.md and AGENTS.md 2026-02-11 14:39:02 +08:00
liuxiongfeng
92745f7534 chore: bump version to 0.1.79.6 2026-02-11 13:08:49 +08:00
liuxiongfeng
f2917aeaf8 fix: use safe type assertion for errcheck lint compliance 2026-02-11 13:06:41 +08:00
liuxiongfeng
a1e2ffd586 refactor: optimize project_id fill to lightweight approach
- Replace heavy RefreshAccountToken with lightweight tryFillProjectID
  (loadCodeAssist → onboardUser → fallback), consistent with
  Antigravity-Manager's behavior
- Add sync.Map cooldown/dedup (60s) to prevent repeated fill attempts
- Add fallback project_id "bamboo-precept-lgxtn" matching AM
- Extract mergeCredentials helper to eliminate duplication
- Use slog structured logging instead of log.Printf
- Fix time.Sleep in OnboardUser to context-aware select
- Fix strings.NewReader(string(bodyBytes)) → bytes.NewReader(bodyBytes)
- Remove redundant tc := tc in test (Go 1.22+)
- Add nil guard in persistProjectID for test safety
2026-02-11 13:00:31 +08:00
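The sync.Map cooldown/dedup item could look roughly like this; a best-effort sketch where a lost race costs at most one extra fill attempt:

```go
package sketch

import (
	"sync"
	"time"
)

var fillCooldown sync.Map // account ID -> time of last attempt

// 60s cooldown/dedup for project_id fill attempts.
func shouldAttemptFill(accountID int64) bool {
	now := time.Now()
	if prev, loaded := fillCooldown.LoadOrStore(accountID, now); loaded {
		if now.Sub(prev.(time.Time)) < time.Minute {
			return false // still cooling down: skip the duplicate attempt
		}
		fillCooldown.Store(accountID, now)
	}
	return true
}
```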
liuxiongfeng
78a9705fad fix: resolve ineffassign lint error in onboard project_id logic 2026-02-11 12:45:52 +08:00
liuxiongfeng
5b6da04a02 docs: update the CI check workflow, emphasizing running all checks locally first 2026-02-11 12:43:45 +08:00
liuxiongfeng
4661c2f90f chore: bump version to 0.1.79.5 2026-02-11 12:35:26 +08:00
liuxiongfeng
90e4328885 chore: remove patch file from tracking 2026-02-11 12:25:15 +08:00
liuxiongfeng
130112a84a fix: complete the project_id acquisition logic for Antigravity OAuth accounts
For some accounts loadCodeAssist does not immediately return cloudaicompanionProject,
leaving the project field empty during forwarding and causing the upstream to return 400 "Invalid project resource name projects/".

- Add an OnboardUser API: when loadCodeAssist returns no project_id,
  complete account initialization via onboardUser and obtain the project_id
- Add an onboard fallback during token refresh
- GetAccessToken fills in on demand: trigger a refresh immediately when forwarding finds project_id empty
- Add resolveDefaultTierID unit tests
2026-02-11 12:25:04 +08:00
liuxiongfeng
b368bb6ea1 chore: bump version to 0.1.79.4 2026-02-11 04:55:17 +08:00
liuxiongfeng
9ecb6211d6 docs: update CLAUDE.md and .gitignore 2026-02-11 04:54:11 +08:00
liuxiongfeng
79fba9c8d3 refactor: consolidate failover logic into FailoverState
- Merge FailoverRetry/FailoverSwitch into single FailoverContinue action
- Extract HandleSelectionExhausted into FailoverState (was duplicated 3×)
- Move helper functions (needForceCacheBilling, sleepWithContext) into failover_loop.go
- Inline sleepFailoverDelay, replace sleepAntigravitySingleAccountBackoff with constant
- Delete gateway_handler_single_account_retry_test.go (tested removed function)
- Add 6 test cases for HandleSelectionExhausted
2026-02-11 04:54:05 +08:00
liuxiongfeng
37c76a93ab chore: bump version to 0.1.79.3 2026-02-11 01:56:52 +08:00
liuxiongfeng
86e600aa52 fix: concurrency/queue panel supports platform/group filtering
- Add watchers on the platformFilter/groupIdFilter props to reload data
  immediately when filters change (fixes "no data" shown after selecting a platform)
- Add platform/group_id filtering to getUserConcurrencyStats across the stack:
  frontend API → handler parses query params → service-layer filtering logic
- The service layer resolves users' AllowedGroups back through account group
  associations, consistent with the filtering pattern of GetConcurrencyStats
2026-02-11 01:41:34 +08:00
liuxiongfeng
6a1c28b70e docs: streamline the beta deployment flow; share the build server repository
- beta and production share the build server's /root/sub2api repository, distinguished by image tags
- First-time deployment and beta updates both use the shared repository plus branch switching
- Add checkpoints and verification notes to each step
2026-02-11 00:41:05 +08:00
liuxiongfeng
9375f1809c docs: add pre-checks and verification to the deployment steps
- Add checkpoint notes to each step to ensure it succeeds before continuing
- Step 0 emphasizes confirming the push succeeded and checking for uncommitted changes
- Step 1 adds version number verification
- Step 2 adds troubleshooting for common build issues (buildx version, disk space)
- Step 4 adds a production server code sync step
- New "Build server first-time initialization" section
2026-02-11 00:34:48 +08:00
liuxiongfeng
f176150c93 Merge branch 'release/custom-0.1.79'
# Conflicts:
#	AGENTS.md
#	CLAUDE.md
#	backend/cmd/server/VERSION
#	backend/internal/handler/gateway_handler.go
#	backend/internal/handler/gemini_v1beta_handler.go
#	backend/internal/service/antigravity_gateway_service.go
#	backend/internal/service/antigravity_rate_limit_test.go
#	backend/internal/service/antigravity_single_account_retry_test.go
#	backend/internal/service/antigravity_smart_retry_test.go
2026-02-11 00:32:11 +08:00
liuxiongfeng
30d25084f0 chore: bump version to 0.1.79.2 2026-02-11 00:22:19 +08:00
liuxiongfeng
91ad94d941 docs: split the deployment flow into separate build and production servers
Image builds move from the production server (clicodeplus) to the build server
(us-asaki-root), transferred via a docker save/load pipe so compile-time resource
usage does not affect live services.
2026-02-11 00:20:35 +08:00
liuxiongfeng
57a778dccf chore: bump version to 0.1.79.1 2026-02-10 23:48:23 +08:00
liuxiongfeng
f702c66659 fix: resolve gofmt alignment issue in failover_loop_test.go
Move inline comments to separate lines to avoid gofmt
consecutive-line comment alignment requirements.
2026-02-10 23:40:37 +08:00
liuxiongfeng
a095468850 fix: correct import path in failover_loop.go and failover_loop_test.go
Use github.com/Wei-Shaw/sub2api/internal/service instead of
sub2api/internal/service to match the module path in go.mod.
2026-02-10 23:35:21 +08:00
liuxiongfeng
f9b6a20995 Merge tag 'v0.1.79' into develop
Strengthen error handling and retries: add fixed-interval same-account retry for MODEL_CAPACITY_EXHAUSTED, prefer same-account retry over failover for transient errors, and greatly optimize error matching performance.

- MODEL_CAPACITY_EXHAUSTED (503) retries up to 60 times at a fixed 1s interval without switching accounts
- Transient errors (Google 400, empty stream responses) retry on the same account twice before triggering failover
- Empty stream responses trigger failover to retry on another account instead of returning 502 directly
- Google "Invalid project resource name" 400 errors trigger failover and temporarily ban the account for 1 hour
- Error passthrough rules gain a skip_monitoring option; matched errors are not recorded in the ops monitoring log
- Antigravity forwarding supports single-URL switching between daily/prod

- Error matching performance: defer/limit body ToLower, precompute rule keywords and platform sets
- Global dedup for MODEL_CAPACITY_EXHAUSTED to avoid concurrent requests retrying redundantly
- Lower the 503 retry body read limit from 2MB to 8KB
- Replace time.After with time.NewTimer to prevent timer leaks on context cancellation
- Shorten the temporary ban cooldown from 30 minutes to 1 minute (after same-account retries are exhausted)

- Fix the skip_monitoring option of error passthrough rules not taking effect
- Fix CI check failures (gofmt, errcheck, staticcheck)

# Conflicts:
#	backend/internal/service/error_passthrough_runtime_test.go
2026-02-10 23:25:51 +08:00
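The time.After item above is a standard Go pattern; a sketch of the leak-free wait, using the sleepWithContext name another commit in this log mentions:

```go
package sketch

import (
	"context"
	"time"
)

// time.After leaks its timer until it fires when the context is cancelled
// first; an explicit timer can be stopped.
func sleepWithContext(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
```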
liuxiongfeng
f2770da880 refactor: extract failover error handling into shared HandleFailoverError
- Extract duplicated failover error handling from gateway_handler.go (Gemini-compat & Claude paths) and gemini_v1beta_handler.go into shared failover_loop.go
- Introduce TempUnscheduler interface for testability (GatewayService implicitly satisfies it)
- Add comprehensive unit tests for HandleFailoverError (32 test cases covering all paths)
- Fix golangci-lint issues: errcheck in test type assertion, staticcheck QF1003 if/else→switch
2026-02-10 23:13:37 +08:00
liuxiongfeng
d269659e61 chore: bump version to 0.1.78.2 2026-02-10 21:28:52 +08:00
liuxiongfeng
c4d6715443 chore: squash merge customizations from develop-old-0.1.77
- Custom docs: CLAUDE.md, AGENTS.md
- UI customizations: WeChat support button, homepage rework, GitHub link removal
- Deployment/ops: docker-compose.yml, load-test scripts
- Minor CI/gitignore changes
2026-02-10 20:59:54 +08:00
erio
fc4a1c5433 Merge branch 'release/custom-0.1.78'
# Conflicts:
#	backend/cmd/server/VERSION
2026-02-10 11:41:29 +08:00
erio
6bdd580b3f chore: bump version to 0.1.78.1 2026-02-10 11:40:36 +08:00
erio
9cf4882f4c Merge tag 'v0.1.78' into develop
Resolve conflict in antigravity_gateway_service.go by keeping both
retry strategies:
- MODEL_CAPACITY_EXHAUSTED: handleModelCapacityExhaustedRetry (ours)
- Single-account 503 long delay: handleSingleAccountRetryInPlace (upstream)

Update tests to reflect that MODEL_CAPACITY_EXHAUSTED always goes
through capacity retry regardless of single-account mode.
2026-02-10 11:32:56 +08:00
erio
406dad998d chore: bump version to 0.1.77.2 2026-02-10 10:59:34 +08:00
erio
8b0db22c18 Merge branch 'develop' into release/custom-0.1.77 2026-02-10 10:58:52 +08:00
erio
f06048eccf fix: simplify MODEL_CAPACITY_EXHAUSTED to single retry for all cases
Both short (<20s) and long (>=20s/missing) retryDelay now retry once:
- Short: wait actual retryDelay, retry once
- Long/missing: wait 20s, retry once
- Still capacity exhausted: switch account
- Different error: let upper layer handle
2026-02-10 04:05:20 +08:00
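A minimal sketch of the simplified wait policy, assuming a retryDelay parsed from the upstream response:

```go
package sketch

import "time"

// Short delays honor the upstream retryDelay hint; long or missing delays
// wait a fixed 20s. Either way there is exactly one retry.
func capacityRetryWait(retryDelay time.Duration) time.Duration {
	if retryDelay > 0 && retryDelay < 20*time.Second {
		return retryDelay
	}
	return 20 * time.Second
}
```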
erio
05f5a8b61d fix: use switch statement for staticcheck QF1003 compliance 2026-02-10 03:59:39 +08:00
erio
662625a091 feat: optimize MODEL_CAPACITY_EXHAUSTED retry and remove extra failover retries
- MODEL_CAPACITY_EXHAUSTED now uses independent retry strategy:
  - retryDelay < 20s: wait actual retryDelay then retry once
  - retryDelay >= 20s or missing: retry up to 5 times at 20s intervals
  - Still capacity exhausted after retries: switch account (failover)
  - Different error during retry (e.g. 429): handle by actual error code
  - No model rate limit set (capacity != rate limit)

- Remove Antigravity extra failover retries feature:
  Same-account retry mechanism (cherry-picked) makes it redundant.
  Removed: antigravityExtraRetries config, sleepFixedDelay, skip-non-antigravity logic.
2026-02-10 03:47:40 +08:00
Edric Li
6328e69441 feat: same-account retry before failover for transient errors
For retryable transient errors (Google 400 "invalid project resource name"
and empty stream responses), retry on the same account up to 2 times
(with 500ms delay) before switching to another account.

- Add RetryableOnSameAccount field to UpstreamFailoverError
- Add same-account retry loop in both Gemini and Claude/OpenAI handler paths
- Move temp-unschedule from service layer to handler layer (only after
  all same-account retries exhausted)
- Reduce temp-unschedule cooldown from 30 minutes to 1 minute
2026-02-10 03:27:19 +08:00
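A hedged sketch of the same-account retry loop; UpstreamFailoverError and RetryableOnSameAccount are named in the commit, everything else is assumed:

```go
package sketch

import (
	"errors"
	"time"
)

// Minimal assumed shape of the error type named in the commit.
type UpstreamFailoverError struct{ RetryableOnSameAccount bool }

func (e *UpstreamFailoverError) Error() string { return "upstream failover" }

// Retry on the same account up to 2 times (500ms apart) before letting the
// caller switch accounts; temp-unschedule happens only after this returns.
func forwardWithSameAccountRetry(forward func() error) error {
	const maxRetries = 2
	for attempt := 0; ; attempt++ {
		err := forward()
		var fo *UpstreamFailoverError
		if err == nil || !errors.As(err, &fo) || !fo.RetryableOnSameAccount || attempt >= maxRetries {
			return err // success, non-retryable, or retries exhausted
		}
		time.Sleep(500 * time.Millisecond)
	}
}
```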
Edric Li
425dfb80d9 feat: failover and temp-unschedule on empty stream response
- Empty stream responses now return UpstreamFailoverError instead of
  plain 502, triggering automatic account switching (up to 10 retries)
- Add tempUnscheduleEmptyResponse: accounts returning empty responses
  are temp-unscheduled for 30 minutes
- Apply to both Claude and Gemini non-streaming paths
- Align googleConfigErrorCooldown from 60m to 30m for consistency
2026-02-10 03:27:02 +08:00
Edric Li
4c1fd570f0 feat: failover and temp-unschedule on Google "Invalid project resource name" 400
Google's backend intermittently returns 400 "Invalid project resource name" errors;
previously this error was passed straight to the client without triggering account switching, so the request simply failed.

- On all forwarding paths of both the Antigravity and Gemini platforms,
  an exact match of this error message triggers failover to retry on another account
- On a hit, the account is temporarily banned for 1 hour to avoid repeatedly scheduling the same faulty account
- Extract shared functions isGoogleProjectConfigError / tempUnscheduleGoogleConfigError
  to eliminate code duplication across services
2026-02-10 03:26:51 +08:00
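A plausible shape for the shared detector (the function name is from the commit; the exact matching policy is an assumption):

```go
package sketch

import "strings"

// Exact-match detection of the transient Google 400 described above.
func isGoogleProjectConfigError(status int, body string) bool {
	return status == 400 && strings.Contains(body, "Invalid project resource name")
}
```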
erio
345f853b5d chore: bump version to 0.1.77.1 2026-02-09 22:27:47 +08:00
erio
100a70f87c Merge remote-tracking branch 'upstream/main' into develop 2026-02-09 22:27:07 +08:00
erio
18b591bc3b feat: Antigravity extra failover retries after default retries exhausted
When default failover retries are exhausted, continue retrying with
Antigravity accounts only (up to 10 times, configurable via
GATEWAY_ANTIGRAVITY_EXTRA_RETRIES). Each extra retry uses a fixed
500ms delay. Non-Antigravity accounts are skipped during the extra
retry phase. Applied to all three endpoints: Gemini compat, Claude,
and Gemini native API paths.
2026-02-09 22:13:44 +08:00
erio
6a52b24369 Merge branch 'develop'
# Conflicts:
#	backend/cmd/server/VERSION
2026-02-09 20:13:54 +08:00
erio
228aca9523 Merge branch 'fix/gemini-error-policy-before-retry' into develop
# Conflicts:
#	backend/cmd/server/VERSION
2026-02-09 20:08:31 +08:00
erio
7e4637cd70 fix: support clearing model-level rate limits from action menu and temp-unsched reset 2026-02-09 20:08:00 +08:00
erio
3e3c015efa fix: Gemini error policy check should precede retry logic 2026-02-09 19:22:32 +08:00
erio
30c30b1712 fix: skip rate limiting when custom error codes don't match upstream status
Add ShouldHandleErrorCode guard at the entry of handleGeminiUpstreamError
and AntigravityGatewayService.handleUpstreamError so that accounts with
custom error codes (e.g. [599]) are not rate-limited when the upstream
returns a non-matching status (e.g. 429).
2026-02-09 18:53:52 +08:00
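A minimal sketch of the ShouldHandleErrorCode guard, with an assumed signature:

```go
package sketch

// With custom error codes configured (e.g. [599]), only a matching upstream
// status may trigger rate limiting; anything else is skipped.
func ShouldHandleErrorCode(customCodes []int, upstreamStatus int) bool {
	if len(customCodes) == 0 {
		return true // no custom list: default handling applies
	}
	for _, code := range customCodes {
		if code == upstreamStatus {
			return true
		}
	}
	return false
}
```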
erio
e666356483 fix: resolve merge conflict in OpsConcurrencyCard.vue 2026-02-09 18:12:39 +08:00
erio
3710bc883b feat: ErrorPolicySkipped returns 500 instead of upstream status code
When custom error codes are enabled and the upstream error code is NOT
in the configured list, return HTTP 500 to the client instead of
transparently forwarding the original status code. This matches the
frontend description: "other errors will return 500".

Also adds integration test TestCustomErrorCode599 verifying that 429,
500, 503, 401, 403 all return 500 without triggering SetRateLimited
or SetError.
2026-02-09 18:10:39 +08:00
erio
64f60d15b0 fix: pass platform prop to GroupBadge in GroupSelector for consistent colors
chore: bump version to 0.1.76.2
2026-02-09 14:11:41 +08:00
erio
bc4a044337 Merge release/custom-0.1.76 into main 2026-02-09 13:44:17 +08:00
erio
cb233bfa66 fix: resolve merge conflict marker in OpsConcurrencyCard.vue 2026-02-09 13:20:14 +08:00
erio
d46059a735 chore: bump version to 0.1.76.1 2026-02-09 13:13:14 +08:00
erio
da2fbd9924 Merge tag 'v0.1.76' into develop
Refactor rate limiting from scope level to model level, add group drag-and-drop sorting, and optimize Antigravity smart retry and error handling.

- Group drag-and-drop sorting: the admin backend supports reordering groups by dragging
- Anti-thundering-herd scheduling: accounts of the same priority are shuffled during scheduling to keep concurrent requests from piling onto one account
- Upstream account routing: AccountTypeUpstream is correctly routed to the ForwardUpstream flow
- Client disconnect detection: detect client disconnects during streaming and keep consuming the upstream response so billing stays accurate
- Antigravity failover delay: add a linear delay when switching accounts to avoid avalanches from instantaneous switching
- Gemini error policy integration: unify the Antigravity error policy and support custom error codes for Gemini accounts

- Refactor rate limiting: move from scope level (claude/gemini_text/gemini_image) to model-level limits, simplifying account selection
- Refactor session storage: replace the Trie structure with a flat cache, simplifying digest session storage
- Rate limit delay: use the upstream-returned retryDelay instead of the fixed 30s default

- Fix Gemini native request format parsing so the session hash is generated correctly
- Fix sessionHash collisions when different users send identical messages
- Fix thoughtSignature cleanup logic, now applied to all clients instead of only the CLI
- Fix the cache billing exemption when sticky sessions fail over

# Conflicts:
#	backend/cmd/server/VERSION
#	backend/internal/service/gateway_multiplatform_test.go
#	backend/internal/service/gemini_multiplatform_test.go
#	frontend/src/components/account/AccountStatusIndicator.vue
#	frontend/src/views/admin/ops/components/OpsConcurrencyCard.vue
2026-02-09 12:40:26 +08:00
erio
084e0adb34 feat: squash merge all changes from develop-0.1.75
Squash of 124 commits from the legacy develop branch (develop-0.1.75)
onto a clean v0.1.75 upstream base, to simplify future upstream merges.

Key changes included:
- Refactor scope-level rate limiting to model-level rate limiting
- Antigravity gateway service improvements (smart retry, error policy)
- Digest session store (flat cache replacing Trie-based store)
- Client disconnect detection during streaming
- Gemini messages compatibility service enhancements
- Scheduler shuffle for thundering herd prevention
- Session hash generation improvements
- Frontend customizations (WeChat service, HomeView, etc.)
- Ops monitoring scope cleanup
2026-02-09 12:32:35 +08:00
shaw
3bddbb6afe chore: update version 2026-02-09 09:42:29 +08:00
159 changed files with 5826 additions and 9799 deletions


@@ -17,6 +17,7 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.1'
@@ -36,6 +37,7 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.1'

.gitignore vendored (10 lines changed)

@@ -78,6 +78,7 @@ Desktop.ini
# ===================
tmp/
temp/
logs/
*.tmp
*.temp
*.log
@@ -128,8 +129,15 @@ deploy/docker-compose.override.yml
vite.config.js
docs/*
.serena/
# ===================
# 压测工具
# ===================
tools/loadtest/
# Antigravity Manager
Antigravity-Manager/
antigravity_projectid_fix.patch
.codex/
frontend/coverage/
aicodex
output/

AGENTS.md (new file, 1392 lines): diff suppressed because it is too large

CLAUDE.md (new file, 1337 lines): diff suppressed because it is too large


@@ -39,16 +39,6 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
- **Concurrency Control** - Per-user and per-account concurrency limits
- **Rate Limiting** - Configurable request and token rate limits
- **Admin Dashboard** - Web interface for monitoring and management
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
## Ecosystem
Community projects that extend or integrate with Sub2API:
| Project | Description | Features |
|---------|-------------|----------|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
## Tech Stack


@@ -39,16 +39,6 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
- **并发控制** - 用户级和账号级并发限制
- **速率限制** - 可配置的请求和 Token 速率限制
- **管理后台** - Web 界面进行监控和管理
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
## 生态项目
围绕 Sub2API 的社区扩展与集成项目:
| 项目 | 说明 | 功能 |
|------|------|------|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
## 技术栈


@@ -1 +1 @@
0.1.88
0.1.96.1


@@ -41,9 +41,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// Server layer ProviderSet
server.ProviderSet,
// Privacy client factory for OpenAI training opt-out
providePrivacyClientFactory,
// BuildInfo provider
provideServiceBuildInfo,
@@ -56,10 +53,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, nil
}
func providePrivacyClientFactory() service.PrivacyClientFactory {
return repository.CreatePrivacyReqClient
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,


@@ -81,7 +81,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemHandler := handler.NewRedeemHandler(redeemService)
@@ -105,8 +104,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
privacyClientFactory := providePrivacyClientFactory()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
@@ -164,9 +162,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
@@ -228,7 +226,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
@@ -247,10 +245,6 @@ type Application struct {
Cleanup func()
}
func providePrivacyClientFactory() service.PrivacyClientFactory {
return repository.CreatePrivacyReqClient
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,


@@ -62,26 +62,28 @@ type Group struct {
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
// 是否仅允许 Claude Code 客户端
// allow Claude Code client only
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// Claude Code 请求降级使用的分组 ID
// fallback group for non-Claude-Code requests
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// 无效请求兜底使用的分组 ID
// fallback group for invalid request
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
// 模型路由配置:模型模式 -> 优先账号ID列表
// model routing config: pattern -> account ids
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
// 是否启用模型路由配置
// whether model routing is enabled
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
// 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)
// whether MCP XML prompt injection is enabled
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
// 支持的模型系列:claude, gemini_text, gemini_image
// supported model scopes: claude, gemini_text, gemini_image
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// 分组显示排序,数值越小越靠前
// group display order, lower comes first
SortOrder int `json:"sort_order,omitempty"`
// 是否允许 /v1/messages 调度到此 OpenAI 分组
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
// 默认映射模型 ID,当账号级映射找不到时使用此值
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
// simulate claude usage as claude-max style (1h cache write)
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -190,7 +192,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case group.FieldModelRouting, group.FieldSupportedModelScopes:
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldSimulateClaudeMaxEnabled:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
values[i] = new(sql.NullFloat64)
@@ -431,6 +433,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.DefaultMappedModel = value.String
}
case group.FieldSimulateClaudeMaxEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field simulate_claude_max_enabled", values[i])
} else if value.Valid {
_m.SimulateClaudeMaxEnabled = value.Bool
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -630,6 +638,9 @@ func (_m *Group) String() string {
builder.WriteString(", ")
builder.WriteString("default_mapped_model=")
builder.WriteString(_m.DefaultMappedModel)
builder.WriteString(", ")
builder.WriteString("simulate_claude_max_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.SimulateClaudeMaxEnabled))
builder.WriteByte(')')
return builder.String()
}


@@ -79,6 +79,8 @@ const (
FieldAllowMessagesDispatch = "allow_messages_dispatch"
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
FieldDefaultMappedModel = "default_mapped_model"
// FieldSimulateClaudeMaxEnabled holds the string denoting the simulate_claude_max_enabled field in the database.
FieldSimulateClaudeMaxEnabled = "simulate_claude_max_enabled"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -186,6 +188,7 @@ var Columns = []string{
FieldSortOrder,
FieldAllowMessagesDispatch,
FieldDefaultMappedModel,
FieldSimulateClaudeMaxEnabled,
}
var (
@@ -259,6 +262,8 @@ var (
DefaultDefaultMappedModel string
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
DefaultMappedModelValidator func(string) error
// DefaultSimulateClaudeMaxEnabled holds the default value on creation for the "simulate_claude_max_enabled" field.
DefaultSimulateClaudeMaxEnabled bool
)
// OrderOption defines the ordering options for the Group queries.
@@ -419,6 +424,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
}
// BySimulateClaudeMaxEnabled orders the results by the simulate_claude_max_enabled field.
func BySimulateClaudeMaxEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSimulateClaudeMaxEnabled, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {


@@ -205,6 +205,11 @@ func DefaultMappedModel(v string) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
}
// SimulateClaudeMaxEnabled applies equality check predicate on the "simulate_claude_max_enabled" field. It's identical to SimulateClaudeMaxEnabledEQ.
func SimulateClaudeMaxEnabled(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1555,6 +1560,16 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
}
// SimulateClaudeMaxEnabledEQ applies the EQ predicate on the "simulate_claude_max_enabled" field.
func SimulateClaudeMaxEnabledEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
}
// SimulateClaudeMaxEnabledNEQ applies the NEQ predicate on the "simulate_claude_max_enabled" field.
func SimulateClaudeMaxEnabledNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSimulateClaudeMaxEnabled, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {
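With the equality predicates in place, filtering on the flag is a one-liner. A sketch under the same assumed imports:

// enabledGroups returns only groups with claude-max simulation turned on.
func enabledGroups(ctx context.Context, client *ent.Client) ([]*ent.Group, error) {
	return client.Group.Query().
		Where(group.SimulateClaudeMaxEnabledEQ(true)).
		All(ctx)
}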

View File

@@ -452,6 +452,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
return _c
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (_c *GroupCreate) SetSimulateClaudeMaxEnabled(v bool) *GroupCreate {
_c.mutation.SetSimulateClaudeMaxEnabled(v)
return _c
}
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupCreate {
if v != nil {
_c.SetSimulateClaudeMaxEnabled(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -649,6 +663,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultMappedModel
_c.mutation.SetDefaultMappedModel(v)
}
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
v := group.DefaultSimulateClaudeMaxEnabled
_c.mutation.SetSimulateClaudeMaxEnabled(v)
}
return nil
}
@@ -730,6 +748,9 @@ func (_c *GroupCreate) check() error {
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
}
}
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
return &ValidationError{Name: "simulate_claude_max_enabled", err: errors.New(`ent: missing required field "Group.simulate_claude_max_enabled"`)}
}
return nil
}
@@ -885,6 +906,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
_node.DefaultMappedModel = value
}
if value, ok := _c.mutation.SimulateClaudeMaxEnabled(); ok {
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
_node.SimulateClaudeMaxEnabled = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1599,6 +1624,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
return u
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (u *GroupUpsert) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsert {
u.Set(group.FieldSimulateClaudeMaxEnabled, v)
return u
}
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSimulateClaudeMaxEnabled() *GroupUpsert {
u.SetExcluded(group.FieldSimulateClaudeMaxEnabled)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -2295,6 +2332,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
})
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (u *GroupUpsertOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSimulateClaudeMaxEnabled(v)
})
}
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSimulateClaudeMaxEnabled() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSimulateClaudeMaxEnabled()
})
}
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -3157,6 +3208,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
})
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (u *GroupUpsertBulk) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSimulateClaudeMaxEnabled(v)
})
}
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSimulateClaudeMaxEnabled() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSimulateClaudeMaxEnabled()
})
}
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
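The create defaults, validation, and upsert helpers above compose in the usual ent way. A sketch under the same assumed imports; the conflict column is illustrative and would need to match the table's actual unique index:

// ensureGroup creates "demo" with simulation enabled, or flips the flag on
// an existing row with the same name (upsert path).
func ensureGroup(ctx context.Context, client *ent.Client) error {
	return client.Group.Create().
		SetName("demo").
		SetSimulateClaudeMaxEnabled(true).
		OnConflictColumns(group.FieldName).
		UpdateSimulateClaudeMaxEnabled().
		Exec(ctx)
}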

View File

@@ -653,6 +653,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
return _u
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (_u *GroupUpdate) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdate {
_u.mutation.SetSimulateClaudeMaxEnabled(v)
return _u
}
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdate {
if v != nil {
_u.SetSimulateClaudeMaxEnabled(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1149,6 +1163,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
}
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -2081,6 +2098,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
return _u
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (_u *GroupUpdateOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdateOne {
_u.mutation.SetSimulateClaudeMaxEnabled(v)
return _u
}
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetSimulateClaudeMaxEnabled(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -2607,6 +2638,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
}
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
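And the corresponding update path, again under the assumed imports:

// disableSimulation turns the flag off for a single group.
func disableSimulation(ctx context.Context, client *ent.Client, groupID int64) (*ent.Group, error) {
	return client.Group.UpdateOneID(groupID).
		SetSimulateClaudeMaxEnabled(false).
		Save(ctx)
}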

View File

@@ -410,6 +410,7 @@ var (
{Name: "sort_order", Type: field.TypeInt, Default: 0},
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
{Name: "simulate_claude_max_enabled", Type: field.TypeBool, Default: false},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{

View File

@@ -8252,6 +8252,7 @@ type GroupMutation struct {
addsort_order *int
allow_messages_dispatch *bool
default_mapped_model *string
simulate_claude_max_enabled *bool
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -10068,6 +10069,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
m.default_mapped_model = nil
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (m *GroupMutation) SetSimulateClaudeMaxEnabled(b bool) {
m.simulate_claude_max_enabled = &b
}
// SimulateClaudeMaxEnabled returns the value of the "simulate_claude_max_enabled" field in the mutation.
func (m *GroupMutation) SimulateClaudeMaxEnabled() (r bool, exists bool) {
v := m.simulate_claude_max_enabled
if v == nil {
return
}
return *v, true
}
// OldSimulateClaudeMaxEnabled returns the old "simulate_claude_max_enabled" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldSimulateClaudeMaxEnabled(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldSimulateClaudeMaxEnabled is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldSimulateClaudeMaxEnabled requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldSimulateClaudeMaxEnabled: %w", err)
}
return oldValue.SimulateClaudeMaxEnabled, nil
}
// ResetSimulateClaudeMaxEnabled resets all changes to the "simulate_claude_max_enabled" field.
func (m *GroupMutation) ResetSimulateClaudeMaxEnabled() {
m.simulate_claude_max_enabled = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -10426,7 +10463,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 32)
fields := make([]string, 0, 33)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -10523,6 +10560,9 @@ func (m *GroupMutation) Fields() []string {
if m.default_mapped_model != nil {
fields = append(fields, group.FieldDefaultMappedModel)
}
if m.simulate_claude_max_enabled != nil {
fields = append(fields, group.FieldSimulateClaudeMaxEnabled)
}
return fields
}
@@ -10595,6 +10635,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.AllowMessagesDispatch()
case group.FieldDefaultMappedModel:
return m.DefaultMappedModel()
case group.FieldSimulateClaudeMaxEnabled:
return m.SimulateClaudeMaxEnabled()
}
return nil, false
}
@@ -10668,6 +10710,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldAllowMessagesDispatch(ctx)
case group.FieldDefaultMappedModel:
return m.OldDefaultMappedModel(ctx)
case group.FieldSimulateClaudeMaxEnabled:
return m.OldSimulateClaudeMaxEnabled(ctx)
}
return nil, fmt.Errorf("unknown Group field %s", name)
}
@@ -10901,6 +10945,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetDefaultMappedModel(v)
return nil
case group.FieldSimulateClaudeMaxEnabled:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetSimulateClaudeMaxEnabled(v)
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -11334,6 +11385,9 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldDefaultMappedModel:
m.ResetDefaultMappedModel()
return nil
case group.FieldSimulateClaudeMaxEnabled:
m.ResetSimulateClaudeMaxEnabled()
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}

View File

@@ -463,6 +463,10 @@ func init() {
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
// groupDescSimulateClaudeMaxEnabled is the schema descriptor for simulate_claude_max_enabled field.
groupDescSimulateClaudeMaxEnabled := groupFields[29].Descriptor()
// group.DefaultSimulateClaudeMaxEnabled holds the default value on creation for the simulate_claude_max_enabled field.
group.DefaultSimulateClaudeMaxEnabled = groupDescSimulateClaudeMaxEnabled.Default.(bool)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
_ = idempotencyrecordMixinFields0

View File

@@ -33,8 +33,6 @@ func (Group) Mixin() []ent.Mixin {
func (Group) Fields() []ent.Field {
return []ent.Field{
// unique constraint implemented via partial index (WHERE deleted_at IS NULL) so names can be reused after soft delete
// see migration 016_soft_delete_partial_unique_indexes.sql
field.String("name").
MaxLen(100).
NotEmpty(),
@@ -51,7 +49,6 @@ func (Group) Fields() []ent.Field {
MaxLen(20).
Default(domain.StatusActive),
// Subscription-related fields (added by migration 003)
field.String("platform").
MaxLen(50).
Default(domain.PlatformAnthropic),
@@ -73,7 +70,6 @@ func (Group) Fields() []ent.Field {
field.Int("default_validity_days").
Default(30),
// image generation pricing config (used by the antigravity and gemini platforms)
field.Float("image_price_1k").
Optional().
Nillable().
@@ -87,7 +83,6 @@ func (Group) Fields() []ent.Field {
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Sora per-request pricing config (phase 1)
field.Float("sora_image_price_360").
Optional().
Nillable().
@@ -109,45 +104,38 @@ func (Group) Fields() []ent.Field {
field.Int64("sora_storage_quota_bytes").
Default(0),
// Claude Code client restriction (added by migration 029)
field.Bool("claude_code_only").
Default(false).
Comment("是否仅允许 Claude Code 客户端"),
Comment("allow Claude Code client only"),
field.Int64("fallback_group_id").
Optional().
Nillable().
Comment("Claude Code 请求降级使用的分组 ID"),
Comment("fallback group for non-Claude-Code requests"),
field.Int64("fallback_group_id_on_invalid_request").
Optional().
Nillable().
Comment("无效请求兜底使用的分组 ID"),
Comment("fallback group for invalid request"),
// model routing config (added by migration 040)
field.JSON("model_routing", map[string][]int64{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
// model routing toggle (added by migration 041)
Comment("model routing config: pattern -> account ids"),
field.Bool("model_routing_enabled").
Default(false).
Comment("是否启用模型路由配置"),
Comment("whether model routing is enabled"),
// MCP XML protocol injection toggle (added by migration 042)
field.Bool("mcp_xml_inject").
Default(true).
Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"),
Comment("whether MCP XML prompt injection is enabled"),
// supported model scopes (added by migration 046)
field.JSON("supported_model_scopes", []string{}).
Default([]string{"claude", "gemini_text", "gemini_image"}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("支持的模型系列:claude, gemini_text, gemini_image"),
Comment("supported model scopes: claude, gemini_text, gemini_image"),
// group sort order (added by migration 052)
field.Int("sort_order").
Default(0).
Comment("分组显示排序,数值越小越靠前"),
Comment("group display order, lower comes first"),
// OpenAI Messages dispatch config (added by migration 069)
field.Bool("allow_messages_dispatch").
@@ -157,6 +145,9 @@ func (Group) Fields() []ent.Field {
MaxLen(100).
Default("").
Comment("默认映射模型 ID当账号级映射找不到时使用此值"),
field.Bool("simulate_claude_max_enabled").
Default(false).
Comment("simulate claude usage as claude-max style (1h cache write)"),
}
}
@@ -172,14 +163,11 @@ func (Group) Edges() []ent.Edge {
edge.From("allowed_users", User.Type).
Ref("allowed_groups").
Through("user_allowed_groups", UserAllowedGroup.Type),
// note: fallback_group_id is deliberately a plain field with no edge defined,
// so that multiple groups can point to the same fallback group (M2O relation)
}
}
func (Group) Indexes() []ent.Index {
return []ent.Index{
// the name field already declares Unique() in Fields(); no duplicate index needed
index.Fields("status"),
index.Fields("platform"),
index.Fields("subscription_type"),

View File

@@ -7,7 +7,7 @@ require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
github.com/aws/aws-sdk-go-v2 v1.41.3
github.com/aws/aws-sdk-go-v2 v1.41.2
github.com/aws/aws-sdk-go-v2/config v1.32.10
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
@@ -66,7 +66,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/aws/smithy-go v1.24.1 // indirect
github.com/bdandy/go-errors v1.2.2 // indirect
github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
@@ -87,6 +87,7 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
@@ -137,6 +138,8 @@ require (
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/quic-go/qpack v0.6.0 // indirect

View File

@@ -24,8 +24,6 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
@@ -62,8 +60,6 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
@@ -128,6 +124,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
@@ -184,6 +182,7 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@@ -203,6 +202,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -285,6 +286,10 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -337,6 +342,8 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -345,8 +352,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
@@ -436,11 +441,11 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=

View File

@@ -934,10 +934,9 @@ type DashboardAggregationConfig struct {
// DashboardAggregationRetentionConfig holds the pre-aggregation retention windows
type DashboardAggregationRetentionConfig struct {
UsageLogsDays int `mapstructure:"usage_logs_days"`
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
UsageLogsDays int `mapstructure:"usage_logs_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
}
// UsageCleanupConfig configures the usage-log cleanup job
@@ -1302,7 +1301,6 @@ func setDefaults() {
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
@@ -1760,12 +1758,6 @@ func (c *Config) Validate() error {
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
}
if c.DashboardAgg.Retention.HourlyDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
}
@@ -1788,14 +1780,6 @@ func (c *Config) Validate() error {
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
}
if c.DashboardAgg.Retention.HourlyDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
}

View File

@@ -441,9 +441,6 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
}
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
}
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
}
@@ -1019,23 +1016,6 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
wantErr: "dashboard_aggregation.retention.usage_logs_days",
},
{
name: "dashboard aggregation dedup retention",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
},
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
},
{
name: "dashboard aggregation dedup retention smaller than usage logs",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.Retention.UsageLogsDays = 30
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
},
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
},
{
name: "dashboard aggregation disabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },

View File

@@ -27,12 +27,10 @@ const (
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth account (full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token account (inference-only scope)
AccountTypeAPIKey = "apikey" // API Key account
AccountTypeUpstream = "upstream" // upstream passthrough account (connects via Base URL + API Key)
AccountTypeBedrock = "bedrock" // AWS Bedrock account (connects via SigV4 signing)
AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key account (connects via Bearer Token)
AccountTypeOAuth = "oauth" // OAuth account (full scope: profile + inference)
AccountTypeSetupToken = "setup-token" // Setup Token account (inference-only scope)
AccountTypeAPIKey = "apikey" // API Key account
AccountTypeUpstream = "upstream" // upstream passthrough account (connects via Base URL + API Key)
)
// Redeem type constants
@@ -115,27 +113,3 @@ var DefaultAntigravityModelMapping = map[string]string{
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",
}
// DefaultBedrockModelMapping is the default model mapping for the AWS Bedrock platform.
// It maps standard Anthropic model names to Bedrock model IDs.
// Note: the "us." prefix here is only a default; ResolveBedrockModelID adjusts it to the
// region prefix matching the account's configured aws_region (e.g. eu., apac., jp.).
var DefaultBedrockModelMapping = map[string]string{
// Claude Opus
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
// Claude Sonnet
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
// Claude Haiku
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
}

View File

@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -865,9 +865,6 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
}
}
// OpenAI OAuth: after a successful refresh, check and set privacy_mode
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
return updatedAccount, "", nil
}
@@ -1341,6 +1338,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
})
return
}
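For API clients the practical effect of this hunk is a richer conflict payload. A sketch of the 409 body after the change, with illustrative values:

// Illustrative 409 body for a mixed-channel conflict (all values made up):
const mixedChannelConflictExample = `{
  "error": "mixed_channel_warning",
  "message": "...",
  "details": {
    "group_id": 42,
    "group_name": "claude-max",
    "current_platform": "anthropic",
    "other_platform": "openai"
  }
}`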

View File

@@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
require.Contains(t, resp["message"], "claude-max")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
@@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
require.Contains(t, resp["message"], "claude-max")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)

View File

@@ -179,14 +179,6 @@ func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) (
return nil, nil
}
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
return nil
}
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
return nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}
@@ -441,9 +433,5 @@ func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) erro
return nil
}
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
return ""
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -466,60 +466,9 @@ type BatchUsersUsageRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
func parseRankingLimit(raw string) int {
limit, err := strconv.Atoi(strings.TrimSpace(raw))
if err != nil || limit <= 0 {
return 12
}
if limit > 50 {
return 50
}
return limit
}
// GetUserSpendingRanking handles getting user spending ranking data.
// GET /api/v1/admin/dashboard/users-ranking
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
keyRaw, _ := json.Marshal(struct {
Start string `json:"start"`
End string `json:"end"`
Limit int `json:"limit"`
}{
Start: startTime.UTC().Format(time.RFC3339),
End: endTime.UTC().Format(time.RFC3339),
Limit: limit,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
if err != nil {
response.Error(c, 500, "Failed to get user spending ranking")
return
}
payload := gin.H{
"ranking": ranking.Ranking,
"total_actual_cost": ranking.TotalActualCost,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
}
dashboardUsersRankingCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {

View File

@@ -19,9 +19,6 @@ type dashboardUsageRepoCapture struct {
trendStream *bool
modelRequestType *int16
modelStream *bool
rankingLimit int
ranking []usagestats.UserSpendingRankingItem
rankingTotal float64
}
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
@@ -52,18 +49,6 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
return []usagestats.ModelStat{}, nil
}
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
ctx context.Context,
startTime, endTime time.Time,
limit int,
) (*usagestats.UserSpendingRankingResponse, error) {
s.rankingLimit = limit
return &usagestats.UserSpendingRankingResponse{
Ranking: s.ranking,
TotalActualCost: s.rankingTotal,
}, nil
}
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
@@ -71,7 +56,6 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
router.GET("/admin/dashboard/models", handler.GetModelStats)
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
return router
}
@@ -146,30 +130,3 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
repo := &dashboardUsageRepoCapture{
ranking: []usagestats.UserSpendingRankingItem{
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
},
rankingTotal: 88.8,
}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 50, repo.rankingLimit)
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
}

View File

@@ -46,9 +46,10 @@ type CreateGroupRequest struct {
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// model routing config (anthropic platform only)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
// supported model scopes (antigravity platform only)
SupportedModelScopes []string `json:"supported_model_scopes"`
// Sora storage quota
@@ -84,9 +85,10 @@ type UpdateGroupRequest struct {
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// model routing config (anthropic platform only)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
// supported model scopes (antigravity platform only)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// Sora storage quota
@@ -207,6 +209,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
@@ -260,6 +263,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
@@ -356,51 +360,6 @@ func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
response.Success(c, entries)
}
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
// DELETE /api/v1/admin/groups/:id/rate-multipliers
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
}
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
type BatchSetGroupRateMultipliersRequest struct {
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
}
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
// PUT /api/v1/admin/groups/:id/rate-multipliers
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
var req BatchSetGroupRateMultipliersRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
}
// UpdateSortOrderRequest represents the request to update group sort orders
type UpdateSortOrderRequest struct {
Updates []struct {
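On the wire the new flag rides along as an optional boolean on both create and update. An illustrative request body; the endpoint path is assumed from the handler's route naming, and omitting the field keeps the schema default of false:

// Illustrative body for POST /api/v1/admin/groups (path assumed).
const createGroupExample = `{
  "name": "demo",
  "platform": "anthropic",
  "simulate_claude_max_enabled": true
}`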

View File

@@ -289,7 +289,6 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
Platform: platform,
Type: "oauth",
Credentials: credentials,
Extra: nil,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,

View File

@@ -41,15 +41,12 @@ type GenerateRedeemCodesRequest struct {
}
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
// Type uses omitempty rather than required for backward compatibility with legacy callers (missing type defaults to balance)
type CreateAndRedeemCodeRequest struct {
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // defaults to balance when omitted, for backward compatibility
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
GroupID *int64 `json:"group_id"` // required for subscription type
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // required for subscription type, must be > 0
Notes string `json:"notes"`
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
Notes string `json:"notes"`
}
// List handles listing all redeem codes with pagination
@@ -139,22 +136,6 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
return
}
req.Code = strings.TrimSpace(req.Code)
// Backward compatibility: legacy callers (e.g. Sub2ApiPay) omit the type field; default it to a balance top-up.
// Do not remove this default, or those callers would start getting 400 errors.
if req.Type == "" {
req.Type = "balance"
}
if req.Type == "subscription" {
if req.GroupID == nil {
response.BadRequest(c, "group_id is required for subscription type")
return
}
if req.ValidityDays <= 0 {
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
return
}
}
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
existing, err := h.redeemService.GetByCode(ctx, req.Code)
@@ -166,13 +147,11 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
}
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
Code: req.Code,
Type: req.Type,
Value: req.Value,
Status: service.StatusUnused,
Notes: req.Notes,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
Code: req.Code,
Type: req.Type,
Value: req.Value,
Status: service.StatusUnused,
Notes: req.Notes,
})
if createErr != nil {
// Unique code race: if code now exists, use idempotent semantics by used_by.
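Note the compatibility consequence of the hunk above: requests that previously omitted type (and were defaulted to balance) now fail binding with 400. A minimal body that stays valid after the change, with illustrative values:

// "type" must now always be sent; this is the smallest valid balance body.
const createAndRedeemExample = `{
  "code": "topup-001",
  "type": "balance",
  "value": 10,
  "user_id": 1
}`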

View File

@@ -1,135 +0,0 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
// parameter-validation layer that runs before any service call.
func newCreateAndRedeemHandler() *RedeemHandler {
return &RedeemHandler{
adminService: newStubAdminService(),
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
}
}
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
// status code. For cases that pass validation and proceed into the service layer,
// a panic may occur (because RedeemService internals are nil); this is expected
// and treated as "validation passed" (returns 0 to indicate panic).
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
jsonBytes, err := json.Marshal(body)
require.NoError(t, err)
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
c.Request.Header.Set("Content-Type", "application/json")
defer func() {
if r := recover(); r != nil {
// Panic means we passed validation and entered service layer (expected for minimal stub).
code = 0
}
}()
handler.CreateAndRedeem(c)
return w.Code
}
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
// Omitting the type field should default to balance and not trigger subscription validation.
// Passing validation and entering the service layer panics (returns 0), proving the default took effect.
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-balance-default",
"value": 10.0,
"user_id": 1,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"omitting type should default to balance and pass validation")
}
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-no-group",
"type": "subscription",
"value": 29.9,
"user_id": 1,
"validity_days": 30,
// group_id omitted
})
assert.Equal(t, http.StatusBadRequest, code)
}
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
groupID := int64(5)
h := newCreateAndRedeemHandler()
cases := []struct {
name string
validityDays int
}{
{"zero", 0},
{"negative", -1},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-bad-days-" + tc.name,
"type": "subscription",
"value": 29.9,
"user_id": 1,
"group_id": groupID,
"validity_days": tc.validityDays,
})
assert.Equal(t, http.StatusBadRequest, code)
})
}
}
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
groupID := int64(5)
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-valid",
"type": "subscription",
"value": 29.9,
"user_id": 1,
"group_id": groupID,
"validity_days": 31,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"valid subscription params should pass validation")
}
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
h := newCreateAndRedeemHandler()
// balance type omits group_id and validity_days; must not return 400
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-balance-no-extras",
"type": "balance",
"value": 50.0,
"user_id": 1,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"balance type should not require group_id or validity_days")
}

View File

@@ -218,12 +218,11 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
// ResetSubscriptionQuotaRequest represents the reset quota request
type ResetSubscriptionQuotaRequest struct {
Daily bool `json:"daily"`
Weekly bool `json:"weekly"`
Monthly bool `json:"monthly"`
Daily bool `json:"daily"`
Weekly bool `json:"weekly"`
}
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
// ResetQuota resets daily and/or weekly usage for a subscription.
// POST /api/v1/admin/subscriptions/:id/reset-quota
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -236,11 +235,11 @@ func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !req.Daily && !req.Weekly && !req.Monthly {
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
if !req.Daily && !req.Weekly {
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
return
}
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
if err != nil {
response.ErrorFrom(c, err)
return
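With the monthly dimension dropped, the reset request carries just two booleans, at least one of which must be true or the handler returns 400. An illustrative body:

// Illustrative body for POST /api/v1/admin/subscriptions/:id/reset-quota.
const resetQuotaExample = `{"daily": true, "weekly": false}`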

View File

@@ -135,14 +135,15 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
SortOrder: g.SortOrder,
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
SortOrder: g.SortOrder,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))

View File

@@ -117,6 +117,8 @@ type AdminGroup struct {
// MCP XML protocol injection (antigravity platform only)
MCPXMLInject bool `json:"mcp_xml_inject"`
// Claude usage simulation toggle (visible to admins only)
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
// OpenAI Messages dispatch config (openai platform only)
DefaultMappedModel string `json:"default_mapped_model"`

View File

@@ -434,21 +434,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// capture request info (for async recording; avoids touching gin.Context inside the goroutine)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// usage recording is submitted via a bounded worker pool to avoid unbounded goroutines on the request hot path
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
Result: result,
ParsedRequest: parsedReq,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),
@@ -632,6 +631,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// ===== user message serial queue END =====
// forward the request - dispatch by account platform
c.Set("parsed_request", parsedReq)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
@@ -738,21 +738,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// capture request info (for async recording; avoids touching gin.Context inside the goroutine)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// usage recording is submitted via a bounded worker pool to avoid unbounded goroutines on the request hot path
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
Subscription: currentSubscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
Result: result,
ParsedRequest: parsedReq,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
Subscription: currentSubscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),

View File

@@ -139,7 +139,6 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // accountRepo (not used: scheduler snapshot hit)
&fakeGroupRepo{group: group},
nil, // usageLogRepo
nil, // usageBillingRepo
nil, // userRepo
nil, // userSubRepo
nil, // userGroupRateRepo

View File

@@ -503,7 +503,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// usage recording is submitted via a bounded worker pool to avoid unbounded goroutines on the request hot path
requestPayloadHash := service.HashUsageRequestPayload(body)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
@@ -513,7 +512,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
LongContextThreshold: 200000, // Gemini 200K threshold
LongContextMultiplier: 2.0, // bill the portion above the threshold at double rate
ForceCacheBilling: fs.ForceCacheBilling,

View File

@@ -352,20 +352,18 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// capture request info (for async recording; avoids touching gin.Context inside the goroutine)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// usage recording is submitted via a bounded worker pool to avoid unbounded goroutines on the request hot path
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
@@ -734,19 +732,17 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.messages"),
@@ -1235,15 +1231,14 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
h.submitUsageRecordTask(func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),

View File

@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
// newMinimalGatewayService creates a minimal GatewayService containing only accountRepo, used to test SelectAccountForModel
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
accountRepo, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
)
}

View File

@@ -399,19 +399,17 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// Usage records are submitted through a bounded worker pool so the request hot path never spawns unbounded goroutines.
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
}); err != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),

View File

@@ -343,9 +343,6 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, nil
}
@@ -434,7 +431,6 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil,
nil,
nil,
nil,
testutil.StubGatewayCache{},
cfg,
nil,

View File

@@ -189,5 +189,6 @@ var DefaultStopSequences = []string{
"<|user|>",
"<|endoftext|>",
"<|end_of_turn|>",
"[DONE]",
"\n\nHuman:",
}

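For context, stop-sequence lists like the one above are typically enforced by truncating generated text at the earliest match; the addition of "[DONE]" keeps the SSE terminator token from leaking into output. A minimal illustration, not the project's actual enforcement code:

import "strings"

// truncateAtStop cuts text at the earliest occurrence of any stop sequence,
// e.g. truncateAtStop("hello [DONE] tail", DefaultStopSequences) == "hello ".
func truncateAtStop(text string, stops []string) string {
	cut := len(text)
	for _, s := range stops {
		if i := strings.Index(text, s); i >= 0 && i < cut {
			cut = i
		}
	}
	return text[:cut]
}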
View File

@@ -18,6 +18,9 @@ const (
BlockTypeFunction
)
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
type UsageMapHook func(usageMap map[string]any)
// StreamingProcessor processes streaming responses
type StreamingProcessor struct {
blockType BlockType
@@ -30,6 +33,7 @@ type StreamingProcessor struct {
originalModel string
webSearchQueries []string
groundingChunks []GeminiGroundingChunk
usageMapHook UsageMapHook
// accumulated usage
inputTokens int
@@ -45,6 +49,25 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
}
}
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
p.usageMapHook = fn
}
func usageToMap(u ClaudeUsage) map[string]any {
m := map[string]any{
"input_tokens": u.InputTokens,
"output_tokens": u.OutputTokens,
}
if u.CacheCreationInputTokens > 0 {
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
}
if u.CacheReadInputTokens > 0 {
m["cache_read_input_tokens"] = u.CacheReadInputTokens
}
return m
}
// ProcessLine processes one SSE line and returns the resulting Claude SSE event bytes
func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line)
@@ -168,6 +191,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
responseID = "msg_" + generateRandomID()
}
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
message := map[string]any{
"id": responseID,
"type": "message",
@@ -176,7 +206,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
"model": p.originalModel,
"stop_reason": nil,
"stop_sequence": nil,
"usage": usage,
"usage": usageValue,
}
event := map[string]any{
@@ -487,13 +517,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
CacheReadInputTokens: p.cacheReadTokens,
}
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
deltaEvent := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": usage,
"usage": usageValue,
}
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))

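A usage sketch for the hook introduced above: the caller installs a UsageMapHook, and both emitMessageStart and emitFinish route their usage payload through it before serialization. The adjustment shown here (folding cache reads into input tokens) and the model name are purely hypothetical:

p := NewStreamingProcessor("claude-sonnet-4") // model name is illustrative
p.SetUsageMapHook(func(usageMap map[string]any) {
	// Hypothetical rewrite: report cache reads as plain input tokens.
	cr, okCR := usageMap["cache_read_input_tokens"].(int)
	in, okIn := usageMap["input_tokens"].(int)
	if okCR && okIn {
		usageMap["input_tokens"] = in + cr
		delete(usageMap, "cache_read_input_tokens")
	}
})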
View File

@@ -96,28 +96,12 @@ type UserUsageTrendPoint struct {
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // standard billing
ActualCost float64 `json:"actual_cost"` // actual deduction
}
// UserSpendingRankingItem represents a user spending ranking row.
type UserSpendingRankingItem struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
ActualCost float64 `json:"actual_cost"` // actual deduction
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
type UserSpendingRankingResponse struct {
Ranking []UserSpendingRankingItem `json:"ranking"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint struct {
Date string `json:"date"`

View File

@@ -164,6 +164,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldModelRoutingEnabled,
group.FieldModelRouting,
group.FieldMcpXMLInject,
group.FieldSimulateClaudeMaxEnabled,
group.FieldSupportedModelScopes,
group.FieldAllowMessagesDispatch,
group.FieldDefaultMappedModel,
@@ -645,6 +646,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.McpXMLInject,
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
SupportedModelScopes: g.SupportedModelScopes,
SortOrder: g.SortOrder,
AllowMessagesDispatch: g.AllowMessagesDispatch,

View File

@@ -17,9 +17,6 @@ type dashboardAggregationRepository struct {
sql sqlExecutor
}
const usageLogsCleanupBatchSize = 10000
const usageBillingDedupCleanupBatchSize = 10000
// NewDashboardAggregationRepository creates the dashboard pre-aggregation repository.
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
if sqlDB == nil {
@@ -45,9 +42,6 @@ func isPostgresDriver(db *sql.DB) bool {
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
if r == nil || r.sql == nil {
return nil
}
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
@@ -67,22 +61,6 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
dayEnd = dayEnd.Add(24 * time.Hour)
}
if db, ok := r.sql.(*sql.DB); ok {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
}
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
// Aggregate on bucket boundaries, allowing coverage of the remainder of the bucket that contains end.
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
@@ -217,58 +195,8 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
if isPartitioned {
return r.dropUsageLogsPartitions(ctx, cutoff)
}
for {
res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid
FROM usage_logs
WHERE created_at < $1
LIMIT $2
)
DELETE FROM usage_logs
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageLogsCleanupBatchSize)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected < usageLogsCleanupBatchSize {
return nil
}
}
}
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
for {
res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
FROM usage_billing_dedup
WHERE created_at < $1
LIMIT $2
), archived AS (
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
SELECT request_id, api_key_id, request_fingerprint, created_at
FROM victims
ON CONFLICT (request_id, api_key_id) DO NOTHING
)
DELETE FROM usage_billing_dedup
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected < usageBillingDedupCleanupBatchSize {
return nil
}
}
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
return err
}
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {

View File

@@ -262,42 +262,6 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
SetKey(k.Key).
SetName(k.Name).
SetStatus(k.Status)
if k.Quota != 0 {
create.SetQuota(k.Quota)
}
if k.QuotaUsed != 0 {
create.SetQuotaUsed(k.QuotaUsed)
}
if k.RateLimit5h != 0 {
create.SetRateLimit5h(k.RateLimit5h)
}
if k.RateLimit1d != 0 {
create.SetRateLimit1d(k.RateLimit1d)
}
if k.RateLimit7d != 0 {
create.SetRateLimit7d(k.RateLimit7d)
}
if k.Usage5h != 0 {
create.SetUsage5h(k.Usage5h)
}
if k.Usage1d != 0 {
create.SetUsage1d(k.Usage1d)
}
if k.Usage7d != 0 {
create.SetUsage7d(k.Usage7d)
}
if k.Window5hStart != nil {
create.SetWindow5hStart(*k.Window5hStart)
}
if k.Window1dStart != nil {
create.SetWindow1dStart(*k.Window1dStart)
}
if k.Window7dStart != nil {
create.SetWindow7dStart(*k.Window7dStart)
}
if k.ExpiresAt != nil {
create.SetExpiresAt(*k.ExpiresAt)
}
if k.GroupID != nil {
create.SetGroupID(*k.GroupID)
}

View File

@@ -61,7 +61,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetDefaultMappedModel(groupIn.DefaultMappedModel)
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
// Set the model routing configuration
if groupIn.ModelRouting != nil {
@@ -129,7 +130,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetDefaultMappedModel(groupIn.DefaultMappedModel)
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
// Handle nullable fields explicitly: nil requires clear, non-nil requires set.
if groupIn.DailyLimitUSD != nil {

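The comment above captures the standard ent pattern for nullable columns: a nil pointer clears the column to NULL, a non-nil pointer sets the value. A minimal sketch on the generated builder; the Set/Clear method names follow ent's naming convention and are assumed, not verified against this schema:

// Nullable column handling: nil => NULL, non-nil => value.
if groupIn.DailyLimitUSD != nil {
	update = update.SetDailyLimitUsd(*groupIn.DailyLimitUSD)
} else {
	update = update.ClearDailyLimitUsd()
}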
View File

@@ -45,20 +45,6 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
// usage_billing_dedup: billing idempotency narrow table
var usageBillingDedupRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
var usageBillingDedupArchiveRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
// settings table should exist
var settingsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
@@ -89,23 +75,6 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
var exists bool
err := tx.QueryRowContext(context.Background(), `
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`, table, index).Scan(&exists)
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
require.True(t, exists, "expected index %s on %s", index, table)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()

View File

@@ -73,14 +73,3 @@ func buildReqClientKey(opts reqClientOptions) string {
opts.ForceHTTP2,
)
}
// CreatePrivacyReqClient creates an HTTP client for OpenAI privacy settings API
// This is exported for use by OpenAIPrivacyService
// Uses Chrome TLS fingerprint impersonation to bypass Cloudflare checks
func CreatePrivacyReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
Impersonate: true, // Enable Chrome TLS fingerprint impersonation
})
}

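For context, the removed CreatePrivacyReqClient delegated to getSharedReqClient, which caches one client per distinct option set (keyed by buildReqClientKey above). A minimal sketch of such a keyed cache using imroc/req; the cache structure and the absence of eviction are assumptions:

import (
	"sync"
	"time"

	"github.com/imroc/req/v3"
)

var sharedClients sync.Map // options key -> *req.Client

func sharedClient(key, proxyURL string, timeout time.Duration, impersonate bool) *req.Client {
	if v, ok := sharedClients.Load(key); ok {
		return v.(*req.Client)
	}
	c := req.C().SetTimeout(timeout)
	if proxyURL != "" {
		c.SetProxyURL(proxyURL)
	}
	if impersonate {
		c.ImpersonateChrome() // Chrome TLS fingerprint, as the removed helper enabled
	}
	actual, _ := sharedClients.LoadOrStore(key, c)
	return actual.(*req.Client)
}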
View File

@@ -1,308 +0,0 @@
package repository
import (
"context"
"database/sql"
"errors"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type usageBillingRepository struct {
db *sql.DB
}
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
return &usageBillingRepository{db: sqlDB}
}
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
if cmd == nil {
return &service.UsageBillingApplyResult{}, nil
}
if r == nil || r.db == nil {
return nil, errors.New("usage billing repository db is nil")
}
cmd.Normalize()
if cmd.RequestID == "" {
return nil, service.ErrUsageBillingRequestIDRequired
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
if err != nil {
return nil, err
}
if !applied {
return &service.UsageBillingApplyResult{Applied: false}, nil
}
result := &service.UsageBillingApplyResult{Applied: true}
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
tx = nil
return result, nil
}
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
var id int64
err := tx.QueryRowContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
VALUES ($1, $2, $3)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
if errors.Is(err, sql.ErrNoRows) {
var existingFingerprint string
if err := tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
return false, err
}
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
return false, service.ErrUsageBillingRequestConflict
}
return false, nil
}
if err != nil {
return false, err
}
var archivedFingerprint string
err = tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup_archive
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
if err == nil {
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
return false, service.ErrUsageBillingRequestConflict
}
return false, nil
}
if !errors.Is(err, sql.ErrNoRows) {
return false, err
}
return true, nil
}
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
return err
}
}
if cmd.BalanceCost > 0 {
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
return err
}
}
if cmd.APIKeyQuotaCost > 0 {
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
if err != nil {
return err
}
result.APIKeyQuotaExhausted = exhausted
}
if cmd.APIKeyRateLimitCost > 0 {
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
return err
}
}
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
return err
}
}
return nil
}
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
const updateSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
weekly_usage_usd = us.weekly_usage_usd + $1,
monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
`
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrSubscriptionNotFound
}
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
res, err := tx.ExecContext(ctx, `
UPDATE users
SET balance = balance - $1,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, amount, userID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrUserNotFound
}
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
var exhausted bool
err := tx.QueryRowContext(ctx, `
UPDATE api_keys
SET quota_used = quota_used + $1,
status = CASE
WHEN quota > 0
AND status = $3
AND quota_used < quota
AND quota_used + $1 >= quota
THEN $4
ELSE status
END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
if errors.Is(err, sql.ErrNoRows) {
return false, service.ErrAPIKeyNotFound
}
if err != nil {
return false, err
}
return exhausted, nil
}
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
res, err := tx.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, cost, apiKeyID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAPIKeyNotFound
}
return nil
}
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
rows, err := tx.QueryContext(ctx,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
COALESCE((extra->>'quota_limit')::numeric, 0)`,
amount, accountID)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
var newUsed, limit float64
if rows.Next() {
if err := rows.Scan(&newUsed, &limit); err != nil {
return err
}
} else {
if err := rows.Err(); err != nil {
return err
}
return service.ErrAccountNotFound
}
if err := rows.Err(); err != nil {
return err
}
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
return err
}
}
return nil
}

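The subtle part of claimUsageBillingKey above is that INSERT ... ON CONFLICT DO NOTHING RETURNING id returns zero rows on conflict, so Scan yields sql.ErrNoRows as the "already claimed" signal rather than a failure. A condensed sketch of just that claim step, assuming the same table shape (the fingerprint and archive checks from the full version are elided):

import (
	"context"
	"database/sql"
	"errors"
)

// claim returns (true, nil) exactly once per (request_id, api_key_id);
// a second call sees sql.ErrNoRows from the conflicting INSERT and
// reports "not applied" instead of an error.
func claim(ctx context.Context, tx *sql.Tx, requestID string, apiKeyID int64, fp string) (bool, error) {
	var id int64
	err := tx.QueryRowContext(ctx, `
		INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
		VALUES ($1, $2, $3)
		ON CONFLICT (request_id, api_key_id) DO NOTHING
		RETURNING id
	`, requestID, apiKeyID, fp).Scan(&id)
	if errors.Is(err, sql.ErrNoRows) {
		return false, nil // duplicate claim; fingerprint comparison elided
	}
	if err != nil {
		return false, err
	}
	return true, nil
}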
View File

@@ -1,279 +0,0 @@
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-" + uuid.NewString(),
Name: "billing",
Quota: 1,
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-account-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: account.ID,
AccountType: service.AccountTypeAPIKey,
BalanceCost: 1.25,
APIKeyQuotaCost: 1.25,
APIKeyRateLimitCost: 1.25,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.NotNil(t, result1)
require.True(t, result1.Applied)
require.True(t, result1.APIKeyQuotaExhausted)
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.NotNil(t, result2)
require.False(t, result2.Applied)
var balance float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
require.InDelta(t, 98.75, balance, 0.000001)
var quotaUsed float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan(&quotaUsed))
require.InDelta(t, 1.25, quotaUsed, 0.000001)
var usage5h float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
require.InDelta(t, 1.25, usage5h, 0.000001)
var status string
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
var dedupCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
require.Equal(t, 1, dedupCount)
}
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
})
group := mustCreateGroup(t, client, &service.Group{
Name: "usage-billing-group-" + uuid.NewString(),
Platform: service.PlatformAnthropic,
SubscriptionType: service.SubscriptionTypeSubscription,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
GroupID: &group.ID,
Key: "sk-usage-billing-sub-" + uuid.NewString(),
Name: "billing-sub",
})
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: 0,
SubscriptionID: &subscription.ID,
SubscriptionCost: 2.5,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.True(t, result1.Applied)
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.False(t, result2.Applied)
var dailyUsage float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
require.InDelta(t, 2.5, dailyUsage, 0.000001)
}
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
Name: "billing-conflict",
})
requestID := uuid.NewString()
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 1.25,
})
require.NoError(t, err)
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 2.50,
})
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
}
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-account-" + uuid.NewString(),
Name: "billing-account",
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-account-quota-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
Extra: map[string]any{
"quota_limit": 100.0,
},
})
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: account.ID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 3.5,
})
require.NoError(t, err)
var quotaUsed float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan(&quotaUsed))
require.InDelta(t, 3.5, quotaUsed, 0.000001)
}
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
ctx := context.Background()
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
oldRequestID := "dedup-old-" + uuid.NewString()
newRequestID := "dedup-new-" + uuid.NewString()
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
newCreatedAt := time.Now().UTC().Add(-time.Hour)
_, err := integrationDB.ExecContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
`,
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
newRequestID, strings.Repeat("b", 64), newCreatedAt,
)
require.NoError(t, err)
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
var oldCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
require.Equal(t, 0, oldCount)
var newCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
require.Equal(t, 1, newCount)
var archivedCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
require.Equal(t, 1, archivedCount)
}
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-archive-" + uuid.NewString(),
Name: "billing-archive",
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 1.25,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.True(t, result1.Applied)
_, err = integrationDB.ExecContext(ctx, `
UPDATE usage_billing_dedup
SET created_at = $1
WHERE request_id = $2 AND api_key_id = $3
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
require.NoError(t, err)
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.False(t, result2.Applied)
var balance float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
require.InDelta(t, 98.75, balance, 0.000001)
}

File diff suppressed because it is too large

View File

@@ -4,8 +4,6 @@ package repository
import (
"context"
"fmt"
"sync"
"testing"
"time"
@@ -16,7 +14,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -87,367 +84,6 @@ func (s *UsageLogRepoSuite) TestCreate() {
s.Require().NotZero(log.ID)
}
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
const total = 16
results := make([]bool, total)
errs := make([]error, total)
logs := make([]*service.UsageLog, total)
var wg sync.WaitGroup
wg.Add(total)
for i := 0; i < total; i++ {
i := i
logs[i] = &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10 + i,
OutputTokens: 20 + i,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
go func() {
defer wg.Done()
results[i], errs[i] = repo.Create(ctx, logs[i])
}()
}
wg.Wait()
for i := 0; i < total; i++ {
require.NoError(t, errs[i])
require.True(t, results[i])
require.NotZero(t, logs[i].ID)
}
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
require.Equal(t, total, count)
}
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
requestID := uuid.NewString()
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
inserted1, err1 := repo.Create(ctx, log1)
inserted2, err2 := repo.Create(ctx, log2)
require.NoError(t, err1)
require.NoError(t, err2)
require.True(t, inserted1)
require.False(t, inserted2)
require.Equal(t, log1.ID, log2.ID)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
require.Equal(t, 1, count)
}
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
requestID := uuid.NewString()
const total = 8
batch := make([]usageLogCreateRequest, 0, total)
logs := make([]*service.UsageLog, 0, total)
for i := 0; i < total; i++ {
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10 + i,
OutputTokens: 20 + i,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
logs = append(logs, log)
batch = append(batch, usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
resultCh: make(chan usageLogCreateResult, 1),
})
}
repo.flushCreateBatch(integrationDB, batch)
insertedCount := 0
var firstID int64
for idx, req := range batch {
res := <-req.resultCh
require.NoError(t, res.err)
if res.inserted {
insertedCount++
}
require.NotZero(t, logs[idx].ID)
if idx == 0 {
firstID = logs[idx].ID
} else {
require.Equal(t, firstID, logs[idx].ID)
}
}
require.Equal(t, 1, insertedCount)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
require.Equal(t, 1, count)
}
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
requestID := uuid.NewString()
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
require.NoError(t, repo.CreateBestEffort(ctx, log1))
require.NoError(t, repo.CreateBestEffort(ctx, log2))
require.Eventually(t, func() bool {
var count int
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
return err == nil && count == 1
}, 3*time.Second, 20*time.Millisecond)
}
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
err := repo.CreateBestEffort(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.Error(t, err)
require.True(t, service.IsUsageLogCreateDropped(err))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
ctx, cancel := context.WithCancel(context.Background())
cancel()
inserted, err := repo.Create(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.False(t, inserted)
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
repo.createBatchCh <- usageLogCreateRequest{}
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
inserted, err := repo.Create(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.False(t, inserted)
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 1)
go func() {
_, err := repo.createBatched(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
errCh <- err
}()
req := <-repo.createBatchCh
require.NotNil(t, req.shared)
cancel()
err := <-errCh
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
}
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
req := usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
shared: &usageLogCreateShared{},
resultCh: make(chan usageLogCreateResult, 1),
}
req.shared.state.Store(usageLogCreateStateCanceled)
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
res := <-req.resultCh
require.False(t, res.inserted)
require.Error(t, res.err)
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})

View File

@@ -248,35 +248,6 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
mock.ExpectQuery("WITH user_spend AS \\(").
WithArgs(start, end, 12).
WillReturnRows(rows)
got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12)
require.NoError(t, err)
require.Equal(t, &usagestats.UserSpendingRankingResponse{
Ranking: []usagestats.UserSpendingRankingItem{
{UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900},
{UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800},
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
},
TotalActualCost: 40.0,
}, got)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
tests := []struct {
name string

View File

@@ -3,11 +3,8 @@
package repository
import (
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
@@ -42,26 +39,3 @@ func TestSafeDateFormat(t *testing.T) {
})
}
}
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
log := &service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-batch-no-update",
Model: "gpt-5",
InputTokens: 10,
OutputTokens: 5,
TotalCost: 1.2,
ActualCost: 1.2,
CreatedAt: time.Now().UTC(),
}
prepared := prepareUsageLogInsert(log)
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
})
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
}

View File

@@ -98,9 +98,9 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
// GetByGroupID returns the per-user rate multipliers for every user in the given group
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
query := `
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
SELECT ugr.user_id, u.email, ugr.rate_multiplier
FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1
ORDER BY ugr.user_id
`
@@ -113,7 +113,7 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
var result []service.UserGroupRateEntry
for rows.Next() {
var entry service.UserGroupRateEntry
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
if err := rows.Scan(&entry.UserID, &entry.UserEmail, &entry.RateMultiplier); err != nil {
return nil, err
}
result = append(result, entry)
@@ -193,31 +193,6 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return nil
}
// SyncGroupRateMultipliers batch-syncs a group's per-user rate multipliers (delete first, then insert)
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
return err
}
if len(entries) == 0 {
return nil
}
userIDs := make([]int64, len(entries))
rates := make([]float64, len(entries))
for i, e := range entries {
userIDs[i] = e.UserID
rates[i] = e.RateMultiplier
}
now := time.Now()
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT data.user_id, $1::bigint, data.rate_multiplier, $2::timestamptz, $2::timestamptz
FROM unnest($3::bigint[], $4::double precision[]) AS data(user_id, rate_multiplier)
ON CONFLICT (user_id, group_id)
DO UPDATE SET rate_multiplier = EXCLUDED.rate_multiplier, updated_at = EXCLUDED.updated_at
`, groupID, now, pq.Array(userIDs), pq.Array(rates))
return err
}
// DeleteByGroupID removes all per-user rate multipliers for the given group
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)

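A usage sketch for the removed SyncGroupRateMultipliers above: the caller passes the full desired state, and the repository replaces the group's rows in two statements, with unnest expanding the parallel arrays into rows. The values below are hypothetical, and repo, ctx, and groupID are assumed to be in scope:

entries := []service.GroupRateMultiplierInput{
	{UserID: 101, RateMultiplier: 0.8}, // discounted user
	{UserID: 102, RateMultiplier: 1.5}, // surcharged user
}
// One DELETE plus one array-valued INSERT; the ON CONFLICT upsert keeps
// the call safe if another writer races the delete.
if err := repo.SyncGroupRateMultipliers(ctx, groupID, entries); err != nil {
	return err
}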
View File

@@ -62,7 +62,6 @@ var ProviderSet = wire.NewSet(
NewAnnouncementRepository,
NewAnnouncementReadRepository,
NewUsageLogRepository,
NewUsageBillingRepository,
NewIdempotencyRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository,

View File

@@ -645,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
@@ -1635,10 +1635,6 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
logs := r.userLogs[userID]
if len(logs) == 0 {

View File

@@ -192,7 +192,6 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
@@ -230,8 +229,6 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
}

View File

@@ -412,7 +412,6 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
}
// The Bedrock default mapping is handled centrally by forwardBedrock (it must adjust for the region prefix)
return nil
}
if len(rawMapping) == 0 {
@@ -765,14 +764,6 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
return false
}
func (a *Account) IsBedrock() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey)
}
func (a *Account) IsBedrockAPIKey() bool {
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey
}
func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI
}

View File

@@ -207,14 +207,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
testModelID = claude.DefaultTestModel
}
// Wildcard model mapping must also be applied when testing API Key account connections.
// For API Key accounts with model mapping, map the model
if account.Type == "apikey" {
testModelID = account.GetMappedModel(testModelID)
}
// Bedrock accounts use a separate test path
if account.IsBedrock() {
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
}
// Determine authentication method and API URL
@@ -312,109 +312,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.processClaudeStream(c, resp.Body)
}
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
region := bedrockRuntimeRegion(account)
resolvedModelID, ok := ResolveBedrockModelID(account, testModelID)
if !ok {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Bedrock model: %s", testModelID))
}
testModelID = resolvedModelID
// Set SSE headers (test UI expects SSE)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
// Create a minimal Bedrock-compatible payload (no stream, no cache_control)
bedrockPayload := map[string]any{
"anthropic_version": "bedrock-2023-05-31",
"messages": []map[string]any{
{
"role": "user",
"content": []map[string]any{
{
"type": "text",
"text": "hi",
},
},
},
},
"max_tokens": 256,
"temperature": 1,
}
bedrockBody, _ := json.Marshal(bedrockPayload)
// Use non-streaming endpoint (response is standard Claude JSON)
apiURL := BuildBedrockURL(region, testModelID, false)
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bedrockBody))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
req.Header.Set("Content-Type", "application/json")
// Sign or set auth based on account type
if account.IsBedrockAPIKey() {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
req.Header.Set("Authorization", "Bearer "+apiKey)
} else {
signer, err := NewBedrockSignerFromAccount(account)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Bedrock signer: %s", err.Error()))
}
if err := signer.SignRequest(ctx, req, bedrockBody); err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to sign request: %s", err.Error()))
}
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
// Bedrock non-streaming response is standard Claude JSON, extract the text
var result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
}
if err := json.Unmarshal(body, &result); err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
}
text := ""
if len(result.Content) > 0 {
text = result.Content[0].Text
}
if text == "" {
text = "(empty response)"
}
s.sendEvent(c, TestEvent{Type: "content", Text: text})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
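Since the non-streaming invoke returns standard Claude JSON, the extraction above reduces to a plain unmarshal. A self-contained sketch with an illustrative payload (the response body is made up for demonstration):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Illustrative Bedrock non-streaming response body (standard Claude JSON).
	body := []byte(`{"content":[{"type":"text","text":"Hello!"}],"stop_reason":"end_turn"}`)
	var result struct {
		Content []struct {
			Text string `json:"text"`
		} `json:"content"`
	}
	if err := json.Unmarshal(body, &result); err == nil && len(result.Content) > 0 {
		fmt.Println(result.Content[0].Text) // Hello!
	}
}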
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()

View File

@@ -47,7 +47,6 @@ type UsageLogRepository interface {
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)

View File

@@ -43,8 +43,6 @@ type AdminService interface {
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin)
@@ -60,8 +58,6 @@ type AdminService interface {
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountError(ctx context.Context, id int64, errorMsg string) error
// EnsureOpenAIPrivacy checks an OpenAI OAuth account's privacy_mode; if unset, it attempts to disable training-data sharing and persists the result.
EnsureOpenAIPrivacy(ctx context.Context, account *Account) string
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
@@ -143,9 +139,10 @@ type CreateGroupInput struct {
// Fallback group ID for invalid requests (anthropic platform only)
FallbackGroupIDOnInvalidRequest *int64
// Model routing config (anthropic platform only)
ModelRouting map[string][]int64
ModelRoutingEnabled bool // whether model routing is enabled
MCPXMLInject *bool
ModelRouting map[string][]int64
ModelRoutingEnabled bool // whether model routing is enabled
MCPXMLInject *bool
SimulateClaudeMaxEnabled *bool
// Supported model scopes (antigravity platform only)
SupportedModelScopes []string
// Sora storage quota
@@ -182,9 +179,10 @@ type UpdateGroupInput struct {
// Fallback group ID for invalid requests (anthropic platform only)
FallbackGroupIDOnInvalidRequest *int64
// Model routing config (anthropic platform only)
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // whether model routing is enabled
MCPXMLInject *bool
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // whether model routing is enabled
MCPXMLInject *bool
SimulateClaudeMaxEnabled *bool
// Supported model scopes (antigravity platform only)
SupportedModelScopes *[]string
// Sora storage quota
@@ -368,6 +366,10 @@ type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
type proxyQualityTarget struct {
Target string
URL string
@@ -438,17 +440,12 @@ type adminServiceImpl struct {
settingService *SettingService
defaultSubAssigner DefaultSubscriptionAssigner
userSubRepo UserSubscriptionRepository
privacyClientFactory PrivacyClientFactory
}
type userGroupRateBatchReader interface {
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
}
type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAdminService creates a new AdminService
func NewAdminService(
userRepo UserRepository,
@@ -467,7 +464,6 @@ func NewAdminService(
settingService *SettingService,
defaultSubAssigner DefaultSubscriptionAssigner,
userSubRepo UserSubscriptionRepository,
privacyClientFactory PrivacyClientFactory,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@@ -486,7 +482,6 @@ func NewAdminService(
settingService: settingService,
defaultSubAssigner: defaultSubAssigner,
userSubRepo: userSubRepo,
privacyClientFactory: privacyClientFactory,
}
}
@@ -868,6 +863,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
if input.MCPXMLInject != nil {
mcpXMLInject = *input.MCPXMLInject
}
simulateClaudeMaxEnabled := false
if input.SimulateClaudeMaxEnabled != nil {
if platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
}
simulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
}
// If a source group to copy accounts from is specified, fetch its account ID list first
var accountIDsToCopy []int64
@@ -924,6 +926,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SimulateClaudeMaxEnabled: simulateClaudeMaxEnabled,
SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
AllowMessagesDispatch: input.AllowMessagesDispatch,
@@ -1135,6 +1138,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.MCPXMLInject != nil {
group.MCPXMLInject = *input.MCPXMLInject
}
if input.SimulateClaudeMaxEnabled != nil {
if group.Platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
}
group.SimulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
}
if group.Platform != PlatformAnthropic {
group.SimulateClaudeMaxEnabled = false
}
// Supported model scopes (antigravity platform only)
if input.SupportedModelScopes != nil {
@@ -1259,20 +1271,6 @@ func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID
return s.userGroupRateRepo.GetByGroupID(ctx, groupID)
}
func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error {
if s.userGroupRateRepo == nil {
return nil
}
return s.userGroupRateRepo.DeleteByGroupID(ctx, groupID)
}
func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
if s.userGroupRateRepo == nil {
return nil
}
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
}
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
@@ -2531,39 +2529,3 @@ func (e *MixedChannelError) Error() string {
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
return s.accountRepo.ResetQuotaUsed(ctx, id)
}
// EnsureOpenAIPrivacy checks whether an OpenAI OAuth account already has privacy_mode set;
// if not, it calls disableOpenAITraining, persists the result to Extra, and returns the mode that was set.
func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Account) string {
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return ""
}
if s.privacyClientFactory == nil {
return ""
}
if account.Extra != nil {
if _, ok := account.Extra["privacy_mode"]; ok {
return ""
}
}
token, _ := account.Credentials["access_token"].(string)
if token == "" {
return ""
}
var proxyURL string
if account.ProxyID != nil {
if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil {
proxyURL = p.URL()
}
}
mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL)
if mode == "" {
return ""
}
_ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode})
return mode
}
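A hedged sketch of a call site (the surrounding handler, adminSvc, and logger are assumed, not part of this diff): the method is deliberately safe to call unconditionally, returning "" whenever there is nothing to do.

// Opportunistic privacy check after loading an OpenAI OAuth account.
// EnsureOpenAIPrivacy returns "" for non-OpenAI/OAuth accounts, accounts
// that already carry privacy_mode in Extra, or accounts without a token.
if mode := adminSvc.EnsureOpenAIPrivacy(ctx, account); mode != "" {
	log.Printf("account %d: privacy_mode set to %q", account.ID, mode)
}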

View File

@@ -43,6 +43,16 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
return nil
}
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
if err, ok := s.listByGroupErr[groupID]; ok {
return nil, err
}
if rows, ok := s.listByGroupData[groupID]; ok {
return rows, nil
}
return nil, nil
}
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
s.getByIDsCalled = true
s.getByIDsIDs = append([]int64{}, ids...)
@@ -63,16 +73,6 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
return nil, errors.New("account not found")
}
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
if err, ok := s.listByGroupErr[groupID]; ok {
return nil, err
}
if rows, ok := s.listByGroupData[groupID]; ok {
return rows, nil
}
return nil, nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs verifies that a successful bulk update returns success_ids/failed_ids.
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}

View File

@@ -1,176 +0,0 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
// userGroupRateRepoStubForGroupRate implements UserGroupRateRepository for group rate tests.
type userGroupRateRepoStubForGroupRate struct {
getByGroupIDData map[int64][]UserGroupRateEntry
getByGroupIDErr error
deletedGroupIDs []int64
deleteByGroupErr error
syncedGroupID int64
syncedEntries []GroupRateMultiplierInput
syncGroupErr error
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
panic("unexpected GetByUserID call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, _, _ int64) (*float64, error) {
panic("unexpected GetByUserAndGroup call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
if s.getByGroupIDErr != nil {
return nil, s.getByGroupIDErr
}
return s.getByGroupIDData[groupID], nil
}
func (s *userGroupRateRepoStubForGroupRate) SyncUserGroupRates(_ context.Context, _ int64, _ map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
s.syncedGroupID = groupID
s.syncedEntries = entries
return s.syncGroupErr
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
return s.deleteByGroupErr
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByUserID(_ context.Context, _ int64) error {
panic("unexpected DeleteByUserID call")
}
func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
t.Run("returns entries for group", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{
10: {
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
},
},
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.NoError(t, err)
require.Len(t, entries, 2)
require.Equal(t, int64(1), entries[0].UserID)
require.Equal(t, "alice", entries[0].UserName)
require.Equal(t, 1.5, entries[0].RateMultiplier)
require.Equal(t, int64(2), entries[1].UserID)
require.Equal(t, 0.8, entries[1].RateMultiplier)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.NoError(t, err)
require.Nil(t, entries)
})
t.Run("returns empty slice for group with no entries", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{},
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 99)
require.NoError(t, err)
require.Nil(t, entries)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDErr: errors.New("db error"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
_, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.Error(t, err)
require.Contains(t, err.Error(), "db error")
})
}
func TestAdminService_ClearGroupRateMultipliers(t *testing.T) {
t.Run("deletes by group ID", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.NoError(t, err)
require.Equal(t, []int64{42}, repo.deletedGroupIDs)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.NoError(t, err)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
deleteByGroupErr: errors.New("delete failed"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.Error(t, err)
require.Contains(t, err.Error(), "delete failed")
})
}
func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
t.Run("syncs entries to repo", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries := []GroupRateMultiplierInput{
{UserID: 1, RateMultiplier: 1.5},
{UserID: 2, RateMultiplier: 0.8},
}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, entries)
require.NoError(t, err)
require.Equal(t, int64(10), repo.syncedGroupID)
require.Equal(t, entries, repo.syncedEntries)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, nil)
require.NoError(t, err)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
syncGroupErr: errors.New("sync failed"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, []GroupRateMultiplierInput{
{UserID: 1, RateMultiplier: 1.0},
})
require.Error(t, err)
require.Contains(t, err.Error(), "sync failed")
})
}

View File

@@ -785,3 +785,57 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_CreateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
enabled := true
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "openai-group",
Platform: PlatformOpenAI,
SimulateClaudeMaxEnabled: &enabled,
})
require.Error(t, err)
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
require.Nil(t, repo.created)
}
func TestAdminService_UpdateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "openai-group",
Platform: PlatformOpenAI,
Status: StatusActive,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
enabled := true
_, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
SimulateClaudeMaxEnabled: &enabled,
})
require.Error(t, err)
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_ClearsSimulateClaudeMaxWhenPlatformChanges(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "anthropic-group",
Platform: PlatformAnthropic,
Status: StatusActive,
SimulateClaudeMaxEnabled: true,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
Platform: PlatformOpenAI,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.False(t, repo.updated.SimulateClaudeMaxEnabled)
}

View File

@@ -68,15 +68,11 @@ func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context
panic("unexpected SyncUserGroupRates call")
}
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, _ int64) ([]UserGroupRateEntry, error) {
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
panic("unexpected GetByGroupID call")
}
func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.Context, _ int64, _ []GroupRateMultiplierInput) error {
panic("unexpected SyncGroupRateMultipliers call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
panic("unexpected DeleteByGroupID call")
}

View File

@@ -1673,7 +1673,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var clientDisconnect bool
if claudeReq.Stream {
// Client requested streaming: convert and pass through directly
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel, account.ID)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
return nil, err
@@ -1683,7 +1683,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
clientDisconnect = streamRes.clientDisconnect
} else {
// Client requested non-streaming: collect the streamed response, then convert and return it
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel, account.ID)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
return nil, err
@@ -1692,6 +1692,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
firstTokenMs = streamRes.firstTokenMs
}
// Claude Max cache billing: keep ForwardResult.Usage consistent with the client response body
applyClaudeMaxCacheBillingPolicyToUsage(usage, parsedRequestFromGinContext(c), claudeMaxGroupFromGinContext(c), originalModel, account.ID)
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
@@ -3595,7 +3598,7 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
// handleClaudeStreamToNonStreaming collects the upstream streaming response and returns it converted to Claude's non-streaming format.
// Used when the client requests non-streaming but the upstream only supports streaming.
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
@@ -3753,6 +3756,9 @@ returnResponse:
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
// Claude Max cache billing simulation (non-streaming)
claudeResp = applyClaudeMaxNonStreamingRewrite(c, claudeResp, agUsage, originalModel, accountID)
c.Data(http.StatusOK, "application/json", claudeResp)
// Convert to service.ClaudeUsage
@@ -3767,7 +3773,7 @@ returnResponse:
}
// handleClaudeStreamingResponse handles Claude streaming responses (Gemini SSE → Claude SSE conversion)
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
@@ -3780,6 +3786,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
processor := antigravity.NewStreamingProcessor(originalModel)
setupClaudeMaxStreamingHook(c, processor, originalModel, accountID)
var firstTokenMs *int
// Use a Scanner with a capped line size; ReadString has no upper bound and can OOM
scanner := bufio.NewScanner(resp.Body)
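The capped-Scanner pattern the comment describes, as a standalone sketch (buffer sizes are illustrative; processLine stands in for the SSE → Claude conversion): the Scanner fails with bufio.ErrTooLong instead of growing without bound, which is the OOM guard ReadString lacks.

scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) // hard cap per SSE line
for scanner.Scan() {
	processLine(scanner.Text())
}
if err := scanner.Err(); err != nil {
	// bufio.ErrTooLong here means a single upstream line exceeded maxLineSize
}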

View File

@@ -922,7 +922,7 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
_ = pr.Close()
require.NoError(t, err)
@@ -999,7 +999,7 @@ func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro", 0)
_ = pr.Close()
require.NoError(t, err)
@@ -1202,7 +1202,7 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
_ = pr.Close()
require.NoError(t, err)
@@ -1234,7 +1234,7 @@ func TestHandleClaudeStreamingResponse_EmptyStream(t *testing.T) {
fmt.Fprintln(pw, "")
}()
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
_ = pr.Close()
// Should return UpstreamFailoverError rather than nil so the caller can trigger failover
@@ -1266,7 +1266,7 @@ func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
require.NoError(t, err)
require.NotNil(t, result)

View File

@@ -59,9 +59,10 @@ type APIKeyAuthGroupSnapshot struct {
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject bool `json:"mcp_xml_inject"`
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject bool `json:"mcp_xml_inject"`
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
// Supported model scopes (antigravity platform only)
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`

View File

@@ -244,6 +244,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ModelRouting: apiKey.Group.ModelRouting,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
MCPXMLInject: apiKey.Group.MCPXMLInject,
SimulateClaudeMaxEnabled: apiKey.Group.SimulateClaudeMaxEnabled,
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
@@ -303,6 +304,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ModelRouting: snapshot.Group.ModelRouting,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
MCPXMLInject: snapshot.Group.MCPXMLInject,
SimulateClaudeMaxEnabled: snapshot.Group.SimulateClaudeMaxEnabled,
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel,

View File

@@ -1,607 +0,0 @@
package service
import (
"encoding/json"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const defaultBedrockRegion = "us-east-1"
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
// BedrockCrossRegionPrefix returns the Bedrock cross-region inference model ID prefix for an AWS region.
// Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
func BedrockCrossRegionPrefix(region string) string {
switch {
case strings.HasPrefix(region, "us-gov"):
return "us-gov" // GovCloud uses its own us-gov prefix
case strings.HasPrefix(region, "us-"):
return "us"
case strings.HasPrefix(region, "eu-"):
return "eu"
case region == "ap-northeast-1":
return "jp" // the Japan region uses its own jp prefix (AWS-defined)
case region == "ap-southeast-2":
return "au" // the Australia region uses its own au prefix (AWS-defined)
case strings.HasPrefix(region, "ap-"):
return "apac" // remaining APAC regions use the generic apac prefix
case strings.HasPrefix(region, "ca-"):
return "us" // Canada regions use us-prefixed cross-region inference
case strings.HasPrefix(region, "sa-"):
return "us" // South America regions use us-prefixed cross-region inference
default:
return "us"
}
}
// AdjustBedrockModelRegionPrefix replaces the model ID's region prefix with one matching the current AWS region.
// e.g. with region=eu-west-1, "us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1".
// The special value region="global" forces the global. prefix.
func AdjustBedrockModelRegionPrefix(modelID, region string) string {
var targetPrefix string
if region == "global" {
targetPrefix = "global"
} else {
targetPrefix = BedrockCrossRegionPrefix(region)
}
for _, p := range bedrockCrossRegionPrefixes {
if strings.HasPrefix(modelID, p) {
if p == targetPrefix+"." {
return modelID // prefix already matches, nothing to replace
}
return targetPrefix + "." + modelID[len(p):]
}
}
// The model ID has no known region prefix (e.g. "anthropic.claude-..."); leave it unchanged
return modelID
}
func bedrockRuntimeRegion(account *Account) string {
if account == nil {
return defaultBedrockRegion
}
if region := account.GetCredential("aws_region"); region != "" {
return region
}
return defaultBedrockRegion
}
func shouldForceBedrockGlobal(account *Account) bool {
return account != nil && account.GetCredential("aws_force_global") == "true"
}
func isRegionalBedrockModelID(modelID string) bool {
for _, prefix := range bedrockCrossRegionPrefixes {
if strings.HasPrefix(modelID, prefix) {
return true
}
}
return false
}
func isLikelyBedrockModelID(modelID string) bool {
lower := strings.ToLower(strings.TrimSpace(modelID))
if lower == "" {
return false
}
if strings.HasPrefix(lower, "arn:") {
return true
}
for _, prefix := range []string{
"anthropic.",
"amazon.",
"meta.",
"mistral.",
"cohere.",
"ai21.",
"deepseek.",
"stability.",
"writer.",
"nova.",
} {
if strings.HasPrefix(lower, prefix) {
return true
}
}
return isRegionalBedrockModelID(lower)
}
func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return "", false, false
}
if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists {
return mapped, true, true
}
if isRegionalBedrockModelID(modelID) {
return modelID, true, true
}
if isLikelyBedrockModelID(modelID) {
return modelID, false, true
}
return "", false, false
}
// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID.
// It applies account model_mapping first, then default Bedrock aliases, and finally
// adjusts Anthropic cross-region prefixes to match the account region.
func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) {
if account == nil {
return "", false
}
mappedModel := account.GetMappedModel(requestedModel)
modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel)
if !ok {
return "", false
}
if shouldAdjustRegion {
targetRegion := bedrockRuntimeRegion(account)
if shouldForceBedrockGlobal(account) {
targetRegion = "global"
}
modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion)
}
return modelID, true
}
// BuildBedrockURL builds the URL for Bedrock InvokeModel.
// With stream=true it uses the invoke-with-response-stream endpoint.
// Special characters in modelID are URL-encoded (matching litellm's urllib.parse.quote(safe="")).
func BuildBedrockURL(region, modelID string, stream bool) string {
if region == "" {
region = defaultBedrockRegion
}
encodedModelID := url.PathEscape(modelID)
// url.PathEscape does not encode colons (the RFC allows ":" in paths),
// but AWS Bedrock expects colons in model IDs to be encoded as %3A
encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A")
if stream {
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID)
}
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID)
}
// PrepareBedrockRequestBody adapts a request body for the Bedrock API:
// 1. Inject anthropic_version
// 2. Inject anthropic_beta (parsed from the client's anthropic-beta header)
// 3. Remove fields Bedrock does not support: model, stream, output_format, output_config
// 4. Remove the custom field from tool definitions (Claude Code sends custom: {defer_loading: true})
// 5. Strip cache_control fields Bedrock does not support: scope, ttl
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
}
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
var err error
// Inject anthropic_version (required by Bedrock)
body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
if err != nil {
return nil, fmt.Errorf("inject anthropic_version: %w", err)
}
// Inject anthropic_beta (Bedrock Invoke passes beta headers in the request body, not HTTP headers)
// 1. Parsed from the client's anthropic-beta header
// 2. Required beta tokens are auto-added based on the request body contents
// See litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock()
if len(betaTokens) > 0 {
body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens)
if err != nil {
return nil, fmt.Errorf("inject anthropic_beta: %w", err)
}
}
// Remove the model field (Bedrock specifies the model via the URL)
body, err = sjson.DeleteBytes(body, "model")
if err != nil {
return nil, fmt.Errorf("remove model field: %w", err)
}
// Remove the stream field (Bedrock controls streaming via the endpoint and rejects stream in the body)
body, err = sjson.DeleteBytes(body, "stream")
if err != nil {
return nil, fmt.Errorf("remove stream field: %w", err)
}
// Convert output_format: Bedrock Invoke does not support this field, but the schema can be inlined into the last user message
// See litellm: _convert_output_format_to_inline_schema()
body = convertOutputFormatToInlineSchema(body)
// Remove the output_config field (unsupported by Bedrock Invoke)
body, err = sjson.DeleteBytes(body, "output_config")
if err != nil {
return nil, fmt.Errorf("remove output_config field: %w", err)
}
// Remove the custom field from tool definitions.
// Claude Code (v2.1.69+) sends custom: {defer_loading: true} in tool definitions;
// the Anthropic API accepts it, but Bedrock rejects it with "Extra inputs are not permitted"
body = removeCustomFieldFromTools(body)
// Strip cache_control fields Bedrock does not support
body = sanitizeBedrockCacheControl(body, modelID)
return body, nil
}
// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering.
func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string {
betaTokens := parseAnthropicBetaHeader(betaHeader)
betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID)
return filterBedrockBetaTokens(betaTokens)
}
// convertOutputFormatToInlineSchema inlines the JSON schema from output_format into the last user message.
// Bedrock Invoke does not support the output_format parameter; litellm's approach is to append the schema to the user message.
// Reference: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema()
func convertOutputFormatToInlineSchema(body []byte) []byte {
outputFormat := gjson.GetBytes(body, "output_format")
if !outputFormat.Exists() || !outputFormat.IsObject() {
return body
}
// Remove output_format from the request body first
body, _ = sjson.DeleteBytes(body, "output_format")
schema := outputFormat.Get("schema")
if !schema.Exists() {
return body
}
// Find the last user message
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
msgArr := messages.Array()
lastUserIdx := -1
for i := len(msgArr) - 1; i >= 0; i-- {
if msgArr[i].Get("role").String() == "user" {
lastUserIdx = i
break
}
}
if lastUserIdx < 0 {
return body
}
// Serialize the schema to JSON text and append it to that message's content array
schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw))
if err != nil {
return body
}
content := msgArr[lastUserIdx].Get("content")
basePath := fmt.Sprintf("messages.%d.content", lastUserIdx)
if content.IsArray() {
// Append a text block to the end of the content array
idx := len(content.Array())
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text")
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON))
} else if content.Type == gjson.String {
// content is a plain string; convert it to array form
originalText := content.String()
body, _ = sjson.SetBytes(body, basePath, []map[string]string{
{"type": "text", "text": originalText},
{"type": "text", "text": string(schemaJSON)},
})
}
return body
}
// removeCustomFieldFromTools removes the custom field from every tool definition in the tools array
func removeCustomFieldFromTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if !tools.Exists() || !tools.IsArray() {
return body
}
var err error
for i := range tools.Array() {
body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i))
if err != nil {
// A failed delete does not affect the overall flow; skip it
continue
}
}
return body
}
// claudeVersionRe matches the version portion of a Claude model ID.
// Supports the claude-{tier}-{major}-{minor} and claude-{tier}-{major}.{minor} formats.
var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`)
// isBedrockClaude45OrNewer reports whether a Bedrock model ID is Claude 4.5 or newer.
// Claude 4.5+ supports the ttl field in cache_control ("5m" and "1h").
func isBedrockClaude45OrNewer(modelID string) bool {
lower := strings.ToLower(modelID)
matches := claudeVersionRe.FindStringSubmatch(lower)
if matches == nil {
return false
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return major > 4 || (major == 4 && minor >= 5)
}
// sanitizeBedrockCacheControl strips fields from cache_control in system and
// messages that Bedrock does not support:
// - scope: unsupported by Bedrock (e.g. "global" cross-request caching)
// - ttl: only Claude 4.5+ supports "5m" and "1h"; it must be removed for older models
func sanitizeBedrockCacheControl(body []byte, modelID string) []byte {
isClaude45 := isBedrockClaude45OrNewer(modelID)
// Strip cache_control in the system array
systemArr := gjson.GetBytes(body, "system")
if systemArr.Exists() && systemArr.IsArray() {
for i, item := range systemArr.Array() {
if !item.IsObject() {
continue
}
cc := item.Get("cache_control")
if !cc.Exists() || !cc.IsObject() {
continue
}
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45)
}
}
// Strip cache_control in messages
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
for mi, msg := range messages.Array() {
if !msg.IsObject() {
continue
}
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
continue
}
for ci, block := range content.Array() {
if !block.IsObject() {
continue
}
cc := block.Get("cache_control")
if !cc.Exists() || !cc.IsObject() {
continue
}
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45)
}
}
return body
}
// deleteCacheControlUnsupportedFields removes Bedrock-unsupported fields under the given cache_control path
func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte {
// Bedrock does not support scope (e.g. "global")
if cc.Get("scope").Exists() {
body, _ = sjson.DeleteBytes(body, basePath+".scope")
}
// ttl仅 Claude 4.5+ 支持 "5m" 和 "1h",其余情况移除
ttl := cc.Get("ttl")
if ttl.Exists() {
shouldRemove := true
if isClaude45 {
v := ttl.String()
if v == "5m" || v == "1h" {
shouldRemove = false
}
}
if shouldRemove {
body, _ = sjson.DeleteBytes(body, basePath+".ttl")
}
}
return body
}
// parseAnthropicBetaHeader parses the comma-separated anthropic-beta header into a token list
func parseAnthropicBetaHeader(header string) []string {
header = strings.TrimSpace(header)
if header == "" {
return nil
}
if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") {
var parsed []any
if err := json.Unmarshal([]byte(header), &parsed); err == nil {
tokens := make([]string, 0, len(parsed))
for _, item := range parsed {
token := strings.TrimSpace(fmt.Sprint(item))
if token != "" {
tokens = append(tokens, token)
}
}
return tokens
}
}
var tokens []string
for _, part := range strings.Split(header, ",") {
t := strings.TrimSpace(part)
if t != "" {
tokens = append(tokens, t)
}
}
return tokens
}
// bedrockSupportedBetaTokens is the allowlist of beta headers supported by Bedrock Invoke.
// Reference: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
// Maintenance: update this allowlist whenever AWS Bedrock adds support for new beta tokens.
var bedrockSupportedBetaTokens = map[string]bool{
"computer-use-2025-01-24": true,
"computer-use-2025-11-24": true,
"context-1m-2025-08-07": true,
"context-management-2025-06-27": true,
"compact-2026-01-12": true,
"interleaved-thinking-2025-05-14": true,
"tool-search-tool-2025-10-19": true,
"tool-examples-2025-10-29": true,
}
// bedrockBetaTokenTransforms defines beta header transform rules specific to Bedrock Invoke.
// The direct Anthropic API uses generic headers; Bedrock Invoke needs specific replacements.
var bedrockBetaTokenTransforms = map[string]string{
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
}
// autoInjectBedrockBetaTokens auto-adds required beta tokens based on the request body contents.
// See litellm: AnthropicModelInfo.get_anthropic_beta_list() and
// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock()
//
// Clients (especially non-Claude-Code clients) may enable a feature only in the body without
// carrying the matching beta token in the header. Detecting body features and filling the tokens
// in here ensures Bedrock Invoke does not return 400 for a missing required beta header.
func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string {
seen := make(map[string]bool, len(tokens))
for _, t := range tokens {
seen[t] = true
}
inject := func(token string) {
if !seen[token] {
tokens = append(tokens, token)
seen[token] = true
}
}
// Detect thinking / interleaved thinking:
// a "thinking" field in the request body → requires the interleaved-thinking beta
if gjson.GetBytes(body, "thinking").Exists() {
inject("interleaved-thinking-2025-05-14")
}
// Detect computer_use tools:
// a tool with type="computer_20xxxxxx" in tools → requires the computer-use beta
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() {
toolSearchUsed := false
programmaticToolCallingUsed := false
inputExamplesUsed := false
for _, tool := range tools.Array() {
toolType := tool.Get("type").String()
if strings.HasPrefix(toolType, "computer_20") {
inject("computer-use-2025-11-24")
}
if isBedrockToolSearchType(toolType) {
toolSearchUsed = true
}
if hasCodeExecutionAllowedCallers(tool) {
programmaticToolCallingUsed = true
}
if hasInputExamples(tool) {
inputExamplesUsed = true
}
}
if programmaticToolCallingUsed || inputExamplesUsed {
// programmatic tool calling and input examples require advanced-tool-use;
// filterBedrockBetaTokens later converts it to the Bedrock-specific tool-search-tool
inject("advanced-tool-use-2025-11-20")
}
if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) {
// For pure tool search (no programmatic/inputExamples), inject the Bedrock-specific header
// directly, skipping the advanced-tool-use → tool-search-tool transform step (matching litellm)
if !programmaticToolCallingUsed && !inputExamplesUsed {
inject("tool-search-tool-2025-10-19")
} else {
inject("advanced-tool-use-2025-11-20")
}
}
}
return tokens
}
func isBedrockToolSearchType(toolType string) bool {
return toolType == "tool_search_tool_regex_20251119" || toolType == "tool_search_tool_bm25_20251119"
}
func hasCodeExecutionAllowedCallers(tool gjson.Result) bool {
allowedCallers := tool.Get("allowed_callers")
if containsStringInJSONArray(allowedCallers, "code_execution_20250825") {
return true
}
return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825")
}
func hasInputExamples(tool gjson.Result) bool {
if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 {
return true
}
arr := tool.Get("function.input_examples")
return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0
}
func containsStringInJSONArray(result gjson.Result, target string) bool {
if !result.Exists() || !result.IsArray() {
return false
}
for _, item := range result.Array() {
if item.String() == target {
return true
}
}
return false
}
// bedrockModelSupportsToolSearch reports whether a Bedrock model supports tool search.
// Currently only Claude Opus/Sonnet 4.5+ support it; Haiku does not.
func bedrockModelSupportsToolSearch(modelID string) bool {
lower := strings.ToLower(modelID)
matches := claudeVersionRe.FindStringSubmatch(lower)
if matches == nil {
return false
}
// Haiku does not support tool search
if strings.Contains(lower, "haiku") {
return false
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return major > 4 || (major == 4 && minor >= 5)
}
// filterBedrockBetaTokens filters and transforms the beta token list, keeping only tokens Bedrock Invoke supports:
// 1. Apply transform rules (e.g. advanced-tool-use → tool-search-tool)
// 2. Drop tokens Bedrock does not support (e.g. output-128k, files-api, structured-outputs)
// 3. Auto-associate tool-examples (when tool-search-tool is present)
func filterBedrockBetaTokens(tokens []string) []string {
seen := make(map[string]bool, len(tokens))
var result []string
for _, t := range tokens {
// Apply transform rules
if replacement, ok := bedrockBetaTokenTransforms[t]; ok {
t = replacement
}
// Keep only allowlisted tokens, deduplicated
if bedrockSupportedBetaTokens[t] && !seen[t] {
result = append(result, t)
seen[t] = true
}
}
// Auto-associate: when tool-search-tool is present, ensure tool-examples is too
if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] {
result = append(result, "tool-examples-2025-10-29")
}
return result
}
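Taken together, the deleted helpers formed a small pipeline: alias resolution, region-prefix adjustment, then URL construction. A sketch of the happy path, consistent with the tests below (the expected values follow from the default alias table and BedrockCrossRegionPrefix):

account := &Account{
	Platform:    PlatformAnthropic,
	Type:        AccountTypeBedrock,
	Credentials: map[string]any{"aws_region": "eu-west-1"},
}
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
// modelID == "eu.anthropic.claude-sonnet-4-5-20250929-v1:0":
// the default alias table supplies the Bedrock ID, and
// BedrockCrossRegionPrefix("eu-west-1") == "eu" rewrites the region prefix.
if ok {
	url := BuildBedrockURL("eu-west-1", modelID, true)
	_ = url // https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke-with-response-stream
}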

View File

@@ -1,659 +0,0 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) {
input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
// anthropic_version should be injected
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
// model and stream should be removed
assert.False(t, gjson.GetBytes(result, "model").Exists())
assert.False(t, gjson.GetBytes(result, "stream").Exists())
// max_tokens should be preserved
assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int())
}
func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) {
t.Run("schema inlined into last user message array content", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
// the schema should be inlined at the end of the last user message's content array
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "text", contentArr[1].Get("type").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`)
})
t.Run("schema inlined into string content", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "compute this", contentArr[0].Get("text").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`)
})
t.Run("no schema field just removes output_format", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
})
t.Run("no messages just removes output_format", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
})
}
func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
}
func TestRemoveCustomFieldFromTools(t *testing.T) {
input := `{
"tools": [
{"name":"tool1","custom":{"defer_loading":true},"description":"desc1"},
{"name":"tool2","description":"desc2"},
{"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"}
]
}`
result := removeCustomFieldFromTools([]byte(input))
tools := gjson.GetBytes(result, "tools").Array()
require.Len(t, tools, 3)
// custom should be removed
assert.False(t, tools[0].Get("custom").Exists())
// name/description should be preserved
assert.Equal(t, "tool1", tools[0].Get("name").String())
assert.Equal(t, "desc1", tools[0].Get("description").String())
// tools without custom are unaffected
assert.Equal(t, "tool2", tools[1].Get("name").String())
// the third tool's custom should be removed as well
assert.False(t, tools[2].Get("custom").Exists())
assert.Equal(t, "tool3", tools[2].Get("name").String())
}
func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}]}`
result := removeCustomFieldFromTools([]byte(input))
// with no tools, the original data is unchanged
assert.JSONEq(t, input, string(result))
}
func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}],
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}]
}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
// scope should be removed
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists())
// type should be preserved
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String())
}
func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
}`
// 旧模型Claude 3.5)不支持 ttl
result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
}
func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
}`
// Claude 4.5+ supports "5m" and "1h"
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
}
func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}]
}`
// Claude 4.5 does not support "10m"
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
}
func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) {
input := `{
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists())
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) {
input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
// with no cache_control, the original data is unchanged
assert.JSONEq(t, input, string(result))
}
func TestIsBedrockClaude45OrNewer(t *testing.T) {
tests := []struct {
modelID string
expect bool
}{
{"us.anthropic.claude-opus-4-6-v1", true},
{"us.anthropic.claude-sonnet-4-6", true},
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
{"us.anthropic.claude-haiku-4-5-20251001-v1:0", true},
{"anthropic.claude-3-5-sonnet-20241022-v2:0", false},
{"anthropic.claude-3-opus-20240229-v1:0", false},
{"anthropic.claude-3-haiku-20240307-v1:0", false},
// future versions should be supported automatically
{"us.anthropic.claude-sonnet-5-0-v1", true},
{"us.anthropic.claude-opus-4-7-v1", true},
// older versions
{"anthropic.claude-opus-4-1-v1", false},
{"anthropic.claude-sonnet-4-0-v1", false},
// non-Claude models
{"amazon.nova-pro-v1", false},
{"meta.llama3-70b", false},
}
for _, tt := range tests {
t.Run(tt.modelID, func(t *testing.T) {
assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID))
})
}
}
func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
// Simulate a complete Claude Code request
input := `{
"model": "claude-opus-4-6",
"stream": true,
"max_tokens": 16384,
"output_format": {"type": "json", "schema": {"result": "string"}},
"output_config": {"max_tokens": 100},
"system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}],
"messages": [
{"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]}
],
"tools": [
{"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}},
{"name": "read", "description": "Read file", "input_schema": {"type": "object"}}
]
}`
betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
require.NoError(t, err)
// Basic fields
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
assert.False(t, gjson.GetBytes(result, "model").Exists())
assert.False(t, gjson.GetBytes(result, "stream").Exists())
assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int())
// anthropic_beta should contain all beta tokens
betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, betaArr, 3)
assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
assert.Equal(t, "compact-2026-01-12", betaArr[2].String())
// output_format should be removed, with the schema inlined into the last user message
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
// content array: original text block + inlined schema block
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "hello", contentArr[0].Get("text").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`)
// custom in tools should be removed
assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists())
assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String())
assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String())
// cache_control: scope should be removed; valid ttl values are kept on Claude 4.6
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
t.Run("empty beta header", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
})
t.Run("single beta token", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
})
t.Run("multiple beta tokens with spaces", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
})
t.Run("json array beta header", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
})
}
func TestParseAnthropicBetaHeader(t *testing.T) {
assert.Nil(t, parseAnthropicBetaHeader(""))
assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b "))
assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`))
}
func TestFilterBedrockBetaTokens(t *testing.T) {
t.Run("supported tokens pass through", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, tokens, result)
})
t.Run("unsupported tokens are filtered out", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
})
t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
tokens := []string{"advanced-tool-use-2025-11-20"}
result := filterBedrockBetaTokens(tokens)
assert.Contains(t, result, "tool-search-tool-2025-10-19")
// tool-examples is auto-associated
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) {
tokens := []string{"tool-search-tool-2025-10-19"}
result := filterBedrockBetaTokens(tokens)
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("no duplication when tool-examples already present", func(t *testing.T) {
tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"}
result := filterBedrockBetaTokens(tokens)
count := 0
for _, tok := range result {
if tok == "tool-examples-2025-10-29" {
count++
}
}
assert.Equal(t, 1, count)
})
t.Run("empty input returns nil", func(t *testing.T) {
result := filterBedrockBetaTokens(nil)
assert.Nil(t, result)
})
t.Run("all unsupported returns nil", func(t *testing.T) {
result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"})
assert.Nil(t, result)
})
t.Run("duplicate tokens are deduplicated", func(t *testing.T) {
tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
})
}
func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
})
t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
"advanced-tool-use-2025-11-20")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String())
assert.Equal(t, "tool-examples-2025-10-29", arr[1].String())
})
}
func TestBedrockCrossRegionPrefix(t *testing.T) {
tests := []struct {
region string
expect string
}{
// US regions
{"us-east-1", "us"},
{"us-east-2", "us"},
{"us-west-1", "us"},
{"us-west-2", "us"},
// GovCloud
{"us-gov-east-1", "us-gov"},
{"us-gov-west-1", "us-gov"},
// EU regions
{"eu-west-1", "eu"},
{"eu-west-2", "eu"},
{"eu-west-3", "eu"},
{"eu-central-1", "eu"},
{"eu-central-2", "eu"},
{"eu-north-1", "eu"},
{"eu-south-1", "eu"},
// APAC regions
{"ap-northeast-1", "jp"},
{"ap-northeast-2", "apac"},
{"ap-southeast-1", "apac"},
{"ap-southeast-2", "au"},
{"ap-south-1", "apac"},
// Canada / South America fallback to us
{"ca-central-1", "us"},
{"sa-east-1", "us"},
// Unknown defaults to us
{"me-south-1", "us"},
}
for _, tt := range tests {
t.Run(tt.region, func(t *testing.T) {
assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region))
})
}
}
func TestResolveBedrockModelID(t *testing.T) {
t.Run("default alias resolves and adjusts region", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "eu-west-1",
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
require.True(t, ok)
assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID)
})
t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "ap-southeast-2",
"model_mapping": map[string]any{
"claude-*": "claude-opus-4-6",
},
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-6-thinking")
require.True(t, ok)
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
})
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
"aws_force_global": "true",
"model_mapping": map[string]any{
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
},
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6")
require.True(t, ok)
assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID)
})
t.Run("direct bedrock model id passes through", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0")
require.True(t, ok)
assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID)
})
t.Run("unsupported alias returns false", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
_, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022")
assert.False(t, ok)
})
}
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
})
t.Run("no duplicate when already present", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
count := 0
for _, tok := range result {
if tok == "interleaved-thinking-2025-05-14" {
count++
}
}
assert.Equal(t, 1, count)
})
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "computer-use-2025-11-24")
})
t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject advanced-tool-use for input examples", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
// pure tool search injects the Bedrock-specific token directly, without the advanced-tool-use transform
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
// the mixed scenario uses advanced-tool-use; the filter later transforms it into tool-search-tool
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
assert.NotContains(t, result, "tool-search-tool-2025-10-19")
})
t.Run("no injection for regular tools", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Empty(t, result)
})
t.Run("no injection when no features detected", func(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Empty(t, result)
})
t.Run("preserves existing tokens", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`)
existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"}
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "context-1m-2025-08-07")
assert.Contains(t, result, "compact-2026-01-12")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
})
}
func TestResolveBedrockBetaTokens(t *testing.T) {
t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
})
}
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
found := false
for _, v := range arr {
if v.String() == "interleaved-thinking-2025-05-14" {
found = true
}
}
assert.True(t, found, "interleaved-thinking should be auto-injected")
})
t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
names := make([]string, len(arr))
for i, v := range arr {
names[i] = v.String()
}
assert.Contains(t, names, "context-1m-2025-08-07")
assert.Contains(t, names, "interleaved-thinking-2025-05-14")
})
}
func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
tests := []struct {
name string
modelID string
region string
expect string
}{
// US region — no change needed
{"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"},
// EU region — replace us → eu
{"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
{"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"},
// APAC region — jp and au have dedicated prefixes per AWS docs
{"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"},
{"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"},
{"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"},
// eu → us (user manually set eu prefix, moved to us region)
{"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"},
// global prefix — replace to match region
{"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
// No known prefix — leave unchanged
{"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"},
// GovCloud — uses independent us-gov prefix
{"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"},
{"govcloud already correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"},
// Force global (special region value)
{"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"},
{"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region))
})
}
}

View File

@@ -1,67 +0,0 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
)
// BedrockSigner signs Bedrock requests with AWS SigV4
type BedrockSigner struct {
credentials aws.Credentials
region string
signer *v4.Signer
}
// NewBedrockSigner creates a BedrockSigner from static AWS credentials
func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner {
return &BedrockSigner{
credentials: aws.Credentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
},
region: region,
signer: v4.NewSigner(),
}
}
// NewBedrockSignerFromAccount builds a BedrockSigner from an Account's credentials
func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) {
accessKeyID := account.GetCredential("aws_access_key_id")
if accessKeyID == "" {
return nil, fmt.Errorf("aws_access_key_id not found in credentials")
}
secretAccessKey := account.GetCredential("aws_secret_access_key")
if secretAccessKey == "" {
return nil, fmt.Errorf("aws_secret_access_key not found in credentials")
}
region := account.GetCredential("aws_region")
if region == "" {
region = defaultBedrockRegion
}
sessionToken := account.GetCredential("aws_session_token") // optional
return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil
}
// SignRequest applies a SigV4 signature to the HTTP request.
// Important constraint: before calling this method, req should carry only AWS-relevant
// headers (such as Content-Type and Accept). Non-AWS headers (such as anthropic-beta)
// are included in the signature computation, and if the Bedrock server does not
// recognize them, signature verification may fail. litellm filters headers via
// _filter_headers_for_aws_signature; in the current implementation,
// buildUpstreamRequestBedrock sets only Content-Type and Accept, so this is safe.
func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error {
payloadHash := sha256Hash(body)
return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now())
}
func sha256Hash(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}

View File

@@ -1,35 +0,0 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_access_key_id": "test-akid",
"aws_secret_access_key": "test-secret",
},
}
signer, err := NewBedrockSignerFromAccount(account)
require.NoError(t, err)
require.NotNil(t, signer)
assert.Equal(t, defaultBedrockRegion, signer.region)
}
func TestFilterBetaTokens(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"}
filterSet := map[string]struct{}{
"tool-search-tool-2025-10-19": {},
}
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet))
assert.Equal(t, tokens, filterBetaTokens(tokens, nil))
assert.Nil(t, filterBetaTokens(nil, filterSet))
}

View File

@@ -1,414 +0,0 @@
package service
import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
"hash/crc32"
"io"
"net/http"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// handleBedrockStreamingResponse handles the EventStream response from Bedrock
// InvokeModelWithResponseStream. Bedrock returns the AWS EventStream binary format;
// each event's payload carries chunk.bytes, a base64-encoded Claude SSE event JSON.
// This method decodes those events and writes them to the client as standard SSE.
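//
// For example, a decoded chunk whose bytes field holds
// {"type":"content_block_delta","delta":{"type":"text_delta","text":"Hi"}}
// is written to the client as:
//
//	event: content_block_delta
//	data: {"type":"content_block_delta","delta":{"type":"text_delta","text":"Hi"}}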
func (s *GatewayService) handleBedrockStreamingResponse(
ctx context.Context,
resp *http.Response,
c *gin.Context,
account *Account,
startTime time.Time,
model string,
) (*streamingResult, error) {
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-amzn-requestid"); v != "" {
c.Header("x-request-id", v)
}
usage := &ClaudeUsage{}
var firstTokenMs *int
clientDisconnected := false
// Bedrock EventStream uses the application/vnd.amazon.eventstream binary format.
// Frame layout: total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4).
// Line-scanning for JSON chunks would not work here because the response is wrapped
// in binary frames, so we parse it properly with an EventStream decoder.
decoder := newBedrockEventStreamDecoder(resp.Body)
type decodeEvent struct {
payload []byte
err error
}
events := make(chan decodeEvent, 16)
done := make(chan struct{})
sendEvent := func(ev decodeEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt atomic.Int64
lastReadAt.Store(time.Now().UnixNano())
go func() {
defer close(events)
for {
payload, err := decoder.Decode()
if err != nil {
if err == io.EOF {
return
}
_ = sendEvent(decodeEvent{err: err})
return
}
lastReadAt.Store(time.Now().UnixNano())
if !sendEvent(decodeEvent{payload: payload}) {
return
}
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
if !clientDisconnected {
flusher.Flush()
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
if ev.err != nil {
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err)
}
// payload is JSON; extract chunk.bytes (base64-encoded Claude SSE event data)
sseData := extractBedrockChunkData(ev.payload)
if sseData == nil {
continue
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// Convert the Bedrock-specific amazon-bedrock-invocationMetrics into the standard
// Anthropic usage format, and remove the field so it is not passed through to the client
sseData = transformBedrockInvocationMetrics(sseData)
// Parse the SSE event data and extract usage
s.parseSSEUsagePassthrough(string(sseData), usage)
// Determine the SSE event type
eventType := gjson.GetBytes(sseData, "type").String()
// Write in standard SSE format
if !clientDisconnected {
var writeErr error
if eventType != "" {
_, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData)
} else {
_, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData)
}
if writeErr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID)
} else {
flusher.Flush()
}
}
case <-intervalCh:
lastRead := time.Unix(0, lastReadAt.Load())
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
}
// extractBedrockChunkData extracts the Claude SSE event data from a Bedrock
// EventStream payload. Bedrock payload format: {"bytes":"<base64-encoded-json>"}
func extractBedrockChunkData(payload []byte) []byte {
b64 := gjson.GetBytes(payload, "bytes").String()
if b64 == "" {
return nil
}
decoded, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
return nil
}
return decoded
}
// transformBedrockInvocationMetrics converts the Bedrock-specific
// amazon-bedrock-invocationMetrics into the standard Anthropic usage format
// and removes the field from the SSE data.
//
// A message_delta event returned by Bedrock Invoke may contain:
//
// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}
//
// which is converted to:
//
// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}}
func transformBedrockInvocationMetrics(data []byte) []byte {
metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics")
if !metrics.Exists() || !metrics.IsObject() {
return data
}
// Remove the Bedrock-specific field
data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics")
// Do not overwrite an existing standard usage field
if gjson.GetBytes(data, "usage").Exists() {
return data
}
// Convert camelCase → snake_case and write into usage
inputTokens := metrics.Get("inputTokenCount")
outputTokens := metrics.Get("outputTokenCount")
if inputTokens.Exists() {
data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int())
}
if outputTokens.Exists() {
data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int())
}
return data
}
// bedrockEventStreamDecoder decodes AWS EventStream binary frames.
// EventStream frame format:
//
// [total_byte_length: 4 bytes]
// [headers_byte_length: 4 bytes]
// [prelude_crc: 4 bytes]
// [headers: variable]
// [payload: variable]
// [message_crc: 4 bytes]
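//
// For example, a frame with 32 bytes of headers and a 100-byte payload has
// total_byte_length = 12 + 32 + 100 + 4 = 148.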
type bedrockEventStreamDecoder struct {
reader *bufio.Reader
}
func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder {
return &bedrockEventStreamDecoder{
reader: bufio.NewReaderSize(r, 64*1024),
}
}
// Decode reads EventStream frames until it finds a chunk-type event and returns that event's payload
func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) {
for {
// Read the prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes
prelude := make([]byte, 12)
if _, err := io.ReadFull(d.reader, prelude); err != nil {
return nil, err
}
// Verify the prelude CRC (AWS EventStream uses standard CRC32 / IEEE)
preludeCRC := bedrockReadUint32(prelude[8:12])
if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC {
return nil, fmt.Errorf("eventstream prelude CRC mismatch")
}
totalLength := bedrockReadUint32(prelude[0:4])
headersLength := bedrockReadUint32(prelude[4:8])
if totalLength < 16 { // minimum: 12 prelude + 4 message_crc
return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength)
}
// Read headers + payload + message_crc
remaining := int(totalLength) - 12
if remaining <= 0 {
continue
}
data := make([]byte, remaining)
if _, err := io.ReadFull(d.reader, data); err != nil {
return nil, err
}
// Verify the message CRC, which covers prelude + headers + payload
messageCRC := bedrockReadUint32(data[len(data)-4:])
h := crc32.New(crc32IEEETable)
_, _ = h.Write(prelude)
_, _ = h.Write(data[:len(data)-4])
if h.Sum32() != messageCRC {
return nil, fmt.Errorf("eventstream message CRC mismatch")
}
// Parse headers
headers := data[:headersLength]
payload := data[headersLength : len(data)-4] // strip message_crc
// Extract :event-type from the headers
eventType := extractEventStreamHeaderValue(headers, ":event-type")
// Only chunk events are processed
if eventType == "chunk" {
// payload is complete JSON containing the bytes field
return payload, nil
}
// Check for exception events
exceptionType := extractEventStreamHeaderValue(headers, ":exception-type")
if exceptionType != "" {
return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload))
}
messageType := extractEventStreamHeaderValue(headers, ":message-type")
if messageType == "exception" || messageType == "error" {
return nil, fmt.Errorf("bedrock error: %s", string(payload))
}
// Skip other event types (e.g. initial-response)
}
}
// extractEventStreamHeaderValue extracts the string value of the named header
// from the EventStream headers binary data.
// EventStream header format:
//
// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable]
//
// value_type = 7 means string; the first 2 bytes of the value are its length
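//
// For example, the header :event-type = "chunk" encodes as:
//
//	0x0B ":event-type" 0x07 0x00 0x05 "chunk"
//
// (name length 11, the name, string type 7, big-endian value length 5, the value)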
func extractEventStreamHeaderValue(headers []byte, targetName string) string {
pos := 0
for pos < len(headers) {
nameLen := int(headers[pos])
pos++
if pos+nameLen > len(headers) {
break
}
name := string(headers[pos : pos+nameLen])
pos += nameLen
if pos >= len(headers) {
break
}
valueType := headers[pos]
pos++
switch valueType {
case 7: // string
if pos+2 > len(headers) {
return ""
}
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
pos += 2
if pos+valueLen > len(headers) {
return ""
}
value := string(headers[pos : pos+valueLen])
pos += valueLen
if name == targetName {
return value
}
case 0: // bool true
if name == targetName {
return "true"
}
case 1: // bool false
if name == targetName {
return "false"
}
case 2: // byte
pos++
if name == targetName {
return ""
}
case 3: // short
pos += 2
if name == targetName {
return ""
}
case 4: // int
pos += 4
if name == targetName {
return ""
}
case 5: // long
pos += 8
if name == targetName {
return ""
}
case 6: // bytes
if pos+2 > len(headers) {
return ""
}
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
pos += 2 + valueLen
case 8: // timestamp
pos += 8
case 9: // uuid
pos += 16
default:
return "" // 未知类型,无法继续解析
}
}
return ""
}
// crc32IEEETable is the CRC32 / IEEE table used by AWS EventStream.
var crc32IEEETable = crc32.MakeTable(crc32.IEEE)
func bedrockReadUint32(b []byte) uint32 {
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
}
func bedrockReadUint16(b []byte) uint16 {
return uint16(b[0])<<8 | uint16(b[1])
}

View File

@@ -1,261 +0,0 @@
package service
import (
"bytes"
"encoding/base64"
"encoding/binary"
"hash/crc32"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestExtractBedrockChunkData(t *testing.T) {
t.Run("valid base64 payload", func(t *testing.T) {
original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`
b64 := base64.StdEncoding.EncodeToString([]byte(original))
payload := []byte(`{"bytes":"` + b64 + `"}`)
result := extractBedrockChunkData(payload)
require.NotNil(t, result)
assert.JSONEq(t, original, string(result))
})
t.Run("empty bytes field", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"bytes":""}`))
assert.Nil(t, result)
})
t.Run("no bytes field", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"other":"value"}`))
assert.Nil(t, result)
})
t.Run("invalid base64", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`))
assert.Nil(t, result)
})
}
func TestTransformBedrockInvocationMetrics(t *testing.T) {
t.Run("converts metrics to usage", func(t *testing.T) {
input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result := transformBedrockInvocationMetrics([]byte(input))
// amazon-bedrock-invocationMetrics should be removed
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
// usage should be set
assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int())
assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int())
// original fields preserved
assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String())
assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String())
})
t.Run("no metrics present", func(t *testing.T) {
input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}`
result := transformBedrockInvocationMetrics([]byte(input))
assert.JSONEq(t, input, string(result))
})
t.Run("does not overwrite existing usage", func(t *testing.T) {
input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result := transformBedrockInvocationMetrics([]byte(input))
// metrics removed but existing usage preserved
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int())
})
}
func TestExtractEventStreamHeaderValue(t *testing.T) {
// Build a header with :event-type = "chunk" (string type = 7)
buildStringHeader := func(name, value string) []byte {
var buf bytes.Buffer
// name length (1 byte)
_ = buf.WriteByte(byte(len(name)))
// name
_, _ = buf.WriteString(name)
// value type (7 = string)
_ = buf.WriteByte(7)
// value length (2 bytes, big-endian)
_ = binary.Write(&buf, binary.BigEndian, uint16(len(value)))
// value
_, _ = buf.WriteString(value)
return buf.Bytes()
}
t.Run("find string header", func(t *testing.T) {
headers := buildStringHeader(":event-type", "chunk")
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
})
t.Run("header not found", func(t *testing.T) {
headers := buildStringHeader(":event-type", "chunk")
assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type"))
})
t.Run("multiple headers", func(t *testing.T) {
var buf bytes.Buffer
_, _ = buf.Write(buildStringHeader(":content-type", "application/json"))
_, _ = buf.Write(buildStringHeader(":event-type", "chunk"))
_, _ = buf.Write(buildStringHeader(":message-type", "event"))
headers := buf.Bytes()
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type"))
assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type"))
})
t.Run("empty headers", func(t *testing.T) {
assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type"))
})
}
func TestBedrockEventStreamDecoder(t *testing.T) {
crc32IeeeTab := crc32.MakeTable(crc32.IEEE)
// Build a valid EventStream frame with correct CRC32/IEEE checksums.
buildFrame := func(eventType string, payload []byte) []byte {
// Build headers
var headersBuf bytes.Buffer
// :event-type header
_ = headersBuf.WriteByte(byte(len(":event-type")))
_, _ = headersBuf.WriteString(":event-type")
_ = headersBuf.WriteByte(7) // string type
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType)))
_, _ = headersBuf.WriteString(eventType)
// :message-type header
_ = headersBuf.WriteByte(byte(len(":message-type")))
_, _ = headersBuf.WriteString(":message-type")
_ = headersBuf.WriteByte(7)
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event")))
_, _ = headersBuf.WriteString("event")
headers := headersBuf.Bytes()
headersLen := uint32(len(headers))
// total = 12 (prelude) + headers + payload + 4 (message_crc)
totalLen := uint32(12 + len(headers) + len(payload) + 4)
// Prelude: total_length(4) + headers_length(4)
var preludeBuf bytes.Buffer
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
preludeBytes := preludeBuf.Bytes()
preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab)
// Build frame: prelude + prelude_crc + headers + payload
var frame bytes.Buffer
_, _ = frame.Write(preludeBytes)
_ = binary.Write(&frame, binary.BigEndian, preludeCRC)
_, _ = frame.Write(headers)
_, _ = frame.Write(payload)
// Message CRC covers everything before itself
messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab)
_ = binary.Write(&frame, binary.BigEndian, messageCRC)
return frame.Bytes()
}
t.Run("decode chunk event", func(t *testing.T) {
payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test")
frame := buildFrame("chunk", payload)
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
result, err := decoder.Decode()
require.NoError(t, err)
assert.Equal(t, payload, result)
})
t.Run("skip non-chunk events", func(t *testing.T) {
// Write initial-response followed by chunk
var buf bytes.Buffer
_, _ = buf.Write(buildFrame("initial-response", []byte(`{}`)))
chunkPayload := []byte(`{"bytes":"aGVsbG8="}`)
_, _ = buf.Write(buildFrame("chunk", chunkPayload))
decoder := newBedrockEventStreamDecoder(&buf)
result, err := decoder.Decode()
require.NoError(t, err)
assert.Equal(t, chunkPayload, result)
})
t.Run("EOF on empty input", func(t *testing.T) {
decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil))
_, err := decoder.Decode()
assert.Equal(t, io.EOF, err)
})
t.Run("corrupted prelude CRC", func(t *testing.T) {
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
// Corrupt the prelude CRC (bytes 8-11)
frame[8] ^= 0xFF
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "prelude CRC mismatch")
})
t.Run("corrupted message CRC", func(t *testing.T) {
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
// Corrupt the message CRC (last 4 bytes)
frame[len(frame)-1] ^= 0xFF
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "message CRC mismatch")
})
t.Run("castagnoli encoded frame is rejected", func(t *testing.T) {
castagnoliTab := crc32.MakeTable(crc32.Castagnoli)
payload := []byte(`{"bytes":"dGVzdA=="}`)
var headersBuf bytes.Buffer
_ = headersBuf.WriteByte(byte(len(":event-type")))
_, _ = headersBuf.WriteString(":event-type")
_ = headersBuf.WriteByte(7)
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk")))
_, _ = headersBuf.WriteString("chunk")
headers := headersBuf.Bytes()
headersLen := uint32(len(headers))
totalLen := uint32(12 + len(headers) + len(payload) + 4)
var preludeBuf bytes.Buffer
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
preludeBytes := preludeBuf.Bytes()
var frame bytes.Buffer
_, _ = frame.Write(preludeBytes)
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab))
_, _ = frame.Write(headers)
_, _ = frame.Write(payload)
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab))
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes()))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "prelude CRC mismatch")
})
}
func TestBuildBedrockURL(t *testing.T) {
t.Run("stream URL with colon in model ID", func(t *testing.T) {
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true)
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url)
})
t.Run("non-stream URL with colon in model ID", func(t *testing.T) {
url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false)
assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url)
})
t.Run("model ID without colon", func(t *testing.T) {
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true)
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url)
})
}

View File

@@ -0,0 +1,450 @@
package service
import (
"encoding/json"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
)
type claudeMaxCacheBillingOutcome struct {
Simulated bool
}
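// applyClaudeMaxCacheBillingPolicyToUsage applies the simulate-Claude-Max billing
// policy to usage in place: for Anthropic groups with the feature enabled, when the
// upstream reported no cache-creation tokens but the request carried ephemeral
// cache_control signals, part of the input window is re-attributed to 1h cache
// creation. The returned outcome reports whether usage was actually simulated.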
func applyClaudeMaxCacheBillingPolicyToUsage(usage *ClaudeUsage, parsed *ParsedRequest, group *Group, model string, accountID int64) claudeMaxCacheBillingOutcome {
var out claudeMaxCacheBillingOutcome
if usage == nil || !shouldApplyClaudeMaxBillingRulesForUsage(group, model, parsed) {
return out
}
resolvedModel := strings.TrimSpace(model)
if resolvedModel == "" && parsed != nil {
resolvedModel = strings.TrimSpace(parsed.Model)
}
if hasCacheCreationTokens(*usage) {
// Upstream already returned cache creation usage; keep original usage.
return out
}
if !shouldSimulateClaudeMaxUsageForUsage(*usage, parsed) {
return out
}
beforeInputTokens := usage.InputTokens
out.Simulated = safelyProjectUsageToClaudeMax1H(usage, parsed)
if out.Simulated {
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
resolvedModel,
accountID,
beforeInputTokens,
usage.InputTokens,
usage.CacheCreation1hTokens,
)
}
return out
}
func isClaudeFamilyModel(model string) bool {
normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model)))
if normalized == "" {
return false
}
return strings.Contains(normalized, "claude-")
}
func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool {
if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil {
return false
}
return shouldApplyClaudeMaxBillingRulesForUsage(input.APIKey.Group, input.Result.Model, input.ParsedRequest)
}
func shouldApplyClaudeMaxBillingRulesForUsage(group *Group, model string, parsed *ParsedRequest) bool {
if group == nil {
return false
}
if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic {
return false
}
resolvedModel := model
if resolvedModel == "" && parsed != nil {
resolvedModel = parsed.Model
}
if !isClaudeFamilyModel(resolvedModel) {
return false
}
return true
}
func hasCacheCreationTokens(usage ClaudeUsage) bool {
return usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0
}
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
if input == nil || input.Result == nil {
return false
}
if !shouldApplyClaudeMaxBillingRules(input) {
return false
}
return shouldSimulateClaudeMaxUsageForUsage(input.Result.Usage, input.ParsedRequest)
}
func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool {
if usage.InputTokens <= 0 {
return false
}
if hasCacheCreationTokens(usage) {
return false
}
if !hasClaudeCacheSignals(parsed) {
return false
}
return true
}
func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) {
defer func() {
if r := recover(); r != nil {
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r)
changed = false
}
}()
return projectUsageToClaudeMax1H(usage, parsed)
}
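// projectUsageToClaudeMax1H splits the token window between simulated input and
// 1h cache creation while conserving the total:
//
//	input' + cache_1h' == input + cache_5m + cache_1h
//
// Both sides of the split are kept >= 1, so a window of <= 1 token is left unchanged.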
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
if usage == nil {
return false
}
totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
if totalWindowTokens <= 1 {
return false
}
simulatedInputTokens := computeClaudeMaxProjectedInputTokens(totalWindowTokens, parsed)
if simulatedInputTokens <= 0 {
simulatedInputTokens = 1
}
if simulatedInputTokens >= totalWindowTokens {
simulatedInputTokens = totalWindowTokens - 1
}
cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens
if usage.InputTokens == simulatedInputTokens &&
usage.CacheCreation5mTokens == 0 &&
usage.CacheCreation1hTokens == cacheCreation1hTokens &&
usage.CacheCreationInputTokens == cacheCreation1hTokens {
return false
}
usage.InputTokens = simulatedInputTokens
usage.CacheCreation5mTokens = 0
usage.CacheCreation1hTokens = cacheCreation1hTokens
usage.CacheCreationInputTokens = cacheCreation1hTokens
return true
}
type claudeCacheProjection struct {
HasBreakpoint bool
BreakpointCount int
TotalEstimatedTokens int
TailEstimatedTokens int
}
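// computeClaudeMaxProjectedInputTokens scales the estimated post-breakpoint tail
// onto the real token window with round-to-nearest integer division. For example,
// with totalWindowTokens=1000, TailEstimatedTokens=100 and TotalEstimatedTokens=400,
// the projected input is (1000*100 + 200) / 400 = 250 tokens.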
func computeClaudeMaxProjectedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
if totalWindowTokens <= 1 {
return totalWindowTokens
}
projection := analyzeClaudeCacheProjection(parsed)
if !projection.HasBreakpoint || projection.TotalEstimatedTokens <= 0 || projection.TailEstimatedTokens <= 0 {
return totalWindowTokens
}
totalEstimate := int64(projection.TotalEstimatedTokens)
tailEstimate := int64(projection.TailEstimatedTokens)
if tailEstimate > totalEstimate {
tailEstimate = totalEstimate
}
scaled := (int64(totalWindowTokens)*tailEstimate + totalEstimate/2) / totalEstimate
if scaled <= 0 {
scaled = 1
}
if scaled >= int64(totalWindowTokens) {
scaled = int64(totalWindowTokens - 1)
}
return int(scaled)
}
func hasClaudeCacheSignals(parsed *ParsedRequest) bool {
if parsed == nil {
return false
}
if hasTopLevelEphemeralCacheControl(parsed) {
return true
}
return countExplicitCacheBreakpoints(parsed) > 0
}
func hasTopLevelEphemeralCacheControl(parsed *ParsedRequest) bool {
if parsed == nil || len(parsed.Body) == 0 {
return false
}
cacheType := strings.TrimSpace(gjson.GetBytes(parsed.Body, "cache_control.type").String())
return strings.EqualFold(cacheType, "ephemeral")
}
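// analyzeClaudeCacheProjection walks the system blocks and messages, accumulating
// a token estimate and recording the running total at the last ephemeral cache
// breakpoint; TailEstimatedTokens is then the estimated token count after that
// breakpoint. A top-level cache_control falls back to the last user message as the tail.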
func analyzeClaudeCacheProjection(parsed *ParsedRequest) claudeCacheProjection {
var projection claudeCacheProjection
if parsed == nil {
return projection
}
total := 0
lastBreakpointAt := -1
switch system := parsed.System.(type) {
case string:
total += claudeMaxMessageOverheadTokens + estimateClaudeTextTokens(system)
case []any:
for _, raw := range system {
block, ok := raw.(map[string]any)
if !ok {
total += claudeMaxUnknownContentTokens
continue
}
total += estimateClaudeBlockTokens(block)
if hasEphemeralCacheControl(block) {
lastBreakpointAt = total
projection.BreakpointCount++
projection.HasBreakpoint = true
}
}
}
for _, rawMsg := range parsed.Messages {
total += claudeMaxMessageOverheadTokens
msg, ok := rawMsg.(map[string]any)
if !ok {
total += claudeMaxUnknownContentTokens
continue
}
content, exists := msg["content"]
if !exists {
continue
}
msgTokens, msgLastBreak, msgBreakCount := estimateClaudeContentTokens(content)
total += msgTokens
if msgBreakCount > 0 {
lastBreakpointAt = total - msgTokens + msgLastBreak
projection.BreakpointCount += msgBreakCount
projection.HasBreakpoint = true
}
}
if total <= 0 {
total = 1
}
projection.TotalEstimatedTokens = total
if projection.HasBreakpoint && lastBreakpointAt >= 0 {
tail := total - lastBreakpointAt
if tail <= 0 {
tail = 1
}
projection.TailEstimatedTokens = tail
return projection
}
if hasTopLevelEphemeralCacheControl(parsed) {
tail := estimateLastUserMessageTokens(parsed)
if tail <= 0 {
tail = 1
}
projection.HasBreakpoint = true
projection.BreakpointCount = 1
projection.TailEstimatedTokens = tail
}
return projection
}
func countExplicitCacheBreakpoints(parsed *ParsedRequest) int {
if parsed == nil {
return 0
}
total := 0
if system, ok := parsed.System.([]any); ok {
for _, raw := range system {
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
total++
}
}
}
for _, rawMsg := range parsed.Messages {
msg, ok := rawMsg.(map[string]any)
if !ok {
continue
}
content, ok := msg["content"].([]any)
if !ok {
continue
}
for _, raw := range content {
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
total++
}
}
}
return total
}
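// hasEphemeralCacheControl reports whether a content block carries
// cache_control with type "ephemeral" (case-insensitive).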
func hasEphemeralCacheControl(block map[string]any) bool {
if block == nil {
return false
}
raw, ok := block["cache_control"]
if !ok || raw == nil {
return false
}
switch cc := raw.(type) {
case map[string]any:
cacheType, _ := cc["type"].(string)
return strings.EqualFold(strings.TrimSpace(cacheType), "ephemeral")
case map[string]string:
return strings.EqualFold(strings.TrimSpace(cc["type"]), "ephemeral")
default:
return false
}
}
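// estimateClaudeContentTokens estimates the tokens for a message content value and
// returns the running-total offset of the last cache breakpoint within it (-1 when
// none) plus the number of breakpoints seen.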
func estimateClaudeContentTokens(content any) (tokens int, lastBreakAt int, breakpointCount int) {
switch value := content.(type) {
case string:
return estimateClaudeTextTokens(value), -1, 0
case []any:
total := 0
lastBreak := -1
breaks := 0
for _, raw := range value {
block, ok := raw.(map[string]any)
if !ok {
total += claudeMaxUnknownContentTokens
continue
}
total += estimateClaudeBlockTokens(block)
if hasEphemeralCacheControl(block) {
lastBreak = total
breaks++
}
}
return total, lastBreak, breaks
default:
return estimateStructuredTokens(value), -1, 0
}
}
func estimateClaudeBlockTokens(block map[string]any) int {
if block == nil {
return claudeMaxUnknownContentTokens
}
tokens := claudeMaxBlockOverheadTokens
blockType, _ := block["type"].(string)
switch blockType {
case "text":
if text, ok := block["text"].(string); ok {
tokens += estimateClaudeTextTokens(text)
}
case "tool_result":
if content, ok := block["content"]; ok {
nested, _, _ := estimateClaudeContentTokens(content)
tokens += nested
}
case "tool_use":
if name, ok := block["name"].(string); ok {
tokens += estimateClaudeTextTokens(name)
}
if input, ok := block["input"]; ok {
tokens += estimateStructuredTokens(input)
}
default:
if text, ok := block["text"].(string); ok {
tokens += estimateClaudeTextTokens(text)
} else if content, ok := block["content"]; ok {
nested, _, _ := estimateClaudeContentTokens(content)
tokens += nested
}
}
if tokens <= claudeMaxBlockOverheadTokens {
tokens += claudeMaxUnknownContentTokens
}
return tokens
}
func estimateLastUserMessageTokens(parsed *ParsedRequest) int {
if parsed == nil || len(parsed.Messages) == 0 {
return 0
}
for i := len(parsed.Messages) - 1; i >= 0; i-- {
msg, ok := parsed.Messages[i].(map[string]any)
if !ok {
continue
}
role, _ := msg["role"].(string)
if !strings.EqualFold(strings.TrimSpace(role), "user") {
continue
}
tokens, _, _ := estimateClaudeContentTokens(msg["content"])
return claudeMaxMessageOverheadTokens + tokens
}
return 0
}
func estimateStructuredTokens(v any) int {
if v == nil {
return 0
}
raw, err := json.Marshal(v)
if err != nil {
return claudeMaxUnknownContentTokens
}
return estimateClaudeTextTokens(string(raw))
}
func estimateClaudeTextTokens(text string) int {
if tokens, ok := estimateTokensByThirdPartyTokenizer(text); ok {
return tokens
}
return estimateClaudeTextTokensHeuristic(text)
}
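// estimateClaudeTextTokensHeuristic approximates the token count without a
// tokenizer: each non-ASCII rune counts as one token, ASCII characters count as
// one token per four (rounded up), and the whitespace-separated word count acts
// as a floor. For example, "hello world" (11 ASCII chars, 2 words) estimates to
// (11+3)/4 = 3 tokens.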
func estimateClaudeTextTokensHeuristic(text string) int {
normalized := strings.Join(strings.Fields(strings.TrimSpace(text)), " ")
if normalized == "" {
return 0
}
asciiChars := 0
nonASCIIChars := 0
for _, r := range normalized {
if r <= 127 {
asciiChars++
} else {
nonASCIIChars++
}
}
tokens := nonASCIIChars
if asciiChars > 0 {
tokens += (asciiChars + 3) / 4
}
if words := len(strings.Fields(normalized)); words > tokens {
tokens = words
}
if tokens <= 0 {
return 1
}
return tokens
}

View File

@@ -0,0 +1,156 @@
package service
import (
"strings"
"testing"
)
func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
usage := &ClaudeUsage{
InputTokens: 1200,
CacheCreationInputTokens: 0,
CacheCreation5mTokens: 0,
CacheCreation1hTokens: 0,
}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": strings.Repeat("cached context ", 200),
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "summarize quickly",
},
},
},
},
}
changed := projectUsageToClaudeMax1H(usage, parsed)
if !changed {
t.Fatalf("expected usage to be projected")
}
total := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
if total != 1200 {
t.Fatalf("total tokens changed: got=%d want=%d", total, 1200)
}
if usage.CacheCreation5mTokens != 0 {
t.Fatalf("cache_creation_5m should be 0, got=%d", usage.CacheCreation5mTokens)
}
if usage.InputTokens <= 0 || usage.InputTokens >= 1200 {
t.Fatalf("simulated input out of range, got=%d", usage.InputTokens)
}
if usage.InputTokens > 100 {
t.Fatalf("simulated input should stay near cache breakpoint tail, got=%d", usage.InputTokens)
}
if usage.CacheCreation1hTokens <= 0 {
t.Fatalf("cache_creation_1h should be > 0, got=%d", usage.CacheCreation1hTokens)
}
if usage.CacheCreationInputTokens != usage.CacheCreation1hTokens {
t.Fatalf("cache_creation_input_tokens mismatch: got=%d want=%d", usage.CacheCreationInputTokens, usage.CacheCreation1hTokens)
}
}
func TestComputeClaudeMaxProjectedInputTokens_Deterministic(t *testing.T) {
parsed := &ParsedRequest{
Model: "claude-opus-4-5",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "build context",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "what is failing now",
},
},
},
},
}
got1 := computeClaudeMaxProjectedInputTokens(4096, parsed)
got2 := computeClaudeMaxProjectedInputTokens(4096, parsed)
if got1 != got2 {
t.Fatalf("non-deterministic input tokens: %d != %d", got1, got2)
}
}
func TestShouldSimulateClaudeMaxUsage(t *testing.T) {
group := &Group{
Platform: PlatformAnthropic,
SimulateClaudeMaxEnabled: true,
}
input := &RecordUsageInput{
Result: &ForwardResult{
Model: "claude-sonnet-4-5",
Usage: ClaudeUsage{
InputTokens: 3000,
CacheCreationInputTokens: 0,
CacheCreation5mTokens: 0,
CacheCreation1hTokens: 0,
},
},
ParsedRequest: &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "cached",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "tail",
},
},
},
},
},
APIKey: &APIKey{Group: group},
}
if !shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=true for claude group with cache signal")
}
input.ParsedRequest = &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "no cache signal"},
},
}
if shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=false when request has no cache signal")
}
input.ParsedRequest = &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "cached",
"cache_control": map[string]any{"type": "ephemeral"},
},
},
},
},
}
input.Result.Usage.CacheCreationInputTokens = 100
if shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=false when cache creation already exists")
}
}

View File

@@ -0,0 +1,41 @@
package service
import (
"sync"
tiktoken "github.com/pkoukk/tiktoken-go"
tiktokenloader "github.com/pkoukk/tiktoken-go-loader"
)
var (
claudeTokenizerOnce sync.Once
claudeTokenizer *tiktoken.Tiktoken
)
func getClaudeTokenizer() *tiktoken.Tiktoken {
claudeTokenizerOnce.Do(func() {
// Use offline loader to avoid runtime dictionary download.
tiktoken.SetBpeLoader(tiktokenloader.NewOfflineLoader())
// Use a high-capacity tokenizer as the default approximation for Claude payloads.
enc, err := tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
if err != nil {
enc, err = tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
}
if err == nil {
claudeTokenizer = enc
}
})
return claudeTokenizer
}
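// estimateTokensByThirdPartyTokenizer counts tokens with the shared tiktoken
// encoder. The bool reports whether the encoder was available and produced a
// positive count; callers fall back to the heuristic otherwise.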
func estimateTokensByThirdPartyTokenizer(text string) (int, bool) {
enc := getClaudeTokenizer()
if enc == nil {
return 0, false
}
tokens := len(enc.EncodeOrdinary(text))
if tokens <= 0 {
return 0, false
}
return tokens, true
}

View File

@@ -343,8 +343,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
// Uses a detached context with timeout to prevent HTTP request cancellation from
// causing the entire batch to fail (which would show all concurrency as 0).
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
@@ -356,5 +357,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
}
return result, nil
}
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
// Use a detached context so that a cancelled HTTP request doesn't cause
// the Redis pipeline to fail and return all-zero concurrency counts.
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
}

View File

@@ -35,7 +35,6 @@ type DashboardAggregationRepository interface {
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
}
@@ -297,7 +296,6 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
if aggErr != nil {
@@ -307,11 +305,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
if usageErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
}
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
if dedupErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
}
if aggErr == nil && usageErr == nil && dedupErr == nil {
if aggErr == nil && usageErr == nil {
s.lastRetentionCleanup.Store(now)
}
}

View File

@@ -12,18 +12,12 @@ import (
type dashboardAggregationRepoTestStub struct {
aggregateCalls int
recomputeCalls int
cleanupUsageCalls int
cleanupDedupCalls int
ensurePartitionCalls int
lastStart time.Time
lastEnd time.Time
watermark time.Time
aggregateErr error
cleanupAggregatesErr error
cleanupUsageErr error
cleanupDedupErr error
ensurePartitionErr error
}
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
@@ -34,7 +28,6 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
}
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
s.recomputeCalls++
return s.AggregateRange(ctx, start, end)
}
@@ -51,18 +44,11 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
s.cleanupUsageCalls++
return s.cleanupUsageErr
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
s.cleanupDedupCalls++
return s.cleanupDedupErr
}
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
s.ensurePartitionCalls++
return s.ensurePartitionErr
return nil
}
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
@@ -104,50 +90,6 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupUsageCalls)
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Enabled: true,
IntervalSeconds: 60,
LookbackSeconds: 120,
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
UsageBillingDedupDays: 2,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.runScheduledAggregation()
require.Equal(t, 1, repo.ensurePartitionCalls)
require.Equal(t, 1, repo.aggregateCalls)
}
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {


@@ -327,14 +327,6 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
return trend, nil
}
func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
ranking, err := s.usageRepo.GetUserSpendingRanking(ctx, startTime, endTime, limit)
if err != nil {
return nil, fmt.Errorf("get user spending ranking: %w", err)
}
return ranking, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
if err != nil {


@@ -124,10 +124,6 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
return nil
}
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
}


@@ -29,12 +29,10 @@ const (
// Account type constants
const (
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth account (full scope: profile + inference)
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token account (inference-only scope)
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key account
AccountTypeUpstream = domain.AccountTypeUpstream // Upstream passthrough account (connects to the upstream via Base URL + API Key)
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock account (connects to Bedrock via SigV4 signing)
AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key account (connects to Bedrock via Bearer Token)
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth account (full scope: profile + inference)
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token account (inference-only scope)
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key account
AccountTypeUpstream = domain.AccountTypeUpstream // Upstream passthrough account (connects to the upstream via Base URL + API Key)
)
// Redeem type constants


@@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
v, exists := c.Get(OpsSkipPassthroughKey)
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
boolVal, ok := v.(bool)
assert.True(t, ok, "value should be bool")
assert.True(t, ok, "value should be a bool")
assert.True(t, boolVal)
}


@@ -136,18 +136,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
},
}
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
deferredService: &DeferredService{},
billingCacheService: nil,
cfg: &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
deferredService: &DeferredService{},
billingCacheService: nil,
}
account := &Account{
@@ -223,16 +221,14 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
},
}
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
cfg: &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
account := &Account{
@@ -731,39 +727,6 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
require.Equal(t, 5, result.usage.OutputTokens)
}
func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
},
rateLimitService: &RateLimitService{},
}
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
"",
`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
"",
}, "\n"))),
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
}
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
@@ -1111,8 +1074,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
_ = pr.Close()
<-done
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete after timeout")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.Equal(t, 9, result.usage.InputTokens)
@@ -1141,8 +1103,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
@@ -1172,8 +1133,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.Equal(t, 8, result.usage.InputTokens)


@@ -0,0 +1,196 @@
package service
import (
"context"
"encoding/json"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/tidwall/sjson"
)
type claudeMaxResponseRewriteContext struct {
Parsed *ParsedRequest
Group *Group
}
type claudeMaxResponseRewriteContextKeyType struct{}
var claudeMaxResponseRewriteContextKey = claudeMaxResponseRewriteContextKeyType{}
func withClaudeMaxResponseRewriteContext(ctx context.Context, c *gin.Context, parsed *ParsedRequest) context.Context {
if ctx == nil {
ctx = context.Background()
}
value := claudeMaxResponseRewriteContext{
Parsed: parsed,
Group: claudeMaxGroupFromGinContext(c),
}
return context.WithValue(ctx, claudeMaxResponseRewriteContextKey, value)
}
func claudeMaxResponseRewriteContextFromContext(ctx context.Context) claudeMaxResponseRewriteContext {
if ctx == nil {
return claudeMaxResponseRewriteContext{}
}
value, _ := ctx.Value(claudeMaxResponseRewriteContextKey).(claudeMaxResponseRewriteContext)
return value
}
func claudeMaxGroupFromGinContext(c *gin.Context) *Group {
if c == nil {
return nil
}
raw, exists := c.Get("api_key")
if !exists {
return nil
}
apiKey, ok := raw.(*APIKey)
if !ok || apiKey == nil {
return nil
}
return apiKey.Group
}
func parsedRequestFromGinContext(c *gin.Context) *ParsedRequest {
if c == nil {
return nil
}
raw, exists := c.Get("parsed_request")
if !exists {
return nil
}
parsed, _ := raw.(*ParsedRequest)
return parsed
}
func applyClaudeMaxSimulationToUsage(ctx context.Context, usage *ClaudeUsage, model string, accountID int64) claudeMaxCacheBillingOutcome {
var out claudeMaxCacheBillingOutcome
if usage == nil {
return out
}
rewriteCtx := claudeMaxResponseRewriteContextFromContext(ctx)
return applyClaudeMaxCacheBillingPolicyToUsage(usage, rewriteCtx.Parsed, rewriteCtx.Group, model, accountID)
}
func applyClaudeMaxSimulationToUsageJSONMap(ctx context.Context, usageObj map[string]any, model string, accountID int64) claudeMaxCacheBillingOutcome {
var out claudeMaxCacheBillingOutcome
if usageObj == nil {
return out
}
usage := claudeUsageFromJSONMap(usageObj)
out = applyClaudeMaxSimulationToUsage(ctx, &usage, model, accountID)
if out.Simulated {
rewriteClaudeUsageJSONMap(usageObj, usage)
}
return out
}
func rewriteClaudeUsageJSONBytes(body []byte, usage ClaudeUsage) []byte {
updated := body
var err error
updated, err = sjson.SetBytes(updated, "usage.input_tokens", usage.InputTokens)
if err != nil {
return body
}
updated, err = sjson.SetBytes(updated, "usage.cache_creation_input_tokens", usage.CacheCreationInputTokens)
if err != nil {
return body
}
updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_5m_input_tokens", usage.CacheCreation5mTokens)
if err != nil {
return body
}
updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_1h_input_tokens", usage.CacheCreation1hTokens)
if err != nil {
return body
}
return updated
}
func claudeUsageFromJSONMap(usageObj map[string]any) ClaudeUsage {
var usage ClaudeUsage
if usageObj == nil {
return usage
}
usage.InputTokens = usageIntFromAny(usageObj["input_tokens"])
usage.OutputTokens = usageIntFromAny(usageObj["output_tokens"])
usage.CacheCreationInputTokens = usageIntFromAny(usageObj["cache_creation_input_tokens"])
usage.CacheReadInputTokens = usageIntFromAny(usageObj["cache_read_input_tokens"])
if ccObj, ok := usageObj["cache_creation"].(map[string]any); ok {
usage.CacheCreation5mTokens = usageIntFromAny(ccObj["ephemeral_5m_input_tokens"])
usage.CacheCreation1hTokens = usageIntFromAny(ccObj["ephemeral_1h_input_tokens"])
}
return usage
}
func rewriteClaudeUsageJSONMap(usageObj map[string]any, usage ClaudeUsage) {
if usageObj == nil {
return
}
usageObj["input_tokens"] = usage.InputTokens
usageObj["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
ccObj, _ := usageObj["cache_creation"].(map[string]any)
if ccObj == nil {
ccObj = make(map[string]any, 2)
usageObj["cache_creation"] = ccObj
}
ccObj["ephemeral_5m_input_tokens"] = usage.CacheCreation5mTokens
ccObj["ephemeral_1h_input_tokens"] = usage.CacheCreation1hTokens
}
func usageIntFromAny(v any) int {
switch value := v.(type) {
case int:
return value
case int64:
return int(value)
case float64:
return int(value)
case json.Number:
if n, err := value.Int64(); err == nil {
return int(n)
}
}
return 0
}
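// Illustrative round trip through the helpers above (not part of the original
// file; token counts are made up). claudeUsageFromJSONMap decodes a usage object
// (JSON numbers arrive as float64), and rewriteClaudeUsageJSONMap writes the
// adjusted counts back, creating the nested cache_creation object when absent.
func exampleClaudeUsageJSONRoundTrip() {
	usageObj := map[string]any{
		"input_tokens":            float64(120),
		"cache_read_input_tokens": float64(40),
	}
	usage := claudeUsageFromJSONMap(usageObj) // InputTokens=120, CacheReadInputTokens=40
	usage.InputTokens = 80
	usage.CacheCreation1hTokens = 40
	rewriteClaudeUsageJSONMap(usageObj, usage)
	// usageObj now holds input_tokens=80, cache_creation_input_tokens=0, and
	// cache_creation={ephemeral_5m_input_tokens: 0, ephemeral_1h_input_tokens: 40}.
}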
// setupClaudeMaxStreamingHook installs the SSE usage rewrite hook for the Antigravity streaming path.
func setupClaudeMaxStreamingHook(c *gin.Context, processor *antigravity.StreamingProcessor, originalModel string, accountID int64) {
group := claudeMaxGroupFromGinContext(c)
parsed := parsedRequestFromGinContext(c)
if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
return
}
processor.SetUsageMapHook(func(usageMap map[string]any) {
svcUsage := claudeUsageFromJSONMap(usageMap)
outcome := applyClaudeMaxCacheBillingPolicyToUsage(&svcUsage, parsed, group, originalModel, accountID)
if outcome.Simulated {
rewriteClaudeUsageJSONMap(usageMap, svcUsage)
}
})
}
// applyClaudeMaxNonStreamingRewrite rewrites the usage in the response body for the Antigravity non-streaming path.
func applyClaudeMaxNonStreamingRewrite(c *gin.Context, claudeResp []byte, agUsage *antigravity.ClaudeUsage, originalModel string, accountID int64) []byte {
group := claudeMaxGroupFromGinContext(c)
parsed := parsedRequestFromGinContext(c)
if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
return claudeResp
}
svcUsage := &ClaudeUsage{
InputTokens: agUsage.InputTokens,
OutputTokens: agUsage.OutputTokens,
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
CacheReadInputTokens: agUsage.CacheReadInputTokens,
}
outcome := applyClaudeMaxCacheBillingPolicyToUsage(svcUsage, parsed, group, originalModel, accountID)
if outcome.Simulated {
return rewriteClaudeUsageJSONBytes(claudeResp, *svcUsage)
}
return claudeResp
}
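// Hypothetical call site, for illustration only (forwardAntigravityUsageRewrite
// is not part of this file): the streaming path installs the SSE hook before
// events flow; the non-streaming path rewrites the buffered JSON body once.
func forwardAntigravityUsageRewrite(c *gin.Context, processor *antigravity.StreamingProcessor, body []byte, agUsage *antigravity.ClaudeUsage, model string, accountID int64, streaming bool) []byte {
	if streaming {
		setupClaudeMaxStreamingHook(c, processor, model, accountID)
		return nil // usage is rewritten inside each SSE event as it streams
	}
	return applyClaudeMaxNonStreamingRewrite(c, body, agUsage, model, accountID)
}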


@@ -0,0 +1,199 @@
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type usageLogRepoRecordUsageStub struct {
UsageLogRepository
last *UsageLog
inserted bool
err error
}
func (s *usageLogRepoRecordUsageStub) Create(_ context.Context, log *UsageLog) (bool, error) {
copied := *log
s.last = &copied
return s.inserted, s.err
}
func newGatewayServiceForRecordUsageTest(repo UsageLogRepository) *GatewayService {
return &GatewayService{
usageLogRepo: repo,
billingService: NewBillingService(&config.Config{}, nil),
cfg: &config.Config{RunMode: config.RunModeSimple},
deferredService: &DeferredService{},
}
}
func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsUsageAndSkipsTTLOverride(t *testing.T) {
repo := &usageLogRepoRecordUsageStub{inserted: true}
svc := newGatewayServiceForRecordUsageTest(repo)
groupID := int64(11)
input := &RecordUsageInput{
Result: &ForwardResult{
RequestID: "req-sim-1",
Model: "claude-sonnet-4",
Duration: time.Second,
Usage: ClaudeUsage{
InputTokens: 160,
},
},
ParsedRequest: &ParsedRequest{
Model: "claude-sonnet-4",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "long cached context for prior turns",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "please summarize the logs and provide root cause analysis",
},
},
},
},
},
APIKey: &APIKey{
ID: 1,
GroupID: &groupID,
Group: &Group{
ID: groupID,
Platform: PlatformAnthropic,
RateMultiplier: 1,
SimulateClaudeMaxEnabled: true,
},
},
User: &User{ID: 2},
Account: &Account{
ID: 3,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
},
}
err := svc.RecordUsage(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, repo.last)
log := repo.last
require.Equal(t, 80, log.InputTokens)
require.Equal(t, 80, log.CacheCreationTokens)
require.Equal(t, 0, log.CacheCreation5mTokens)
require.Equal(t, 80, log.CacheCreation1hTokens)
require.False(t, log.CacheTTLOverridden, "simulate outcome should skip account ttl override")
}
func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T) {
repo := &usageLogRepoRecordUsageStub{inserted: true}
svc := newGatewayServiceForRecordUsageTest(repo)
groupID := int64(12)
input := &RecordUsageInput{
Result: &ForwardResult{
RequestID: "req-sim-2",
Model: "claude-sonnet-4",
Duration: time.Second,
Usage: ClaudeUsage{
InputTokens: 40,
CacheCreationInputTokens: 120,
CacheCreation1hTokens: 120,
},
},
APIKey: &APIKey{
ID: 2,
GroupID: &groupID,
Group: &Group{
ID: groupID,
Platform: PlatformAnthropic,
RateMultiplier: 1,
SimulateClaudeMaxEnabled: false,
},
},
User: &User{ID: 3},
Account: &Account{
ID: 4,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
},
}
err := svc.RecordUsage(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, repo.last)
log := repo.last
require.Equal(t, 120, log.CacheCreationTokens)
require.Equal(t, 120, log.CacheCreation5mTokens)
require.Equal(t, 0, log.CacheCreation1hTokens)
require.True(t, log.CacheTTLOverridden)
}
func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationBypassesSimulation(t *testing.T) {
repo := &usageLogRepoRecordUsageStub{inserted: true}
svc := newGatewayServiceForRecordUsageTest(repo)
groupID := int64(13)
input := &RecordUsageInput{
Result: &ForwardResult{
RequestID: "req-sim-3",
Model: "claude-sonnet-4",
Duration: time.Second,
Usage: ClaudeUsage{
InputTokens: 20,
CacheCreationInputTokens: 120,
CacheCreation5mTokens: 120,
},
},
APIKey: &APIKey{
ID: 3,
GroupID: &groupID,
Group: &Group{
ID: groupID,
Platform: PlatformAnthropic,
RateMultiplier: 1,
SimulateClaudeMaxEnabled: true,
},
},
User: &User{ID: 4},
Account: &Account{
ID: 5,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
},
}
err := svc.RecordUsage(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, repo.last)
log := repo.last
require.Equal(t, 20, log.InputTokens)
require.Equal(t, 120, log.CacheCreation5mTokens)
require.Equal(t, 0, log.CacheCreation1hTokens)
require.Equal(t, 120, log.CacheCreationTokens)
require.False(t, log.CacheTTLOverridden, "existing cache_creation with SimulateClaudeMax enabled should skip account ttl override")
}


@@ -1,371 +0,0 @@
//go:build unit
package service
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
return NewGatewayService(
nil,
nil,
usageRepo,
nil,
userRepo,
subRepo,
nil,
nil,
cfg,
nil,
nil,
NewBillingService(cfg, nil),
nil,
&BillingCacheService{},
nil,
nil,
&DeferredService{},
nil,
nil,
nil,
nil,
nil,
)
}
func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
svc.usageBillingRepo = billingRepo
return svc
}
type openAIRecordUsageBestEffortLogRepoStub struct {
UsageLogRepository
bestEffortErr error
createErr error
bestEffortCalls int
createCalls int
lastLog *UsageLog
lastCtxErr error
}
func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error {
s.bestEffortCalls++
s.lastLog = log
s.lastCtxErr = ctx.Err()
return s.bestEffortErr
}
func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
s.createCalls++
s.lastLog = log
s.lastCtxErr = ctx.Err()
return false, s.createErr
}
func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_detached_ctx",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 501,
Quota: 100,
},
User: &User{ID: 601},
Account: &Account{ID: 701},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_payload_hash",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
RequestPayloadHash: payloadHash,
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
}
func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_payload_fallback",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
}
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_not_persisted",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 503,
Quota: 100,
},
User: &User{ID: 603},
Account: &Account{ID: 703},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 1, quotaSvc.quotaCalls)
}
func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
Result: &ForwardResult{
RequestID: "gateway_long_context_detached_ctx",
Usage: ClaudeUsage{
InputTokens: 12,
OutputTokens: 8,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 502,
Quota: 100,
},
User: &User{ID: 602},
Account: &Account{ID: 702},
LongContextThreshold: 200000,
LongContextMultiplier: 2,
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 504},
User: &User{ID: 604},
Account: &Account{ID: 704},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123")
ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "upstream-volatile-456",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 506},
User: &User{ID: 606},
Account: &Account{ID: 706},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 507},
User: &User{ID: 607},
Account: &Account{ID: 707},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) {
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{
bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")),
}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_drop_usage_log",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 508},
User: &User{ID: 608},
Account: &Account{ID: 708},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.bestEffortCalls)
require.Equal(t, 0, usageRepo.createCalls)
}
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_billing_fail",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 505},
User: &User{ID: 605},
Account: &Account{ID: 705},
})
require.Error(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 0, usageRepo.calls)
}


@@ -0,0 +1,170 @@
package service
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestHandleNonStreamingResponse_UsageAlignedWithClaudeMaxSimulation(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := &GatewayService{
cfg: &config.Config{},
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 11,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
}
group := &Group{
ID: 99,
Platform: PlatformAnthropic,
SimulateClaudeMaxEnabled: true,
}
parsed := &ParsedRequest{
Model: "claude-sonnet-4",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "long cached context",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "new user question",
},
},
},
},
}
upstreamBody := []byte(`{"id":"msg_1","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: ioNopCloserBytes(upstreamBody),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
c.Set("api_key", &APIKey{Group: group})
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
require.NoError(t, err)
require.NotNil(t, usage)
var rendered struct {
Usage ClaudeUsage `json:"usage"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rendered))
rendered.Usage.CacheCreation5mTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_5m_input_tokens").Int())
rendered.Usage.CacheCreation1hTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_1h_input_tokens").Int())
require.Equal(t, rendered.Usage.InputTokens, usage.InputTokens)
require.Equal(t, rendered.Usage.OutputTokens, usage.OutputTokens)
require.Equal(t, rendered.Usage.CacheCreationInputTokens, usage.CacheCreationInputTokens)
require.Equal(t, rendered.Usage.CacheCreation5mTokens, usage.CacheCreation5mTokens)
require.Equal(t, rendered.Usage.CacheCreation1hTokens, usage.CacheCreation1hTokens)
require.Equal(t, rendered.Usage.CacheReadInputTokens, usage.CacheReadInputTokens)
require.Greater(t, usage.CacheCreation1hTokens, 0)
require.Equal(t, 0, usage.CacheCreation5mTokens)
require.Less(t, usage.InputTokens, 120)
}
func TestHandleNonStreamingResponse_ClaudeMaxDisabled_NoSimulationIntercept(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := &GatewayService{
cfg: &config.Config{},
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 12,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"cache_ttl_override_enabled": true,
"cache_ttl_override_target": "5m",
},
}
group := &Group{
ID: 100,
Platform: PlatformAnthropic,
SimulateClaudeMaxEnabled: false,
}
parsed := &ParsedRequest{
Model: "claude-sonnet-4",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "long cached context",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "new user question",
},
},
},
},
}
upstreamBody := []byte(`{"id":"msg_2","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: ioNopCloserBytes(upstreamBody),
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
c.Set("api_key", &APIKey{Group: group})
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, 120, usage.InputTokens)
require.Equal(t, 0, usage.CacheCreationInputTokens)
require.Equal(t, 0, usage.CacheCreation5mTokens)
require.Equal(t, 0, usage.CacheCreation1hTokens)
}
// ioNopCloserBytes wraps a byte slice in a ReadCloser with a no-op Close,
// equivalent to io.NopCloser(bytes.NewReader(b)), for stubbing response bodies.
func ioNopCloserBytes(b []byte) *readCloserFromBytes {
return &readCloserFromBytes{Reader: bytes.NewReader(b)}
}
// readCloserFromBytes adapts *bytes.Reader into an io.ReadCloser.
type readCloserFromBytes struct {
*bytes.Reader
}
func (r *readCloserFromBytes) Close() error {
return nil
}

File diff suppressed because it is too large


@@ -1,267 +0,0 @@
package service
import (
"context"
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
type betaPolicySettingRepoStub struct {
values map[string]string
}
func (s *betaPolicySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *betaPolicySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", ErrSettingNotFound
}
func (s *betaPolicySettingRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *betaPolicySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *betaPolicySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *betaPolicySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *betaPolicySettingRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestResolveBedrockBetaTokensForRequest_BlocksOnOriginalAnthropicToken(t *testing.T) {
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "advanced-tool-use-2025-11-20",
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "advanced tool use is blocked",
},
},
}
raw, err := json.Marshal(settings)
if err != nil {
t.Fatalf("marshal settings: %v", err)
}
svc := &GatewayService{
settingService: NewSettingService(
&betaPolicySettingRepoStub{values: map[string]string{
SettingKeyBetaPolicySettings: string(raw),
}},
&config.Config{},
),
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
_, err = svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"advanced-tool-use-2025-11-20",
[]byte(`{"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-opus-4-6-v1",
)
if err == nil {
t.Fatal("expected raw advanced-tool-use token to be blocked before Bedrock transform")
}
if err.Error() != "advanced tool use is blocked" {
t.Fatalf("unexpected error: %v", err)
}
}
func TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform(t *testing.T) {
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "tool-search-tool-2025-10-19",
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
},
},
}
raw, err := json.Marshal(settings)
if err != nil {
t.Fatalf("marshal settings: %v", err)
}
svc := &GatewayService{
settingService: NewSettingService(
&betaPolicySettingRepoStub{values: map[string]string{
SettingKeyBetaPolicySettings: string(raw),
}},
&config.Config{},
),
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
betaTokens, err := svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"advanced-tool-use-2025-11-20",
[]byte(`{"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-opus-4-6-v1",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
for _, token := range betaTokens {
if token == "tool-search-tool-2025-10-19" {
t.Fatalf("expected transformed Bedrock token to be filtered")
}
}
}
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking verifies:
// the admin has blocked interleaved-thinking and the client does not send the token in the header,
// but the request body contains a thinking field → after auto-injection it must be blocked.
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *testing.T) {
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "interleaved-thinking-2025-05-14",
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "thinking is blocked",
},
},
}
raw, err := json.Marshal(settings)
if err != nil {
t.Fatalf("marshal settings: %v", err)
}
svc := &GatewayService{
settingService: NewSettingService(
&betaPolicySettingRepoStub{values: map[string]string{
SettingKeyBetaPolicySettings: string(raw),
}},
&config.Config{},
),
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
// no beta token in the header, but the body carries a thinking field
_, err = svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"", // 空 header
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-opus-4-6-v1",
)
if err == nil {
t.Fatal("expected body-injected interleaved-thinking to be blocked")
}
if err.Error() != "thinking is blocked" {
t.Fatalf("unexpected error: %v", err)
}
}
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch verifies:
// the admin has blocked tool-search-tool and the client does not send the beta token in the header,
// but the request body contains a tool search tool → after auto-injection it must be blocked.
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch(t *testing.T) {
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "tool-search-tool-2025-10-19",
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "tool search is blocked",
},
},
}
raw, err := json.Marshal(settings)
if err != nil {
t.Fatalf("marshal settings: %v", err)
}
svc := &GatewayService{
settingService: NewSettingService(
&betaPolicySettingRepoStub{values: map[string]string{
SettingKeyBetaPolicySettings: string(raw),
}},
&config.Config{},
),
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
// no beta token in the header, but the body carries a tool_search_tool tool
_, err = svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"",
[]byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-sonnet-4-6",
)
if err == nil {
t.Fatal("expected body-injected tool-search-tool to be blocked")
}
if err.Error() != "tool search is blocked" {
t.Fatalf("unexpected error: %v", err)
}
}
// TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches verifies:
// a token auto-injected from the body passes through normally when no block rule matches it.
func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *testing.T) {
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "computer-use-2025-11-24",
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "computer use is blocked",
},
},
}
raw, err := json.Marshal(settings)
if err != nil {
t.Fatalf("marshal settings: %v", err)
}
svc := &GatewayService{
settingService: NewSettingService(
&betaPolicySettingRepoStub{values: map[string]string{
SettingKeyBetaPolicySettings: string(raw),
}},
&config.Config{},
),
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
// the body carries thinking, which auto-injects interleaved-thinking, but the block rule only targets computer-use
tokens, err := svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"",
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-opus-4-6-v1",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
found := false
for _, token := range tokens {
if token == "interleaved-thinking-2025-05-14" {
found = true
}
}
if !found {
t.Fatal("expected interleaved-thinking token to be present")
}
}
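For reference, a minimal sketch of the policy payload the stubs above serve under SettingKeyBetaPolicySettings (tokens and messages taken from these tests):

settings := &BetaPolicySettings{
	Rules: []BetaPolicyRule{
		// Block rejects the request outright with the configured message.
		{BetaToken: "computer-use-2025-11-24", Action: BetaPolicyActionBlock, Scope: BetaPolicyScopeAll, ErrorMessage: "computer use is blocked"},
		// Filter silently drops the token from the outgoing beta list.
		{BetaToken: "tool-search-tool-2025-10-19", Action: BetaPolicyActionFilter, Scope: BetaPolicyScopeAll},
	},
}
raw, _ := json.Marshal(settings) // persisted as the setting value
_ = raw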


@@ -1,48 +0,0 @@
package service
import "testing"
func TestGatewayServiceIsModelSupportedByAccount_BedrockDefaultMappingRestrictsModels(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-5") {
t.Fatalf("expected default Bedrock alias to be supported")
}
if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") {
t.Fatalf("expected unsupported alias to be rejected for Bedrock account")
}
}
func TestGatewayServiceIsModelSupportedByAccount_BedrockCustomMappingStillActsAsAllowlist(t *testing.T) {
svc := &GatewayService{}
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "eu-west-1",
"model_mapping": map[string]any{
"claude-sonnet-*": "claude-sonnet-4-6",
},
},
}
if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-6") {
t.Fatalf("expected matched custom mapping to be supported")
}
if !svc.isModelSupportedByAccount(account, "claude-opus-4-6") {
t.Fatalf("expected default Bedrock alias fallback to remain supported")
}
if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") {
t.Fatalf("expected unsupported model to still be rejected")
}
}


@@ -181,8 +181,7 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
_ = pr.Close()
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NoError(t, err)
require.NotNil(t, result)
}


@@ -50,6 +50,9 @@ type Group struct {
// MCP XML protocol injection switch (antigravity platform only)
MCPXMLInject bool
// Claude usage simulation switch: rewrites usage with no cache writes into claude-max-style usage
SimulateClaudeMaxEnabled bool
// Supported model families (antigravity platform only)
// Possible values: claude, gemini_text, gemini_image
SupportedModelScopes []string


@@ -129,41 +129,6 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
}
}
// Legacy compatibility: convert functions and function_call into tools and tool_choice
if functionsRaw, ok := reqBody["functions"]; ok {
if functions, k := functionsRaw.([]any); k {
tools := make([]any, 0, len(functions))
for _, f := range functions {
tools = append(tools, map[string]any{
"type": "function",
"function": f,
})
}
reqBody["tools"] = tools
}
delete(reqBody, "functions")
result.Modified = true
}
if fcRaw, ok := reqBody["function_call"]; ok {
if fcStr, ok := fcRaw.(string); ok {
// e.g. "auto", "none"
reqBody["tool_choice"] = fcStr
} else if fcObj, ok := fcRaw.(map[string]any); ok {
// e.g. {"name": "my_func"}
if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
reqBody["tool_choice"] = map[string]any{
"type": "function",
"function": map[string]any{
"name": name,
},
}
}
}
delete(reqBody, "function_call")
result.Modified = true
}
if normalizeCodexTools(reqBody) {
result.Modified = true
}
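For context, the legacy block removed above rewrote OpenAI-style functions/function_call into the tools/tool_choice shape; an illustrative before/after with made-up field values:

reqBody := map[string]any{
	"functions":     []any{map[string]any{"name": "my_func", "parameters": map[string]any{}}},
	"function_call": map[string]any{"name": "my_func"},
}
// After the removed transform, reqBody would instead hold:
//   "tools":       [{"type": "function", "function": {"name": "my_func", "parameters": {}}}]
//   "tool_choice": {"type": "function", "function": {"name": "my_func"}}
_ = reqBody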
@@ -338,18 +303,6 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
continue
}
typ, _ := m["type"].(string)
// Works around the latest OpenAI upstream validation: "Expected an ID that begins with 'fc'"
fixIDPrefix := func(id string) string {
if id == "" || strings.HasPrefix(id, "fc") {
return id
}
if strings.HasPrefix(id, "call_") {
return "fc" + strings.TrimPrefix(id, "call_")
}
return "fc_" + id
}
if typ == "item_reference" {
if !preserveReferences {
continue
@@ -358,9 +311,6 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
for key, value := range m {
newItem[key] = value
}
if id, ok := newItem["id"].(string); ok && id != "" {
newItem["id"] = fixIDPrefix(id)
}
filtered = append(filtered, newItem)
continue
}
@@ -380,20 +330,10 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
if isCodexToolCallItemType(typ) {
callID, ok := m["call_id"].(string)
if !ok || strings.TrimSpace(callID) == "" {
if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
callID = id
ensureCopy()
newItem["call_id"] = callID
}
}
if callID != "" {
fixedCallID := fixIDPrefix(callID)
if fixedCallID != callID {
ensureCopy()
newItem["call_id"] = fixedCallID
newItem["call_id"] = id
}
}
}
@@ -404,14 +344,6 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
if !isCodexToolCallItemType(typ) {
delete(newItem, "call_id")
}
} else {
if id, ok := newItem["id"].(string); ok && id != "" {
fixedID := fixIDPrefix(id)
if fixedID != id {
ensureCopy()
newItem["id"] = fixedID
}
}
}
filtered = append(filtered, newItem)

Some files were not shown because too many files have changed in this diff