Compare commits

..

51 Commits
dev ... v0.1.98

Author SHA1 Message Date
Wesley Liddick
2e3e8687e1 Merge pull request #993 from xvhuan/fix/codex-responses-id-prefix-hotfix-20260314
fix: 止血 Codex/Responses 原生 input id 被误改成 fc_*
2026-03-14 13:54:26 +08:00
ius
ca42a45802 fix: stop rewriting native responses input ids 2026-03-14 13:47:01 +08:00
Wesley Liddick
9350ecb62b Merge pull request #987 from Ethan0x0000/feat-chatcompletions2repsonses-fix
fix: chat compatibility model fallback and reasoning_content output
2026-03-14 13:41:47 +08:00
Wesley Liddick
a4a026e8da Merge pull request #990 from LvyuanW/admin-openai-available-models-fix
fix: respect OpenAI OAuth model mapping in admin available models
2026-03-14 13:33:18 +08:00
Wesley Liddick
342fd03e72 Merge pull request #986 from LvyuanW/openai-model-mapping-fix
fix: honor account model mapping before group fallback
2026-03-14 13:32:26 +08:00
Ethan0x0000
e3f1fd9b63 fix: handle strings.Builder write errors in assistant parsing 2026-03-14 13:12:17 +08:00
Wang Lvyuan
a377e99088 fix: remove unused wildcard mapping helper 2026-03-14 12:56:34 +08:00
Wang Lvyuan
1d3d7a3033 fix: respect OpenAI model mapping in admin available models 2026-03-14 12:45:10 +08:00
Wesley Liddick
e7086cb3a3 Merge pull request #988 from LvyuanW/scheduler-snapshot-sync
fix: sync scheduler snapshot on account updates
2026-03-14 12:38:48 +08:00
shaw
4f2a97073e chore: update docs 2026-03-14 12:37:36 +08:00
Wesley Liddick
7407e3b45d Merge pull request #984 from touwaeriol/docs/ecosystem-projects
docs: add iframe integration feature and ecosystem projects section
2026-03-14 12:31:25 +08:00
Wang Lvyuan
01ef7340aa Merge remote-tracking branch 'origin/main' into openai-model-mapping-fix 2026-03-14 12:27:08 +08:00
Wang Lvyuan
1c960d22c1 fix: sync scheduler snapshot on account updates 2026-03-14 12:21:28 +08:00
Ethan0x0000
ece0606fed fix: consolidate chat-completions compatibility fixes
- apply default mapped model only when scheduling fallback is actually used

- preserve reasoning in OpenAI-compatible output via reasoning_content and avoid invalid input function_call ids
2026-03-14 12:12:08 +08:00
Wesley Liddick
e6d59216d4 Merge pull request #975 from Ylarod/aws-bedrock
sub2api: add bedrock support
2026-03-14 10:52:24 +08:00
Wang Lvyuan
4e8615f276 fix: honor account model mapping before group fallback 2026-03-14 10:47:31 +08:00
erio
91e4d95660 docs: add iframe integration feature and ecosystem projects section 2026-03-14 03:16:07 +08:00
Wesley Liddick
4588258d80 Merge pull request #960 from 0xObjc/codex/user-spending-ranking
feat(admin): add user spending ranking dashboard view
2026-03-13 23:06:30 +08:00
Wesley Liddick
c12e48f966 Merge pull request #949 from kunish/fix/remove-done-stop-sequence
fix: remove SSE termination marker from DefaultStopSequences
2026-03-13 22:56:29 +08:00
Wesley Liddick
ec8f50a658 Merge pull request #951 from wanXcode/fix/dashboard-user-trend-label
fix(dashboard): prefer username over email prefix in recent usage chart
2026-03-13 22:56:13 +08:00
Wesley Liddick
99c9191784 Merge pull request #974 from 0xObjc/codex/next-pr-base-20260313
fix(admin): default dashboard date range to today
2026-03-13 22:55:14 +08:00
shaw
6bb02d141f chore: remove accidentally committed PR diagnostic report 2026-03-13 22:52:05 +08:00
Wesley Liddick
07bb2a5f3f Merge pull request #952 from xvhuan/feat/billing-ledger-decouple-usage-log-20260312
feat: 解耦计费正确性与 usage_logs 批量写压
2026-03-13 22:46:09 +08:00
Wesley Liddick
417861a48e Merge pull request #956 from share-wey/main
chore: codex transform fixes and feature compatibility
2026-03-13 22:36:29 +08:00
Wesley Liddick
b7e878de64 Merge pull request #980 from touwaeriol/feat/redeem-subscription-support
feat(redeem): support subscription type in create-and-redeem API
2026-03-13 22:15:33 +08:00
erio
05edb5514b feat(redeem): support subscription type in create-and-redeem API
Add group_id and validity_days fields to CreateAndRedeemCodeRequest,
enabling subscription-type redemption codes to be created and redeemed
in a single API call.

- Type defaults to "balance" when omitted for backward compatibility
- Subscription type requires group_id (non-nil) and validity_days (>0)
- Existing balance/concurrency callers are unaffected
2026-03-13 21:26:46 +08:00
Ylarod
e90ec847b6 fix lint 2026-03-13 19:15:27 +08:00
Connie Borer
7e288acc90 Merge branch 'Wei-Shaw:main' into main 2026-03-13 17:28:14 +08:00
Peter
27ff222cfb fix(admin): default dashboard date range to today 2026-03-13 17:02:54 +08:00
Ylarod
11f7b83522 sub2api: add bedrock support 2026-03-13 17:00:16 +08:00
Wesley Liddick
1ee984478f Merge pull request #957 from touwaeriol/feat/group-rate-multipliers-modal
feat(groups): add rate multipliers management modal
2026-03-13 11:11:13 +08:00
Wesley Liddick
fd693dc526 Merge pull request #967 from StarryKira/fix/admin-reset-quota-monthly
fix: 管理员重置配额补全 monthly 字段并修复 ristretto 缓存异步问题 fix issue #964
2026-03-13 11:10:47 +08:00
haruka
e73531ce9b fix: 管理员重置配额补全 monthly 字段并修复 ristretto 缓存异步问题
- 后端 handler:ResetSubscriptionQuotaRequest 新增 Monthly 字段,
  验证逻辑扩展为 daily/weekly/monthly 至少一项为 true
- 后端 service:AdminResetQuota 新增 resetMonthly 参数,
  调用 ResetMonthlyUsage;重置后追加 subCacheL1.Wait(),
  保证 ristretto Del() 的异步删除立即生效,消除重置后
  /v1/usage 返回旧用量数据的竞态窗口
- 后端测试:更新存量测试用例匹配新签名,补充
  TestAdminResetQuota_ResetMonthlyOnly /
  TestAdminResetQuota_ResetMonthlyUsageError 两个新用例
- 前端 API:resetQuota options 类型新增 monthly: boolean
- 前端视图:confirmResetQuota 改为同时重置 daily/weekly/monthly
- i18n:中英文确认提示文案更新,提及每月配额

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-13 10:39:35 +08:00
Wesley Liddick
ecea13757b Merge pull request #955 from DaydreamCoding/feat/gpt-training-off
feat: GPT Private设置数据不用于训练
2026-03-13 09:12:33 +08:00
Peter
80d8d6c3bc feat(admin): add user spending ranking dashboard view 2026-03-13 03:43:03 +08:00
erio
d648811233 feat(groups): add rate multipliers management modal
Add a dedicated modal in group management for viewing, adding, editing,
and deleting per-user rate multipliers within a group.

Backend:
- GET /admin/groups/:id/rate-multipliers - list entries with user details
- PUT /admin/groups/:id/rate-multipliers - batch sync (full replace)
- DELETE /admin/groups/:id/rate-multipliers - clear all entries
- Repository: GetByGroupID, SyncGroupRateMultipliers methods on
  user_group_rate_multipliers table (same table as user-side rates)

Frontend:
- New GroupRateMultipliersModal component with:
  - User search and add with email autocomplete
  - Editable rate column with local edit mode (cancel/save)
  - Batch adjust: multiply all rates by a factor
  - Clear all (local operation, requires save to persist)
  - Pagination (10/20/50 per page)
  - Platform icon with brand colors in group info bar
  - Unsaved changes indicator with revert option
- Unit tests for all three backend endpoints
2026-03-12 23:37:36 +08:00
QTom
34695acb85 fix: 移除账号导入时同步调用 disableOpenAITraining,避免网络超时导致导入失败
privacy_mode 改为由 TokenRefreshService 在 token 刷新后异步补设。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-12 22:36:25 +08:00
QTom
a63de12182 feat: GPT 隐私模式 + no-train 前端展示优化 2026-03-12 21:24:01 +08:00
yexueduxing
f16910d616 chore: codex transform fixes and feature compatibility 2026-03-12 20:52:35 +08:00
ius
64b3f3cec1 test: relocate best-effort usage log stub 2026-03-12 18:43:37 +08:00
ius
6a685727d0 fix: harden usage billing idempotency and backpressure 2026-03-12 18:38:09 +08:00
ius
32d25f76fc fix: respect preconfigured usage log batch channels 2026-03-12 17:44:57 +08:00
wanXcode
69cafe8674 fix(dashboard): prefer username in user usage trend 2026-03-12 17:42:41 +08:00
ius
18ba8d9166 fix: stabilize repository integration paths 2026-03-12 17:42:41 +08:00
ius
e97fd7e81c test: align oauth passthrough stream expectations 2026-03-12 17:22:01 +08:00
kunish
cdb64b0d33 fix: remove SSE termination marker from DefaultStopSequences
The SSE stream termination marker string was incorrectly included in
DefaultStopSequences, causing Gemini to prematurely stop generating
output whenever the model produced text containing that marker.

The SSE-level protocol filtering in stream_transformer.go already
handles this marker correctly; it should not be a stop sequence for
the model's text generation.
2026-03-12 17:10:01 +08:00
ius
8d4d3b03bb fix: remove unused gateway usage helpers 2026-03-12 17:08:57 +08:00
ius
addefe79e1 fix: align docker health checks with runtime image 2026-03-12 17:03:21 +08:00
ius
b764d3b8f6 Merge remote-tracking branch 'origin/main' into feat/billing-ledger-decouple-usage-log-20260312 2026-03-12 16:53:28 +08:00
ius
611fd884bd feat: decouple billing correctness from usage log batching 2026-03-12 16:53:18 +08:00
ius
c9debc50b1 Batch usage log writes in repository 2026-03-11 20:29:48 +08:00
174 changed files with 10338 additions and 5884 deletions

View File

@@ -17,7 +17,6 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.1'
@@ -37,7 +36,6 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.1'

10
.gitignore vendored
View File

@@ -78,7 +78,6 @@ Desktop.ini
# ===================
tmp/
temp/
logs/
*.tmp
*.temp
*.log
@@ -129,15 +128,8 @@ deploy/docker-compose.override.yml
vite.config.js
docs/*
.serena/
# ===================
# 压测工具
# ===================
tools/loadtest/
# Antigravity Manager
Antigravity-Manager/
antigravity_projectid_fix.patch
.codex/
frontend/coverage/
aicodex
output/

1392
AGENTS.md

File diff suppressed because it is too large Load Diff

1337
CLAUDE.md

File diff suppressed because it is too large Load Diff

View File

@@ -39,6 +39,16 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
- **Concurrency Control** - Per-user and per-account concurrency limits
- **Rate Limiting** - Configurable request and token rate limits
- **Admin Dashboard** - Web interface for monitoring and management
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
## Ecosystem
Community projects that extend or integrate with Sub2API:
| Project | Description | Features |
|---------|-------------|----------|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
## Tech Stack

View File

@@ -39,6 +39,16 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
- **并发控制** - 用户级和账号级并发限制
- **速率限制** - 可配置的请求和 Token 速率限制
- **管理后台** - Web 界面进行监控和管理
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
## 生态项目
围绕 Sub2API 的社区扩展与集成项目:
| 项目 | 说明 | 功能 |
|------|------|------|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买兼容易支付协议、微信官方支付、支付宝官方支付、Stripe支持 iframe 嵌入管理后台 |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用iOS/Android/Web支持用户管理、账号管理、监控看板、多后端切换基于 Expo + React Native 构建 |
## 技术栈

View File

@@ -1 +1 @@
0.1.96.1
0.1.88

View File

@@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// Server layer ProviderSet
server.ProviderSet,
// Privacy client factory for OpenAI training opt-out
providePrivacyClientFactory,
// BuildInfo provider
provideServiceBuildInfo,
@@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, nil
}
func providePrivacyClientFactory() service.PrivacyClientFactory {
return repository.CreatePrivacyReqClient
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,

View File

@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemHandler := handler.NewRedeemHandler(redeemService)
@@ -104,7 +105,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
privacyClientFactory := providePrivacyClientFactory()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
@@ -162,9 +164,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
@@ -226,7 +228,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
@@ -245,6 +247,10 @@ type Application struct {
Cleanup func()
}
func providePrivacyClientFactory() service.PrivacyClientFactory {
return repository.CreatePrivacyReqClient
}
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,

View File

@@ -62,28 +62,26 @@ type Group struct {
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
// allow Claude Code client only
// 是否仅允许 Claude Code 客户端
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// fallback group for non-Claude-Code requests
// Claude Code 请求降级使用的分组 ID
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// fallback group for invalid request
// 无效请求兜底使用的分组 ID
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
// model routing config: pattern -> account ids
// 模型路由配置:模型模式 -> 优先账号ID列表
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
// whether model routing is enabled
// 是否启用模型路由配置
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
// whether MCP XML prompt injection is enabled
// 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
// supported model scopes: claude, gemini_text, gemini_image
// 支持的模型系列:claude, gemini_text, gemini_image
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// group display order, lower comes first
// 分组显示排序,数值越小越靠前
SortOrder int `json:"sort_order,omitempty"`
// 是否允许 /v1/messages 调度到此 OpenAI 分组
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
// 默认映射模型 ID当账号级映射找不到时使用此值
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
// simulate claude usage as claude-max style (1h cache write)
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -192,7 +190,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case group.FieldModelRouting, group.FieldSupportedModelScopes:
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldSimulateClaudeMaxEnabled:
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
values[i] = new(sql.NullFloat64)
@@ -433,12 +431,6 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.DefaultMappedModel = value.String
}
case group.FieldSimulateClaudeMaxEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field simulate_claude_max_enabled", values[i])
} else if value.Valid {
_m.SimulateClaudeMaxEnabled = value.Bool
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -638,9 +630,6 @@ func (_m *Group) String() string {
builder.WriteString(", ")
builder.WriteString("default_mapped_model=")
builder.WriteString(_m.DefaultMappedModel)
builder.WriteString(", ")
builder.WriteString("simulate_claude_max_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.SimulateClaudeMaxEnabled))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -79,8 +79,6 @@ const (
FieldAllowMessagesDispatch = "allow_messages_dispatch"
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
FieldDefaultMappedModel = "default_mapped_model"
// FieldSimulateClaudeMaxEnabled holds the string denoting the simulate_claude_max_enabled field in the database.
FieldSimulateClaudeMaxEnabled = "simulate_claude_max_enabled"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -188,7 +186,6 @@ var Columns = []string{
FieldSortOrder,
FieldAllowMessagesDispatch,
FieldDefaultMappedModel,
FieldSimulateClaudeMaxEnabled,
}
var (
@@ -262,8 +259,6 @@ var (
DefaultDefaultMappedModel string
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
DefaultMappedModelValidator func(string) error
// DefaultSimulateClaudeMaxEnabled holds the default value on creation for the "simulate_claude_max_enabled" field.
DefaultSimulateClaudeMaxEnabled bool
)
// OrderOption defines the ordering options for the Group queries.
@@ -424,11 +419,6 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
}
// BySimulateClaudeMaxEnabled orders the results by the simulate_claude_max_enabled field.
func BySimulateClaudeMaxEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSimulateClaudeMaxEnabled, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {

View File

@@ -205,11 +205,6 @@ func DefaultMappedModel(v string) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
}
// SimulateClaudeMaxEnabled applies equality check predicate on the "simulate_claude_max_enabled" field. It's identical to SimulateClaudeMaxEnabledEQ.
func SimulateClaudeMaxEnabled(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1560,16 +1555,6 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
}
// SimulateClaudeMaxEnabledEQ applies the EQ predicate on the "simulate_claude_max_enabled" field.
func SimulateClaudeMaxEnabledEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
}
// SimulateClaudeMaxEnabledNEQ applies the NEQ predicate on the "simulate_claude_max_enabled" field.
func SimulateClaudeMaxEnabledNEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldNEQ(FieldSimulateClaudeMaxEnabled, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {

View File

@@ -452,20 +452,6 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
return _c
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (_c *GroupCreate) SetSimulateClaudeMaxEnabled(v bool) *GroupCreate {
_c.mutation.SetSimulateClaudeMaxEnabled(v)
return _c
}
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
func (_c *GroupCreate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupCreate {
if v != nil {
_c.SetSimulateClaudeMaxEnabled(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -663,10 +649,6 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultMappedModel
_c.mutation.SetDefaultMappedModel(v)
}
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
v := group.DefaultSimulateClaudeMaxEnabled
_c.mutation.SetSimulateClaudeMaxEnabled(v)
}
return nil
}
@@ -748,9 +730,6 @@ func (_c *GroupCreate) check() error {
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
}
}
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
return &ValidationError{Name: "simulate_claude_max_enabled", err: errors.New(`ent: missing required field "Group.simulate_claude_max_enabled"`)}
}
return nil
}
@@ -906,10 +885,6 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
_node.DefaultMappedModel = value
}
if value, ok := _c.mutation.SimulateClaudeMaxEnabled(); ok {
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
_node.SimulateClaudeMaxEnabled = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1624,18 +1599,6 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
return u
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (u *GroupUpsert) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsert {
u.Set(group.FieldSimulateClaudeMaxEnabled, v)
return u
}
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
func (u *GroupUpsert) UpdateSimulateClaudeMaxEnabled() *GroupUpsert {
u.SetExcluded(group.FieldSimulateClaudeMaxEnabled)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -2332,20 +2295,6 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
})
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (u *GroupUpsertOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.SetSimulateClaudeMaxEnabled(v)
})
}
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
func (u *GroupUpsertOne) UpdateSimulateClaudeMaxEnabled() *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
s.UpdateSimulateClaudeMaxEnabled()
})
}
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -3208,20 +3157,6 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
})
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (u *GroupUpsertBulk) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.SetSimulateClaudeMaxEnabled(v)
})
}
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
func (u *GroupUpsertBulk) UpdateSimulateClaudeMaxEnabled() *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
s.UpdateSimulateClaudeMaxEnabled()
})
}
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -653,20 +653,6 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
return _u
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (_u *GroupUpdate) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdate {
_u.mutation.SetSimulateClaudeMaxEnabled(v)
return _u
}
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
func (_u *GroupUpdate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdate {
if v != nil {
_u.SetSimulateClaudeMaxEnabled(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1163,9 +1149,6 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
}
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -2098,20 +2081,6 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
return _u
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (_u *GroupUpdateOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdateOne {
_u.mutation.SetSimulateClaudeMaxEnabled(v)
return _u
}
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
func (_u *GroupUpdateOne) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdateOne {
if v != nil {
_u.SetSimulateClaudeMaxEnabled(*v)
}
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -2638,9 +2607,6 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
}
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,

View File

@@ -410,7 +410,6 @@ var (
{Name: "sort_order", Type: field.TypeInt, Default: 0},
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
{Name: "simulate_claude_max_enabled", Type: field.TypeBool, Default: false},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{

View File

@@ -8252,7 +8252,6 @@ type GroupMutation struct {
addsort_order *int
allow_messages_dispatch *bool
default_mapped_model *string
simulate_claude_max_enabled *bool
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -10069,42 +10068,6 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
m.default_mapped_model = nil
}
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
func (m *GroupMutation) SetSimulateClaudeMaxEnabled(b bool) {
m.simulate_claude_max_enabled = &b
}
// SimulateClaudeMaxEnabled returns the value of the "simulate_claude_max_enabled" field in the mutation.
func (m *GroupMutation) SimulateClaudeMaxEnabled() (r bool, exists bool) {
v := m.simulate_claude_max_enabled
if v == nil {
return
}
return *v, true
}
// OldSimulateClaudeMaxEnabled returns the old "simulate_claude_max_enabled" field's value of the Group entity.
// If the Group object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *GroupMutation) OldSimulateClaudeMaxEnabled(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldSimulateClaudeMaxEnabled is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldSimulateClaudeMaxEnabled requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldSimulateClaudeMaxEnabled: %w", err)
}
return oldValue.SimulateClaudeMaxEnabled, nil
}
// ResetSimulateClaudeMaxEnabled resets all changes to the "simulate_claude_max_enabled" field.
func (m *GroupMutation) ResetSimulateClaudeMaxEnabled() {
m.simulate_claude_max_enabled = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -10463,7 +10426,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
fields := make([]string, 0, 33)
fields := make([]string, 0, 32)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -10560,9 +10523,6 @@ func (m *GroupMutation) Fields() []string {
if m.default_mapped_model != nil {
fields = append(fields, group.FieldDefaultMappedModel)
}
if m.simulate_claude_max_enabled != nil {
fields = append(fields, group.FieldSimulateClaudeMaxEnabled)
}
return fields
}
@@ -10635,8 +10595,6 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.AllowMessagesDispatch()
case group.FieldDefaultMappedModel:
return m.DefaultMappedModel()
case group.FieldSimulateClaudeMaxEnabled:
return m.SimulateClaudeMaxEnabled()
}
return nil, false
}
@@ -10710,8 +10668,6 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldAllowMessagesDispatch(ctx)
case group.FieldDefaultMappedModel:
return m.OldDefaultMappedModel(ctx)
case group.FieldSimulateClaudeMaxEnabled:
return m.OldSimulateClaudeMaxEnabled(ctx)
}
return nil, fmt.Errorf("unknown Group field %s", name)
}
@@ -10945,13 +10901,6 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetDefaultMappedModel(v)
return nil
case group.FieldSimulateClaudeMaxEnabled:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetSimulateClaudeMaxEnabled(v)
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -11385,9 +11334,6 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldDefaultMappedModel:
m.ResetDefaultMappedModel()
return nil
case group.FieldSimulateClaudeMaxEnabled:
m.ResetSimulateClaudeMaxEnabled()
return nil
}
return fmt.Errorf("unknown Group field %s", name)
}

View File

@@ -463,10 +463,6 @@ func init() {
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
// groupDescSimulateClaudeMaxEnabled is the schema descriptor for simulate_claude_max_enabled field.
groupDescSimulateClaudeMaxEnabled := groupFields[29].Descriptor()
// group.DefaultSimulateClaudeMaxEnabled holds the default value on creation for the simulate_claude_max_enabled field.
group.DefaultSimulateClaudeMaxEnabled = groupDescSimulateClaudeMaxEnabled.Default.(bool)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
_ = idempotencyrecordMixinFields0

View File

@@ -33,6 +33,8 @@ func (Group) Mixin() []ent.Mixin {
func (Group) Fields() []ent.Field {
return []ent.Field{
// 唯一约束通过部分索引实现WHERE deleted_at IS NULL支持软删除后重用
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
field.String("name").
MaxLen(100).
NotEmpty(),
@@ -49,6 +51,7 @@ func (Group) Fields() []ent.Field {
MaxLen(20).
Default(domain.StatusActive),
// Subscription-related fields (added by migration 003)
field.String("platform").
MaxLen(50).
Default(domain.PlatformAnthropic),
@@ -70,6 +73,7 @@ func (Group) Fields() []ent.Field {
field.Int("default_validity_days").
Default(30),
// 图片生成计费配置antigravity 和 gemini 平台使用)
field.Float("image_price_1k").
Optional().
Nillable().
@@ -83,6 +87,7 @@ func (Group) Fields() []ent.Field {
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
// Sora 按次计费配置(阶段 1
field.Float("sora_image_price_360").
Optional().
Nillable().
@@ -104,38 +109,45 @@ func (Group) Fields() []ent.Field {
field.Int64("sora_storage_quota_bytes").
Default(0),
// Claude Code 客户端限制 (added by migration 029)
field.Bool("claude_code_only").
Default(false).
Comment("allow Claude Code client only"),
Comment("是否仅允许 Claude Code 客户端"),
field.Int64("fallback_group_id").
Optional().
Nillable().
Comment("fallback group for non-Claude-Code requests"),
Comment("Claude Code 请求降级使用的分组 ID"),
field.Int64("fallback_group_id_on_invalid_request").
Optional().
Nillable().
Comment("fallback group for invalid request"),
Comment("无效请求兜底使用的分组 ID"),
// 模型路由配置 (added by migration 040)
field.JSON("model_routing", map[string][]int64{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("model routing config: pattern -> account ids"),
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
// 模型路由开关 (added by migration 041)
field.Bool("model_routing_enabled").
Default(false).
Comment("whether model routing is enabled"),
Comment("是否启用模型路由配置"),
// MCP XML 协议注入开关 (added by migration 042)
field.Bool("mcp_xml_inject").
Default(true).
Comment("whether MCP XML prompt injection is enabled"),
Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"),
// 支持的模型系列 (added by migration 046)
field.JSON("supported_model_scopes", []string{}).
Default([]string{"claude", "gemini_text", "gemini_image"}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("supported model scopes: claude, gemini_text, gemini_image"),
Comment("支持的模型系列:claude, gemini_text, gemini_image"),
// 分组排序 (added by migration 052)
field.Int("sort_order").
Default(0).
Comment("group display order, lower comes first"),
Comment("分组显示排序,数值越小越靠前"),
// OpenAI Messages 调度配置 (added by migration 069)
field.Bool("allow_messages_dispatch").
@@ -145,9 +157,6 @@ func (Group) Fields() []ent.Field {
MaxLen(100).
Default("").
Comment("默认映射模型 ID当账号级映射找不到时使用此值"),
field.Bool("simulate_claude_max_enabled").
Default(false).
Comment("simulate claude usage as claude-max style (1h cache write)"),
}
}
@@ -163,11 +172,14 @@ func (Group) Edges() []ent.Edge {
edge.From("allowed_users", User.Type).
Ref("allowed_groups").
Through("user_allowed_groups", UserAllowedGroup.Type),
// 注意fallback_group_id 直接作为字段使用,不定义 edge
// 这样允许多个分组指向同一个降级分组M2O 关系)
}
}
func (Group) Indexes() []ent.Index {
return []ent.Index{
// name 字段已在 Fields() 中声明 Unique(),无需重复索引
index.Fields("status"),
index.Fields("platform"),
index.Fields("subscription_type"),

View File

@@ -7,7 +7,7 @@ require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
github.com/aws/aws-sdk-go-v2 v1.41.2
github.com/aws/aws-sdk-go-v2 v1.41.3
github.com/aws/aws-sdk-go-v2/config v1.32.10
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
@@ -66,7 +66,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
github.com/aws/smithy-go v1.24.1 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/bdandy/go-errors v1.2.2 // indirect
github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
@@ -87,7 +87,6 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/docker/docker v28.5.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
@@ -138,8 +137,6 @@ require (
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/quic-go/qpack v0.6.0 // indirect

View File

@@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
@@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
@@ -124,8 +128,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
@@ -182,7 +184,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@@ -202,8 +203,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -286,10 +285,6 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -342,8 +337,6 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -352,6 +345,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
@@ -441,11 +436,11 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=

View File

@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
// DashboardAggregationRetentionConfig 预聚合保留窗口
type DashboardAggregationRetentionConfig struct {
UsageLogsDays int `mapstructure:"usage_logs_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
UsageLogsDays int `mapstructure:"usage_logs_days"`
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
}
// UsageCleanupConfig 使用记录清理任务配置
@@ -1301,6 +1302,7 @@ func setDefaults() {
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
}
if c.DashboardAgg.Retention.HourlyDays <= 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
}
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
}
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
}
if c.DashboardAgg.Retention.HourlyDays < 0 {
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
}

View File

@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
}
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
}
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
}
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
wantErr: "dashboard_aggregation.retention.usage_logs_days",
},
{
name: "dashboard aggregation dedup retention",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
},
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
},
{
name: "dashboard aggregation dedup retention smaller than usage logs",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.Retention.UsageLogsDays = 30
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
},
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
},
{
name: "dashboard aggregation disabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },

View File

@@ -27,10 +27,12 @@ const (
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock
AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock
)
// Redeem type constants
@@ -113,3 +115,27 @@ var DefaultAntigravityModelMapping = map[string]string{
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",
}
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
// 注意:此处的 "us." 前缀仅为默认值ResolveBedrockModelID 会根据账号配置的
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
var DefaultBedrockModelMapping = map[string]string{
// Claude Opus
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
// Claude Sonnet
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
// Claude Haiku
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
}

View File

@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -865,6 +865,9 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
}
}
// OpenAI OAuth: 刷新成功后检查并设置 privacy_mode
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
return updatedAccount, "", nil
}
@@ -1338,12 +1341,6 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
})
return
}
@@ -1721,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
// Handle OpenAI accounts
if account.IsOpenAI() {
// For OAuth accounts: return default OpenAI models
if account.IsOAuth() {
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
if account.IsOpenAIPassthroughEnabled() {
response.Success(c, openai.DefaultModels)
return
}
// For API Key accounts: check model_mapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
response.Success(c, openai.DefaultModels)

View File

@@ -0,0 +1,105 @@
package admin
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type availableModelsAdminService struct {
*stubAdminService
account service.Account
}
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
if s.account.ID == id {
acc := s.account
return &acc, nil
}
return s.stubAdminService.GetAccount(context.Background(), id)
}
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
return router
}
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 42,
Name: "openai-oauth",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.1",
},
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Len(t, resp.Data, 1)
require.Equal(t, "gpt-5", resp.Data[0].ID)
}
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 43,
Name: "openai-oauth-passthrough",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.1",
},
},
Extra: map[string]any{
"openai_passthrough": true,
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp.Data)
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
}

View File

@@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "claude-max")
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
@@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "claude-max")
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)

View File

@@ -179,6 +179,14 @@ func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) (
return nil, nil
}
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
return nil
}
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
return nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}
@@ -433,5 +441,9 @@ func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) erro
return nil
}
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
return ""
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -466,9 +466,60 @@ type BatchUsersUsageRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
func parseRankingLimit(raw string) int {
limit, err := strconv.Atoi(strings.TrimSpace(raw))
if err != nil || limit <= 0 {
return 12
}
if limit > 50 {
return 50
}
return limit
}
// GetUserSpendingRanking handles getting user spending ranking data.
// GET /api/v1/admin/dashboard/users-ranking
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
keyRaw, _ := json.Marshal(struct {
Start string `json:"start"`
End string `json:"end"`
Limit int `json:"limit"`
}{
Start: startTime.UTC().Format(time.RFC3339),
End: endTime.UTC().Format(time.RFC3339),
Limit: limit,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
if err != nil {
response.Error(c, 500, "Failed to get user spending ranking")
return
}
payload := gin.H{
"ranking": ranking.Ranking,
"total_actual_cost": ranking.TotalActualCost,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
}
dashboardUsersRankingCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {

View File

@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
trendStream *bool
modelRequestType *int16
modelStream *bool
rankingLimit int
ranking []usagestats.UserSpendingRankingItem
rankingTotal float64
}
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
@@ -49,6 +52,18 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
return []usagestats.ModelStat{}, nil
}
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
ctx context.Context,
startTime, endTime time.Time,
limit int,
) (*usagestats.UserSpendingRankingResponse, error) {
s.rankingLimit = limit
return &usagestats.UserSpendingRankingResponse{
Ranking: s.ranking,
TotalActualCost: s.rankingTotal,
}, nil
}
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
@@ -56,6 +71,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
router.GET("/admin/dashboard/models", handler.GetModelStats)
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
return router
}
@@ -130,3 +146,30 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
repo := &dashboardUsageRepoCapture{
ranking: []usagestats.UserSpendingRankingItem{
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
},
rankingTotal: 88.8,
}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 50, repo.rankingLimit)
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
}

View File

@@ -46,10 +46,9 @@ type CreateGroupRequest struct {
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
// Sora 存储配额
@@ -85,10 +84,9 @@ type UpdateGroupRequest struct {
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// Sora 存储配额
@@ -209,7 +207,6 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
@@ -263,7 +260,6 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
@@ -360,6 +356,51 @@ func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
response.Success(c, entries)
}
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
// DELETE /api/v1/admin/groups/:id/rate-multipliers
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
}
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
type BatchSetGroupRateMultipliersRequest struct {
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
}
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
// PUT /api/v1/admin/groups/:id/rate-multipliers
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
var req BatchSetGroupRateMultipliersRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
}
// UpdateSortOrderRequest represents the request to update group sort orders
type UpdateSortOrderRequest struct {
Updates []struct {

View File

@@ -289,6 +289,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
Platform: platform,
Type: "oauth",
Credentials: credentials,
Extra: nil,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,

View File

@@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct {
}
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance
type CreateAndRedeemCodeRequest struct {
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
Notes string `json:"notes"`
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance向后兼容
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
GroupID *int64 `json:"group_id"` // subscription 类型必填
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
Notes string `json:"notes"`
}
// List handles listing all redeem codes with pagination
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
return
}
req.Code = strings.TrimSpace(req.Code)
// 向后兼容:旧版调用方(如 Sub2ApiPay不传 type 字段,默认当作 balance 充值处理。
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
if req.Type == "" {
req.Type = "balance"
}
if req.Type == "subscription" {
if req.GroupID == nil {
response.BadRequest(c, "group_id is required for subscription type")
return
}
if req.ValidityDays <= 0 {
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
return
}
}
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
existing, err := h.redeemService.GetByCode(ctx, req.Code)
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
}
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
Code: req.Code,
Type: req.Type,
Value: req.Value,
Status: service.StatusUnused,
Notes: req.Notes,
Code: req.Code,
Type: req.Type,
Value: req.Value,
Status: service.StatusUnused,
Notes: req.Notes,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if createErr != nil {
// Unique code race: if code now exists, use idempotent semantics by used_by.

View File

@@ -0,0 +1,135 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
// parameter-validation layer that runs before any service call.
func newCreateAndRedeemHandler() *RedeemHandler {
return &RedeemHandler{
adminService: newStubAdminService(),
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
}
}
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
// status code. For cases that pass validation and proceed into the service layer,
// a panic may occur (because RedeemService internals are nil); this is expected
// and treated as "validation passed" (returns 0 to indicate panic).
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
jsonBytes, err := json.Marshal(body)
require.NoError(t, err)
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
c.Request.Header.Set("Content-Type", "application/json")
defer func() {
if r := recover(); r != nil {
// Panic means we passed validation and entered service layer (expected for minimal stub).
code = 0
}
}()
handler.CreateAndRedeem(c)
return w.Code
}
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
// 不传 type 字段时应默认 balance不触发 subscription 校验。
// 验证通过后进入 service 层会 panic返回 0说明默认值生效。
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-balance-default",
"value": 10.0,
"user_id": 1,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"omitting type should default to balance and pass validation")
}
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-no-group",
"type": "subscription",
"value": 29.9,
"user_id": 1,
"validity_days": 30,
// group_id 缺失
})
assert.Equal(t, http.StatusBadRequest, code)
}
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
groupID := int64(5)
h := newCreateAndRedeemHandler()
cases := []struct {
name string
validityDays int
}{
{"zero", 0},
{"negative", -1},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-bad-days-" + tc.name,
"type": "subscription",
"value": 29.9,
"user_id": 1,
"group_id": groupID,
"validity_days": tc.validityDays,
})
assert.Equal(t, http.StatusBadRequest, code)
})
}
}
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
groupID := int64(5)
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-valid",
"type": "subscription",
"value": 29.9,
"user_id": 1,
"group_id": groupID,
"validity_days": 31,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"valid subscription params should pass validation")
}
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
h := newCreateAndRedeemHandler()
// balance 类型不传 group_id 和 validity_days不应报 400
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-balance-no-extras",
"type": "balance",
"value": 50.0,
"user_id": 1,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"balance type should not require group_id or validity_days")
}

View File

@@ -218,11 +218,12 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
// ResetSubscriptionQuotaRequest represents the reset quota request
type ResetSubscriptionQuotaRequest struct {
Daily bool `json:"daily"`
Weekly bool `json:"weekly"`
Daily bool `json:"daily"`
Weekly bool `json:"weekly"`
Monthly bool `json:"monthly"`
}
// ResetQuota resets daily and/or weekly usage for a subscription.
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
// POST /api/v1/admin/subscriptions/:id/reset-quota
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -235,11 +236,11 @@ func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !req.Daily && !req.Weekly {
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
if !req.Daily && !req.Weekly && !req.Monthly {
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
return
}
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -135,15 +135,14 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil
}
out := &AdminGroup{
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
SortOrder: g.SortOrder,
Group: groupFromServiceBase(g),
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.MCPXMLInject,
DefaultMappedModel: g.DefaultMappedModel,
SupportedModelScopes: g.SupportedModelScopes,
AccountCount: g.AccountCount,
SortOrder: g.SortOrder,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))

View File

@@ -117,8 +117,6 @@ type AdminGroup struct {
// MCP XML 协议注入(仅 antigravity 平台使用)
MCPXMLInject bool `json:"mcp_xml_inject"`
// Claude usage 模拟开关(仅管理员可见)
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
DefaultMappedModel string `json:"default_mapped_model"`

View File

@@ -434,20 +434,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),
@@ -631,7 +632,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// ===== 用户消息串行队列 END =====
// 转发请求 - 根据账号平台分流
c.Set("parsed_request", parsedReq)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
@@ -738,20 +738,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
Subscription: currentSubscription,
UserAgent: userAgent,
IPAddress: clientIP,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
Subscription: currentSubscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),

View File

@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // accountRepo (not used: scheduler snapshot hit)
&fakeGroupRepo{group: group},
nil, // usageLogRepo
nil, // usageBillingRepo
nil, // userRepo
nil, // userSubRepo
nil, // userGroupRateRepo

View File

@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
requestPayloadHash := service.HashUsageRequestPayload(body)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fs.ForceCacheBilling,

View File

@@ -181,13 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
defaultMappedModel := ""
if apiKey.Group != nil {
defaultMappedModel = apiKey.Group.DefaultMappedModel
}
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
defaultMappedModel = fallbackModel
}
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds()

View File

@@ -352,18 +352,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
@@ -653,14 +655,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
defaultMappedModel := ""
if apiKey.Group != nil {
defaultMappedModel = apiKey.Group.DefaultMappedModel
}
// 如果使用了降级模型调度,强制使用降级模型
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
defaultMappedModel = fallbackModel
}
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
defaultMappedModel := c.GetString("openai_messages_fallback_model")
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
@@ -732,17 +729,19 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.messages"),
@@ -1231,14 +1230,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
h.submitUsageRecordTask(func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),

View File

@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService用于测试 SelectAccountForModel
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil,
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
)
}

View File

@@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
}); err != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),

View File

@@ -343,6 +343,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, nil
}
@@ -431,6 +434,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil,
nil,
nil,
nil,
testutil.StubGatewayCache{},
cfg,
nil,

View File

@@ -189,6 +189,5 @@ var DefaultStopSequences = []string{
"<|user|>",
"<|endoftext|>",
"<|end_of_turn|>",
"[DONE]",
"\n\nHuman:",
}

View File

@@ -18,9 +18,6 @@ const (
BlockTypeFunction
)
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
type UsageMapHook func(usageMap map[string]any)
// StreamingProcessor 流式响应处理器
type StreamingProcessor struct {
blockType BlockType
@@ -33,7 +30,6 @@ type StreamingProcessor struct {
originalModel string
webSearchQueries []string
groundingChunks []GeminiGroundingChunk
usageMapHook UsageMapHook
// 累计 usage
inputTokens int
@@ -49,25 +45,6 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
}
}
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
p.usageMapHook = fn
}
func usageToMap(u ClaudeUsage) map[string]any {
m := map[string]any{
"input_tokens": u.InputTokens,
"output_tokens": u.OutputTokens,
}
if u.CacheCreationInputTokens > 0 {
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
}
if u.CacheReadInputTokens > 0 {
m["cache_read_input_tokens"] = u.CacheReadInputTokens
}
return m
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line)
@@ -191,13 +168,6 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
responseID = "msg_" + generateRandomID()
}
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
message := map[string]any{
"id": responseID,
"type": "message",
@@ -206,7 +176,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
"model": p.originalModel,
"stop_reason": nil,
"stop_sequence": nil,
"usage": usageValue,
"usage": usage,
}
event := map[string]any{
@@ -517,20 +487,13 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
CacheReadInputTokens: p.cacheReadTokens,
}
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
deltaEvent := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": usageValue,
"usage": usage,
}
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))

View File

@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
assert.Equal(t, "assistant", items[1].Role)
assert.Equal(t, "function_call", items[2].Type)
assert.Equal(t, "fc_call_1", items[2].CallID)
assert.Empty(t, items[2].ID)
assert.Equal(t, "function_call_output", items[3].Type)
assert.Equal(t, "fc_call_1", items[3].CallID)
assert.Equal(t, "Sunny, 72°F", items[3].Output)

View File

@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
CallID: fcID,
Name: b.Name,
Arguments: args,
ID: fcID,
})
}

View File

@@ -99,6 +99,7 @@ func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
// Check function_call item
assert.Equal(t, "function_call", items[1].Type)
assert.Equal(t, "call_1", items[1].CallID)
assert.Empty(t, items[1].ID)
assert.Equal(t, "ping", items[1].Name)
// Check function_call_output item
@@ -252,6 +253,55 @@ func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T)
assert.Equal(t, "user", items[0].Role)
assert.Equal(t, "assistant", items[1].Role)
assert.Equal(t, "function_call", items[2].Type)
assert.Empty(t, items[2].ID)
}
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Hi"`)},
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2)
assert.Equal(t, "assistant", items[1].Role)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "output_text", parts[0].Type)
assert.Equal(t, "AB", parts[0].Text)
}
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Hi"`)},
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "output_text", parts[0].Type)
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
assert.Contains(t, parts[0].Text, "final answer")
}
// ---------------------------------------------------------------------------
@@ -344,8 +394,8 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
var content string
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
// Reasoning summary is prepended to text
assert.Equal(t, "I thought about it.The answer is 42.", content)
assert.Equal(t, "The answer is 42.", content)
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
}
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
@@ -582,8 +632,35 @@ func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
Delta: "Thinking...",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.reasoning_summary_text.done",
}, state)
require.Len(t, chunks, 0)
}
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.SentRole = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.reasoning_summary_text.delta",
Delta: "plan",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.output_text.delta",
Delta: "answer",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
}
func TestFinalizeResponsesChatStream(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package apicompat
import (
"encoding/json"
"fmt"
"strings"
)
// ChatCompletionsToResponses converts a Chat Completions request into a
@@ -174,8 +175,11 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// Emit assistant message with output_text if content is non-empty.
if len(m.Content) > 0 {
var s string
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
s, err := parseAssistantContent(m.Content)
if err != nil {
return nil, err
}
if s != "" {
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
partsJSON, err := json.Marshal(parts)
if err != nil {
@@ -196,13 +200,82 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: args,
ID: tc.ID,
})
}
return items, nil
}
// parseAssistantContent returns assistant content as plain text.
//
// Supported formats:
// - JSON string
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
//
// For structured thinking/reasoning parts, it preserves semantics by wrapping
// the text in explicit tags so downstream can still distinguish it from normal text.
func parseAssistantContent(raw json.RawMessage) (string, error) {
if len(raw) == 0 {
return "", nil
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return s, nil
}
var parts []map[string]any
if err := json.Unmarshal(raw, &parts); err != nil {
// Keep compatibility with prior behavior: unsupported assistant content
// formats are ignored instead of failing the whole request conversion.
return "", nil
}
var b strings.Builder
write := func(v string) error {
_, err := b.WriteString(v)
return err
}
for _, p := range parts {
typ, _ := p["type"].(string)
text, _ := p["text"].(string)
thinking, _ := p["thinking"].(string)
switch typ {
case "thinking", "reasoning":
if thinking != "" {
if err := write("<thinking>"); err != nil {
return "", err
}
if err := write(thinking); err != nil {
return "", err
}
if err := write("</thinking>"); err != nil {
return "", err
}
} else if text != "" {
if err := write("<thinking>"); err != nil {
return "", err
}
if err := write(text); err != nil {
return "", err
}
if err := write("</thinking>"); err != nil {
return "", err
}
}
default:
if text != "" {
if err := write(text); err != nil {
return "", err
}
}
}
}
return b.String(), nil
}
// chatToolToResponses converts a tool result message (role=tool) into a
// function_call_output item.
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {

View File

@@ -29,6 +29,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
}
var contentText string
var reasoningText string
var toolCalls []ChatToolCall
for _, item := range resp.Output {
@@ -51,7 +52,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
case "reasoning":
for _, s := range item.Summary {
if s.Type == "summary_text" && s.Text != "" {
contentText += s.Text
reasoningText += s.Text
}
}
case "web_search_call":
@@ -67,6 +68,9 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
raw, _ := json.Marshal(contentText)
msg.Content = raw
}
if reasoningText != "" {
msg.ReasoningContent = reasoningText
}
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
@@ -153,6 +157,8 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
return resToChatHandleFuncArgsDelta(evt, state)
case "response.reasoning_summary_text.delta":
return resToChatHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done":
return nil
case "response.completed", "response.incomplete", "response.failed":
return resToChatHandleCompleted(evt, state)
default:
@@ -276,8 +282,8 @@ func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEv
if evt.Delta == "" {
return nil
}
content := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
reasoning := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
}
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {

View File

@@ -361,11 +361,12 @@ type ChatStreamOptions struct {
// ChatMessage is a single message in the Chat Completions conversation.
type ChatMessage struct {
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
Content json.RawMessage `json:"content,omitempty"`
Name string `json:"name,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
Content json.RawMessage `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Name string `json:"name,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
// Legacy function calling
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
@@ -466,9 +467,10 @@ type ChatChunkChoice struct {
// ChatDelta carries incremental content in a streaming chunk.
type ChatDelta struct {
Role string `json:"role,omitempty"`
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
Role string `json:"role,omitempty"`
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
ReasoningContent *string `json:"reasoning_content,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
}
// ---------------------------------------------------------------------------

View File

@@ -96,12 +96,28 @@ type UserUsageTrendPoint struct {
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserSpendingRankingItem represents a user spending ranking row.
type UserSpendingRankingItem struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
ActualCost float64 `json:"actual_cost"` // 实际扣除
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
type UserSpendingRankingResponse struct {
Ranking []UserSpendingRankingItem `json:"ranking"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint struct {
Date string `json:"date"`

View File

@@ -397,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
r.syncSchedulerAccountSnapshot(ctx, account.ID)
}
// 普通账号编辑(如 model_mapping / credentials也需要立即刷新单账号快照
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
r.syncSchedulerAccountSnapshot(ctx, account.ID)
return nil
}

View File

@@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
}
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "sync-credentials-update",
Status: service.StatusActive,
Schedulable: true,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.1",
},
},
})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
account.Credentials = map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.2",
},
}
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
s.Require().True(ok)
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})

View File

@@ -164,7 +164,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldModelRoutingEnabled,
group.FieldModelRouting,
group.FieldMcpXMLInject,
group.FieldSimulateClaudeMaxEnabled,
group.FieldSupportedModelScopes,
group.FieldAllowMessagesDispatch,
group.FieldDefaultMappedModel,
@@ -646,7 +645,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
MCPXMLInject: g.McpXMLInject,
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
SupportedModelScopes: g.SupportedModelScopes,
SortOrder: g.SortOrder,
AllowMessagesDispatch: g.AllowMessagesDispatch,

View File

@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
sql sqlExecutor
}
const usageLogsCleanupBatchSize = 10000
const usageBillingDedupCleanupBatchSize = 10000
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
if sqlDB == nil {
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
if r == nil || r.sql == nil {
return nil
}
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
dayEnd = dayEnd.Add(24 * time.Hour)
}
if db, ok := r.sql.(*sql.DB); ok {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
}
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
if isPartitioned {
return r.dropUsageLogsPartitions(ctx, cutoff)
}
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
return err
for {
res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid
FROM usage_logs
WHERE created_at < $1
LIMIT $2
)
DELETE FROM usage_logs
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageLogsCleanupBatchSize)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected < usageLogsCleanupBatchSize {
return nil
}
}
}
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
for {
res, err := r.sql.ExecContext(ctx, `
WITH victims AS (
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
FROM usage_billing_dedup
WHERE created_at < $1
LIMIT $2
), archived AS (
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
SELECT request_id, api_key_id, request_fingerprint, created_at
FROM victims
ON CONFLICT (request_id, api_key_id) DO NOTHING
)
DELETE FROM usage_billing_dedup
WHERE ctid IN (SELECT ctid FROM victims)
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected < usageBillingDedupCleanupBatchSize {
return nil
}
}
}
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {

View File

@@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
SetKey(k.Key).
SetName(k.Name).
SetStatus(k.Status)
if k.Quota != 0 {
create.SetQuota(k.Quota)
}
if k.QuotaUsed != 0 {
create.SetQuotaUsed(k.QuotaUsed)
}
if k.RateLimit5h != 0 {
create.SetRateLimit5h(k.RateLimit5h)
}
if k.RateLimit1d != 0 {
create.SetRateLimit1d(k.RateLimit1d)
}
if k.RateLimit7d != 0 {
create.SetRateLimit7d(k.RateLimit7d)
}
if k.Usage5h != 0 {
create.SetUsage5h(k.Usage5h)
}
if k.Usage1d != 0 {
create.SetUsage1d(k.Usage1d)
}
if k.Usage7d != 0 {
create.SetUsage7d(k.Usage7d)
}
if k.Window5hStart != nil {
create.SetWindow5hStart(*k.Window5hStart)
}
if k.Window1dStart != nil {
create.SetWindow1dStart(*k.Window1dStart)
}
if k.Window7dStart != nil {
create.SetWindow7dStart(*k.Window7dStart)
}
if k.ExpiresAt != nil {
create.SetExpiresAt(*k.ExpiresAt)
}
if k.GroupID != nil {
create.SetGroupID(*k.GroupID)
}

View File

@@ -61,8 +61,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
SetDefaultMappedModel(groupIn.DefaultMappedModel)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
@@ -130,8 +129,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
SetDefaultMappedModel(groupIn.DefaultMappedModel)
// 显式处理可空字段nil 需要 clear非 nil 需要 set。
if groupIn.DailyLimitUSD != nil {

View File

@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
// usage_billing_dedup: billing idempotency narrow table
var usageBillingDedupRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
var usageBillingDedupArchiveRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
// settings table should exist
var settingsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
var exists bool
err := tx.QueryRowContext(context.Background(), `
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`, table, index).Scan(&exists)
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
require.True(t, exists, "expected index %s on %s", index, table)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()

View File

@@ -73,3 +73,14 @@ func buildReqClientKey(opts reqClientOptions) string {
opts.ForceHTTP2,
)
}
// CreatePrivacyReqClient creates an HTTP client for OpenAI privacy settings API
// This is exported for use by OpenAIPrivacyService
// Uses Chrome TLS fingerprint impersonation to bypass Cloudflare checks
func CreatePrivacyReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
Impersonate: true, // Enable Chrome TLS fingerprint impersonation
})
}

View File

@@ -0,0 +1,308 @@
package repository
import (
"context"
"database/sql"
"errors"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type usageBillingRepository struct {
db *sql.DB
}
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
return &usageBillingRepository{db: sqlDB}
}
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
if cmd == nil {
return &service.UsageBillingApplyResult{}, nil
}
if r == nil || r.db == nil {
return nil, errors.New("usage billing repository db is nil")
}
cmd.Normalize()
if cmd.RequestID == "" {
return nil, service.ErrUsageBillingRequestIDRequired
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
if err != nil {
return nil, err
}
if !applied {
return &service.UsageBillingApplyResult{Applied: false}, nil
}
result := &service.UsageBillingApplyResult{Applied: true}
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
tx = nil
return result, nil
}
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
var id int64
err := tx.QueryRowContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
VALUES ($1, $2, $3)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
if errors.Is(err, sql.ErrNoRows) {
var existingFingerprint string
if err := tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
return false, err
}
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
return false, service.ErrUsageBillingRequestConflict
}
return false, nil
}
if err != nil {
return false, err
}
var archivedFingerprint string
err = tx.QueryRowContext(ctx, `
SELECT request_fingerprint
FROM usage_billing_dedup_archive
WHERE request_id = $1 AND api_key_id = $2
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
if err == nil {
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
return false, service.ErrUsageBillingRequestConflict
}
return false, nil
}
if !errors.Is(err, sql.ErrNoRows) {
return false, err
}
return true, nil
}
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
return err
}
}
if cmd.BalanceCost > 0 {
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
return err
}
}
if cmd.APIKeyQuotaCost > 0 {
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
if err != nil {
return err
}
result.APIKeyQuotaExhausted = exhausted
}
if cmd.APIKeyRateLimitCost > 0 {
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
return err
}
}
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
return err
}
}
return nil
}
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
const updateSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
weekly_usage_usd = us.weekly_usage_usd + $1,
monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
`
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrSubscriptionNotFound
}
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
res, err := tx.ExecContext(ctx, `
UPDATE users
SET balance = balance - $1,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, amount, userID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrUserNotFound
}
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
var exhausted bool
err := tx.QueryRowContext(ctx, `
UPDATE api_keys
SET quota_used = quota_used + $1,
status = CASE
WHEN quota > 0
AND status = $3
AND quota_used < quota
AND quota_used + $1 >= quota
THEN $4
ELSE status
END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
if errors.Is(err, sql.ErrNoRows) {
return false, service.ErrAPIKeyNotFound
}
if err != nil {
return false, err
}
return exhausted, nil
}
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
res, err := tx.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
`, cost, apiKeyID)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAPIKeyNotFound
}
return nil
}
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
rows, err := tx.QueryContext(ctx,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
COALESCE((extra->>'quota_limit')::numeric, 0)`,
amount, accountID)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
var newUsed, limit float64
if rows.Next() {
if err := rows.Scan(&newUsed, &limit); err != nil {
return err
}
} else {
if err := rows.Err(); err != nil {
return err
}
return service.ErrAccountNotFound
}
if err := rows.Err(); err != nil {
return err
}
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
return err
}
}
return nil
}

View File

@@ -0,0 +1,279 @@
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-" + uuid.NewString(),
Name: "billing",
Quota: 1,
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-account-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: account.ID,
AccountType: service.AccountTypeAPIKey,
BalanceCost: 1.25,
APIKeyQuotaCost: 1.25,
APIKeyRateLimitCost: 1.25,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.NotNil(t, result1)
require.True(t, result1.Applied)
require.True(t, result1.APIKeyQuotaExhausted)
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.NotNil(t, result2)
require.False(t, result2.Applied)
var balance float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
require.InDelta(t, 98.75, balance, 0.000001)
var quotaUsed float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan(&quotaUsed))
require.InDelta(t, 1.25, quotaUsed, 0.000001)
var usage5h float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
require.InDelta(t, 1.25, usage5h, 0.000001)
var status string
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
var dedupCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
require.Equal(t, 1, dedupCount)
}
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
})
group := mustCreateGroup(t, client, &service.Group{
Name: "usage-billing-group-" + uuid.NewString(),
Platform: service.PlatformAnthropic,
SubscriptionType: service.SubscriptionTypeSubscription,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
GroupID: &group.ID,
Key: "sk-usage-billing-sub-" + uuid.NewString(),
Name: "billing-sub",
})
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: 0,
SubscriptionID: &subscription.ID,
SubscriptionCost: 2.5,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.True(t, result1.Applied)
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.False(t, result2.Applied)
var dailyUsage float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
require.InDelta(t, 2.5, dailyUsage, 0.000001)
}
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
Name: "billing-conflict",
})
requestID := uuid.NewString()
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 1.25,
})
require.NoError(t, err)
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 2.50,
})
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
}
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-account-" + uuid.NewString(),
Name: "billing-account",
})
account := mustCreateAccount(t, client, &service.Account{
Name: "usage-billing-account-quota-" + uuid.NewString(),
Type: service.AccountTypeAPIKey,
Extra: map[string]any{
"quota_limit": 100.0,
},
})
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
RequestID: uuid.NewString(),
APIKeyID: apiKey.ID,
UserID: user.ID,
AccountID: account.ID,
AccountType: service.AccountTypeAPIKey,
AccountQuotaCost: 3.5,
})
require.NoError(t, err)
var quotaUsed float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan(&quotaUsed))
require.InDelta(t, 3.5, quotaUsed, 0.000001)
}
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
ctx := context.Background()
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
oldRequestID := "dedup-old-" + uuid.NewString()
newRequestID := "dedup-new-" + uuid.NewString()
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
newCreatedAt := time.Now().UTC().Add(-time.Hour)
_, err := integrationDB.ExecContext(ctx, `
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
`,
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
newRequestID, strings.Repeat("b", 64), newCreatedAt,
)
require.NoError(t, err)
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
var oldCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
require.Equal(t, 0, oldCount)
var newCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
require.Equal(t, 1, newCount)
var archivedCount int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
require.Equal(t, 1, archivedCount)
}
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUsageBillingRepository(client, integrationDB)
aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
user := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
PasswordHash: "hash",
Balance: 100,
})
apiKey := mustCreateApiKey(t, client, &service.APIKey{
UserID: user.ID,
Key: "sk-usage-billing-archive-" + uuid.NewString(),
Name: "billing-archive",
})
requestID := uuid.NewString()
cmd := &service.UsageBillingCommand{
RequestID: requestID,
APIKeyID: apiKey.ID,
UserID: user.ID,
BalanceCost: 1.25,
}
result1, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.True(t, result1.Applied)
_, err = integrationDB.ExecContext(ctx, `
UPDATE usage_billing_dedup
SET created_at = $1
WHERE request_id = $2 AND api_key_id = $3
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
require.NoError(t, err)
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
result2, err := repo.Apply(ctx, cmd)
require.NoError(t, err)
require.False(t, result2.Applied)
var balance float64
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
require.InDelta(t, 98.75, balance, 0.000001)
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,8 @@ package repository
import (
"context"
"fmt"
"sync"
"testing"
"time"
@@ -14,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() {
s.Require().NotZero(log.ID)
}
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
const total = 16
results := make([]bool, total)
errs := make([]error, total)
logs := make([]*service.UsageLog, total)
var wg sync.WaitGroup
wg.Add(total)
for i := 0; i < total; i++ {
i := i
logs[i] = &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10 + i,
OutputTokens: 20 + i,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
go func() {
defer wg.Done()
results[i], errs[i] = repo.Create(ctx, logs[i])
}()
}
wg.Wait()
for i := 0; i < total; i++ {
require.NoError(t, errs[i])
require.True(t, results[i])
require.NotZero(t, logs[i].ID)
}
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
require.Equal(t, total, count)
}
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
requestID := uuid.NewString()
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
inserted1, err1 := repo.Create(ctx, log1)
inserted2, err2 := repo.Create(ctx, log2)
require.NoError(t, err1)
require.NoError(t, err2)
require.True(t, inserted1)
require.False(t, inserted2)
require.Equal(t, log1.ID, log2.ID)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
require.Equal(t, 1, count)
}
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
requestID := uuid.NewString()
const total = 8
batch := make([]usageLogCreateRequest, 0, total)
logs := make([]*service.UsageLog, 0, total)
for i := 0; i < total; i++ {
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10 + i,
OutputTokens: 20 + i,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
logs = append(logs, log)
batch = append(batch, usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
resultCh: make(chan usageLogCreateResult, 1),
})
}
repo.flushCreateBatch(integrationDB, batch)
insertedCount := 0
var firstID int64
for idx, req := range batch {
res := <-req.resultCh
require.NoError(t, res.err)
if res.inserted {
insertedCount++
}
require.NotZero(t, logs[idx].ID)
if idx == 0 {
firstID = logs[idx].ID
} else {
require.Equal(t, firstID, logs[idx].ID)
}
}
require.Equal(t, 1, insertedCount)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
require.Equal(t, 1, count)
}
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
requestID := uuid.NewString()
log1 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
log2 := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: requestID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
require.NoError(t, repo.CreateBestEffort(ctx, log1))
require.NoError(t, repo.CreateBestEffort(ctx, log2))
require.Eventually(t, func() bool {
var count int
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
return err == nil && count == 1
}, 3*time.Second, 20*time.Millisecond)
}
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
err := repo.CreateBestEffort(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.Error(t, err)
require.True(t, service.IsUsageLogCreateDropped(err))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
ctx, cancel := context.WithCancel(context.Background())
cancel()
inserted, err := repo.Create(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.False(t, inserted)
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
repo.createBatchCh <- usageLogCreateRequest{}
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
inserted, err := repo.Create(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
require.False(t, inserted)
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 1)
go func() {
_, err := repo.createBatched(ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
})
errCh <- err
}()
req := <-repo.createBatchCh
require.NotNil(t, req.shared)
cancel()
err := <-errCh
require.Error(t, err)
require.True(t, service.IsUsageLogCreateNotPersisted(err))
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
}
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.NewString(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: time.Now().UTC(),
}
req := usageLogCreateRequest{
log: log,
prepared: prepareUsageLogInsert(log),
shared: &usageLogCreateShared{},
resultCh: make(chan usageLogCreateResult, 1),
}
req.shared.state.Store(usageLogCreateStateCanceled)
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
res := <-req.resultCh
require.False(t, res.inserted)
require.Error(t, res.err)
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})

View File

@@ -248,6 +248,35 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageLogRepository{sql: db}
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
mock.ExpectQuery("WITH user_spend AS \\(").
WithArgs(start, end, 12).
WillReturnRows(rows)
got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12)
require.NoError(t, err)
require.Equal(t, &usagestats.UserSpendingRankingResponse{
Ranking: []usagestats.UserSpendingRankingItem{
{UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900},
{UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800},
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
},
TotalActualCost: 40.0,
}, got)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
tests := []struct {
name string

View File

@@ -3,8 +3,11 @@
package repository
import (
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
})
}
}
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
log := &service.UsageLog{
UserID: 1,
APIKeyID: 2,
AccountID: 3,
RequestID: "req-batch-no-update",
Model: "gpt-5",
InputTokens: 10,
OutputTokens: 5,
TotalCost: 1.2,
ActualCost: 1.2,
CreatedAt: time.Now().UTC(),
}
prepared := prepareUsageLogInsert(log)
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
})
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
}

View File

@@ -98,9 +98,9 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
// GetByGroupID 获取指定分组下所有用户的专属倍率
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
query := `
SELECT ugr.user_id, u.email, ugr.rate_multiplier
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
JOIN users u ON u.id = ugr.user_id
WHERE ugr.group_id = $1
ORDER BY ugr.user_id
`
@@ -113,7 +113,7 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
var result []service.UserGroupRateEntry
for rows.Next() {
var entry service.UserGroupRateEntry
if err := rows.Scan(&entry.UserID, &entry.UserEmail, &entry.RateMultiplier); err != nil {
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
return nil, err
}
result = append(result, entry)
@@ -193,6 +193,31 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return nil
}
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
return err
}
if len(entries) == 0 {
return nil
}
userIDs := make([]int64, len(entries))
rates := make([]float64, len(entries))
for i, e := range entries {
userIDs[i] = e.UserID
rates[i] = e.RateMultiplier
}
now := time.Now()
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT data.user_id, $1::bigint, data.rate_multiplier, $2::timestamptz, $2::timestamptz
FROM unnest($3::bigint[], $4::double precision[]) AS data(user_id, rate_multiplier)
ON CONFLICT (user_id, group_id)
DO UPDATE SET rate_multiplier = EXCLUDED.rate_multiplier, updated_at = EXCLUDED.updated_at
`, groupID, now, pq.Array(userIDs), pq.Array(rates))
return err
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)

View File

@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
NewAnnouncementRepository,
NewAnnouncementReadRepository,
NewUsageLogRepository,
NewUsageBillingRepository,
NewIdempotencyRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository,

View File

@@ -645,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
@@ -1635,6 +1635,10 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
logs := r.userLogs[userID]
if len(logs) == 0 {

View File

@@ -192,6 +192,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
@@ -229,6 +230,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
}

View File

@@ -412,6 +412,7 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
}
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
return nil
}
if len(rawMapping) == 0 {
@@ -521,16 +522,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
// 如果未配置 mapping返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string {
mappedModel, _ := a.ResolveMappedModel(requestedModel)
return mappedModel
}
// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。
// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。
func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return requestedModel
return requestedModel, false
}
// 精确匹配优先
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel
return mappedModel, true
}
// 通配符匹配(最长优先)
return matchWildcardMapping(mapping, requestedModel)
return matchWildcardMappingResult(mapping, requestedModel)
}
func (a *Account) GetBaseURL() string {
@@ -604,9 +612,7 @@ func matchWildcard(pattern, str string) bool {
return matchAntigravityWildcard(pattern, str)
}
// matchWildcardMapping 通配符映射匹配(最长优先)
// 如果没有匹配,返回原始字符串
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) {
// 收集所有匹配的 pattern按长度降序排序最长优先
type patternMatch struct {
pattern string
@@ -621,7 +627,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
}
if len(matches) == 0 {
return requestedModel // 无匹配,返回原始模型名
return requestedModel, false // 无匹配,返回原始模型名
}
// 按 pattern 长度降序排序
@@ -632,7 +638,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
return matches[i].pattern < matches[j].pattern
})
return matches[0].target
return matches[0].target, true
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
@@ -764,6 +770,14 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
return false
}
func (a *Account) IsBedrock() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey)
}
func (a *Account) IsBedrockAPIKey() bool {
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey
}
func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI
}

View File

@@ -207,14 +207,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
testModelID = claude.DefaultTestModel
}
// For API Key accounts with model mapping, map the model
// API Key 账号测试连接时也需要应用通配符模型映射。
if account.Type == "apikey" {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
testModelID = mappedModel
}
}
testModelID = account.GetMappedModel(testModelID)
}
// Bedrock accounts use a separate test path
if account.IsBedrock() {
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
}
// Determine authentication method and API URL
@@ -312,6 +312,109 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.processClaudeStream(c, resp.Body)
}
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
region := bedrockRuntimeRegion(account)
resolvedModelID, ok := ResolveBedrockModelID(account, testModelID)
if !ok {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Bedrock model: %s", testModelID))
}
testModelID = resolvedModelID
// Set SSE headers (test UI expects SSE)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
// Create a minimal Bedrock-compatible payload (no stream, no cache_control)
bedrockPayload := map[string]any{
"anthropic_version": "bedrock-2023-05-31",
"messages": []map[string]any{
{
"role": "user",
"content": []map[string]any{
{
"type": "text",
"text": "hi",
},
},
},
},
"max_tokens": 256,
"temperature": 1,
}
bedrockBody, _ := json.Marshal(bedrockPayload)
// Use non-streaming endpoint (response is standard Claude JSON)
apiURL := BuildBedrockURL(region, testModelID, false)
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bedrockBody))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
req.Header.Set("Content-Type", "application/json")
// Sign or set auth based on account type
if account.IsBedrockAPIKey() {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
req.Header.Set("Authorization", "Bearer "+apiKey)
} else {
signer, err := NewBedrockSignerFromAccount(account)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Bedrock signer: %s", err.Error()))
}
if err := signer.SignRequest(ctx, req, bedrockBody); err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to sign request: %s", err.Error()))
}
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
// Bedrock non-streaming response is standard Claude JSON, extract the text
var result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
}
if err := json.Unmarshal(body, &result); err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
}
text := ""
if len(result.Content) > 0 {
text = result.Content[0].Text
}
if text == "" {
text = "(empty response)"
}
s.sendEvent(c, TestEvent{Type: "content", Text: text})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()

View File

@@ -47,6 +47,7 @@ type UsageLogRepository interface {
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)

View File

@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) {
}
}
func TestMatchWildcardMapping(t *testing.T) {
func TestMatchWildcardMappingResult(t *testing.T) {
tests := []struct {
name string
mapping map[string]string
requestedModel string
expected string
matched bool
}{
// 精确匹配优先于通配符
{
@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5-exact",
matched: true,
},
// 最长通配符优先
@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-series",
matched: true,
},
// 单个通配符
@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel: "claude-opus-4-5",
expected: "claude-mapped",
matched: true,
},
// 无匹配返回原始模型
@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel: "gemini-3-flash",
expected: "gemini-3-flash",
matched: false,
},
// 空映射返回原始模型
@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) {
mapping: map[string]string{},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5",
matched: false,
},
// Gemini 模型映射
@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) {
},
requestedModel: "gemini-3-flash-preview",
expected: "gemini-3-pro-high",
matched: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchWildcardMapping(tt.mapping, tt.requestedModel)
if result != tt.expected {
t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected)
result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
if result != tt.expected || matched != tt.matched {
t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
}
})
}
@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) {
}
}
func TestAccountResolveMappedModel(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
requestedModel string
expectedModel string
expectedMatch bool
}{
{
name: "no mapping reports unmatched",
credentials: nil,
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
{
name: "exact passthrough mapping still counts as matched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "wildcard passthrough mapping still counts as matched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-*": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "missing mapping reports unmatched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.2": "gpt-5.2",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Credentials: tt.credentials,
}
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
}
})
}
}
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,

View File

@@ -43,6 +43,8 @@ type AdminService interface {
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin)
@@ -58,6 +60,8 @@ type AdminService interface {
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountError(ctx context.Context, id int64, errorMsg string) error
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号 privacy_mode未设置则尝试关闭训练数据共享并持久化。
EnsureOpenAIPrivacy(ctx context.Context, account *Account) string
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
@@ -139,10 +143,9 @@ type CreateGroupInput struct {
// 无效请求兜底分组 ID仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由
MCPXMLInject *bool
SimulateClaudeMaxEnabled *bool
ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// Sora 存储配额
@@ -179,10 +182,9 @@ type UpdateGroupInput struct {
// 无效请求兜底分组 ID仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由
MCPXMLInject *bool
SimulateClaudeMaxEnabled *bool
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// Sora 存储配额
@@ -366,10 +368,6 @@ type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
type proxyQualityTarget struct {
Target string
URL string
@@ -440,12 +438,17 @@ type adminServiceImpl struct {
settingService *SettingService
defaultSubAssigner DefaultSubscriptionAssigner
userSubRepo UserSubscriptionRepository
privacyClientFactory PrivacyClientFactory
}
type userGroupRateBatchReader interface {
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
}
type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAdminService creates a new AdminService
func NewAdminService(
userRepo UserRepository,
@@ -464,6 +467,7 @@ func NewAdminService(
settingService *SettingService,
defaultSubAssigner DefaultSubscriptionAssigner,
userSubRepo UserSubscriptionRepository,
privacyClientFactory PrivacyClientFactory,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@@ -482,6 +486,7 @@ func NewAdminService(
settingService: settingService,
defaultSubAssigner: defaultSubAssigner,
userSubRepo: userSubRepo,
privacyClientFactory: privacyClientFactory,
}
}
@@ -863,13 +868,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
if input.MCPXMLInject != nil {
mcpXMLInject = *input.MCPXMLInject
}
simulateClaudeMaxEnabled := false
if input.SimulateClaudeMaxEnabled != nil {
if platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
}
simulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
}
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var accountIDsToCopy []int64
@@ -926,7 +924,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SimulateClaudeMaxEnabled: simulateClaudeMaxEnabled,
SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
AllowMessagesDispatch: input.AllowMessagesDispatch,
@@ -1138,15 +1135,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.MCPXMLInject != nil {
group.MCPXMLInject = *input.MCPXMLInject
}
if input.SimulateClaudeMaxEnabled != nil {
if group.Platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
}
group.SimulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
}
if group.Platform != PlatformAnthropic {
group.SimulateClaudeMaxEnabled = false
}
// 支持的模型系列(仅 antigravity 平台使用)
if input.SupportedModelScopes != nil {
@@ -1271,6 +1259,20 @@ func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID
return s.userGroupRateRepo.GetByGroupID(ctx, groupID)
}
func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error {
if s.userGroupRateRepo == nil {
return nil
}
return s.userGroupRateRepo.DeleteByGroupID(ctx, groupID)
}
func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
if s.userGroupRateRepo == nil {
return nil
}
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
}
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
@@ -2529,3 +2531,39 @@ func (e *MixedChannelError) Error() string {
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
return s.accountRepo.ResetQuotaUsed(ctx, id)
}
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号是否已设置 privacy_mode
// 未设置则调用 disableOpenAITraining 并持久化到 Extra返回设置的 mode 值。
func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Account) string {
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
return ""
}
if s.privacyClientFactory == nil {
return ""
}
if account.Extra != nil {
if _, ok := account.Extra["privacy_mode"]; ok {
return ""
}
}
token, _ := account.Credentials["access_token"].(string)
if token == "" {
return ""
}
var proxyURL string
if account.ProxyID != nil {
if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil {
proxyURL = p.URL()
}
}
mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL)
if mode == "" {
return ""
}
_ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode})
return mode
}

View File

@@ -43,16 +43,6 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
return nil
}
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
if err, ok := s.listByGroupErr[groupID]; ok {
return nil, err
}
if rows, ok := s.listByGroupData[groupID]; ok {
return rows, nil
}
return nil, nil
}
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
s.getByIDsCalled = true
s.getByIDsIDs = append([]int64{}, ids...)
@@ -73,6 +63,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
return nil, errors.New("account not found")
}
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
if err, ok := s.listByGroupErr[groupID]; ok {
return nil, err
}
if rows, ok := s.listByGroupData[groupID]; ok {
return rows, nil
}
return nil, nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}

View File

@@ -0,0 +1,176 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
// userGroupRateRepoStubForGroupRate implements UserGroupRateRepository for group rate tests.
type userGroupRateRepoStubForGroupRate struct {
getByGroupIDData map[int64][]UserGroupRateEntry
getByGroupIDErr error
deletedGroupIDs []int64
deleteByGroupErr error
syncedGroupID int64
syncedEntries []GroupRateMultiplierInput
syncGroupErr error
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
panic("unexpected GetByUserID call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, _, _ int64) (*float64, error) {
panic("unexpected GetByUserAndGroup call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
if s.getByGroupIDErr != nil {
return nil, s.getByGroupIDErr
}
return s.getByGroupIDData[groupID], nil
}
func (s *userGroupRateRepoStubForGroupRate) SyncUserGroupRates(_ context.Context, _ int64, _ map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
s.syncedGroupID = groupID
s.syncedEntries = entries
return s.syncGroupErr
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
return s.deleteByGroupErr
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByUserID(_ context.Context, _ int64) error {
panic("unexpected DeleteByUserID call")
}
func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
t.Run("returns entries for group", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{
10: {
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
},
},
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.NoError(t, err)
require.Len(t, entries, 2)
require.Equal(t, int64(1), entries[0].UserID)
require.Equal(t, "alice", entries[0].UserName)
require.Equal(t, 1.5, entries[0].RateMultiplier)
require.Equal(t, int64(2), entries[1].UserID)
require.Equal(t, 0.8, entries[1].RateMultiplier)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.NoError(t, err)
require.Nil(t, entries)
})
t.Run("returns empty slice for group with no entries", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{},
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 99)
require.NoError(t, err)
require.Nil(t, entries)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDErr: errors.New("db error"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
_, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.Error(t, err)
require.Contains(t, err.Error(), "db error")
})
}
func TestAdminService_ClearGroupRateMultipliers(t *testing.T) {
t.Run("deletes by group ID", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.NoError(t, err)
require.Equal(t, []int64{42}, repo.deletedGroupIDs)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.NoError(t, err)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
deleteByGroupErr: errors.New("delete failed"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.Error(t, err)
require.Contains(t, err.Error(), "delete failed")
})
}
func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
t.Run("syncs entries to repo", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries := []GroupRateMultiplierInput{
{UserID: 1, RateMultiplier: 1.5},
{UserID: 2, RateMultiplier: 0.8},
}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, entries)
require.NoError(t, err)
require.Equal(t, int64(10), repo.syncedGroupID)
require.Equal(t, entries, repo.syncedEntries)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, nil)
require.NoError(t, err)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
syncGroupErr: errors.New("sync failed"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, []GroupRateMultiplierInput{
{UserID: 1, RateMultiplier: 1.0},
})
require.Error(t, err)
require.Contains(t, err.Error(), "sync failed")
})
}

View File

@@ -785,57 +785,3 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_CreateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
enabled := true
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "openai-group",
Platform: PlatformOpenAI,
SimulateClaudeMaxEnabled: &enabled,
})
require.Error(t, err)
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
require.Nil(t, repo.created)
}
func TestAdminService_UpdateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "openai-group",
Platform: PlatformOpenAI,
Status: StatusActive,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
enabled := true
_, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
SimulateClaudeMaxEnabled: &enabled,
})
require.Error(t, err)
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_ClearsSimulateClaudeMaxWhenPlatformChanges(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "anthropic-group",
Platform: PlatformAnthropic,
Status: StatusActive,
SimulateClaudeMaxEnabled: true,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
Platform: PlatformOpenAI,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.False(t, repo.updated.SimulateClaudeMaxEnabled)
}

View File

@@ -68,11 +68,15 @@ func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context
panic("unexpected SyncUserGroupRates call")
}
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, _ int64) ([]UserGroupRateEntry, error) {
panic("unexpected GetByGroupID call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.Context, _ int64, _ []GroupRateMultiplierInput) error {
panic("unexpected SyncGroupRateMultipliers call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
panic("unexpected DeleteByGroupID call")
}

View File

@@ -1673,7 +1673,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var clientDisconnect bool
if claudeReq.Stream {
// 客户端要求流式,直接透传转换
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel, account.ID)
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
return nil, err
@@ -1683,7 +1683,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
clientDisconnect = streamRes.clientDisconnect
} else {
// 客户端要求非流式,收集流式响应后转换返回
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel, account.ID)
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
return nil, err
@@ -1692,9 +1692,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
firstTokenMs = streamRes.firstTokenMs
}
// Claude Max cache billing: 同步 ForwardResult.Usage 与客户端响应体一致
applyClaudeMaxCacheBillingPolicyToUsage(usage, parsedRequestFromGinContext(c), claudeMaxGroupFromGinContext(c), originalModel, account.ID)
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
@@ -3598,7 +3595,7 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
// 用于处理客户端非流式请求但上游只支持流式的情况
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
@@ -3756,9 +3753,6 @@ returnResponse:
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
// Claude Max cache billing simulation (non-streaming)
claudeResp = applyClaudeMaxNonStreamingRewrite(c, claudeResp, agUsage, originalModel, accountID)
c.Data(http.StatusOK, "application/json", claudeResp)
// 转换为 service.ClaudeUsage
@@ -3773,7 +3767,7 @@ returnResponse:
}
// handleClaudeStreamingResponse 处理 Claude 流式响应Gemini SSE → Claude SSE 转换)
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
@@ -3786,8 +3780,6 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
processor := antigravity.NewStreamingProcessor(originalModel)
setupClaudeMaxStreamingHook(c, processor, originalModel, accountID)
var firstTokenMs *int
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
scanner := bufio.NewScanner(resp.Body)

View File

@@ -922,7 +922,7 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
require.NoError(t, err)
@@ -999,7 +999,7 @@ func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro", 0)
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
_ = pr.Close()
require.NoError(t, err)
@@ -1202,7 +1202,7 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
fmt.Fprintln(pw, "")
}()
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
require.NoError(t, err)
@@ -1234,7 +1234,7 @@ func TestHandleClaudeStreamingResponse_EmptyStream(t *testing.T) {
fmt.Fprintln(pw, "")
}()
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
_ = pr.Close()
// 应当返回 UpstreamFailoverError 而非 nil以便上层触发 failover
@@ -1266,7 +1266,7 @@ func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
require.NoError(t, err)
require.NotNil(t, result)

View File

@@ -59,10 +59,9 @@ type APIKeyAuthGroupSnapshot struct {
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject bool `json:"mcp_xml_inject"`
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`

View File

@@ -244,7 +244,6 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ModelRouting: apiKey.Group.ModelRouting,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
MCPXMLInject: apiKey.Group.MCPXMLInject,
SimulateClaudeMaxEnabled: apiKey.Group.SimulateClaudeMaxEnabled,
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
@@ -304,7 +303,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ModelRouting: snapshot.Group.ModelRouting,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
MCPXMLInject: snapshot.Group.MCPXMLInject,
SimulateClaudeMaxEnabled: snapshot.Group.SimulateClaudeMaxEnabled,
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel,

View File

@@ -0,0 +1,607 @@
package service
import (
"encoding/json"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const defaultBedrockRegion = "us-east-1"
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀
// 参考: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
func BedrockCrossRegionPrefix(region string) string {
switch {
case strings.HasPrefix(region, "us-gov"):
return "us-gov" // GovCloud 使用独立的 us-gov 前缀
case strings.HasPrefix(region, "us-"):
return "us"
case strings.HasPrefix(region, "eu-"):
return "eu"
case region == "ap-northeast-1":
return "jp" // 日本区域使用独立的 jp 前缀AWS 官方定义)
case region == "ap-southeast-2":
return "au" // 澳大利亚区域使用独立的 au 前缀AWS 官方定义)
case strings.HasPrefix(region, "ap-"):
return "apac" // 其余亚太区域使用通用 apac 前缀
case strings.HasPrefix(region, "ca-"):
return "us" // 加拿大区域使用 us 前缀的跨区域推理
case strings.HasPrefix(region, "sa-"):
return "us" // 南美区域使用 us 前缀的跨区域推理
default:
return "us"
}
}
// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀
// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1"
// 特殊值 region="global" 强制使用 global. 前缀
func AdjustBedrockModelRegionPrefix(modelID, region string) string {
var targetPrefix string
if region == "global" {
targetPrefix = "global"
} else {
targetPrefix = BedrockCrossRegionPrefix(region)
}
for _, p := range bedrockCrossRegionPrefixes {
if strings.HasPrefix(modelID, p) {
if p == targetPrefix+"." {
return modelID // 前缀已匹配,无需替换
}
return targetPrefix + "." + modelID[len(p):]
}
}
// 模型 ID 没有已知区域前缀(如 "anthropic.claude-..."),不做修改
return modelID
}
func bedrockRuntimeRegion(account *Account) string {
if account == nil {
return defaultBedrockRegion
}
if region := account.GetCredential("aws_region"); region != "" {
return region
}
return defaultBedrockRegion
}
func shouldForceBedrockGlobal(account *Account) bool {
return account != nil && account.GetCredential("aws_force_global") == "true"
}
func isRegionalBedrockModelID(modelID string) bool {
for _, prefix := range bedrockCrossRegionPrefixes {
if strings.HasPrefix(modelID, prefix) {
return true
}
}
return false
}
func isLikelyBedrockModelID(modelID string) bool {
lower := strings.ToLower(strings.TrimSpace(modelID))
if lower == "" {
return false
}
if strings.HasPrefix(lower, "arn:") {
return true
}
for _, prefix := range []string{
"anthropic.",
"amazon.",
"meta.",
"mistral.",
"cohere.",
"ai21.",
"deepseek.",
"stability.",
"writer.",
"nova.",
} {
if strings.HasPrefix(lower, prefix) {
return true
}
}
return isRegionalBedrockModelID(lower)
}
func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return "", false, false
}
if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists {
return mapped, true, true
}
if isRegionalBedrockModelID(modelID) {
return modelID, true, true
}
if isLikelyBedrockModelID(modelID) {
return modelID, false, true
}
return "", false, false
}
// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID.
// It applies account model_mapping first, then default Bedrock aliases, and finally
// adjusts Anthropic cross-region prefixes to match the account region.
func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) {
if account == nil {
return "", false
}
mappedModel := account.GetMappedModel(requestedModel)
modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel)
if !ok {
return "", false
}
if shouldAdjustRegion {
targetRegion := bedrockRuntimeRegion(account)
if shouldForceBedrockGlobal(account) {
targetRegion = "global"
}
modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion)
}
return modelID, true
}
// BuildBedrockURL 构建 Bedrock InvokeModel 的 URL
// stream=true 时使用 invoke-with-response-stream 端点
// modelID 中的特殊字符会被 URL 编码(与 litellm 的 urllib.parse.quote(safe="") 对齐)
func BuildBedrockURL(region, modelID string, stream bool) string {
if region == "" {
region = defaultBedrockRegion
}
encodedModelID := url.PathEscape(modelID)
// url.PathEscape 不编码冒号RFC 允许 path 中出现 ":"
// 但 AWS Bedrock 期望模型 ID 中的冒号被编码为 %3A
encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A")
if stream {
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID)
}
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID)
}
// PrepareBedrockRequestBody 处理请求体以适配 Bedrock API
// 1. 注入 anthropic_version
// 2. 注入 anthropic_beta从客户端 anthropic-beta 头解析)
// 3. 移除 Bedrock 不支持的字段model, stream, output_format, output_config
// 4. 移除工具定义中的 custom 字段Claude Code 会发送 custom: {defer_loading: true}
// 5. 清理 cache_control 中 Bedrock 不支持的字段scope, ttl
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
}
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
var err error
// 注入 anthropic_versionBedrock 要求)
body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
if err != nil {
return nil, fmt.Errorf("inject anthropic_version: %w", err)
}
// 注入 anthropic_betaBedrock Invoke 通过请求体传递 beta 头,而非 HTTP 头)
// 1. 从客户端 anthropic-beta header 解析
// 2. 根据请求体内容自动补齐必要的 beta token
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock()
if len(betaTokens) > 0 {
body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens)
if err != nil {
return nil, fmt.Errorf("inject anthropic_beta: %w", err)
}
}
// 移除 model 字段Bedrock 通过 URL 指定模型)
body, err = sjson.DeleteBytes(body, "model")
if err != nil {
return nil, fmt.Errorf("remove model field: %w", err)
}
// 移除 stream 字段Bedrock 通过不同端点控制流式,不接受请求体中的 stream 字段)
body, err = sjson.DeleteBytes(body, "stream")
if err != nil {
return nil, fmt.Errorf("remove stream field: %w", err)
}
// 转换 output_formatBedrock Invoke 不支持此字段,但可将 schema 内联到最后一条 user message
// 参考 litellm: _convert_output_format_to_inline_schema()
body = convertOutputFormatToInlineSchema(body)
// 移除 output_config 字段Bedrock Invoke 不支持)
body, err = sjson.DeleteBytes(body, "output_config")
if err != nil {
return nil, fmt.Errorf("remove output_config field: %w", err)
}
// 移除工具定义中的 custom 字段
// Claude Code (v2.1.69+) 在 tool 定义中发送 custom: {defer_loading: true}
// Anthropic API 接受但 Bedrock 会拒绝并报 "Extra inputs are not permitted"
body = removeCustomFieldFromTools(body)
// 清理 cache_control 中 Bedrock 不支持的字段
body = sanitizeBedrockCacheControl(body, modelID)
return body, nil
}
// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering.
func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string {
betaTokens := parseAnthropicBetaHeader(betaHeader)
betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID)
return filterBedrockBetaTokens(betaTokens)
}
// convertOutputFormatToInlineSchema 将 output_format 中的 JSON schema 内联到最后一条 user message
// Bedrock Invoke 不支持 output_format 参数litellm 的做法是将 schema 追加到用户消息中
// 参考: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema()
func convertOutputFormatToInlineSchema(body []byte) []byte {
outputFormat := gjson.GetBytes(body, "output_format")
if !outputFormat.Exists() || !outputFormat.IsObject() {
return body
}
// 先从请求体中移除 output_format
body, _ = sjson.DeleteBytes(body, "output_format")
schema := outputFormat.Get("schema")
if !schema.Exists() {
return body
}
// 找到最后一条 user message
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
msgArr := messages.Array()
lastUserIdx := -1
for i := len(msgArr) - 1; i >= 0; i-- {
if msgArr[i].Get("role").String() == "user" {
lastUserIdx = i
break
}
}
if lastUserIdx < 0 {
return body
}
// 将 schema 序列化为 JSON 文本追加到该 message 的 content 数组
schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw))
if err != nil {
return body
}
content := msgArr[lastUserIdx].Get("content")
basePath := fmt.Sprintf("messages.%d.content", lastUserIdx)
if content.IsArray() {
// 追加一个 text block 到 content 数组末尾
idx := len(content.Array())
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text")
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON))
} else if content.Type == gjson.String {
// content 是纯字符串,转换为数组格式
originalText := content.String()
body, _ = sjson.SetBytes(body, basePath, []map[string]string{
{"type": "text", "text": originalText},
{"type": "text", "text": string(schemaJSON)},
})
}
return body
}
// removeCustomFieldFromTools 移除 tools 数组中每个工具定义的 custom 字段
func removeCustomFieldFromTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if !tools.Exists() || !tools.IsArray() {
return body
}
var err error
for i := range tools.Array() {
body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i))
if err != nil {
// 删除失败不影响整体流程,跳过
continue
}
}
return body
}
// claudeVersionRe 匹配 Claude 模型 ID 中的版本号部分
// 支持 claude-{tier}-{major}-{minor} 和 claude-{tier}-{major}.{minor} 格式
var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`)
// isBedrockClaude45OrNewer 判断 Bedrock 模型 ID 是否为 Claude 4.5 或更新版本
// Claude 4.5+ 支持 cache_control 中的 ttl 字段("5m" 和 "1h"
func isBedrockClaude45OrNewer(modelID string) bool {
lower := strings.ToLower(modelID)
matches := claudeVersionRe.FindStringSubmatch(lower)
if matches == nil {
return false
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return major > 4 || (major == 4 && minor >= 5)
}
// sanitizeBedrockCacheControl 清理 system 和 messages 中 cache_control 里
// Bedrock 不支持的字段:
// - scopeBedrock 不支持(如 "global" 跨请求缓存)
// - ttl仅 Claude 4.5+ 支持 "5m" 和 "1h",旧模型需要移除
func sanitizeBedrockCacheControl(body []byte, modelID string) []byte {
isClaude45 := isBedrockClaude45OrNewer(modelID)
// 清理 system 数组中的 cache_control
systemArr := gjson.GetBytes(body, "system")
if systemArr.Exists() && systemArr.IsArray() {
for i, item := range systemArr.Array() {
if !item.IsObject() {
continue
}
cc := item.Get("cache_control")
if !cc.Exists() || !cc.IsObject() {
continue
}
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45)
}
}
// 清理 messages 中的 cache_control
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
for mi, msg := range messages.Array() {
if !msg.IsObject() {
continue
}
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
continue
}
for ci, block := range content.Array() {
if !block.IsObject() {
continue
}
cc := block.Get("cache_control")
if !cc.Exists() || !cc.IsObject() {
continue
}
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45)
}
}
return body
}
// deleteCacheControlUnsupportedFields 删除给定 cache_control 路径下 Bedrock 不支持的字段
func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte {
// Bedrock 不支持 scope如 "global"
if cc.Get("scope").Exists() {
body, _ = sjson.DeleteBytes(body, basePath+".scope")
}
// ttl仅 Claude 4.5+ 支持 "5m" 和 "1h",其余情况移除
ttl := cc.Get("ttl")
if ttl.Exists() {
shouldRemove := true
if isClaude45 {
v := ttl.String()
if v == "5m" || v == "1h" {
shouldRemove = false
}
}
if shouldRemove {
body, _ = sjson.DeleteBytes(body, basePath+".ttl")
}
}
return body
}
// parseAnthropicBetaHeader 解析 anthropic-beta 头的逗号分隔字符串为 token 列表
func parseAnthropicBetaHeader(header string) []string {
header = strings.TrimSpace(header)
if header == "" {
return nil
}
if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") {
var parsed []any
if err := json.Unmarshal([]byte(header), &parsed); err == nil {
tokens := make([]string, 0, len(parsed))
for _, item := range parsed {
token := strings.TrimSpace(fmt.Sprint(item))
if token != "" {
tokens = append(tokens, token)
}
}
return tokens
}
}
var tokens []string
for _, part := range strings.Split(header, ",") {
t := strings.TrimSpace(part)
if t != "" {
tokens = append(tokens, t)
}
}
return tokens
}
// bedrockSupportedBetaTokens 是 Bedrock Invoke 支持的 beta 头白名单
// 参考: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
// 更新策略: 当 AWS Bedrock 新增支持的 beta token 时需同步更新此白名单
var bedrockSupportedBetaTokens = map[string]bool{
"computer-use-2025-01-24": true,
"computer-use-2025-11-24": true,
"context-1m-2025-08-07": true,
"context-management-2025-06-27": true,
"compact-2026-01-12": true,
"interleaved-thinking-2025-05-14": true,
"tool-search-tool-2025-10-19": true,
"tool-examples-2025-10-29": true,
}
// bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则
// Anthropic 直接 API 使用通用头Bedrock Invoke 需要特定的替代头
var bedrockBetaTokenTransforms = map[string]string{
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
}
// autoInjectBedrockBetaTokens 根据请求体内容自动补齐必要的 beta token
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() 和
// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock()
//
// 客户端(特别是非 Claude Code 客户端)可能只在 body 中启用了功能而不在 header 中带对应 beta token
// 这里通过检测请求体特征自动补齐,确保 Bedrock Invoke 不会因缺少必要 beta 头而 400。
func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string {
seen := make(map[string]bool, len(tokens))
for _, t := range tokens {
seen[t] = true
}
inject := func(token string) {
if !seen[token] {
tokens = append(tokens, token)
seen[token] = true
}
}
// 检测 thinking / interleaved thinking
// 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta
if gjson.GetBytes(body, "thinking").Exists() {
inject("interleaved-thinking-2025-05-14")
}
// 检测 computer_use 工具
// tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() {
toolSearchUsed := false
programmaticToolCallingUsed := false
inputExamplesUsed := false
for _, tool := range tools.Array() {
toolType := tool.Get("type").String()
if strings.HasPrefix(toolType, "computer_20") {
inject("computer-use-2025-11-24")
}
if isBedrockToolSearchType(toolType) {
toolSearchUsed = true
}
if hasCodeExecutionAllowedCallers(tool) {
programmaticToolCallingUsed = true
}
if hasInputExamples(tool) {
inputExamplesUsed = true
}
}
if programmaticToolCallingUsed || inputExamplesUsed {
// programmatic tool calling 和 input examples 需要 advanced-tool-use
// 后续 filterBedrockBetaTokens 会将其转换为 Bedrock 特定的 tool-search-tool
inject("advanced-tool-use-2025-11-20")
}
if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) {
// 纯 tool search无 programmatic/inputExamples时直接注入 Bedrock 特定头,
// 跳过 advanced-tool-use → tool-search-tool 的转换步骤(与 litellm 对齐)
if !programmaticToolCallingUsed && !inputExamplesUsed {
inject("tool-search-tool-2025-10-19")
} else {
inject("advanced-tool-use-2025-11-20")
}
}
}
return tokens
}
func isBedrockToolSearchType(toolType string) bool {
return toolType == "tool_search_tool_regex_20251119" || toolType == "tool_search_tool_bm25_20251119"
}
func hasCodeExecutionAllowedCallers(tool gjson.Result) bool {
allowedCallers := tool.Get("allowed_callers")
if containsStringInJSONArray(allowedCallers, "code_execution_20250825") {
return true
}
return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825")
}
func hasInputExamples(tool gjson.Result) bool {
if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 {
return true
}
arr := tool.Get("function.input_examples")
return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0
}
func containsStringInJSONArray(result gjson.Result, target string) bool {
if !result.Exists() || !result.IsArray() {
return false
}
for _, item := range result.Array() {
if item.String() == target {
return true
}
}
return false
}
// bedrockModelSupportsToolSearch 判断 Bedrock 模型是否支持 tool search
// 目前仅 Claude Opus/Sonnet 4.5+ 支持Haiku 不支持
func bedrockModelSupportsToolSearch(modelID string) bool {
lower := strings.ToLower(modelID)
matches := claudeVersionRe.FindStringSubmatch(lower)
if matches == nil {
return false
}
// Haiku 不支持 tool search
if strings.Contains(lower, "haiku") {
return false
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return major > 4 || (major == 4 && minor >= 5)
}
// filterBedrockBetaTokens 过滤并转换 beta token 列表,仅保留 Bedrock Invoke 支持的 token
// 1. 应用转换规则(如 advanced-tool-use → tool-search-tool
// 2. 过滤掉 Bedrock 不支持的 token如 output-128k, files-api, structured-outputs 等)
// 3. 自动关联 tool-examples当 tool-search-tool 存在时)
func filterBedrockBetaTokens(tokens []string) []string {
seen := make(map[string]bool, len(tokens))
var result []string
for _, t := range tokens {
// 应用转换规则
if replacement, ok := bedrockBetaTokenTransforms[t]; ok {
t = replacement
}
// 只保留白名单中的 token且去重
if bedrockSupportedBetaTokens[t] && !seen[t] {
result = append(result, t)
seen[t] = true
}
}
// 自动关联: tool-search-tool 存在时,确保 tool-examples 也存在
if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] {
result = append(result, "tool-examples-2025-10-29")
}
return result
}

View File

@@ -0,0 +1,659 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) {
input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
// anthropic_version 应被注入
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
// model 和 stream 应被移除
assert.False(t, gjson.GetBytes(result, "model").Exists())
assert.False(t, gjson.GetBytes(result, "stream").Exists())
// max_tokens 应保留
assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int())
}
func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) {
t.Run("schema inlined into last user message array content", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
// schema 应内联到最后一条 user message 的 content 数组末尾
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "text", contentArr[1].Get("type").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`)
})
t.Run("schema inlined into string content", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "compute this", contentArr[0].Get("text").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`)
})
t.Run("no schema field just removes output_format", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
})
t.Run("no messages just removes output_format", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
})
}
func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
}
func TestRemoveCustomFieldFromTools(t *testing.T) {
input := `{
"tools": [
{"name":"tool1","custom":{"defer_loading":true},"description":"desc1"},
{"name":"tool2","description":"desc2"},
{"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"}
]
}`
result := removeCustomFieldFromTools([]byte(input))
tools := gjson.GetBytes(result, "tools").Array()
require.Len(t, tools, 3)
// custom 应被移除
assert.False(t, tools[0].Get("custom").Exists())
// name/description 应保留
assert.Equal(t, "tool1", tools[0].Get("name").String())
assert.Equal(t, "desc1", tools[0].Get("description").String())
// 没有 custom 的工具不受影响
assert.Equal(t, "tool2", tools[1].Get("name").String())
// 第三个工具的 custom 也应被移除
assert.False(t, tools[2].Get("custom").Exists())
assert.Equal(t, "tool3", tools[2].Get("name").String())
}
func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}]}`
result := removeCustomFieldFromTools([]byte(input))
// 无 tools 时不改变原始数据
assert.JSONEq(t, input, string(result))
}
func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}],
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}]
}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
// scope 应被移除
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists())
// type 应保留
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String())
}
func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
}`
// 旧模型Claude 3.5)不支持 ttl
result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
}
func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
}`
// Claude 4.5+ 支持 "5m" 和 "1h"
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
}
func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}]
}`
// Claude 4.5 不支持 "10m"
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
}
func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) {
input := `{
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists())
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) {
input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
// 无 cache_control 时不改变原始数据
assert.JSONEq(t, input, string(result))
}
func TestIsBedrockClaude45OrNewer(t *testing.T) {
tests := []struct {
modelID string
expect bool
}{
{"us.anthropic.claude-opus-4-6-v1", true},
{"us.anthropic.claude-sonnet-4-6", true},
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
{"us.anthropic.claude-haiku-4-5-20251001-v1:0", true},
{"anthropic.claude-3-5-sonnet-20241022-v2:0", false},
{"anthropic.claude-3-opus-20240229-v1:0", false},
{"anthropic.claude-3-haiku-20240307-v1:0", false},
// 未来版本应自动支持
{"us.anthropic.claude-sonnet-5-0-v1", true},
{"us.anthropic.claude-opus-4-7-v1", true},
// 旧版本
{"anthropic.claude-opus-4-1-v1", false},
{"anthropic.claude-sonnet-4-0-v1", false},
// 非 Claude 模型
{"amazon.nova-pro-v1", false},
{"meta.llama3-70b", false},
}
for _, tt := range tests {
t.Run(tt.modelID, func(t *testing.T) {
assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID))
})
}
}
func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
// 模拟一个完整的 Claude Code 请求
input := `{
"model": "claude-opus-4-6",
"stream": true,
"max_tokens": 16384,
"output_format": {"type": "json", "schema": {"result": "string"}},
"output_config": {"max_tokens": 100},
"system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}],
"messages": [
{"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]}
],
"tools": [
{"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}},
{"name": "read", "description": "Read file", "input_schema": {"type": "object"}}
]
}`
betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
require.NoError(t, err)
// 基本字段
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
assert.False(t, gjson.GetBytes(result, "model").Exists())
assert.False(t, gjson.GetBytes(result, "stream").Exists())
assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int())
// anthropic_beta 应包含所有 beta tokens
betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, betaArr, 3)
assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
assert.Equal(t, "compact-2026-01-12", betaArr[2].String())
// output_format 应被移除schema 内联到最后一条 user message
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
// content 数组:原始 text block + 内联 schema block
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "hello", contentArr[0].Get("text").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`)
// tools 中的 custom 应被移除
assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists())
assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String())
assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String())
// cache_control: scope 应被移除ttl 在 Claude 4.6 上保留合法值
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
t.Run("empty beta header", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
})
t.Run("single beta token", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
})
t.Run("multiple beta tokens with spaces", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
})
t.Run("json array beta header", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
})
}
func TestParseAnthropicBetaHeader(t *testing.T) {
assert.Nil(t, parseAnthropicBetaHeader(""))
assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b "))
assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`))
}
func TestFilterBedrockBetaTokens(t *testing.T) {
t.Run("supported tokens pass through", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, tokens, result)
})
t.Run("unsupported tokens are filtered out", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
})
t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
tokens := []string{"advanced-tool-use-2025-11-20"}
result := filterBedrockBetaTokens(tokens)
assert.Contains(t, result, "tool-search-tool-2025-10-19")
// tool-examples 自动关联
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) {
tokens := []string{"tool-search-tool-2025-10-19"}
result := filterBedrockBetaTokens(tokens)
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("no duplication when tool-examples already present", func(t *testing.T) {
tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"}
result := filterBedrockBetaTokens(tokens)
count := 0
for _, t := range result {
if t == "tool-examples-2025-10-29" {
count++
}
}
assert.Equal(t, 1, count)
})
t.Run("empty input returns nil", func(t *testing.T) {
result := filterBedrockBetaTokens(nil)
assert.Nil(t, result)
})
t.Run("all unsupported returns nil", func(t *testing.T) {
result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"})
assert.Nil(t, result)
})
t.Run("duplicate tokens are deduplicated", func(t *testing.T) {
tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
})
}
func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
})
t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
"advanced-tool-use-2025-11-20")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String())
assert.Equal(t, "tool-examples-2025-10-29", arr[1].String())
})
}
func TestBedrockCrossRegionPrefix(t *testing.T) {
tests := []struct {
region string
expect string
}{
// US regions
{"us-east-1", "us"},
{"us-east-2", "us"},
{"us-west-1", "us"},
{"us-west-2", "us"},
// GovCloud
{"us-gov-east-1", "us-gov"},
{"us-gov-west-1", "us-gov"},
// EU regions
{"eu-west-1", "eu"},
{"eu-west-2", "eu"},
{"eu-west-3", "eu"},
{"eu-central-1", "eu"},
{"eu-central-2", "eu"},
{"eu-north-1", "eu"},
{"eu-south-1", "eu"},
// APAC regions
{"ap-northeast-1", "jp"},
{"ap-northeast-2", "apac"},
{"ap-southeast-1", "apac"},
{"ap-southeast-2", "au"},
{"ap-south-1", "apac"},
// Canada / South America fallback to us
{"ca-central-1", "us"},
{"sa-east-1", "us"},
// Unknown defaults to us
{"me-south-1", "us"},
}
for _, tt := range tests {
t.Run(tt.region, func(t *testing.T) {
assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region))
})
}
}
func TestResolveBedrockModelID(t *testing.T) {
t.Run("default alias resolves and adjusts region", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "eu-west-1",
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
require.True(t, ok)
assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID)
})
t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "ap-southeast-2",
"model_mapping": map[string]any{
"claude-*": "claude-opus-4-6",
},
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-6-thinking")
require.True(t, ok)
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
})
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
"aws_force_global": "true",
"model_mapping": map[string]any{
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
},
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6")
require.True(t, ok)
assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID)
})
t.Run("direct bedrock model id passes through", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0")
require.True(t, ok)
assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID)
})
t.Run("unsupported alias returns false", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
_, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022")
assert.False(t, ok)
})
}
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
})
t.Run("no duplicate when already present", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
count := 0
for _, t := range result {
if t == "interleaved-thinking-2025-05-14" {
count++
}
}
assert.Equal(t, 1, count)
})
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "computer-use-2025-11-24")
})
t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject advanced-tool-use for input examples", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
// 纯 tool search 场景直接注入 Bedrock 特定头,不走 advanced-tool-use 转换
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
// 混合场景使用 advanced-tool-use后续由 filter 转换为 tool-search-tool
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
assert.NotContains(t, result, "tool-search-tool-2025-10-19")
})
t.Run("no injection for regular tools", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Empty(t, result)
})
t.Run("no injection when no features detected", func(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Empty(t, result)
})
t.Run("preserves existing tokens", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`)
existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"}
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "context-1m-2025-08-07")
assert.Contains(t, result, "compact-2026-01-12")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
})
}
func TestResolveBedrockBetaTokens(t *testing.T) {
t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
})
}
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
found := false
for _, v := range arr {
if v.String() == "interleaved-thinking-2025-05-14" {
found = true
}
}
assert.True(t, found, "interleaved-thinking should be auto-injected")
})
t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
names := make([]string, len(arr))
for i, v := range arr {
names[i] = v.String()
}
assert.Contains(t, names, "context-1m-2025-08-07")
assert.Contains(t, names, "interleaved-thinking-2025-05-14")
})
}
func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
tests := []struct {
name string
modelID string
region string
expect string
}{
// US region — no change needed
{"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"},
// EU region — replace us → eu
{"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
{"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"},
// APAC region — jp and au have dedicated prefixes per AWS docs
{"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"},
{"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"},
{"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"},
// eu → us (user manually set eu prefix, moved to us region)
{"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"},
// global prefix — replace to match region
{"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
// No known prefix — leave unchanged
{"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"},
// GovCloud — uses independent us-gov prefix
{"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"},
{"govcloud already correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"},
// Force global (special region value)
{"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"},
{"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region))
})
}
}

View File

@@ -0,0 +1,67 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
)
// BedrockSigner 使用 AWS SigV4 对 Bedrock 请求签名
type BedrockSigner struct {
credentials aws.Credentials
region string
signer *v4.Signer
}
// NewBedrockSigner 创建 BedrockSigner
func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner {
return &BedrockSigner{
credentials: aws.Credentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
},
region: region,
signer: v4.NewSigner(),
}
}
// NewBedrockSignerFromAccount 从 Account 凭证创建 BedrockSigner
func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) {
accessKeyID := account.GetCredential("aws_access_key_id")
if accessKeyID == "" {
return nil, fmt.Errorf("aws_access_key_id not found in credentials")
}
secretAccessKey := account.GetCredential("aws_secret_access_key")
if secretAccessKey == "" {
return nil, fmt.Errorf("aws_secret_access_key not found in credentials")
}
region := account.GetCredential("aws_region")
if region == "" {
region = defaultBedrockRegion
}
sessionToken := account.GetCredential("aws_session_token") // 可选
return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil
}
// SignRequest 对 HTTP 请求进行 SigV4 签名
// 重要约束调用此方法前req 应只包含 AWS 相关的 header如 Content-Type、Accept
// 非 AWS header如 anthropic-beta会参与签名计算如果 Bedrock 服务端不识别这些 header
// 签名验证可能失败。litellm 通过 _filter_headers_for_aws_signature 实现头过滤,
// 当前实现中 buildUpstreamRequestBedrock 仅设置了 Content-Type 和 Accept因此是安全的。
func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error {
payloadHash := sha256Hash(body)
return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now())
}
func sha256Hash(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}

View File

@@ -0,0 +1,35 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_access_key_id": "test-akid",
"aws_secret_access_key": "test-secret",
},
}
signer, err := NewBedrockSignerFromAccount(account)
require.NoError(t, err)
require.NotNil(t, signer)
assert.Equal(t, defaultBedrockRegion, signer.region)
}
func TestFilterBetaTokens(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"}
filterSet := map[string]struct{}{
"tool-search-tool-2025-10-19": {},
}
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet))
assert.Equal(t, tokens, filterBetaTokens(tokens, nil))
assert.Nil(t, filterBetaTokens(nil, filterSet))
}

View File

@@ -0,0 +1,414 @@
package service
import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
"hash/crc32"
"io"
"net/http"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// handleBedrockStreamingResponse 处理 Bedrock InvokeModelWithResponseStream 的 EventStream 响应
// Bedrock 返回 AWS EventStream 二进制格式,每个事件的 payload 中 chunk.bytes 是 base64 编码的
// Claude SSE 事件 JSON。本方法解码后转换为标准 SSE 格式写入客户端。
func (s *GatewayService) handleBedrockStreamingResponse(
ctx context.Context,
resp *http.Response,
c *gin.Context,
account *Account,
startTime time.Time,
model string,
) (*streamingResult, error) {
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-amzn-requestid"); v != "" {
c.Header("x-request-id", v)
}
usage := &ClaudeUsage{}
var firstTokenMs *int
clientDisconnected := false
// Bedrock EventStream 使用 application/vnd.amazon.eventstream 二进制格式。
// 每个帧结构total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4)
// 但更实用的方式是使用行扫描找 JSON chunks因为 Bedrock 的响应在二进制帧中。
// 我们使用 EventStream decoder 来正确解析。
decoder := newBedrockEventStreamDecoder(resp.Body)
type decodeEvent struct {
payload []byte
err error
}
events := make(chan decodeEvent, 16)
done := make(chan struct{})
sendEvent := func(ev decodeEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt atomic.Int64
lastReadAt.Store(time.Now().UnixNano())
go func() {
defer close(events)
for {
payload, err := decoder.Decode()
if err != nil {
if err == io.EOF {
return
}
_ = sendEvent(decodeEvent{err: err})
return
}
lastReadAt.Store(time.Now().UnixNano())
if !sendEvent(decodeEvent{payload: payload}) {
return
}
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
if !clientDisconnected {
flusher.Flush()
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
if ev.err != nil {
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err)
}
// payload 是 JSON提取 chunk.bytesbase64 编码的 Claude SSE 事件数据)
sseData := extractBedrockChunkData(ev.payload)
if sseData == nil {
continue
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式
// 同时移除该字段避免透传给客户端
sseData = transformBedrockInvocationMetrics(sseData)
// 解析 SSE 事件数据提取 usage
s.parseSSEUsagePassthrough(string(sseData), usage)
// 确定 SSE event type
eventType := gjson.GetBytes(sseData, "type").String()
// 写入标准 SSE 格式
if !clientDisconnected {
var writeErr error
if eventType != "" {
_, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData)
} else {
_, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData)
}
if writeErr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID)
} else {
flusher.Flush()
}
}
case <-intervalCh:
lastRead := time.Unix(0, lastReadAt.Load())
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
}
// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据
// Bedrock payload 格式:{"bytes":"<base64-encoded-json>"}
func extractBedrockChunkData(payload []byte) []byte {
b64 := gjson.GetBytes(payload, "bytes").String()
if b64 == "" {
return nil
}
decoded, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
return nil
}
return decoded
}
// transformBedrockInvocationMetrics 将 Bedrock 特有的 amazon-bedrock-invocationMetrics
// 转换为标准 Anthropic usage 格式,并从 SSE 数据中移除该字段。
//
// Bedrock Invoke 返回的 message_delta 事件可能包含:
//
// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}
//
// 转换为:
//
// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}}
func transformBedrockInvocationMetrics(data []byte) []byte {
metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics")
if !metrics.Exists() || !metrics.IsObject() {
return data
}
// 移除 Bedrock 特有字段
data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics")
// 如果已有标准 usage 字段,不覆盖
if gjson.GetBytes(data, "usage").Exists() {
return data
}
// 转换 camelCase → snake_case 写入 usage
inputTokens := metrics.Get("inputTokenCount")
outputTokens := metrics.Get("outputTokenCount")
if inputTokens.Exists() {
data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int())
}
if outputTokens.Exists() {
data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int())
}
return data
}
// bedrockEventStreamDecoder 解码 AWS EventStream 二进制帧
// EventStream 帧格式:
//
// [total_byte_length: 4 bytes]
// [headers_byte_length: 4 bytes]
// [prelude_crc: 4 bytes]
// [headers: variable]
// [payload: variable]
// [message_crc: 4 bytes]
type bedrockEventStreamDecoder struct {
reader *bufio.Reader
}
func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder {
return &bedrockEventStreamDecoder{
reader: bufio.NewReaderSize(r, 64*1024),
}
}
// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload
func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) {
for {
// 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes
prelude := make([]byte, 12)
if _, err := io.ReadFull(d.reader, prelude); err != nil {
return nil, err
}
// 验证 prelude CRCAWS EventStream 使用标准 CRC32 / IEEE
preludeCRC := bedrockReadUint32(prelude[8:12])
if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC {
return nil, fmt.Errorf("eventstream prelude CRC mismatch")
}
totalLength := bedrockReadUint32(prelude[0:4])
headersLength := bedrockReadUint32(prelude[4:8])
if totalLength < 16 { // minimum: 12 prelude + 4 message_crc
return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength)
}
// 读取 headers + payload + message_crc
remaining := int(totalLength) - 12
if remaining <= 0 {
continue
}
data := make([]byte, remaining)
if _, err := io.ReadFull(d.reader, data); err != nil {
return nil, err
}
// 验证 message CRC覆盖 prelude + headers + payload
messageCRC := bedrockReadUint32(data[len(data)-4:])
h := crc32.New(crc32IEEETable)
_, _ = h.Write(prelude)
_, _ = h.Write(data[:len(data)-4])
if h.Sum32() != messageCRC {
return nil, fmt.Errorf("eventstream message CRC mismatch")
}
// 解析 headers
headers := data[:headersLength]
payload := data[headersLength : len(data)-4] // 去掉 message_crc
// 从 headers 中提取 :event-type
eventType := extractEventStreamHeaderValue(headers, ":event-type")
// 只处理 chunk 事件
if eventType == "chunk" {
// payload 是完整的 JSON包含 bytes 字段
return payload, nil
}
// 检查异常事件
exceptionType := extractEventStreamHeaderValue(headers, ":exception-type")
if exceptionType != "" {
return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload))
}
messageType := extractEventStreamHeaderValue(headers, ":message-type")
if messageType == "exception" || messageType == "error" {
return nil, fmt.Errorf("bedrock error: %s", string(payload))
}
// 跳过其他事件类型(如 initial-response
}
}
// extractEventStreamHeaderValue 从 EventStream headers 二进制数据中提取指定 header 的字符串值
// EventStream header 格式:
//
// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable]
//
// value_type = 7 表示 string 类型,前 2 bytes 为长度
func extractEventStreamHeaderValue(headers []byte, targetName string) string {
pos := 0
for pos < len(headers) {
if pos >= len(headers) {
break
}
nameLen := int(headers[pos])
pos++
if pos+nameLen > len(headers) {
break
}
name := string(headers[pos : pos+nameLen])
pos += nameLen
if pos >= len(headers) {
break
}
valueType := headers[pos]
pos++
switch valueType {
case 7: // string
if pos+2 > len(headers) {
return ""
}
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
pos += 2
if pos+valueLen > len(headers) {
return ""
}
value := string(headers[pos : pos+valueLen])
pos += valueLen
if name == targetName {
return value
}
case 0: // bool true
if name == targetName {
return "true"
}
case 1: // bool false
if name == targetName {
return "false"
}
case 2: // byte
pos++
if name == targetName {
return ""
}
case 3: // short
pos += 2
if name == targetName {
return ""
}
case 4: // int
pos += 4
if name == targetName {
return ""
}
case 5: // long
pos += 8
if name == targetName {
return ""
}
case 6: // bytes
if pos+2 > len(headers) {
return ""
}
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
pos += 2 + valueLen
case 8: // timestamp
pos += 8
case 9: // uuid
pos += 16
default:
return "" // 未知类型,无法继续解析
}
}
return ""
}
// crc32IEEETable is the CRC32 / IEEE table used by AWS EventStream.
var crc32IEEETable = crc32.MakeTable(crc32.IEEE)
func bedrockReadUint32(b []byte) uint32 {
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
}
func bedrockReadUint16(b []byte) uint16 {
return uint16(b[0])<<8 | uint16(b[1])
}

View File

@@ -0,0 +1,261 @@
package service
import (
"bytes"
"encoding/base64"
"encoding/binary"
"hash/crc32"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestExtractBedrockChunkData(t *testing.T) {
t.Run("valid base64 payload", func(t *testing.T) {
original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`
b64 := base64.StdEncoding.EncodeToString([]byte(original))
payload := []byte(`{"bytes":"` + b64 + `"}`)
result := extractBedrockChunkData(payload)
require.NotNil(t, result)
assert.JSONEq(t, original, string(result))
})
t.Run("empty bytes field", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"bytes":""}`))
assert.Nil(t, result)
})
t.Run("no bytes field", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"other":"value"}`))
assert.Nil(t, result)
})
t.Run("invalid base64", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`))
assert.Nil(t, result)
})
}
func TestTransformBedrockInvocationMetrics(t *testing.T) {
t.Run("converts metrics to usage", func(t *testing.T) {
input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result := transformBedrockInvocationMetrics([]byte(input))
// amazon-bedrock-invocationMetrics should be removed
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
// usage should be set
assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int())
assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int())
// original fields preserved
assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String())
assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String())
})
t.Run("no metrics present", func(t *testing.T) {
input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}`
result := transformBedrockInvocationMetrics([]byte(input))
assert.JSONEq(t, input, string(result))
})
t.Run("does not overwrite existing usage", func(t *testing.T) {
input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result := transformBedrockInvocationMetrics([]byte(input))
// metrics removed but existing usage preserved
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int())
})
}
func TestExtractEventStreamHeaderValue(t *testing.T) {
// Build a header with :event-type = "chunk" (string type = 7)
buildStringHeader := func(name, value string) []byte {
var buf bytes.Buffer
// name length (1 byte)
_ = buf.WriteByte(byte(len(name)))
// name
_, _ = buf.WriteString(name)
// value type (7 = string)
_ = buf.WriteByte(7)
// value length (2 bytes, big-endian)
_ = binary.Write(&buf, binary.BigEndian, uint16(len(value)))
// value
_, _ = buf.WriteString(value)
return buf.Bytes()
}
t.Run("find string header", func(t *testing.T) {
headers := buildStringHeader(":event-type", "chunk")
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
})
t.Run("header not found", func(t *testing.T) {
headers := buildStringHeader(":event-type", "chunk")
assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type"))
})
t.Run("multiple headers", func(t *testing.T) {
var buf bytes.Buffer
_, _ = buf.Write(buildStringHeader(":content-type", "application/json"))
_, _ = buf.Write(buildStringHeader(":event-type", "chunk"))
_, _ = buf.Write(buildStringHeader(":message-type", "event"))
headers := buf.Bytes()
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type"))
assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type"))
})
t.Run("empty headers", func(t *testing.T) {
assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type"))
})
}
func TestBedrockEventStreamDecoder(t *testing.T) {
crc32IeeeTab := crc32.MakeTable(crc32.IEEE)
// Build a valid EventStream frame with correct CRC32/IEEE checksums.
buildFrame := func(eventType string, payload []byte) []byte {
// Build headers
var headersBuf bytes.Buffer
// :event-type header
_ = headersBuf.WriteByte(byte(len(":event-type")))
_, _ = headersBuf.WriteString(":event-type")
_ = headersBuf.WriteByte(7) // string type
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType)))
_, _ = headersBuf.WriteString(eventType)
// :message-type header
_ = headersBuf.WriteByte(byte(len(":message-type")))
_, _ = headersBuf.WriteString(":message-type")
_ = headersBuf.WriteByte(7)
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event")))
_, _ = headersBuf.WriteString("event")
headers := headersBuf.Bytes()
headersLen := uint32(len(headers))
// total = 12 (prelude) + headers + payload + 4 (message_crc)
totalLen := uint32(12 + len(headers) + len(payload) + 4)
// Prelude: total_length(4) + headers_length(4)
var preludeBuf bytes.Buffer
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
preludeBytes := preludeBuf.Bytes()
preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab)
// Build frame: prelude + prelude_crc + headers + payload
var frame bytes.Buffer
_, _ = frame.Write(preludeBytes)
_ = binary.Write(&frame, binary.BigEndian, preludeCRC)
_, _ = frame.Write(headers)
_, _ = frame.Write(payload)
// Message CRC covers everything before itself
messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab)
_ = binary.Write(&frame, binary.BigEndian, messageCRC)
return frame.Bytes()
}
t.Run("decode chunk event", func(t *testing.T) {
payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test")
frame := buildFrame("chunk", payload)
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
result, err := decoder.Decode()
require.NoError(t, err)
assert.Equal(t, payload, result)
})
t.Run("skip non-chunk events", func(t *testing.T) {
// Write initial-response followed by chunk
var buf bytes.Buffer
_, _ = buf.Write(buildFrame("initial-response", []byte(`{}`)))
chunkPayload := []byte(`{"bytes":"aGVsbG8="}`)
_, _ = buf.Write(buildFrame("chunk", chunkPayload))
decoder := newBedrockEventStreamDecoder(&buf)
result, err := decoder.Decode()
require.NoError(t, err)
assert.Equal(t, chunkPayload, result)
})
t.Run("EOF on empty input", func(t *testing.T) {
decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil))
_, err := decoder.Decode()
assert.Equal(t, io.EOF, err)
})
t.Run("corrupted prelude CRC", func(t *testing.T) {
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
// Corrupt the prelude CRC (bytes 8-11)
frame[8] ^= 0xFF
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "prelude CRC mismatch")
})
t.Run("corrupted message CRC", func(t *testing.T) {
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
// Corrupt the message CRC (last 4 bytes)
frame[len(frame)-1] ^= 0xFF
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "message CRC mismatch")
})
t.Run("castagnoli encoded frame is rejected", func(t *testing.T) {
castagnoliTab := crc32.MakeTable(crc32.Castagnoli)
payload := []byte(`{"bytes":"dGVzdA=="}`)
var headersBuf bytes.Buffer
_ = headersBuf.WriteByte(byte(len(":event-type")))
_, _ = headersBuf.WriteString(":event-type")
_ = headersBuf.WriteByte(7)
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk")))
_, _ = headersBuf.WriteString("chunk")
headers := headersBuf.Bytes()
headersLen := uint32(len(headers))
totalLen := uint32(12 + len(headers) + len(payload) + 4)
var preludeBuf bytes.Buffer
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
preludeBytes := preludeBuf.Bytes()
var frame bytes.Buffer
_, _ = frame.Write(preludeBytes)
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab))
_, _ = frame.Write(headers)
_, _ = frame.Write(payload)
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab))
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes()))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "prelude CRC mismatch")
})
}
func TestBuildBedrockURL(t *testing.T) {
t.Run("stream URL with colon in model ID", func(t *testing.T) {
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true)
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url)
})
t.Run("non-stream URL with colon in model ID", func(t *testing.T) {
url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false)
assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url)
})
t.Run("model ID without colon", func(t *testing.T) {
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true)
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url)
})
}

View File

@@ -1,450 +0,0 @@
package service
import (
"encoding/json"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
)
type claudeMaxCacheBillingOutcome struct {
Simulated bool
}
func applyClaudeMaxCacheBillingPolicyToUsage(usage *ClaudeUsage, parsed *ParsedRequest, group *Group, model string, accountID int64) claudeMaxCacheBillingOutcome {
var out claudeMaxCacheBillingOutcome
if usage == nil || !shouldApplyClaudeMaxBillingRulesForUsage(group, model, parsed) {
return out
}
resolvedModel := strings.TrimSpace(model)
if resolvedModel == "" && parsed != nil {
resolvedModel = strings.TrimSpace(parsed.Model)
}
if hasCacheCreationTokens(*usage) {
// Upstream already returned cache creation usage; keep original usage.
return out
}
if !shouldSimulateClaudeMaxUsageForUsage(*usage, parsed) {
return out
}
beforeInputTokens := usage.InputTokens
out.Simulated = safelyProjectUsageToClaudeMax1H(usage, parsed)
if out.Simulated {
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
resolvedModel,
accountID,
beforeInputTokens,
usage.InputTokens,
usage.CacheCreation1hTokens,
)
}
return out
}
func isClaudeFamilyModel(model string) bool {
normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model)))
if normalized == "" {
return false
}
return strings.Contains(normalized, "claude-")
}
func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool {
if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil {
return false
}
return shouldApplyClaudeMaxBillingRulesForUsage(input.APIKey.Group, input.Result.Model, input.ParsedRequest)
}
func shouldApplyClaudeMaxBillingRulesForUsage(group *Group, model string, parsed *ParsedRequest) bool {
if group == nil {
return false
}
if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic {
return false
}
resolvedModel := model
if resolvedModel == "" && parsed != nil {
resolvedModel = parsed.Model
}
if !isClaudeFamilyModel(resolvedModel) {
return false
}
return true
}
func hasCacheCreationTokens(usage ClaudeUsage) bool {
return usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0
}
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
if input == nil || input.Result == nil {
return false
}
if !shouldApplyClaudeMaxBillingRules(input) {
return false
}
return shouldSimulateClaudeMaxUsageForUsage(input.Result.Usage, input.ParsedRequest)
}
func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool {
if usage.InputTokens <= 0 {
return false
}
if hasCacheCreationTokens(usage) {
return false
}
if !hasClaudeCacheSignals(parsed) {
return false
}
return true
}
func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) {
defer func() {
if r := recover(); r != nil {
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r)
changed = false
}
}()
return projectUsageToClaudeMax1H(usage, parsed)
}
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
if usage == nil {
return false
}
totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
if totalWindowTokens <= 1 {
return false
}
simulatedInputTokens := computeClaudeMaxProjectedInputTokens(totalWindowTokens, parsed)
if simulatedInputTokens <= 0 {
simulatedInputTokens = 1
}
if simulatedInputTokens >= totalWindowTokens {
simulatedInputTokens = totalWindowTokens - 1
}
cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens
if usage.InputTokens == simulatedInputTokens &&
usage.CacheCreation5mTokens == 0 &&
usage.CacheCreation1hTokens == cacheCreation1hTokens &&
usage.CacheCreationInputTokens == cacheCreation1hTokens {
return false
}
usage.InputTokens = simulatedInputTokens
usage.CacheCreation5mTokens = 0
usage.CacheCreation1hTokens = cacheCreation1hTokens
usage.CacheCreationInputTokens = cacheCreation1hTokens
return true
}
type claudeCacheProjection struct {
HasBreakpoint bool
BreakpointCount int
TotalEstimatedTokens int
TailEstimatedTokens int
}
func computeClaudeMaxProjectedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
if totalWindowTokens <= 1 {
return totalWindowTokens
}
projection := analyzeClaudeCacheProjection(parsed)
if !projection.HasBreakpoint || projection.TotalEstimatedTokens <= 0 || projection.TailEstimatedTokens <= 0 {
return totalWindowTokens
}
totalEstimate := int64(projection.TotalEstimatedTokens)
tailEstimate := int64(projection.TailEstimatedTokens)
if tailEstimate > totalEstimate {
tailEstimate = totalEstimate
}
scaled := (int64(totalWindowTokens)*tailEstimate + totalEstimate/2) / totalEstimate
if scaled <= 0 {
scaled = 1
}
if scaled >= int64(totalWindowTokens) {
scaled = int64(totalWindowTokens - 1)
}
return int(scaled)
}
func hasClaudeCacheSignals(parsed *ParsedRequest) bool {
if parsed == nil {
return false
}
if hasTopLevelEphemeralCacheControl(parsed) {
return true
}
return countExplicitCacheBreakpoints(parsed) > 0
}
func hasTopLevelEphemeralCacheControl(parsed *ParsedRequest) bool {
if parsed == nil || len(parsed.Body) == 0 {
return false
}
cacheType := strings.TrimSpace(gjson.GetBytes(parsed.Body, "cache_control.type").String())
return strings.EqualFold(cacheType, "ephemeral")
}
func analyzeClaudeCacheProjection(parsed *ParsedRequest) claudeCacheProjection {
var projection claudeCacheProjection
if parsed == nil {
return projection
}
total := 0
lastBreakpointAt := -1
switch system := parsed.System.(type) {
case string:
total += claudeMaxMessageOverheadTokens + estimateClaudeTextTokens(system)
case []any:
for _, raw := range system {
block, ok := raw.(map[string]any)
if !ok {
total += claudeMaxUnknownContentTokens
continue
}
total += estimateClaudeBlockTokens(block)
if hasEphemeralCacheControl(block) {
lastBreakpointAt = total
projection.BreakpointCount++
projection.HasBreakpoint = true
}
}
}
for _, rawMsg := range parsed.Messages {
total += claudeMaxMessageOverheadTokens
msg, ok := rawMsg.(map[string]any)
if !ok {
total += claudeMaxUnknownContentTokens
continue
}
content, exists := msg["content"]
if !exists {
continue
}
msgTokens, msgLastBreak, msgBreakCount := estimateClaudeContentTokens(content)
total += msgTokens
if msgBreakCount > 0 {
lastBreakpointAt = total - msgTokens + msgLastBreak
projection.BreakpointCount += msgBreakCount
projection.HasBreakpoint = true
}
}
if total <= 0 {
total = 1
}
projection.TotalEstimatedTokens = total
if projection.HasBreakpoint && lastBreakpointAt >= 0 {
tail := total - lastBreakpointAt
if tail <= 0 {
tail = 1
}
projection.TailEstimatedTokens = tail
return projection
}
if hasTopLevelEphemeralCacheControl(parsed) {
tail := estimateLastUserMessageTokens(parsed)
if tail <= 0 {
tail = 1
}
projection.HasBreakpoint = true
projection.BreakpointCount = 1
projection.TailEstimatedTokens = tail
}
return projection
}
func countExplicitCacheBreakpoints(parsed *ParsedRequest) int {
if parsed == nil {
return 0
}
total := 0
if system, ok := parsed.System.([]any); ok {
for _, raw := range system {
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
total++
}
}
}
for _, rawMsg := range parsed.Messages {
msg, ok := rawMsg.(map[string]any)
if !ok {
continue
}
content, ok := msg["content"].([]any)
if !ok {
continue
}
for _, raw := range content {
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
total++
}
}
}
return total
}
func hasEphemeralCacheControl(block map[string]any) bool {
if block == nil {
return false
}
raw, ok := block["cache_control"]
if !ok || raw == nil {
return false
}
switch cc := raw.(type) {
case map[string]any:
cacheType, _ := cc["type"].(string)
return strings.EqualFold(strings.TrimSpace(cacheType), "ephemeral")
case map[string]string:
return strings.EqualFold(strings.TrimSpace(cc["type"]), "ephemeral")
default:
return false
}
}
func estimateClaudeContentTokens(content any) (tokens int, lastBreakAt int, breakpointCount int) {
switch value := content.(type) {
case string:
return estimateClaudeTextTokens(value), -1, 0
case []any:
total := 0
lastBreak := -1
breaks := 0
for _, raw := range value {
block, ok := raw.(map[string]any)
if !ok {
total += claudeMaxUnknownContentTokens
continue
}
total += estimateClaudeBlockTokens(block)
if hasEphemeralCacheControl(block) {
lastBreak = total
breaks++
}
}
return total, lastBreak, breaks
default:
return estimateStructuredTokens(value), -1, 0
}
}
func estimateClaudeBlockTokens(block map[string]any) int {
if block == nil {
return claudeMaxUnknownContentTokens
}
tokens := claudeMaxBlockOverheadTokens
blockType, _ := block["type"].(string)
switch blockType {
case "text":
if text, ok := block["text"].(string); ok {
tokens += estimateClaudeTextTokens(text)
}
case "tool_result":
if content, ok := block["content"]; ok {
nested, _, _ := estimateClaudeContentTokens(content)
tokens += nested
}
case "tool_use":
if name, ok := block["name"].(string); ok {
tokens += estimateClaudeTextTokens(name)
}
if input, ok := block["input"]; ok {
tokens += estimateStructuredTokens(input)
}
default:
if text, ok := block["text"].(string); ok {
tokens += estimateClaudeTextTokens(text)
} else if content, ok := block["content"]; ok {
nested, _, _ := estimateClaudeContentTokens(content)
tokens += nested
}
}
if tokens <= claudeMaxBlockOverheadTokens {
tokens += claudeMaxUnknownContentTokens
}
return tokens
}
func estimateLastUserMessageTokens(parsed *ParsedRequest) int {
if parsed == nil || len(parsed.Messages) == 0 {
return 0
}
for i := len(parsed.Messages) - 1; i >= 0; i-- {
msg, ok := parsed.Messages[i].(map[string]any)
if !ok {
continue
}
role, _ := msg["role"].(string)
if !strings.EqualFold(strings.TrimSpace(role), "user") {
continue
}
tokens, _, _ := estimateClaudeContentTokens(msg["content"])
return claudeMaxMessageOverheadTokens + tokens
}
return 0
}
func estimateStructuredTokens(v any) int {
if v == nil {
return 0
}
raw, err := json.Marshal(v)
if err != nil {
return claudeMaxUnknownContentTokens
}
return estimateClaudeTextTokens(string(raw))
}
func estimateClaudeTextTokens(text string) int {
if tokens, ok := estimateTokensByThirdPartyTokenizer(text); ok {
return tokens
}
return estimateClaudeTextTokensHeuristic(text)
}
func estimateClaudeTextTokensHeuristic(text string) int {
normalized := strings.Join(strings.Fields(strings.TrimSpace(text)), " ")
if normalized == "" {
return 0
}
asciiChars := 0
nonASCIIChars := 0
for _, r := range normalized {
if r <= 127 {
asciiChars++
} else {
nonASCIIChars++
}
}
tokens := nonASCIIChars
if asciiChars > 0 {
tokens += (asciiChars + 3) / 4
}
if words := len(strings.Fields(normalized)); words > tokens {
tokens = words
}
if tokens <= 0 {
return 1
}
return tokens
}

View File

@@ -1,156 +0,0 @@
package service
import (
"strings"
"testing"
)
func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
usage := &ClaudeUsage{
InputTokens: 1200,
CacheCreationInputTokens: 0,
CacheCreation5mTokens: 0,
CacheCreation1hTokens: 0,
}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": strings.Repeat("cached context ", 200),
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "summarize quickly",
},
},
},
},
}
changed := projectUsageToClaudeMax1H(usage, parsed)
if !changed {
t.Fatalf("expected usage to be projected")
}
total := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
if total != 1200 {
t.Fatalf("total tokens changed: got=%d want=%d", total, 1200)
}
if usage.CacheCreation5mTokens != 0 {
t.Fatalf("cache_creation_5m should be 0, got=%d", usage.CacheCreation5mTokens)
}
if usage.InputTokens <= 0 || usage.InputTokens >= 1200 {
t.Fatalf("simulated input out of range, got=%d", usage.InputTokens)
}
if usage.InputTokens > 100 {
t.Fatalf("simulated input should stay near cache breakpoint tail, got=%d", usage.InputTokens)
}
if usage.CacheCreation1hTokens <= 0 {
t.Fatalf("cache_creation_1h should be > 0, got=%d", usage.CacheCreation1hTokens)
}
if usage.CacheCreationInputTokens != usage.CacheCreation1hTokens {
t.Fatalf("cache_creation_input_tokens mismatch: got=%d want=%d", usage.CacheCreationInputTokens, usage.CacheCreation1hTokens)
}
}
func TestComputeClaudeMaxProjectedInputTokens_Deterministic(t *testing.T) {
parsed := &ParsedRequest{
Model: "claude-opus-4-5",
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "build context",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "what is failing now",
},
},
},
},
}
got1 := computeClaudeMaxProjectedInputTokens(4096, parsed)
got2 := computeClaudeMaxProjectedInputTokens(4096, parsed)
if got1 != got2 {
t.Fatalf("non-deterministic input tokens: %d != %d", got1, got2)
}
}
func TestShouldSimulateClaudeMaxUsage(t *testing.T) {
group := &Group{
Platform: PlatformAnthropic,
SimulateClaudeMaxEnabled: true,
}
input := &RecordUsageInput{
Result: &ForwardResult{
Model: "claude-sonnet-4-5",
Usage: ClaudeUsage{
InputTokens: 3000,
CacheCreationInputTokens: 0,
CacheCreation5mTokens: 0,
CacheCreation1hTokens: 0,
},
},
ParsedRequest: &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "cached",
"cache_control": map[string]any{"type": "ephemeral"},
},
map[string]any{
"type": "text",
"text": "tail",
},
},
},
},
},
APIKey: &APIKey{Group: group},
}
if !shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=true for claude group with cache signal")
}
input.ParsedRequest = &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "no cache signal"},
},
}
if shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=false when request has no cache signal")
}
input.ParsedRequest = &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "text",
"text": "cached",
"cache_control": map[string]any{"type": "ephemeral"},
},
},
},
},
}
input.Result.Usage.CacheCreationInputTokens = 100
if shouldSimulateClaudeMaxUsage(input) {
t.Fatalf("expected simulate=false when cache creation already exists")
}
}

View File

@@ -1,41 +0,0 @@
package service
import (
"sync"
tiktoken "github.com/pkoukk/tiktoken-go"
tiktokenloader "github.com/pkoukk/tiktoken-go-loader"
)
var (
claudeTokenizerOnce sync.Once
claudeTokenizer *tiktoken.Tiktoken
)
func getClaudeTokenizer() *tiktoken.Tiktoken {
claudeTokenizerOnce.Do(func() {
// Use offline loader to avoid runtime dictionary download.
tiktoken.SetBpeLoader(tiktokenloader.NewOfflineLoader())
// Use a high-capacity tokenizer as the default approximation for Claude payloads.
enc, err := tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
if err != nil {
enc, err = tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
}
if err == nil {
claudeTokenizer = enc
}
})
return claudeTokenizer
}
func estimateTokensByThirdPartyTokenizer(text string) (int, bool) {
enc := getClaudeTokenizer()
if enc == nil {
return 0, false
}
tokens := len(enc.EncodeOrdinary(text))
if tokens <= 0 {
return 0, false
}
return tokens, true
}

View File

@@ -343,9 +343,8 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
// Uses a detached context with timeout to prevent HTTP request cancellation from
// causing the entire batch to fail (which would show all concurrency as 0).
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
@@ -357,11 +356,5 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
}
return result, nil
}
// Use a detached context so that a cancelled HTTP request doesn't cause
// the Redis pipeline to fail and return all-zero concurrency counts.
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
}

View File

@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
}
@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
if aggErr != nil {
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
if usageErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
}
if aggErr == nil && usageErr == nil {
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
if dedupErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
}
if aggErr == nil && usageErr == nil && dedupErr == nil {
s.lastRetentionCleanup.Store(now)
}
}

View File

@@ -12,12 +12,18 @@ import (
type dashboardAggregationRepoTestStub struct {
aggregateCalls int
recomputeCalls int
cleanupUsageCalls int
cleanupDedupCalls int
ensurePartitionCalls int
lastStart time.Time
lastEnd time.Time
watermark time.Time
aggregateErr error
cleanupAggregatesErr error
cleanupUsageErr error
cleanupDedupErr error
ensurePartitionErr error
}
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
}
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
s.recomputeCalls++
return s.AggregateRange(ctx, start, end)
}
@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
s.cleanupUsageCalls++
return s.cleanupUsageErr
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
s.cleanupDedupCalls++
return s.cleanupDedupErr
}
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
s.ensurePartitionCalls++
return s.ensurePartitionErr
}
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupUsageCalls)
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Enabled: true,
IntervalSeconds: 60,
LookbackSeconds: 120,
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
UsageBillingDedupDays: 2,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.runScheduledAggregation()
require.Equal(t, 1, repo.ensurePartitionCalls)
require.Equal(t, 1, repo.aggregateCalls)
}
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {

View File

@@ -327,6 +327,14 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
return trend, nil
}
func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
ranking, err := s.usageRepo.GetUserSpendingRanking(ctx, startTime, endTime, limit)
if err != nil {
return nil, fmt.Errorf("get user spending ranking: %w", err)
}
return ranking, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
if err != nil {

View File

@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
return nil
}
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
}

View File

@@ -29,10 +29,12 @@ const (
// Account type constants
const (
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号inference only scope
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号inference only scope
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock
AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock
)
// Redeem type constants

View File

@@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
v, exists := c.Get(OpsSkipPassthroughKey)
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
boolVal, ok := v.(bool)
assert.True(t, ok, "value should be a bool")
assert.True(t, ok, "value should be bool")
assert.True(t, boolVal)
}

Some files were not shown because too many files have changed in this diff Show More