mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-06 00:10:21 +08:00
Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f2a97073e | ||
|
|
7407e3b45d | ||
|
|
e6d59216d4 | ||
|
|
91e4d95660 | ||
|
|
4588258d80 | ||
|
|
c12e48f966 | ||
|
|
ec8f50a658 | ||
|
|
99c9191784 | ||
|
|
6bb02d141f | ||
|
|
07bb2a5f3f | ||
|
|
417861a48e | ||
|
|
b7e878de64 | ||
|
|
05edb5514b | ||
|
|
e90ec847b6 | ||
|
|
7e288acc90 | ||
|
|
27ff222cfb | ||
|
|
11f7b83522 | ||
|
|
1ee984478f | ||
|
|
fd693dc526 | ||
|
|
e73531ce9b | ||
|
|
ecea13757b | ||
|
|
80d8d6c3bc | ||
|
|
d648811233 | ||
|
|
34695acb85 | ||
|
|
a63de12182 | ||
|
|
f16910d616 | ||
|
|
64b3f3cec1 | ||
|
|
6a685727d0 | ||
|
|
32d25f76fc | ||
|
|
69cafe8674 | ||
|
|
18ba8d9166 | ||
|
|
e97fd7e81c | ||
|
|
cdb64b0d33 | ||
|
|
8d4d3b03bb | ||
|
|
addefe79e1 | ||
|
|
b764d3b8f6 | ||
|
|
611fd884bd | ||
|
|
c9debc50b1 |
2
.github/workflows/backend-ci.yml
vendored
2
.github/workflows/backend-ci.yml
vendored
@@ -17,7 +17,6 @@ jobs:
|
|||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
check-latest: false
|
check-latest: false
|
||||||
cache: true
|
cache: true
|
||||||
cache-dependency-path: backend/go.sum
|
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.1'
|
go version | grep -q 'go1.26.1'
|
||||||
@@ -37,7 +36,6 @@ jobs:
|
|||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
check-latest: false
|
check-latest: false
|
||||||
cache: true
|
cache: true
|
||||||
cache-dependency-path: backend/go.sum
|
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.1'
|
go version | grep -q 'go1.26.1'
|
||||||
|
|||||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -78,7 +78,6 @@ Desktop.ini
|
|||||||
# ===================
|
# ===================
|
||||||
tmp/
|
tmp/
|
||||||
temp/
|
temp/
|
||||||
logs/
|
|
||||||
*.tmp
|
*.tmp
|
||||||
*.temp
|
*.temp
|
||||||
*.log
|
*.log
|
||||||
@@ -129,15 +128,8 @@ deploy/docker-compose.override.yml
|
|||||||
vite.config.js
|
vite.config.js
|
||||||
docs/*
|
docs/*
|
||||||
.serena/
|
.serena/
|
||||||
|
|
||||||
# ===================
|
|
||||||
# 压测工具
|
|
||||||
# ===================
|
|
||||||
tools/loadtest/
|
|
||||||
# Antigravity Manager
|
|
||||||
Antigravity-Manager/
|
|
||||||
antigravity_projectid_fix.patch
|
|
||||||
.codex/
|
.codex/
|
||||||
frontend/coverage/
|
frontend/coverage/
|
||||||
aicodex
|
aicodex
|
||||||
output/
|
output/
|
||||||
|
|
||||||
|
|||||||
10
README.md
10
README.md
@@ -39,6 +39,16 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
|||||||
- **Concurrency Control** - Per-user and per-account concurrency limits
|
- **Concurrency Control** - Per-user and per-account concurrency limits
|
||||||
- **Rate Limiting** - Configurable request and token rate limits
|
- **Rate Limiting** - Configurable request and token rate limits
|
||||||
- **Admin Dashboard** - Web interface for monitoring and management
|
- **Admin Dashboard** - Web interface for monitoring and management
|
||||||
|
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||||
|
|
||||||
|
## Ecosystem
|
||||||
|
|
||||||
|
Community projects that extend or integrate with Sub2API:
|
||||||
|
|
||||||
|
| Project | Description | Features |
|
||||||
|
|---------|-------------|----------|
|
||||||
|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
|
||||||
|
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
|
||||||
|
|
||||||
## Tech Stack
|
## Tech Stack
|
||||||
|
|
||||||
|
|||||||
10
README_CN.md
10
README_CN.md
@@ -39,6 +39,16 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
- **并发控制** - 用户级和账号级并发限制
|
- **并发控制** - 用户级和账号级并发限制
|
||||||
- **速率限制** - 可配置的请求和 Token 速率限制
|
- **速率限制** - 可配置的请求和 Token 速率限制
|
||||||
- **管理后台** - Web 界面进行监控和管理
|
- **管理后台** - Web 界面进行监控和管理
|
||||||
|
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||||
|
|
||||||
|
## 生态项目
|
||||||
|
|
||||||
|
围绕 Sub2API 的社区扩展与集成项目:
|
||||||
|
|
||||||
|
| 项目 | 说明 | 功能 |
|
||||||
|
|------|------|------|
|
||||||
|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
|
||||||
|
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
|
||||||
|
|
||||||
## 技术栈
|
## 技术栈
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
0.1.96.1
|
0.1.88
|
||||||
@@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
// Server layer ProviderSet
|
// Server layer ProviderSet
|
||||||
server.ProviderSet,
|
server.ProviderSet,
|
||||||
|
|
||||||
|
// Privacy client factory for OpenAI training opt-out
|
||||||
|
providePrivacyClientFactory,
|
||||||
|
|
||||||
// BuildInfo provider
|
// BuildInfo provider
|
||||||
provideServiceBuildInfo,
|
provideServiceBuildInfo,
|
||||||
|
|
||||||
@@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||||
|
return repository.CreatePrivacyReqClient
|
||||||
|
}
|
||||||
|
|
||||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
return service.BuildInfo{
|
return service.BuildInfo{
|
||||||
Version: buildInfo.Version,
|
Version: buildInfo.Version,
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
userHandler := handler.NewUserHandler(userService)
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
|
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
@@ -104,7 +105,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
|
privacyClientFactory := providePrivacyClientFactory()
|
||||||
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||||
@@ -162,9 +164,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||||
digestSessionStore := service.NewDigestSessionStore()
|
digestSessionStore := service.NewDigestSessionStore()
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||||
@@ -226,7 +228,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
@@ -245,6 +247,10 @@ type Application struct {
|
|||||||
Cleanup func()
|
Cleanup func()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||||
|
return repository.CreatePrivacyReqClient
|
||||||
|
}
|
||||||
|
|
||||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
return service.BuildInfo{
|
return service.BuildInfo{
|
||||||
Version: buildInfo.Version,
|
Version: buildInfo.Version,
|
||||||
|
|||||||
@@ -62,28 +62,26 @@ type Group struct {
|
|||||||
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||||
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
// SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
|
||||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
|
||||||
// allow Claude Code client only
|
// 是否仅允许 Claude Code 客户端
|
||||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||||
// fallback group for non-Claude-Code requests
|
// 非 Claude Code 请求降级使用的分组 ID
|
||||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
// fallback group for invalid request
|
// 无效请求兜底使用的分组 ID
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||||
// model routing config: pattern -> account ids
|
// 模型路由配置:模型模式 -> 优先账号ID列表
|
||||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||||
// whether model routing is enabled
|
// 是否启用模型路由配置
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
|
||||||
// whether MCP XML prompt injection is enabled
|
// 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)
|
||||||
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
|
McpXMLInject bool `json:"mcp_xml_inject,omitempty"`
|
||||||
// supported model scopes: claude, gemini_text, gemini_image
|
// 支持的模型系列:claude, gemini_text, gemini_image
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||||
// group display order, lower comes first
|
// 分组显示排序,数值越小越靠前
|
||||||
SortOrder int `json:"sort_order,omitempty"`
|
SortOrder int `json:"sort_order,omitempty"`
|
||||||
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
// 是否允许 /v1/messages 调度到此 OpenAI 分组
|
||||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"`
|
||||||
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
// 默认映射模型 ID,当账号级映射找不到时使用此值
|
||||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||||
// simulate claude usage as claude-max style (1h cache write)
|
|
||||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled,omitempty"`
|
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||||
Edges GroupEdges `json:"edges"`
|
Edges GroupEdges `json:"edges"`
|
||||||
@@ -192,7 +190,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
|||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
case group.FieldModelRouting, group.FieldSupportedModelScopes:
|
||||||
values[i] = new([]byte)
|
values[i] = new([]byte)
|
||||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldSimulateClaudeMaxEnabled:
|
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
@@ -433,12 +431,6 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.DefaultMappedModel = value.String
|
_m.DefaultMappedModel = value.String
|
||||||
}
|
}
|
||||||
case group.FieldSimulateClaudeMaxEnabled:
|
|
||||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field simulate_claude_max_enabled", values[i])
|
|
||||||
} else if value.Valid {
|
|
||||||
_m.SimulateClaudeMaxEnabled = value.Bool
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -638,9 +630,6 @@ func (_m *Group) String() string {
|
|||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
builder.WriteString("default_mapped_model=")
|
builder.WriteString("default_mapped_model=")
|
||||||
builder.WriteString(_m.DefaultMappedModel)
|
builder.WriteString(_m.DefaultMappedModel)
|
||||||
builder.WriteString(", ")
|
|
||||||
builder.WriteString("simulate_claude_max_enabled=")
|
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.SimulateClaudeMaxEnabled))
|
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,8 +79,6 @@ const (
|
|||||||
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
FieldAllowMessagesDispatch = "allow_messages_dispatch"
|
||||||
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
|
||||||
FieldDefaultMappedModel = "default_mapped_model"
|
FieldDefaultMappedModel = "default_mapped_model"
|
||||||
// FieldSimulateClaudeMaxEnabled holds the string denoting the simulate_claude_max_enabled field in the database.
|
|
||||||
FieldSimulateClaudeMaxEnabled = "simulate_claude_max_enabled"
|
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@@ -188,7 +186,6 @@ var Columns = []string{
|
|||||||
FieldSortOrder,
|
FieldSortOrder,
|
||||||
FieldAllowMessagesDispatch,
|
FieldAllowMessagesDispatch,
|
||||||
FieldDefaultMappedModel,
|
FieldDefaultMappedModel,
|
||||||
FieldSimulateClaudeMaxEnabled,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -262,8 +259,6 @@ var (
|
|||||||
DefaultDefaultMappedModel string
|
DefaultDefaultMappedModel string
|
||||||
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||||
DefaultMappedModelValidator func(string) error
|
DefaultMappedModelValidator func(string) error
|
||||||
// DefaultSimulateClaudeMaxEnabled holds the default value on creation for the "simulate_claude_max_enabled" field.
|
|
||||||
DefaultSimulateClaudeMaxEnabled bool
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the Group queries.
|
// OrderOption defines the ordering options for the Group queries.
|
||||||
@@ -424,11 +419,6 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
// BySimulateClaudeMaxEnabled orders the results by the simulate_claude_max_enabled field.
|
|
||||||
func BySimulateClaudeMaxEnabled(opts ...sql.OrderTermOption) OrderOption {
|
|
||||||
return sql.OrderByField(FieldSimulateClaudeMaxEnabled, opts...).ToFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -205,11 +205,6 @@ func DefaultMappedModel(v string) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SimulateClaudeMaxEnabled applies equality check predicate on the "simulate_claude_max_enabled" field. It's identical to SimulateClaudeMaxEnabledEQ.
|
|
||||||
func SimulateClaudeMaxEnabled(v bool) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -1560,16 +1555,6 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SimulateClaudeMaxEnabledEQ applies the EQ predicate on the "simulate_claude_max_enabled" field.
|
|
||||||
func SimulateClaudeMaxEnabledEQ(v bool) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldEQ(FieldSimulateClaudeMaxEnabled, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SimulateClaudeMaxEnabledNEQ applies the NEQ predicate on the "simulate_claude_max_enabled" field.
|
|
||||||
func SimulateClaudeMaxEnabledNEQ(v bool) predicate.Group {
|
|
||||||
return predicate.Group(sql.FieldNEQ(FieldSimulateClaudeMaxEnabled, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.Group {
|
func HasAPIKeys() predicate.Group {
|
||||||
return predicate.Group(func(s *sql.Selector) {
|
return predicate.Group(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -452,20 +452,6 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (_c *GroupCreate) SetSimulateClaudeMaxEnabled(v bool) *GroupCreate {
|
|
||||||
_c.mutation.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
|
||||||
func (_c *GroupCreate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupCreate {
|
|
||||||
if v != nil {
|
|
||||||
_c.SetSimulateClaudeMaxEnabled(*v)
|
|
||||||
}
|
|
||||||
return _c
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -663,10 +649,6 @@ func (_c *GroupCreate) defaults() error {
|
|||||||
v := group.DefaultDefaultMappedModel
|
v := group.DefaultDefaultMappedModel
|
||||||
_c.mutation.SetDefaultMappedModel(v)
|
_c.mutation.SetDefaultMappedModel(v)
|
||||||
}
|
}
|
||||||
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
|
|
||||||
v := group.DefaultSimulateClaudeMaxEnabled
|
|
||||||
_c.mutation.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -748,9 +730,6 @@ func (_c *GroupCreate) check() error {
|
|||||||
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, ok := _c.mutation.SimulateClaudeMaxEnabled(); !ok {
|
|
||||||
return &ValidationError{Name: "simulate_claude_max_enabled", err: errors.New(`ent: missing required field "Group.simulate_claude_max_enabled"`)}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -906,10 +885,6 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||||
_node.DefaultMappedModel = value
|
_node.DefaultMappedModel = value
|
||||||
}
|
}
|
||||||
if value, ok := _c.mutation.SimulateClaudeMaxEnabled(); ok {
|
|
||||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
|
||||||
_node.SimulateClaudeMaxEnabled = value
|
|
||||||
}
|
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1624,18 +1599,6 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (u *GroupUpsert) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsert {
|
|
||||||
u.Set(group.FieldSimulateClaudeMaxEnabled, v)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsert) UpdateSimulateClaudeMaxEnabled() *GroupUpsert {
|
|
||||||
u.SetExcluded(group.FieldSimulateClaudeMaxEnabled)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -2332,20 +2295,6 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (u *GroupUpsertOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertOne) UpdateSimulateClaudeMaxEnabled() *GroupUpsertOne {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSimulateClaudeMaxEnabled()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -3208,20 +3157,6 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (u *GroupUpsertBulk) SetSimulateClaudeMaxEnabled(v bool) *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field to the value that was provided on create.
|
|
||||||
func (u *GroupUpsertBulk) UpdateSimulateClaudeMaxEnabled() *GroupUpsertBulk {
|
|
||||||
return u.Update(func(s *GroupUpsert) {
|
|
||||||
s.UpdateSimulateClaudeMaxEnabled()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -653,20 +653,6 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (_u *GroupUpdate) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdate {
|
|
||||||
_u.mutation.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdate) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdate {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSimulateClaudeMaxEnabled(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -1163,9 +1149,6 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
|
|
||||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -2098,20 +2081,6 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (_u *GroupUpdateOne) SetSimulateClaudeMaxEnabled(v bool) *GroupUpdateOne {
|
|
||||||
_u.mutation.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNillableSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field if the given value is not nil.
|
|
||||||
func (_u *GroupUpdateOne) SetNillableSimulateClaudeMaxEnabled(v *bool) *GroupUpdateOne {
|
|
||||||
if v != nil {
|
|
||||||
_u.SetSimulateClaudeMaxEnabled(*v)
|
|
||||||
}
|
|
||||||
return _u
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -2638,9 +2607,6 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
|||||||
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
if value, ok := _u.mutation.DefaultMappedModel(); ok {
|
||||||
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
|
||||||
}
|
}
|
||||||
if value, ok := _u.mutation.SimulateClaudeMaxEnabled(); ok {
|
|
||||||
_spec.SetField(group.FieldSimulateClaudeMaxEnabled, field.TypeBool, value)
|
|
||||||
}
|
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@@ -410,7 +410,6 @@ var (
|
|||||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
{Name: "allow_messages_dispatch", Type: field.TypeBool, Default: false},
|
||||||
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
|
||||||
{Name: "simulate_claude_max_enabled", Type: field.TypeBool, Default: false},
|
|
||||||
}
|
}
|
||||||
// GroupsTable holds the schema information for the "groups" table.
|
// GroupsTable holds the schema information for the "groups" table.
|
||||||
GroupsTable = &schema.Table{
|
GroupsTable = &schema.Table{
|
||||||
|
|||||||
@@ -8252,7 +8252,6 @@ type GroupMutation struct {
|
|||||||
addsort_order *int
|
addsort_order *int
|
||||||
allow_messages_dispatch *bool
|
allow_messages_dispatch *bool
|
||||||
default_mapped_model *string
|
default_mapped_model *string
|
||||||
simulate_claude_max_enabled *bool
|
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
api_keys map[int64]struct{}
|
api_keys map[int64]struct{}
|
||||||
removedapi_keys map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
@@ -10069,42 +10068,6 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
|
|||||||
m.default_mapped_model = nil
|
m.default_mapped_model = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSimulateClaudeMaxEnabled sets the "simulate_claude_max_enabled" field.
|
|
||||||
func (m *GroupMutation) SetSimulateClaudeMaxEnabled(b bool) {
|
|
||||||
m.simulate_claude_max_enabled = &b
|
|
||||||
}
|
|
||||||
|
|
||||||
// SimulateClaudeMaxEnabled returns the value of the "simulate_claude_max_enabled" field in the mutation.
|
|
||||||
func (m *GroupMutation) SimulateClaudeMaxEnabled() (r bool, exists bool) {
|
|
||||||
v := m.simulate_claude_max_enabled
|
|
||||||
if v == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return *v, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// OldSimulateClaudeMaxEnabled returns the old "simulate_claude_max_enabled" field's value of the Group entity.
|
|
||||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
|
||||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
|
||||||
func (m *GroupMutation) OldSimulateClaudeMaxEnabled(ctx context.Context) (v bool, err error) {
|
|
||||||
if !m.op.Is(OpUpdateOne) {
|
|
||||||
return v, errors.New("OldSimulateClaudeMaxEnabled is only allowed on UpdateOne operations")
|
|
||||||
}
|
|
||||||
if m.id == nil || m.oldValue == nil {
|
|
||||||
return v, errors.New("OldSimulateClaudeMaxEnabled requires an ID field in the mutation")
|
|
||||||
}
|
|
||||||
oldValue, err := m.oldValue(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return v, fmt.Errorf("querying old value for OldSimulateClaudeMaxEnabled: %w", err)
|
|
||||||
}
|
|
||||||
return oldValue.SimulateClaudeMaxEnabled, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetSimulateClaudeMaxEnabled resets all changes to the "simulate_claude_max_enabled" field.
|
|
||||||
func (m *GroupMutation) ResetSimulateClaudeMaxEnabled() {
|
|
||||||
m.simulate_claude_max_enabled = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||||
if m.api_keys == nil {
|
if m.api_keys == nil {
|
||||||
@@ -10463,7 +10426,7 @@ func (m *GroupMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *GroupMutation) Fields() []string {
|
func (m *GroupMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 33)
|
fields := make([]string, 0, 32)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, group.FieldCreatedAt)
|
fields = append(fields, group.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -10560,9 +10523,6 @@ func (m *GroupMutation) Fields() []string {
|
|||||||
if m.default_mapped_model != nil {
|
if m.default_mapped_model != nil {
|
||||||
fields = append(fields, group.FieldDefaultMappedModel)
|
fields = append(fields, group.FieldDefaultMappedModel)
|
||||||
}
|
}
|
||||||
if m.simulate_claude_max_enabled != nil {
|
|
||||||
fields = append(fields, group.FieldSimulateClaudeMaxEnabled)
|
|
||||||
}
|
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -10635,8 +10595,6 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.AllowMessagesDispatch()
|
return m.AllowMessagesDispatch()
|
||||||
case group.FieldDefaultMappedModel:
|
case group.FieldDefaultMappedModel:
|
||||||
return m.DefaultMappedModel()
|
return m.DefaultMappedModel()
|
||||||
case group.FieldSimulateClaudeMaxEnabled:
|
|
||||||
return m.SimulateClaudeMaxEnabled()
|
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -10710,8 +10668,6 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
|||||||
return m.OldAllowMessagesDispatch(ctx)
|
return m.OldAllowMessagesDispatch(ctx)
|
||||||
case group.FieldDefaultMappedModel:
|
case group.FieldDefaultMappedModel:
|
||||||
return m.OldDefaultMappedModel(ctx)
|
return m.OldDefaultMappedModel(ctx)
|
||||||
case group.FieldSimulateClaudeMaxEnabled:
|
|
||||||
return m.OldSimulateClaudeMaxEnabled(ctx)
|
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -10945,13 +10901,6 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetDefaultMappedModel(v)
|
m.SetDefaultMappedModel(v)
|
||||||
return nil
|
return nil
|
||||||
case group.FieldSimulateClaudeMaxEnabled:
|
|
||||||
v, ok := value.(bool)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
|
||||||
}
|
|
||||||
m.SetSimulateClaudeMaxEnabled(v)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -11385,9 +11334,6 @@ func (m *GroupMutation) ResetField(name string) error {
|
|||||||
case group.FieldDefaultMappedModel:
|
case group.FieldDefaultMappedModel:
|
||||||
m.ResetDefaultMappedModel()
|
m.ResetDefaultMappedModel()
|
||||||
return nil
|
return nil
|
||||||
case group.FieldSimulateClaudeMaxEnabled:
|
|
||||||
m.ResetSimulateClaudeMaxEnabled()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -463,10 +463,6 @@ func init() {
|
|||||||
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
|
||||||
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
|
||||||
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
|
||||||
// groupDescSimulateClaudeMaxEnabled is the schema descriptor for simulate_claude_max_enabled field.
|
|
||||||
groupDescSimulateClaudeMaxEnabled := groupFields[29].Descriptor()
|
|
||||||
// group.DefaultSimulateClaudeMaxEnabled holds the default value on creation for the simulate_claude_max_enabled field.
|
|
||||||
group.DefaultSimulateClaudeMaxEnabled = groupDescSimulateClaudeMaxEnabled.Default.(bool)
|
|
||||||
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
|
||||||
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
|
||||||
_ = idempotencyrecordMixinFields0
|
_ = idempotencyrecordMixinFields0
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ func (Group) Mixin() []ent.Mixin {
|
|||||||
|
|
||||||
func (Group) Fields() []ent.Field {
|
func (Group) Fields() []ent.Field {
|
||||||
return []ent.Field{
|
return []ent.Field{
|
||||||
|
// 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重用
|
||||||
|
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
|
||||||
field.String("name").
|
field.String("name").
|
||||||
MaxLen(100).
|
MaxLen(100).
|
||||||
NotEmpty(),
|
NotEmpty(),
|
||||||
@@ -49,6 +51,7 @@ func (Group) Fields() []ent.Field {
|
|||||||
MaxLen(20).
|
MaxLen(20).
|
||||||
Default(domain.StatusActive),
|
Default(domain.StatusActive),
|
||||||
|
|
||||||
|
// Subscription-related fields (added by migration 003)
|
||||||
field.String("platform").
|
field.String("platform").
|
||||||
MaxLen(50).
|
MaxLen(50).
|
||||||
Default(domain.PlatformAnthropic),
|
Default(domain.PlatformAnthropic),
|
||||||
@@ -70,6 +73,7 @@ func (Group) Fields() []ent.Field {
|
|||||||
field.Int("default_validity_days").
|
field.Int("default_validity_days").
|
||||||
Default(30),
|
Default(30),
|
||||||
|
|
||||||
|
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||||
field.Float("image_price_1k").
|
field.Float("image_price_1k").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
@@ -83,6 +87,7 @@ func (Group) Fields() []ent.Field {
|
|||||||
Nillable().
|
Nillable().
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||||
|
|
||||||
|
// Sora 按次计费配置(阶段 1)
|
||||||
field.Float("sora_image_price_360").
|
field.Float("sora_image_price_360").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
@@ -104,38 +109,45 @@ func (Group) Fields() []ent.Field {
|
|||||||
field.Int64("sora_storage_quota_bytes").
|
field.Int64("sora_storage_quota_bytes").
|
||||||
Default(0),
|
Default(0),
|
||||||
|
|
||||||
|
// Claude Code 客户端限制 (added by migration 029)
|
||||||
field.Bool("claude_code_only").
|
field.Bool("claude_code_only").
|
||||||
Default(false).
|
Default(false).
|
||||||
Comment("allow Claude Code client only"),
|
Comment("是否仅允许 Claude Code 客户端"),
|
||||||
field.Int64("fallback_group_id").
|
field.Int64("fallback_group_id").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
Comment("fallback group for non-Claude-Code requests"),
|
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
||||||
field.Int64("fallback_group_id_on_invalid_request").
|
field.Int64("fallback_group_id_on_invalid_request").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
Comment("fallback group for invalid request"),
|
Comment("无效请求兜底使用的分组 ID"),
|
||||||
|
|
||||||
|
// 模型路由配置 (added by migration 040)
|
||||||
field.JSON("model_routing", map[string][]int64{}).
|
field.JSON("model_routing", map[string][]int64{}).
|
||||||
Optional().
|
Optional().
|
||||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||||
Comment("model routing config: pattern -> account ids"),
|
Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
|
||||||
|
|
||||||
|
// 模型路由开关 (added by migration 041)
|
||||||
field.Bool("model_routing_enabled").
|
field.Bool("model_routing_enabled").
|
||||||
Default(false).
|
Default(false).
|
||||||
Comment("whether model routing is enabled"),
|
Comment("是否启用模型路由配置"),
|
||||||
|
|
||||||
|
// MCP XML 协议注入开关 (added by migration 042)
|
||||||
field.Bool("mcp_xml_inject").
|
field.Bool("mcp_xml_inject").
|
||||||
Default(true).
|
Default(true).
|
||||||
Comment("whether MCP XML prompt injection is enabled"),
|
Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"),
|
||||||
|
|
||||||
|
// 支持的模型系列 (added by migration 046)
|
||||||
field.JSON("supported_model_scopes", []string{}).
|
field.JSON("supported_model_scopes", []string{}).
|
||||||
Default([]string{"claude", "gemini_text", "gemini_image"}).
|
Default([]string{"claude", "gemini_text", "gemini_image"}).
|
||||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
|
||||||
Comment("supported model scopes: claude, gemini_text, gemini_image"),
|
Comment("支持的模型系列:claude, gemini_text, gemini_image"),
|
||||||
|
|
||||||
|
// 分组排序 (added by migration 052)
|
||||||
field.Int("sort_order").
|
field.Int("sort_order").
|
||||||
Default(0).
|
Default(0).
|
||||||
Comment("group display order, lower comes first"),
|
Comment("分组显示排序,数值越小越靠前"),
|
||||||
|
|
||||||
// OpenAI Messages 调度配置 (added by migration 069)
|
// OpenAI Messages 调度配置 (added by migration 069)
|
||||||
field.Bool("allow_messages_dispatch").
|
field.Bool("allow_messages_dispatch").
|
||||||
@@ -145,9 +157,6 @@ func (Group) Fields() []ent.Field {
|
|||||||
MaxLen(100).
|
MaxLen(100).
|
||||||
Default("").
|
Default("").
|
||||||
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
|
||||||
field.Bool("simulate_claude_max_enabled").
|
|
||||||
Default(false).
|
|
||||||
Comment("simulate claude usage as claude-max style (1h cache write)"),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,11 +172,14 @@ func (Group) Edges() []ent.Edge {
|
|||||||
edge.From("allowed_users", User.Type).
|
edge.From("allowed_users", User.Type).
|
||||||
Ref("allowed_groups").
|
Ref("allowed_groups").
|
||||||
Through("user_allowed_groups", UserAllowedGroup.Type),
|
Through("user_allowed_groups", UserAllowedGroup.Type),
|
||||||
|
// 注意:fallback_group_id 直接作为字段使用,不定义 edge
|
||||||
|
// 这样允许多个分组指向同一个降级分组(M2O 关系)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (Group) Indexes() []ent.Index {
|
func (Group) Indexes() []ent.Index {
|
||||||
return []ent.Index{
|
return []ent.Index{
|
||||||
|
// name 字段已在 Fields() 中声明 Unique(),无需重复索引
|
||||||
index.Fields("status"),
|
index.Fields("status"),
|
||||||
index.Fields("platform"),
|
index.Fields("platform"),
|
||||||
index.Fields("subscription_type"),
|
index.Fields("subscription_type"),
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ require (
|
|||||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||||
github.com/alitto/pond/v2 v2.6.2
|
github.com/alitto/pond/v2 v2.6.2
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||||
@@ -66,7 +66,7 @@ require (
|
|||||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||||
github.com/aws/smithy-go v1.24.1 // indirect
|
github.com/aws/smithy-go v1.24.2 // indirect
|
||||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||||
@@ -87,7 +87,6 @@ require (
|
|||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
|
||||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||||
github.com/docker/go-connections v0.6.0 // indirect
|
github.com/docker/go-connections v0.6.0 // indirect
|
||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
@@ -138,8 +137,6 @@ require (
|
|||||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
|
|
||||||
github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect
|
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||||
github.com/quic-go/qpack v0.6.0 // indirect
|
github.com/quic-go/qpack v0.6.0 // indirect
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
|
|||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||||
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
||||||
@@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
|
|||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
||||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||||
|
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||||
|
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||||
@@ -124,8 +128,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
|||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
|
||||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
|
||||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||||
@@ -182,7 +184,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
|||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||||
@@ -202,8 +203,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
|||||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
|
||||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
@@ -286,10 +285,6 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
|
|||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
|
||||||
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
|
||||||
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
|
|
||||||
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
@@ -342,8 +337,6 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
|
|||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
@@ -352,6 +345,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
|||||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
|
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
|
||||||
@@ -441,11 +436,11 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||||
|
|||||||
@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
|
|||||||
|
|
||||||
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||||
type DashboardAggregationRetentionConfig struct {
|
type DashboardAggregationRetentionConfig struct {
|
||||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||||
HourlyDays int `mapstructure:"hourly_days"`
|
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
|
||||||
DailyDays int `mapstructure:"daily_days"`
|
HourlyDays int `mapstructure:"hourly_days"`
|
||||||
|
DailyDays int `mapstructure:"daily_days"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageCleanupConfig 使用记录清理任务配置
|
// UsageCleanupConfig 使用记录清理任务配置
|
||||||
@@ -1301,6 +1302,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||||
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||||
|
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||||
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
|
|||||||
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||||
}
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||||
|
}
|
||||||
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||||
}
|
}
|
||||||
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
|
|||||||
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||||
}
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
|
||||||
|
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||||
|
}
|
||||||
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
|||||||
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||||
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||||
}
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
|
||||||
|
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
|
||||||
|
}
|
||||||
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||||
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||||
}
|
}
|
||||||
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
|
|||||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "dashboard aggregation dedup retention",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.DashboardAgg.Enabled = true
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
|
||||||
|
},
|
||||||
|
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dashboard aggregation dedup retention smaller than usage logs",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.DashboardAgg.Enabled = true
|
||||||
|
c.DashboardAgg.Retention.UsageLogsDays = 30
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
|
||||||
|
},
|
||||||
|
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "dashboard aggregation disabled interval",
|
name: "dashboard aggregation disabled interval",
|
||||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||||
|
|||||||
@@ -27,10 +27,12 @@ const (
|
|||||||
|
|
||||||
// Account type constants
|
// Account type constants
|
||||||
const (
|
const (
|
||||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||||
|
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
|
||||||
|
AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem type constants
|
// Redeem type constants
|
||||||
@@ -113,3 +115,27 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||||
|
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||||
|
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||||
|
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
|
||||||
|
var DefaultBedrockModelMapping = map[string]string{
|
||||||
|
// Claude Opus
|
||||||
|
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
|
||||||
|
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
|
||||||
|
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||||
|
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||||
|
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
|
||||||
|
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
|
||||||
|
// Claude Sonnet
|
||||||
|
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||||
|
// Claude Haiku
|
||||||
|
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||||
|
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||||
|
}
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
|
|||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Platform string `json:"platform" binding:"required"`
|
Platform string `json:"platform" binding:"required"`
|
||||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
|
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
|
|||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
|
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||||
Credentials map[string]any `json:"credentials"`
|
Credentials map[string]any `json:"credentials"`
|
||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
@@ -865,6 +865,9 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenAI OAuth: 刷新成功后检查并设置 privacy_mode
|
||||||
|
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
|
||||||
|
|
||||||
return updatedAccount, "", nil
|
return updatedAccount, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1338,12 +1341,6 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
|||||||
c.JSON(409, gin.H{
|
c.JSON(409, gin.H{
|
||||||
"error": "mixed_channel_warning",
|
"error": "mixed_channel_warning",
|
||||||
"message": mixedErr.Error(),
|
"message": mixedErr.Error(),
|
||||||
"details": gin.H{
|
|
||||||
"group_id": mixedErr.GroupID,
|
|
||||||
"group_name": mixedErr.GroupName,
|
|
||||||
"current_platform": mixedErr.CurrentPlatform,
|
|
||||||
"other_platform": mixedErr.OtherPlatform,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T
|
|||||||
var resp map[string]any
|
var resp map[string]any
|
||||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||||
require.Contains(t, resp["message"], "claude-max")
|
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||||
_, hasDetails := resp["details"]
|
_, hasDetails := resp["details"]
|
||||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||||
require.False(t, hasDetails)
|
require.False(t, hasDetails)
|
||||||
@@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
|
|||||||
var resp map[string]any
|
var resp map[string]any
|
||||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||||
require.Contains(t, resp["message"], "claude-max")
|
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||||
_, hasDetails := resp["details"]
|
_, hasDetails := resp["details"]
|
||||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||||
require.False(t, hasDetails)
|
require.False(t, hasDetails)
|
||||||
|
|||||||
@@ -179,6 +179,14 @@ func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) (
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||||
return s.accounts, int64(len(s.accounts)), nil
|
return s.accounts, int64(len(s.accounts)), nil
|
||||||
}
|
}
|
||||||
@@ -433,5 +441,9 @@ func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure stub implements interface.
|
// Ensure stub implements interface.
|
||||||
var _ service.AdminService = (*stubAdminService)(nil)
|
var _ service.AdminService = (*stubAdminService)(nil)
|
||||||
|
|||||||
@@ -466,9 +466,60 @@ type BatchUsersUsageRequest struct {
|
|||||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||||
|
|
||||||
|
func parseRankingLimit(raw string) int {
|
||||||
|
limit, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||||
|
if err != nil || limit <= 0 {
|
||||||
|
return 12
|
||||||
|
}
|
||||||
|
if limit > 50 {
|
||||||
|
return 50
|
||||||
|
}
|
||||||
|
return limit
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserSpendingRanking handles getting user spending ranking data.
|
||||||
|
// GET /api/v1/admin/dashboard/users-ranking
|
||||||
|
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
||||||
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
|
||||||
|
|
||||||
|
keyRaw, _ := json.Marshal(struct {
|
||||||
|
Start string `json:"start"`
|
||||||
|
End string `json:"end"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
}{
|
||||||
|
Start: startTime.UTC().Format(time.RFC3339),
|
||||||
|
End: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Limit: limit,
|
||||||
|
})
|
||||||
|
cacheKey := string(keyRaw)
|
||||||
|
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
|
||||||
|
c.Header("X-Snapshot-Cache", "hit")
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get user spending ranking")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := gin.H{
|
||||||
|
"ranking": ranking.Ranking,
|
||||||
|
"total_actual_cost": ranking.TotalActualCost,
|
||||||
|
"start_date": startTime.Format("2006-01-02"),
|
||||||
|
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||||
|
}
|
||||||
|
dashboardUsersRankingCache.Set(cacheKey, payload)
|
||||||
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
|
response.Success(c, payload)
|
||||||
|
}
|
||||||
|
|
||||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||||
// POST /api/v1/admin/dashboard/users-usage
|
// POST /api/v1/admin/dashboard/users-usage
|
||||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
|
|||||||
trendStream *bool
|
trendStream *bool
|
||||||
modelRequestType *int16
|
modelRequestType *int16
|
||||||
modelStream *bool
|
modelStream *bool
|
||||||
|
rankingLimit int
|
||||||
|
ranking []usagestats.UserSpendingRankingItem
|
||||||
|
rankingTotal float64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||||
@@ -49,6 +52,18 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
|||||||
return []usagestats.ModelStat{}, nil
|
return []usagestats.ModelStat{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
limit int,
|
||||||
|
) (*usagestats.UserSpendingRankingResponse, error) {
|
||||||
|
s.rankingLimit = limit
|
||||||
|
return &usagestats.UserSpendingRankingResponse{
|
||||||
|
Ranking: s.ranking,
|
||||||
|
TotalActualCost: s.rankingTotal,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
@@ -56,6 +71,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
|
|||||||
router := gin.New()
|
router := gin.New()
|
||||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||||
|
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,3 +146,30 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
|||||||
|
|
||||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||||
|
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||||
|
repo := &dashboardUsageRepoCapture{
|
||||||
|
ranking: []usagestats.UserSpendingRankingItem{
|
||||||
|
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
|
||||||
|
},
|
||||||
|
rankingTotal: 88.8,
|
||||||
|
}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, 50, repo.rankingLimit)
|
||||||
|
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||||
|
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec2, req2)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec2.Code)
|
||||||
|
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||||
|
}
|
||||||
|
|||||||
@@ -46,10 +46,9 @@ type CreateGroupRequest struct {
|
|||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||||
// Sora 存储配额
|
// Sora 存储配额
|
||||||
@@ -85,10 +84,9 @@ type UpdateGroupRequest struct {
|
|||||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||||
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||||
// Sora 存储配额
|
// Sora 存储配额
|
||||||
@@ -209,7 +207,6 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
ModelRouting: req.ModelRouting,
|
ModelRouting: req.ModelRouting,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
MCPXMLInject: req.MCPXMLInject,
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
|
||||||
SupportedModelScopes: req.SupportedModelScopes,
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||||
@@ -263,7 +260,6 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
ModelRouting: req.ModelRouting,
|
ModelRouting: req.ModelRouting,
|
||||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||||
MCPXMLInject: req.MCPXMLInject,
|
MCPXMLInject: req.MCPXMLInject,
|
||||||
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
|
|
||||||
SupportedModelScopes: req.SupportedModelScopes,
|
SupportedModelScopes: req.SupportedModelScopes,
|
||||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||||
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
AllowMessagesDispatch: req.AllowMessagesDispatch,
|
||||||
@@ -360,6 +356,51 @@ func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
|||||||
response.Success(c, entries)
|
response.Success(c, entries)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
|
||||||
|
// DELETE /api/v1/admin/groups/:id/rate-multipliers
|
||||||
|
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
|
||||||
|
type BatchSetGroupRateMultipliersRequest struct {
|
||||||
|
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
|
||||||
|
// PUT /api/v1/admin/groups/:id/rate-multipliers
|
||||||
|
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req BatchSetGroupRateMultipliersRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||||
type UpdateSortOrderRequest struct {
|
type UpdateSortOrderRequest struct {
|
||||||
Updates []struct {
|
Updates []struct {
|
||||||
|
|||||||
@@ -289,6 +289,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
|||||||
Platform: platform,
|
Platform: platform,
|
||||||
Type: "oauth",
|
Type: "oauth",
|
||||||
Credentials: credentials,
|
Credentials: credentials,
|
||||||
|
Extra: nil,
|
||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
Priority: req.Priority,
|
Priority: req.Priority,
|
||||||
|
|||||||
@@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||||
|
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||||
type CreateAndRedeemCodeRequest struct {
|
type CreateAndRedeemCodeRequest struct {
|
||||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||||
Value float64 `json:"value" binding:"required,gt=0"`
|
Value float64 `json:"value" binding:"required,gt=0"`
|
||||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||||
Notes string `json:"notes"`
|
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||||
|
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
|
||||||
|
Notes string `json:"notes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// List handles listing all redeem codes with pagination
|
// List handles listing all redeem codes with pagination
|
||||||
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Code = strings.TrimSpace(req.Code)
|
req.Code = strings.TrimSpace(req.Code)
|
||||||
|
// 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。
|
||||||
|
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
|
||||||
|
if req.Type == "" {
|
||||||
|
req.Type = "balance"
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Type == "subscription" {
|
||||||
|
if req.GroupID == nil {
|
||||||
|
response.BadRequest(c, "group_id is required for subscription type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.ValidityDays <= 0 {
|
||||||
|
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||||
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
||||||
Code: req.Code,
|
Code: req.Code,
|
||||||
Type: req.Type,
|
Type: req.Type,
|
||||||
Value: req.Value,
|
Value: req.Value,
|
||||||
Status: service.StatusUnused,
|
Status: service.StatusUnused,
|
||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
|
GroupID: req.GroupID,
|
||||||
|
ValidityDays: req.ValidityDays,
|
||||||
})
|
})
|
||||||
if createErr != nil {
|
if createErr != nil {
|
||||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||||
|
|||||||
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
|
||||||
|
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
|
||||||
|
// parameter-validation layer that runs before any service call.
|
||||||
|
func newCreateAndRedeemHandler() *RedeemHandler {
|
||||||
|
return &RedeemHandler{
|
||||||
|
adminService: newStubAdminService(),
|
||||||
|
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
|
||||||
|
// status code. For cases that pass validation and proceed into the service layer,
|
||||||
|
// a panic may occur (because RedeemService internals are nil); this is expected
|
||||||
|
// and treated as "validation passed" (returns 0 to indicate panic).
|
||||||
|
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
|
||||||
|
t.Helper()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
// Panic means we passed validation and entered service layer (expected for minimal stub).
|
||||||
|
code = 0
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
handler.CreateAndRedeem(c)
|
||||||
|
return w.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
|
||||||
|
// 不传 type 字段时应默认 balance,不触发 subscription 校验。
|
||||||
|
// 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-balance-default",
|
||||||
|
"value": 10.0,
|
||||||
|
"user_id": 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"omitting type should default to balance and pass validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-no-group",
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"validity_days": 30,
|
||||||
|
// group_id 缺失
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
|
||||||
|
groupID := int64(5)
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
validityDays int
|
||||||
|
}{
|
||||||
|
{"zero", 0},
|
||||||
|
{"negative", -1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-bad-days-" + tc.name,
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"group_id": groupID,
|
||||||
|
"validity_days": tc.validityDays,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
|
||||||
|
groupID := int64(5)
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-valid",
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"group_id": groupID,
|
||||||
|
"validity_days": 31,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"valid subscription params should pass validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
// balance 类型不传 group_id 和 validity_days,不应报 400
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-balance-no-extras",
|
||||||
|
"type": "balance",
|
||||||
|
"value": 50.0,
|
||||||
|
"user_id": 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"balance type should not require group_id or validity_days")
|
||||||
|
}
|
||||||
@@ -218,11 +218,12 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
|||||||
|
|
||||||
// ResetSubscriptionQuotaRequest represents the reset quota request
|
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||||
type ResetSubscriptionQuotaRequest struct {
|
type ResetSubscriptionQuotaRequest struct {
|
||||||
Daily bool `json:"daily"`
|
Daily bool `json:"daily"`
|
||||||
Weekly bool `json:"weekly"`
|
Weekly bool `json:"weekly"`
|
||||||
|
Monthly bool `json:"monthly"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetQuota resets daily and/or weekly usage for a subscription.
|
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
|
||||||
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||||
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
@@ -235,11 +236,11 @@ func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !req.Daily && !req.Weekly {
|
if !req.Daily && !req.Weekly && !req.Monthly {
|
||||||
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
|
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
|
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -135,15 +135,14 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
out := &AdminGroup{
|
out := &AdminGroup{
|
||||||
Group: groupFromServiceBase(g),
|
Group: groupFromServiceBase(g),
|
||||||
ModelRouting: g.ModelRouting,
|
ModelRouting: g.ModelRouting,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
MCPXMLInject: g.MCPXMLInject,
|
MCPXMLInject: g.MCPXMLInject,
|
||||||
DefaultMappedModel: g.DefaultMappedModel,
|
DefaultMappedModel: g.DefaultMappedModel,
|
||||||
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
SupportedModelScopes: g.SupportedModelScopes,
|
AccountCount: g.AccountCount,
|
||||||
AccountCount: g.AccountCount,
|
SortOrder: g.SortOrder,
|
||||||
SortOrder: g.SortOrder,
|
|
||||||
}
|
}
|
||||||
if len(g.AccountGroups) > 0 {
|
if len(g.AccountGroups) > 0 {
|
||||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||||
|
|||||||
@@ -117,8 +117,6 @@ type AdminGroup struct {
|
|||||||
|
|
||||||
// MCP XML 协议注入(仅 antigravity 平台使用)
|
// MCP XML 协议注入(仅 antigravity 平台使用)
|
||||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||||
// Claude usage 模拟开关(仅管理员可见)
|
|
||||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
|
|
||||||
|
|
||||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||||
DefaultMappedModel string `json:"default_mapped_model"`
|
DefaultMappedModel string `json:"default_mapped_model"`
|
||||||
|
|||||||
@@ -434,20 +434,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
ParsedRequest: parsedReq,
|
APIKey: apiKey,
|
||||||
APIKey: apiKey,
|
User: apiKey.User,
|
||||||
User: apiKey.User,
|
Account: account,
|
||||||
Account: account,
|
Subscription: subscription,
|
||||||
Subscription: subscription,
|
UserAgent: userAgent,
|
||||||
UserAgent: userAgent,
|
IPAddress: clientIP,
|
||||||
IPAddress: clientIP,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.gateway.messages"),
|
zap.String("component", "handler.gateway.messages"),
|
||||||
@@ -631,7 +632,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// ===== 用户消息串行队列 END =====
|
// ===== 用户消息串行队列 END =====
|
||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
c.Set("parsed_request", parsedReq)
|
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
@@ -738,20 +738,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
ParsedRequest: parsedReq,
|
APIKey: currentAPIKey,
|
||||||
APIKey: currentAPIKey,
|
User: currentAPIKey.User,
|
||||||
User: currentAPIKey.User,
|
Account: account,
|
||||||
Account: account,
|
Subscription: currentSubscription,
|
||||||
Subscription: currentSubscription,
|
UserAgent: userAgent,
|
||||||
UserAgent: userAgent,
|
IPAddress: clientIP,
|
||||||
IPAddress: clientIP,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.gateway.messages"),
|
zap.String("component", "handler.gateway.messages"),
|
||||||
|
|||||||
@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
|||||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||||
&fakeGroupRepo{group: group},
|
&fakeGroupRepo{group: group},
|
||||||
nil, // usageLogRepo
|
nil, // usageLogRepo
|
||||||
|
nil, // usageBillingRepo
|
||||||
nil, // userRepo
|
nil, // userRepo
|
||||||
nil, // userSubRepo
|
nil, // userSubRepo
|
||||||
nil, // userGroupRateRepo
|
nil, // userGroupRateRepo
|
||||||
|
|||||||
@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
|
|||||||
@@ -352,18 +352,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
APIKeyService: h.apiKeyService,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.openai_gateway.responses"),
|
zap.String("component", "handler.openai_gateway.responses"),
|
||||||
@@ -732,17 +734,19 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
APIKeyService: h.apiKeyService,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.openai_gateway.messages"),
|
zap.String("component", "handler.openai_gateway.messages"),
|
||||||
@@ -1231,14 +1235,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
APIKeyService: h.apiKeyService,
|
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
reqLog.Error("openai.websocket_record_usage_failed",
|
reqLog.Error("openai.websocket_record_usage_failed",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
|
|||||||
@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
|||||||
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
||||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||||
return service.NewGatewayService(
|
return service.NewGatewayService(
|
||||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||||
|
|||||||
@@ -343,6 +343,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e
|
|||||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -431,6 +434,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
testutil.StubGatewayCache{},
|
testutil.StubGatewayCache{},
|
||||||
cfg,
|
cfg,
|
||||||
nil,
|
nil,
|
||||||
|
|||||||
@@ -189,6 +189,5 @@ var DefaultStopSequences = []string{
|
|||||||
"<|user|>",
|
"<|user|>",
|
||||||
"<|endoftext|>",
|
"<|endoftext|>",
|
||||||
"<|end_of_turn|>",
|
"<|end_of_turn|>",
|
||||||
"[DONE]",
|
|
||||||
"\n\nHuman:",
|
"\n\nHuman:",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,9 +18,6 @@ const (
|
|||||||
BlockTypeFunction
|
BlockTypeFunction
|
||||||
)
|
)
|
||||||
|
|
||||||
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
|
|
||||||
type UsageMapHook func(usageMap map[string]any)
|
|
||||||
|
|
||||||
// StreamingProcessor 流式响应处理器
|
// StreamingProcessor 流式响应处理器
|
||||||
type StreamingProcessor struct {
|
type StreamingProcessor struct {
|
||||||
blockType BlockType
|
blockType BlockType
|
||||||
@@ -33,7 +30,6 @@ type StreamingProcessor struct {
|
|||||||
originalModel string
|
originalModel string
|
||||||
webSearchQueries []string
|
webSearchQueries []string
|
||||||
groundingChunks []GeminiGroundingChunk
|
groundingChunks []GeminiGroundingChunk
|
||||||
usageMapHook UsageMapHook
|
|
||||||
|
|
||||||
// 累计 usage
|
// 累计 usage
|
||||||
inputTokens int
|
inputTokens int
|
||||||
@@ -49,25 +45,6 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
|
|
||||||
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
|
|
||||||
p.usageMapHook = fn
|
|
||||||
}
|
|
||||||
|
|
||||||
func usageToMap(u ClaudeUsage) map[string]any {
|
|
||||||
m := map[string]any{
|
|
||||||
"input_tokens": u.InputTokens,
|
|
||||||
"output_tokens": u.OutputTokens,
|
|
||||||
}
|
|
||||||
if u.CacheCreationInputTokens > 0 {
|
|
||||||
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
|
|
||||||
}
|
|
||||||
if u.CacheReadInputTokens > 0 {
|
|
||||||
m["cache_read_input_tokens"] = u.CacheReadInputTokens
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
@@ -191,13 +168,6 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
responseID = "msg_" + generateRandomID()
|
responseID = "msg_" + generateRandomID()
|
||||||
}
|
}
|
||||||
|
|
||||||
var usageValue any = usage
|
|
||||||
if p.usageMapHook != nil {
|
|
||||||
usageMap := usageToMap(usage)
|
|
||||||
p.usageMapHook(usageMap)
|
|
||||||
usageValue = usageMap
|
|
||||||
}
|
|
||||||
|
|
||||||
message := map[string]any{
|
message := map[string]any{
|
||||||
"id": responseID,
|
"id": responseID,
|
||||||
"type": "message",
|
"type": "message",
|
||||||
@@ -206,7 +176,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
"model": p.originalModel,
|
"model": p.originalModel,
|
||||||
"stop_reason": nil,
|
"stop_reason": nil,
|
||||||
"stop_sequence": nil,
|
"stop_sequence": nil,
|
||||||
"usage": usageValue,
|
"usage": usage,
|
||||||
}
|
}
|
||||||
|
|
||||||
event := map[string]any{
|
event := map[string]any{
|
||||||
@@ -517,20 +487,13 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
|||||||
CacheReadInputTokens: p.cacheReadTokens,
|
CacheReadInputTokens: p.cacheReadTokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
var usageValue any = usage
|
|
||||||
if p.usageMapHook != nil {
|
|
||||||
usageMap := usageToMap(usage)
|
|
||||||
p.usageMapHook(usageMap)
|
|
||||||
usageValue = usageMap
|
|
||||||
}
|
|
||||||
|
|
||||||
deltaEvent := map[string]any{
|
deltaEvent := map[string]any{
|
||||||
"type": "message_delta",
|
"type": "message_delta",
|
||||||
"delta": map[string]any{
|
"delta": map[string]any{
|
||||||
"stop_reason": stopReason,
|
"stop_reason": stopReason,
|
||||||
"stop_sequence": nil,
|
"stop_sequence": nil,
|
||||||
},
|
},
|
||||||
"usage": usageValue,
|
"usage": usage,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||||
|
|||||||
@@ -96,12 +96,28 @@ type UserUsageTrendPoint struct {
|
|||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
Username string `json:"username"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserSpendingRankingItem represents a user spending ranking row.
|
||||||
|
type UserSpendingRankingItem struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
|
||||||
|
type UserSpendingRankingResponse struct {
|
||||||
|
Ranking []UserSpendingRankingItem `json:"ranking"`
|
||||||
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
|
}
|
||||||
|
|
||||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||||
type APIKeyUsageTrendPoint struct {
|
type APIKeyUsageTrendPoint struct {
|
||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
|
|||||||
@@ -164,7 +164,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
|||||||
group.FieldModelRoutingEnabled,
|
group.FieldModelRoutingEnabled,
|
||||||
group.FieldModelRouting,
|
group.FieldModelRouting,
|
||||||
group.FieldMcpXMLInject,
|
group.FieldMcpXMLInject,
|
||||||
group.FieldSimulateClaudeMaxEnabled,
|
|
||||||
group.FieldSupportedModelScopes,
|
group.FieldSupportedModelScopes,
|
||||||
group.FieldAllowMessagesDispatch,
|
group.FieldAllowMessagesDispatch,
|
||||||
group.FieldDefaultMappedModel,
|
group.FieldDefaultMappedModel,
|
||||||
@@ -646,7 +645,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
|||||||
ModelRouting: g.ModelRouting,
|
ModelRouting: g.ModelRouting,
|
||||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||||
MCPXMLInject: g.McpXMLInject,
|
MCPXMLInject: g.McpXMLInject,
|
||||||
SimulateClaudeMaxEnabled: g.SimulateClaudeMaxEnabled,
|
|
||||||
SupportedModelScopes: g.SupportedModelScopes,
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
SortOrder: g.SortOrder,
|
SortOrder: g.SortOrder,
|
||||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
|
|||||||
sql sqlExecutor
|
sql sqlExecutor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const usageLogsCleanupBatchSize = 10000
|
||||||
|
const usageBillingDedupCleanupBatchSize = 10000
|
||||||
|
|
||||||
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||||
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||||
if sqlDB == nil {
|
if sqlDB == nil {
|
||||||
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
if r == nil || r.sql == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
loc := timezone.Location()
|
loc := timezone.Location()
|
||||||
startLocal := start.In(loc)
|
startLocal := start.In(loc)
|
||||||
endLocal := end.In(loc)
|
endLocal := end.In(loc)
|
||||||
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
|
|||||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if db, ok := r.sql.(*sql.DB); ok {
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||||
|
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||||
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
|
|||||||
if isPartitioned {
|
if isPartitioned {
|
||||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||||
}
|
}
|
||||||
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
for {
|
||||||
return err
|
res, err := r.sql.ExecContext(ctx, `
|
||||||
|
WITH victims AS (
|
||||||
|
SELECT ctid
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at < $1
|
||||||
|
LIMIT $2
|
||||||
|
)
|
||||||
|
DELETE FROM usage_logs
|
||||||
|
WHERE ctid IN (SELECT ctid FROM victims)
|
||||||
|
`, cutoff.UTC(), usageLogsCleanupBatchSize)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected < usageLogsCleanupBatchSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||||
|
for {
|
||||||
|
res, err := r.sql.ExecContext(ctx, `
|
||||||
|
WITH victims AS (
|
||||||
|
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
|
||||||
|
FROM usage_billing_dedup
|
||||||
|
WHERE created_at < $1
|
||||||
|
LIMIT $2
|
||||||
|
), archived AS (
|
||||||
|
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
|
||||||
|
SELECT request_id, api_key_id, request_fingerprint, created_at
|
||||||
|
FROM victims
|
||||||
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
|
)
|
||||||
|
DELETE FROM usage_billing_dedup
|
||||||
|
WHERE ctid IN (SELECT ctid FROM victims)
|
||||||
|
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected < usageBillingDedupCleanupBatchSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
|
|||||||
@@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
|
|||||||
SetKey(k.Key).
|
SetKey(k.Key).
|
||||||
SetName(k.Name).
|
SetName(k.Name).
|
||||||
SetStatus(k.Status)
|
SetStatus(k.Status)
|
||||||
|
if k.Quota != 0 {
|
||||||
|
create.SetQuota(k.Quota)
|
||||||
|
}
|
||||||
|
if k.QuotaUsed != 0 {
|
||||||
|
create.SetQuotaUsed(k.QuotaUsed)
|
||||||
|
}
|
||||||
|
if k.RateLimit5h != 0 {
|
||||||
|
create.SetRateLimit5h(k.RateLimit5h)
|
||||||
|
}
|
||||||
|
if k.RateLimit1d != 0 {
|
||||||
|
create.SetRateLimit1d(k.RateLimit1d)
|
||||||
|
}
|
||||||
|
if k.RateLimit7d != 0 {
|
||||||
|
create.SetRateLimit7d(k.RateLimit7d)
|
||||||
|
}
|
||||||
|
if k.Usage5h != 0 {
|
||||||
|
create.SetUsage5h(k.Usage5h)
|
||||||
|
}
|
||||||
|
if k.Usage1d != 0 {
|
||||||
|
create.SetUsage1d(k.Usage1d)
|
||||||
|
}
|
||||||
|
if k.Usage7d != 0 {
|
||||||
|
create.SetUsage7d(k.Usage7d)
|
||||||
|
}
|
||||||
|
if k.Window5hStart != nil {
|
||||||
|
create.SetWindow5hStart(*k.Window5hStart)
|
||||||
|
}
|
||||||
|
if k.Window1dStart != nil {
|
||||||
|
create.SetWindow1dStart(*k.Window1dStart)
|
||||||
|
}
|
||||||
|
if k.Window7dStart != nil {
|
||||||
|
create.SetWindow7dStart(*k.Window7dStart)
|
||||||
|
}
|
||||||
|
if k.ExpiresAt != nil {
|
||||||
|
create.SetExpiresAt(*k.ExpiresAt)
|
||||||
|
}
|
||||||
if k.GroupID != nil {
|
if k.GroupID != nil {
|
||||||
create.SetGroupID(*k.GroupID)
|
create.SetGroupID(*k.GroupID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,8 +61,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||||
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
|
|
||||||
|
|
||||||
// 设置模型路由配置
|
// 设置模型路由配置
|
||||||
if groupIn.ModelRouting != nil {
|
if groupIn.ModelRouting != nil {
|
||||||
@@ -130,8 +129,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||||
SetDefaultMappedModel(groupIn.DefaultMappedModel).
|
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||||
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
|
|
||||||
|
|
||||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||||
if groupIn.DailyLimitUSD != nil {
|
if groupIn.DailyLimitUSD != nil {
|
||||||
|
|||||||
@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
|||||||
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
||||||
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
||||||
|
|
||||||
|
// usage_billing_dedup: billing idempotency narrow table
|
||||||
|
var usageBillingDedupRegclass sql.NullString
|
||||||
|
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
|
||||||
|
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
|
||||||
|
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
|
||||||
|
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
|
||||||
|
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
|
||||||
|
|
||||||
|
var usageBillingDedupArchiveRegclass sql.NullString
|
||||||
|
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
|
||||||
|
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
|
||||||
|
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
|
||||||
|
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
|
||||||
|
|
||||||
// settings table should exist
|
// settings table should exist
|
||||||
var settingsRegclass sql.NullString
|
var settingsRegclass sql.NullString
|
||||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||||
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
|||||||
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
|
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var exists bool
|
||||||
|
err := tx.QueryRowContext(context.Background(), `
|
||||||
|
SELECT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_indexes
|
||||||
|
WHERE schemaname = 'public'
|
||||||
|
AND tablename = $1
|
||||||
|
AND indexname = $2
|
||||||
|
)
|
||||||
|
`, table, index).Scan(&exists)
|
||||||
|
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
|
||||||
|
require.True(t, exists, "expected index %s on %s", index, table)
|
||||||
|
}
|
||||||
|
|
||||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -73,3 +73,14 @@ func buildReqClientKey(opts reqClientOptions) string {
|
|||||||
opts.ForceHTTP2,
|
opts.ForceHTTP2,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreatePrivacyReqClient creates an HTTP client for OpenAI privacy settings API
|
||||||
|
// This is exported for use by OpenAIPrivacyService
|
||||||
|
// Uses Chrome TLS fingerprint impersonation to bypass Cloudflare checks
|
||||||
|
func CreatePrivacyReqClient(proxyURL string) (*req.Client, error) {
|
||||||
|
return getSharedReqClient(reqClientOptions{
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
Impersonate: true, // Enable Chrome TLS fingerprint impersonation
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
308
backend/internal/repository/usage_billing_repo.go
Normal file
308
backend/internal/repository/usage_billing_repo.go
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type usageBillingRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
|
||||||
|
return &usageBillingRepository{db: sqlDB}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
|
||||||
|
if cmd == nil {
|
||||||
|
return &service.UsageBillingApplyResult{}, nil
|
||||||
|
}
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, errors.New("usage billing repository db is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Normalize()
|
||||||
|
if cmd.RequestID == "" {
|
||||||
|
return nil, service.ErrUsageBillingRequestIDRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := r.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if tx != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !applied {
|
||||||
|
return &service.UsageBillingApplyResult{Applied: false}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &service.UsageBillingApplyResult{Applied: true}
|
||||||
|
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tx = nil
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
|
||||||
|
var id int64
|
||||||
|
err := tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
|
RETURNING id
|
||||||
|
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
var existingFingerprint string
|
||||||
|
if err := tx.QueryRowContext(ctx, `
|
||||||
|
SELECT request_fingerprint
|
||||||
|
FROM usage_billing_dedup
|
||||||
|
WHERE request_id = $1 AND api_key_id = $2
|
||||||
|
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||||
|
return false, service.ErrUsageBillingRequestConflict
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
var archivedFingerprint string
|
||||||
|
err = tx.QueryRowContext(ctx, `
|
||||||
|
SELECT request_fingerprint
|
||||||
|
FROM usage_billing_dedup_archive
|
||||||
|
WHERE request_id = $1 AND api_key_id = $2
|
||||||
|
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
|
||||||
|
if err == nil {
|
||||||
|
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||||
|
return false, service.ErrUsageBillingRequestConflict
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
|
||||||
|
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
|
||||||
|
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.BalanceCost > 0 {
|
||||||
|
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.APIKeyQuotaCost > 0 {
|
||||||
|
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result.APIKeyQuotaExhausted = exhausted
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.APIKeyRateLimitCost > 0 {
|
||||||
|
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
|
||||||
|
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
|
||||||
|
const updateSQL = `
|
||||||
|
UPDATE user_subscriptions us
|
||||||
|
SET
|
||||||
|
daily_usage_usd = us.daily_usage_usd + $1,
|
||||||
|
weekly_usage_usd = us.weekly_usage_usd + $1,
|
||||||
|
monthly_usage_usd = us.monthly_usage_usd + $1,
|
||||||
|
updated_at = NOW()
|
||||||
|
FROM groups g
|
||||||
|
WHERE us.id = $2
|
||||||
|
AND us.deleted_at IS NULL
|
||||||
|
AND us.group_id = g.id
|
||||||
|
AND g.deleted_at IS NULL
|
||||||
|
`
|
||||||
|
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return service.ErrSubscriptionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
|
||||||
|
res, err := tx.ExecContext(ctx, `
|
||||||
|
UPDATE users
|
||||||
|
SET balance = balance - $1,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
|
`, amount, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return service.ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
|
||||||
|
var exhausted bool
|
||||||
|
err := tx.QueryRowContext(ctx, `
|
||||||
|
UPDATE api_keys
|
||||||
|
SET quota_used = quota_used + $1,
|
||||||
|
status = CASE
|
||||||
|
WHEN quota > 0
|
||||||
|
AND status = $3
|
||||||
|
AND quota_used < quota
|
||||||
|
AND quota_used + $1 >= quota
|
||||||
|
THEN $4
|
||||||
|
ELSE status
|
||||||
|
END,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
|
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
|
||||||
|
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return false, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return exhausted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
|
||||||
|
res, err := tx.ExecContext(ctx, `
|
||||||
|
UPDATE api_keys SET
|
||||||
|
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||||
|
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||||
|
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||||
|
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||||
|
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||||
|
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
|
`, cost, apiKeyID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected == 0 {
|
||||||
|
return service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
|
||||||
|
rows, err := tx.QueryContext(ctx,
|
||||||
|
`UPDATE accounts SET extra = (
|
||||||
|
COALESCE(extra, '{}'::jsonb)
|
||||||
|
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||||
|
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||||
|
jsonb_build_object(
|
||||||
|
'quota_daily_used',
|
||||||
|
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '24 hours'::interval <= NOW()
|
||||||
|
THEN $1
|
||||||
|
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||||
|
'quota_daily_start',
|
||||||
|
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '24 hours'::interval <= NOW()
|
||||||
|
THEN `+nowUTC+`
|
||||||
|
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||||
|
)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
|
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||||
|
jsonb_build_object(
|
||||||
|
'quota_weekly_used',
|
||||||
|
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '168 hours'::interval <= NOW()
|
||||||
|
THEN $1
|
||||||
|
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||||
|
'quota_weekly_start',
|
||||||
|
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '168 hours'::interval <= NOW()
|
||||||
|
THEN `+nowUTC+`
|
||||||
|
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||||
|
)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
|
), updated_at = NOW()
|
||||||
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
|
RETURNING
|
||||||
|
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||||
|
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
||||||
|
amount, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var newUsed, limit float64
|
||||||
|
if rows.Next() {
|
||||||
|
if err := rows.Scan(&newUsed, &limit); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return service.ErrAccountNotFound
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
||||||
|
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewUsageBillingRepository(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Balance: 100,
|
||||||
|
})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-usage-billing-" + uuid.NewString(),
|
||||||
|
Name: "billing",
|
||||||
|
Quota: 1,
|
||||||
|
})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{
|
||||||
|
Name: "usage-billing-account-" + uuid.NewString(),
|
||||||
|
Type: service.AccountTypeAPIKey,
|
||||||
|
})
|
||||||
|
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
cmd := &service.UsageBillingCommand{
|
||||||
|
RequestID: requestID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountType: service.AccountTypeAPIKey,
|
||||||
|
BalanceCost: 1.25,
|
||||||
|
APIKeyQuotaCost: 1.25,
|
||||||
|
APIKeyRateLimitCost: 1.25,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result1)
|
||||||
|
require.True(t, result1.Applied)
|
||||||
|
require.True(t, result1.APIKeyQuotaExhausted)
|
||||||
|
|
||||||
|
result2, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result2)
|
||||||
|
require.False(t, result2.Applied)
|
||||||
|
|
||||||
|
var balance float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||||
|
require.InDelta(t, 98.75, balance, 0.000001)
|
||||||
|
|
||||||
|
var quotaUsed float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan("aUsed))
|
||||||
|
require.InDelta(t, 1.25, quotaUsed, 0.000001)
|
||||||
|
|
||||||
|
var usage5h float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
|
||||||
|
require.InDelta(t, 1.25, usage5h, 0.000001)
|
||||||
|
|
||||||
|
var status string
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
|
||||||
|
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
|
||||||
|
|
||||||
|
var dedupCount int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
|
||||||
|
require.Equal(t, 1, dedupCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewUsageBillingRepository(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
})
|
||||||
|
group := mustCreateGroup(t, client, &service.Group{
|
||||||
|
Name: "usage-billing-group-" + uuid.NewString(),
|
||||||
|
Platform: service.PlatformAnthropic,
|
||||||
|
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||||
|
})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: &group.ID,
|
||||||
|
Key: "sk-usage-billing-sub-" + uuid.NewString(),
|
||||||
|
Name: "billing-sub",
|
||||||
|
})
|
||||||
|
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
})
|
||||||
|
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
cmd := &service.UsageBillingCommand{
|
||||||
|
RequestID: requestID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
AccountID: 0,
|
||||||
|
SubscriptionID: &subscription.ID,
|
||||||
|
SubscriptionCost: 2.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result1.Applied)
|
||||||
|
|
||||||
|
result2, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, result2.Applied)
|
||||||
|
|
||||||
|
var dailyUsage float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
|
||||||
|
require.InDelta(t, 2.5, dailyUsage, 0.000001)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewUsageBillingRepository(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Balance: 100,
|
||||||
|
})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
|
||||||
|
Name: "billing-conflict",
|
||||||
|
})
|
||||||
|
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||||
|
RequestID: requestID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
BalanceCost: 1.25,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
|
||||||
|
RequestID: requestID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
BalanceCost: 2.50,
|
||||||
|
})
|
||||||
|
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewUsageBillingRepository(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-usage-billing-account-" + uuid.NewString(),
|
||||||
|
Name: "billing-account",
|
||||||
|
})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{
|
||||||
|
Name: "usage-billing-account-quota-" + uuid.NewString(),
|
||||||
|
Type: service.AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"quota_limit": 100.0,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountType: service.AccountTypeAPIKey,
|
||||||
|
AccountQuotaCost: 3.5,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var quotaUsed float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan("aUsed))
|
||||||
|
require.InDelta(t, 3.5, quotaUsed, 0.000001)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||||
|
|
||||||
|
oldRequestID := "dedup-old-" + uuid.NewString()
|
||||||
|
newRequestID := "dedup-new-" + uuid.NewString()
|
||||||
|
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
|
||||||
|
newCreatedAt := time.Now().UTC().Add(-time.Hour)
|
||||||
|
|
||||||
|
_, err := integrationDB.ExecContext(ctx, `
|
||||||
|
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
|
||||||
|
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
|
||||||
|
`,
|
||||||
|
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
|
||||||
|
newRequestID, strings.Repeat("b", 64), newCreatedAt,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||||
|
|
||||||
|
var oldCount int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
|
||||||
|
require.Equal(t, 0, oldCount)
|
||||||
|
|
||||||
|
var newCount int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
|
||||||
|
require.Equal(t, 1, newCount)
|
||||||
|
|
||||||
|
var archivedCount int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
|
||||||
|
require.Equal(t, 1, archivedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewUsageBillingRepository(client, integrationDB)
|
||||||
|
aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Balance: 100,
|
||||||
|
})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-usage-billing-archive-" + uuid.NewString(),
|
||||||
|
Name: "billing-archive",
|
||||||
|
})
|
||||||
|
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
cmd := &service.UsageBillingCommand{
|
||||||
|
RequestID: requestID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
BalanceCost: 1.25,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result1.Applied)
|
||||||
|
|
||||||
|
_, err = integrationDB.ExecContext(ctx, `
|
||||||
|
UPDATE usage_billing_dedup
|
||||||
|
SET created_at = $1
|
||||||
|
WHERE request_id = $2 AND api_key_id = $3
|
||||||
|
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||||
|
|
||||||
|
result2, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, result2.Applied)
|
||||||
|
|
||||||
|
var balance float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||||
|
require.InDelta(t, 98.75, balance, 0.000001)
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,8 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -14,6 +16,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
|||||||
s.Require().NotZero(log.ID)
|
s.Require().NotZero(log.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
|
||||||
|
|
||||||
|
const total = 16
|
||||||
|
results := make([]bool, total)
|
||||||
|
errs := make([]error, total)
|
||||||
|
logs := make([]*service.UsageLog, total)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(total)
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
i := i
|
||||||
|
logs[i] = &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10 + i,
|
||||||
|
OutputTokens: 20 + i,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
results[i], errs[i] = repo.Create(ctx, logs[i])
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
require.NoError(t, errs[i])
|
||||||
|
require.True(t, results[i])
|
||||||
|
require.NotZero(t, logs[i].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
|
||||||
|
require.Equal(t, total, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
|
||||||
|
log1 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
log2 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
inserted1, err1 := repo.Create(ctx, log1)
|
||||||
|
inserted2, err2 := repo.Create(ctx, log2)
|
||||||
|
require.NoError(t, err1)
|
||||||
|
require.NoError(t, err2)
|
||||||
|
require.True(t, inserted1)
|
||||||
|
require.False(t, inserted2)
|
||||||
|
require.Equal(t, log1.ID, log2.ID)
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
|
||||||
|
const total = 8
|
||||||
|
batch := make([]usageLogCreateRequest, 0, total)
|
||||||
|
logs := make([]*service.UsageLog, 0, total)
|
||||||
|
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
log := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10 + i,
|
||||||
|
OutputTokens: 20 + i,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
logs = append(logs, log)
|
||||||
|
batch = append(batch, usageLogCreateRequest{
|
||||||
|
log: log,
|
||||||
|
prepared: prepareUsageLogInsert(log),
|
||||||
|
resultCh: make(chan usageLogCreateResult, 1),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.flushCreateBatch(integrationDB, batch)
|
||||||
|
|
||||||
|
insertedCount := 0
|
||||||
|
var firstID int64
|
||||||
|
for idx, req := range batch {
|
||||||
|
res := <-req.resultCh
|
||||||
|
require.NoError(t, res.err)
|
||||||
|
if res.inserted {
|
||||||
|
insertedCount++
|
||||||
|
}
|
||||||
|
require.NotZero(t, logs[idx].ID)
|
||||||
|
if idx == 0 {
|
||||||
|
firstID = logs[idx].ID
|
||||||
|
} else {
|
||||||
|
require.Equal(t, firstID, logs[idx].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, insertedCount)
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
|
||||||
|
log1 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
log2 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, repo.CreateBestEffort(ctx, log1))
|
||||||
|
require.NoError(t, repo.CreateBestEffort(ctx, log2))
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
var count int
|
||||||
|
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
|
||||||
|
return err == nil && count == 1
|
||||||
|
}, 3*time.Second, 20*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
|
||||||
|
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
|
||||||
|
|
||||||
|
err := repo.CreateBestEffort(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateDropped(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.False(t, inserted)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||||
|
repo.createBatchCh <- usageLogCreateRequest{}
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
|
||||||
|
|
||||||
|
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.False(t, inserted)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := repo.createBatched(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
req := <-repo.createBatchCh
|
||||||
|
require.NotNil(t, req.shared)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := <-errCh
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||||
|
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
|
||||||
|
|
||||||
|
log := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
req := usageLogCreateRequest{
|
||||||
|
log: log,
|
||||||
|
prepared: prepareUsageLogInsert(log),
|
||||||
|
shared: &usageLogCreateShared{},
|
||||||
|
resultCh: make(chan usageLogCreateResult, 1),
|
||||||
|
}
|
||||||
|
req.shared.state.Store(usageLogCreateStateCanceled)
|
||||||
|
|
||||||
|
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
|
||||||
|
|
||||||
|
res := <-req.resultCh
|
||||||
|
require.False(t, res.inserted)
|
||||||
|
require.Error(t, res.err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||||
|
|||||||
@@ -248,6 +248,35 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
|
|||||||
require.NoError(t, mock.ExpectationsWereMet())
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
|
||||||
|
db, mock := newSQLMock(t)
|
||||||
|
repo := &usageLogRepository{sql: db}
|
||||||
|
|
||||||
|
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
end := start.Add(24 * time.Hour)
|
||||||
|
|
||||||
|
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
|
||||||
|
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
|
||||||
|
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
|
||||||
|
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
|
||||||
|
|
||||||
|
mock.ExpectQuery("WITH user_spend AS \\(").
|
||||||
|
WithArgs(start, end, 12).
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, &usagestats.UserSpendingRankingResponse{
|
||||||
|
Ranking: []usagestats.UserSpendingRankingItem{
|
||||||
|
{UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900},
|
||||||
|
{UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800},
|
||||||
|
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
|
||||||
|
},
|
||||||
|
TotalActualCost: 40.0,
|
||||||
|
}, got)
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
|
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -3,8 +3,11 @@
|
|||||||
package repository
|
package repository
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
|
||||||
|
log := &service.UsageLog{
|
||||||
|
UserID: 1,
|
||||||
|
APIKeyID: 2,
|
||||||
|
AccountID: 3,
|
||||||
|
RequestID: "req-batch-no-update",
|
||||||
|
Model: "gpt-5",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalCost: 1.2,
|
||||||
|
ActualCost: 1.2,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
prepared := prepareUsageLogInsert(log)
|
||||||
|
|
||||||
|
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
|
||||||
|
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
|
||||||
|
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
|
||||||
|
}
|
||||||
|
|||||||
@@ -98,9 +98,9 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
|||||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||||
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT ugr.user_id, u.email, ugr.rate_multiplier
|
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
|
||||||
FROM user_group_rate_multipliers ugr
|
FROM user_group_rate_multipliers ugr
|
||||||
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
|
JOIN users u ON u.id = ugr.user_id
|
||||||
WHERE ugr.group_id = $1
|
WHERE ugr.group_id = $1
|
||||||
ORDER BY ugr.user_id
|
ORDER BY ugr.user_id
|
||||||
`
|
`
|
||||||
@@ -113,7 +113,7 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
|
|||||||
var result []service.UserGroupRateEntry
|
var result []service.UserGroupRateEntry
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var entry service.UserGroupRateEntry
|
var entry service.UserGroupRateEntry
|
||||||
if err := rows.Scan(&entry.UserID, &entry.UserEmail, &entry.RateMultiplier); err != nil {
|
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
result = append(result, entry)
|
result = append(result, entry)
|
||||||
@@ -193,6 +193,31 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
|
||||||
|
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
|
||||||
|
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
userIDs := make([]int64, len(entries))
|
||||||
|
rates := make([]float64, len(entries))
|
||||||
|
for i, e := range entries {
|
||||||
|
userIDs[i] = e.UserID
|
||||||
|
rates[i] = e.RateMultiplier
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
_, err := r.sql.ExecContext(ctx, `
|
||||||
|
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||||
|
SELECT data.user_id, $1::bigint, data.rate_multiplier, $2::timestamptz, $2::timestamptz
|
||||||
|
FROM unnest($3::bigint[], $4::double precision[]) AS data(user_id, rate_multiplier)
|
||||||
|
ON CONFLICT (user_id, group_id)
|
||||||
|
DO UPDATE SET rate_multiplier = EXCLUDED.rate_multiplier, updated_at = EXCLUDED.updated_at
|
||||||
|
`, groupID, now, pq.Array(userIDs), pq.Array(rates))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
||||||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewAnnouncementRepository,
|
NewAnnouncementRepository,
|
||||||
NewAnnouncementReadRepository,
|
NewAnnouncementReadRepository,
|
||||||
NewUsageLogRepository,
|
NewUsageLogRepository,
|
||||||
|
NewUsageBillingRepository,
|
||||||
NewIdempotencyRepository,
|
NewIdempotencyRepository,
|
||||||
NewUsageCleanupRepository,
|
NewUsageCleanupRepository,
|
||||||
NewDashboardAggregationRepository,
|
NewDashboardAggregationRepository,
|
||||||
|
|||||||
@@ -645,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
@@ -1635,6 +1635,10 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||||
logs := r.userLogs[userID]
|
logs := r.userLogs[userID]
|
||||||
if len(logs) == 0 {
|
if len(logs) == 0 {
|
||||||
|
|||||||
@@ -192,6 +192,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
|
dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
|
||||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
|
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
|
||||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||||
|
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
|
||||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||||
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
||||||
@@ -229,6 +230,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||||
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
|
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
|
||||||
|
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
|
||||||
|
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
|
||||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -412,6 +412,7 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
|
|||||||
if a.Platform == domain.PlatformAntigravity {
|
if a.Platform == domain.PlatformAntigravity {
|
||||||
return domain.DefaultAntigravityModelMapping
|
return domain.DefaultAntigravityModelMapping
|
||||||
}
|
}
|
||||||
|
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if len(rawMapping) == 0 {
|
if len(rawMapping) == 0 {
|
||||||
@@ -764,6 +765,14 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) IsBedrock() bool {
|
||||||
|
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) IsBedrockAPIKey() bool {
|
||||||
|
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) IsOpenAI() bool {
|
func (a *Account) IsOpenAI() bool {
|
||||||
return a.Platform == PlatformOpenAI
|
return a.Platform == PlatformOpenAI
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -207,14 +207,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
|||||||
testModelID = claude.DefaultTestModel
|
testModelID = claude.DefaultTestModel
|
||||||
}
|
}
|
||||||
|
|
||||||
// For API Key accounts with model mapping, map the model
|
// API Key 账号测试连接时也需要应用通配符模型映射。
|
||||||
if account.Type == "apikey" {
|
if account.Type == "apikey" {
|
||||||
mapping := account.GetModelMapping()
|
testModelID = account.GetMappedModel(testModelID)
|
||||||
if len(mapping) > 0 {
|
}
|
||||||
if mappedModel, exists := mapping[testModelID]; exists {
|
|
||||||
testModelID = mappedModel
|
// Bedrock accounts use a separate test path
|
||||||
}
|
if account.IsBedrock() {
|
||||||
}
|
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine authentication method and API URL
|
// Determine authentication method and API URL
|
||||||
@@ -312,6 +312,109 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
|||||||
return s.processClaudeStream(c, resp.Body)
|
return s.processClaudeStream(c, resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
|
||||||
|
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
|
||||||
|
region := bedrockRuntimeRegion(account)
|
||||||
|
resolvedModelID, ok := ResolveBedrockModelID(account, testModelID)
|
||||||
|
if !ok {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Bedrock model: %s", testModelID))
|
||||||
|
}
|
||||||
|
testModelID = resolvedModelID
|
||||||
|
|
||||||
|
// Set SSE headers (test UI expects SSE)
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Writer.Flush()
|
||||||
|
|
||||||
|
// Create a minimal Bedrock-compatible payload (no stream, no cache_control)
|
||||||
|
bedrockPayload := map[string]any{
|
||||||
|
"anthropic_version": "bedrock-2023-05-31",
|
||||||
|
"messages": []map[string]any{
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": []map[string]any{
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "hi",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"max_tokens": 256,
|
||||||
|
"temperature": 1,
|
||||||
|
}
|
||||||
|
bedrockBody, _ := json.Marshal(bedrockPayload)
|
||||||
|
|
||||||
|
// Use non-streaming endpoint (response is standard Claude JSON)
|
||||||
|
apiURL := BuildBedrockURL(region, testModelID, false)
|
||||||
|
|
||||||
|
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bedrockBody))
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Sign or set auth based on account type
|
||||||
|
if account.IsBedrockAPIKey() {
|
||||||
|
apiKey := account.GetCredential("api_key")
|
||||||
|
if apiKey == "" {
|
||||||
|
return s.sendErrorAndEnd(c, "No API key available")
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
} else {
|
||||||
|
signer, err := NewBedrockSignerFromAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Bedrock signer: %s", err.Error()))
|
||||||
|
}
|
||||||
|
if err := signer.SignRequest(ctx, req, bedrockBody); err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to sign request: %s", err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL := ""
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bedrock non-streaming response is standard Claude JSON, extract the text
|
||||||
|
var result struct {
|
||||||
|
Content []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"content"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
text := ""
|
||||||
|
if len(result.Content) > 0 {
|
||||||
|
text = result.Content[0].Text
|
||||||
|
}
|
||||||
|
if text == "" {
|
||||||
|
text = "(empty response)"
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||||
|
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// testOpenAIAccountConnection tests an OpenAI account's connection
|
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ type UsageLogRepository interface {
|
|||||||
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
||||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||||
|
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
|
||||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
|
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ type AdminService interface {
|
|||||||
DeleteGroup(ctx context.Context, id int64) error
|
DeleteGroup(ctx context.Context, id int64) error
|
||||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||||
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||||
|
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
|
||||||
|
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
||||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||||
|
|
||||||
// API Key management (admin)
|
// API Key management (admin)
|
||||||
@@ -58,6 +60,8 @@ type AdminService interface {
|
|||||||
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
||||||
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
||||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||||
|
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号 privacy_mode,未设置则尝试关闭训练数据共享并持久化。
|
||||||
|
EnsureOpenAIPrivacy(ctx context.Context, account *Account) string
|
||||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||||
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
|
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
|
||||||
@@ -139,10 +143,9 @@ type CreateGroupInput struct {
|
|||||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||||
FallbackGroupIDOnInvalidRequest *int64
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64
|
ModelRouting map[string][]int64
|
||||||
ModelRoutingEnabled bool // 是否启用模型路由
|
ModelRoutingEnabled bool // 是否启用模型路由
|
||||||
MCPXMLInject *bool
|
MCPXMLInject *bool
|
||||||
SimulateClaudeMaxEnabled *bool
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string
|
SupportedModelScopes []string
|
||||||
// Sora 存储配额
|
// Sora 存储配额
|
||||||
@@ -179,10 +182,9 @@ type UpdateGroupInput struct {
|
|||||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||||
FallbackGroupIDOnInvalidRequest *int64
|
FallbackGroupIDOnInvalidRequest *int64
|
||||||
// 模型路由配置(仅 anthropic 平台使用)
|
// 模型路由配置(仅 anthropic 平台使用)
|
||||||
ModelRouting map[string][]int64
|
ModelRouting map[string][]int64
|
||||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||||
MCPXMLInject *bool
|
MCPXMLInject *bool
|
||||||
SimulateClaudeMaxEnabled *bool
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes *[]string
|
SupportedModelScopes *[]string
|
||||||
// Sora 存储配额
|
// Sora 存储配额
|
||||||
@@ -366,10 +368,6 @@ type ProxyExitInfoProber interface {
|
|||||||
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type groupExistenceBatchReader interface {
|
|
||||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type proxyQualityTarget struct {
|
type proxyQualityTarget struct {
|
||||||
Target string
|
Target string
|
||||||
URL string
|
URL string
|
||||||
@@ -440,12 +438,17 @@ type adminServiceImpl struct {
|
|||||||
settingService *SettingService
|
settingService *SettingService
|
||||||
defaultSubAssigner DefaultSubscriptionAssigner
|
defaultSubAssigner DefaultSubscriptionAssigner
|
||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
|
privacyClientFactory PrivacyClientFactory
|
||||||
}
|
}
|
||||||
|
|
||||||
type userGroupRateBatchReader interface {
|
type userGroupRateBatchReader interface {
|
||||||
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type groupExistenceBatchReader interface {
|
||||||
|
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
// NewAdminService creates a new AdminService
|
// NewAdminService creates a new AdminService
|
||||||
func NewAdminService(
|
func NewAdminService(
|
||||||
userRepo UserRepository,
|
userRepo UserRepository,
|
||||||
@@ -464,6 +467,7 @@ func NewAdminService(
|
|||||||
settingService *SettingService,
|
settingService *SettingService,
|
||||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||||
userSubRepo UserSubscriptionRepository,
|
userSubRepo UserSubscriptionRepository,
|
||||||
|
privacyClientFactory PrivacyClientFactory,
|
||||||
) AdminService {
|
) AdminService {
|
||||||
return &adminServiceImpl{
|
return &adminServiceImpl{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
@@ -482,6 +486,7 @@ func NewAdminService(
|
|||||||
settingService: settingService,
|
settingService: settingService,
|
||||||
defaultSubAssigner: defaultSubAssigner,
|
defaultSubAssigner: defaultSubAssigner,
|
||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
|
privacyClientFactory: privacyClientFactory,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -863,13 +868,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
if input.MCPXMLInject != nil {
|
if input.MCPXMLInject != nil {
|
||||||
mcpXMLInject = *input.MCPXMLInject
|
mcpXMLInject = *input.MCPXMLInject
|
||||||
}
|
}
|
||||||
simulateClaudeMaxEnabled := false
|
|
||||||
if input.SimulateClaudeMaxEnabled != nil {
|
|
||||||
if platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
|
|
||||||
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
|
|
||||||
}
|
|
||||||
simulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果指定了复制账号的源分组,先获取账号 ID 列表
|
// 如果指定了复制账号的源分组,先获取账号 ID 列表
|
||||||
var accountIDsToCopy []int64
|
var accountIDsToCopy []int64
|
||||||
@@ -926,7 +924,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
||||||
ModelRouting: input.ModelRouting,
|
ModelRouting: input.ModelRouting,
|
||||||
MCPXMLInject: mcpXMLInject,
|
MCPXMLInject: mcpXMLInject,
|
||||||
SimulateClaudeMaxEnabled: simulateClaudeMaxEnabled,
|
|
||||||
SupportedModelScopes: input.SupportedModelScopes,
|
SupportedModelScopes: input.SupportedModelScopes,
|
||||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||||
AllowMessagesDispatch: input.AllowMessagesDispatch,
|
AllowMessagesDispatch: input.AllowMessagesDispatch,
|
||||||
@@ -1138,15 +1135,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
if input.MCPXMLInject != nil {
|
if input.MCPXMLInject != nil {
|
||||||
group.MCPXMLInject = *input.MCPXMLInject
|
group.MCPXMLInject = *input.MCPXMLInject
|
||||||
}
|
}
|
||||||
if input.SimulateClaudeMaxEnabled != nil {
|
|
||||||
if group.Platform != PlatformAnthropic && *input.SimulateClaudeMaxEnabled {
|
|
||||||
return nil, fmt.Errorf("simulate_claude_max_enabled only supported for anthropic groups")
|
|
||||||
}
|
|
||||||
group.SimulateClaudeMaxEnabled = *input.SimulateClaudeMaxEnabled
|
|
||||||
}
|
|
||||||
if group.Platform != PlatformAnthropic {
|
|
||||||
group.SimulateClaudeMaxEnabled = false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
if input.SupportedModelScopes != nil {
|
if input.SupportedModelScopes != nil {
|
||||||
@@ -1271,6 +1259,20 @@ func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID
|
|||||||
return s.userGroupRateRepo.GetByGroupID(ctx, groupID)
|
return s.userGroupRateRepo.GetByGroupID(ctx, groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error {
|
||||||
|
if s.userGroupRateRepo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.userGroupRateRepo.DeleteByGroupID(ctx, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
|
||||||
|
if s.userGroupRateRepo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||||
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
||||||
}
|
}
|
||||||
@@ -2529,3 +2531,39 @@ func (e *MixedChannelError) Error() string {
|
|||||||
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
|
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||||
return s.accountRepo.ResetQuotaUsed(ctx, id)
|
return s.accountRepo.ResetQuotaUsed(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号是否已设置 privacy_mode,
|
||||||
|
// 未设置则调用 disableOpenAITraining 并持久化到 Extra,返回设置的 mode 值。
|
||||||
|
func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Account) string {
|
||||||
|
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if s.privacyClientFactory == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if account.Extra != nil {
|
||||||
|
if _, ok := account.Extra["privacy_mode"]; ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
token, _ := account.Credentials["access_token"].(string)
|
||||||
|
if token == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var proxyURL string
|
||||||
|
if account.ProxyID != nil {
|
||||||
|
if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil {
|
||||||
|
proxyURL = p.URL()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL)
|
||||||
|
if mode == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode})
|
||||||
|
return mode
|
||||||
|
}
|
||||||
|
|||||||
@@ -43,16 +43,6 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
|
||||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
|
||||||
return rows, nil
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
|
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
|
||||||
s.getByIDsCalled = true
|
s.getByIDsCalled = true
|
||||||
s.getByIDsIDs = append([]int64{}, ids...)
|
s.getByIDsIDs = append([]int64{}, ids...)
|
||||||
@@ -73,6 +63,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
|
|||||||
return nil, errors.New("account not found")
|
return nil, errors.New("account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||||
|
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||||
repo := &accountRepoStubForBulkUpdate{}
|
repo := &accountRepoStubForBulkUpdate{}
|
||||||
|
|||||||
176
backend/internal/service/admin_service_group_rate_test.go
Normal file
176
backend/internal/service/admin_service_group_rate_test.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// userGroupRateRepoStubForGroupRate implements UserGroupRateRepository for group rate tests.
|
||||||
|
type userGroupRateRepoStubForGroupRate struct {
|
||||||
|
getByGroupIDData map[int64][]UserGroupRateEntry
|
||||||
|
getByGroupIDErr error
|
||||||
|
|
||||||
|
deletedGroupIDs []int64
|
||||||
|
deleteByGroupErr error
|
||||||
|
|
||||||
|
syncedGroupID int64
|
||||||
|
syncedEntries []GroupRateMultiplierInput
|
||||||
|
syncGroupErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
|
||||||
|
panic("unexpected GetByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, _, _ int64) (*float64, error) {
|
||||||
|
panic("unexpected GetByUserAndGroup call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||||
|
if s.getByGroupIDErr != nil {
|
||||||
|
return nil, s.getByGroupIDErr
|
||||||
|
}
|
||||||
|
return s.getByGroupIDData[groupID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) SyncUserGroupRates(_ context.Context, _ int64, _ map[int64]*float64) error {
|
||||||
|
panic("unexpected SyncUserGroupRates call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
|
||||||
|
s.syncedGroupID = groupID
|
||||||
|
s.syncedEntries = entries
|
||||||
|
return s.syncGroupErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||||
|
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
|
||||||
|
return s.deleteByGroupErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) DeleteByUserID(_ context.Context, _ int64) error {
|
||||||
|
panic("unexpected DeleteByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
|
||||||
|
t.Run("returns entries for group", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{
|
||||||
|
getByGroupIDData: map[int64][]UserGroupRateEntry{
|
||||||
|
10: {
|
||||||
|
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
|
||||||
|
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, entries, 2)
|
||||||
|
require.Equal(t, int64(1), entries[0].UserID)
|
||||||
|
require.Equal(t, "alice", entries[0].UserName)
|
||||||
|
require.Equal(t, 1.5, entries[0].RateMultiplier)
|
||||||
|
require.Equal(t, int64(2), entries[1].UserID)
|
||||||
|
require.Equal(t, 0.8, entries[1].RateMultiplier)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||||
|
|
||||||
|
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, entries)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns empty slice for group with no entries", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{
|
||||||
|
getByGroupIDData: map[int64][]UserGroupRateEntry{},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
entries, err := svc.GetGroupRateMultipliers(context.Background(), 99)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, entries)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates repo error", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{
|
||||||
|
getByGroupIDErr: errors.New("db error"),
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "db error")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_ClearGroupRateMultipliers(t *testing.T) {
|
||||||
|
t.Run("deletes by group ID", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []int64{42}, repo.deletedGroupIDs)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||||
|
|
||||||
|
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates repo error", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{
|
||||||
|
deleteByGroupErr: errors.New("delete failed"),
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "delete failed")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
|
||||||
|
t.Run("syncs entries to repo", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
entries := []GroupRateMultiplierInput{
|
||||||
|
{UserID: 1, RateMultiplier: 1.5},
|
||||||
|
{UserID: 2, RateMultiplier: 0.8},
|
||||||
|
}
|
||||||
|
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, entries)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(10), repo.syncedGroupID)
|
||||||
|
require.Equal(t, entries, repo.syncedEntries)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||||
|
|
||||||
|
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates repo error", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{
|
||||||
|
syncGroupErr: errors.New("sync failed"),
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, []GroupRateMultiplierInput{
|
||||||
|
{UserID: 1, RateMultiplier: 1.0},
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "sync failed")
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -785,57 +785,3 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
|
|||||||
require.NotNil(t, repo.updated)
|
require.NotNil(t, repo.updated)
|
||||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAdminService_CreateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
|
|
||||||
repo := &groupRepoStubForAdmin{}
|
|
||||||
svc := &adminServiceImpl{groupRepo: repo}
|
|
||||||
|
|
||||||
enabled := true
|
|
||||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
|
||||||
Name: "openai-group",
|
|
||||||
Platform: PlatformOpenAI,
|
|
||||||
SimulateClaudeMaxEnabled: &enabled,
|
|
||||||
})
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
|
|
||||||
require.Nil(t, repo.created)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAdminService_UpdateGroup_SimulateClaudeMaxRequiresAnthropic(t *testing.T) {
|
|
||||||
existingGroup := &Group{
|
|
||||||
ID: 1,
|
|
||||||
Name: "openai-group",
|
|
||||||
Platform: PlatformOpenAI,
|
|
||||||
Status: StatusActive,
|
|
||||||
}
|
|
||||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
|
||||||
svc := &adminServiceImpl{groupRepo: repo}
|
|
||||||
|
|
||||||
enabled := true
|
|
||||||
_, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
|
||||||
SimulateClaudeMaxEnabled: &enabled,
|
|
||||||
})
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "simulate_claude_max_enabled only supported for anthropic groups")
|
|
||||||
require.Nil(t, repo.updated)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAdminService_UpdateGroup_ClearsSimulateClaudeMaxWhenPlatformChanges(t *testing.T) {
|
|
||||||
existingGroup := &Group{
|
|
||||||
ID: 1,
|
|
||||||
Name: "anthropic-group",
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
Status: StatusActive,
|
|
||||||
SimulateClaudeMaxEnabled: true,
|
|
||||||
}
|
|
||||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
|
||||||
svc := &adminServiceImpl{groupRepo: repo}
|
|
||||||
|
|
||||||
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
|
|
||||||
Platform: PlatformOpenAI,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, group)
|
|
||||||
require.NotNil(t, repo.updated)
|
|
||||||
require.False(t, repo.updated.SimulateClaudeMaxEnabled)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -68,11 +68,15 @@ func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context
|
|||||||
panic("unexpected SyncUserGroupRates call")
|
panic("unexpected SyncUserGroupRates call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, _ int64) ([]UserGroupRateEntry, error) {
|
||||||
panic("unexpected GetByGroupID call")
|
panic("unexpected GetByGroupID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
|
func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.Context, _ int64, _ []GroupRateMultiplierInput) error {
|
||||||
|
panic("unexpected SyncGroupRateMultipliers call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
|
||||||
panic("unexpected DeleteByGroupID call")
|
panic("unexpected DeleteByGroupID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1673,7 +1673,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
var clientDisconnect bool
|
var clientDisconnect bool
|
||||||
if claudeReq.Stream {
|
if claudeReq.Stream {
|
||||||
// 客户端要求流式,直接透传转换
|
// 客户端要求流式,直接透传转换
|
||||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel, account.ID)
|
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
|
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1683,7 +1683,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
clientDisconnect = streamRes.clientDisconnect
|
clientDisconnect = streamRes.clientDisconnect
|
||||||
} else {
|
} else {
|
||||||
// 客户端要求非流式,收集流式响应后转换返回
|
// 客户端要求非流式,收集流式响应后转换返回
|
||||||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel, account.ID)
|
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
|
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1692,9 +1692,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
firstTokenMs = streamRes.firstTokenMs
|
firstTokenMs = streamRes.firstTokenMs
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claude Max cache billing: 同步 ForwardResult.Usage 与客户端响应体一致
|
|
||||||
applyClaudeMaxCacheBillingPolicyToUsage(usage, parsedRequestFromGinContext(c), claudeMaxGroupFromGinContext(c), originalModel, account.ID)
|
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
@@ -3598,7 +3595,7 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
|
|||||||
|
|
||||||
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
|
// handleClaudeStreamToNonStreaming 收集上游流式响应,转换为 Claude 非流式格式返回
|
||||||
// 用于处理客户端非流式请求但上游只支持流式的情况
|
// 用于处理客户端非流式请求但上游只支持流式的情况
|
||||||
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
|
func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
maxLineSize := defaultMaxLineSize
|
maxLineSize := defaultMaxLineSize
|
||||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||||
@@ -3756,9 +3753,6 @@ returnResponse:
|
|||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claude Max cache billing simulation (non-streaming)
|
|
||||||
claudeResp = applyClaudeMaxNonStreamingRewrite(c, claudeResp, agUsage, originalModel, accountID)
|
|
||||||
|
|
||||||
c.Data(http.StatusOK, "application/json", claudeResp)
|
c.Data(http.StatusOK, "application/json", claudeResp)
|
||||||
|
|
||||||
// 转换为 service.ClaudeUsage
|
// 转换为 service.ClaudeUsage
|
||||||
@@ -3773,7 +3767,7 @@ returnResponse:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
|
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
|
||||||
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string, accountID int64) (*antigravityStreamResult, error) {
|
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
@@ -3786,8 +3780,6 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
|||||||
}
|
}
|
||||||
|
|
||||||
processor := antigravity.NewStreamingProcessor(originalModel)
|
processor := antigravity.NewStreamingProcessor(originalModel)
|
||||||
setupClaudeMaxStreamingHook(c, processor, originalModel, accountID)
|
|
||||||
|
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
// 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
|||||||
@@ -922,7 +922,7 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
|||||||
fmt.Fprintln(pw, "")
|
fmt.Fprintln(pw, "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -999,7 +999,7 @@ func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
|
|||||||
fmt.Fprintln(pw, "")
|
fmt.Fprintln(pw, "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro", 0)
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1202,7 +1202,7 @@ func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) {
|
|||||||
fmt.Fprintln(pw, "")
|
fmt.Fprintln(pw, "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1234,7 +1234,7 @@ func TestHandleClaudeStreamingResponse_EmptyStream(t *testing.T) {
|
|||||||
fmt.Fprintln(pw, "")
|
fmt.Fprintln(pw, "")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
_, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
|
|
||||||
// 应当返回 UpstreamFailoverError 而非 nil,以便上层触发 failover
|
// 应当返回 UpstreamFailoverError 而非 nil,以便上层触发 failover
|
||||||
@@ -1266,7 +1266,7 @@ func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) {
|
|||||||
|
|
||||||
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}}
|
||||||
|
|
||||||
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5", 0)
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5")
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
|
|||||||
@@ -59,10 +59,9 @@ type APIKeyAuthGroupSnapshot struct {
|
|||||||
|
|
||||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||||
// Only anthropic groups use these fields; others may leave them empty.
|
// Only anthropic groups use these fields; others may leave them empty.
|
||||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||||
SimulateClaudeMaxEnabled bool `json:"simulate_claude_max_enabled"`
|
|
||||||
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||||
|
|||||||
@@ -244,7 +244,6 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
|||||||
ModelRouting: apiKey.Group.ModelRouting,
|
ModelRouting: apiKey.Group.ModelRouting,
|
||||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||||
MCPXMLInject: apiKey.Group.MCPXMLInject,
|
MCPXMLInject: apiKey.Group.MCPXMLInject,
|
||||||
SimulateClaudeMaxEnabled: apiKey.Group.SimulateClaudeMaxEnabled,
|
|
||||||
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
|
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
|
||||||
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
||||||
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
||||||
@@ -304,7 +303,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
ModelRouting: snapshot.Group.ModelRouting,
|
ModelRouting: snapshot.Group.ModelRouting,
|
||||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||||
MCPXMLInject: snapshot.Group.MCPXMLInject,
|
MCPXMLInject: snapshot.Group.MCPXMLInject,
|
||||||
SimulateClaudeMaxEnabled: snapshot.Group.SimulateClaudeMaxEnabled,
|
|
||||||
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
||||||
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
||||||
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
||||||
|
|||||||
607
backend/internal/service/bedrock_request.go
Normal file
607
backend/internal/service/bedrock_request.go
Normal file
@@ -0,0 +1,607 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultBedrockRegion = "us-east-1"
|
||||||
|
|
||||||
|
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
|
||||||
|
|
||||||
|
// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀
|
||||||
|
// 参考: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
|
||||||
|
func BedrockCrossRegionPrefix(region string) string {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(region, "us-gov"):
|
||||||
|
return "us-gov" // GovCloud 使用独立的 us-gov 前缀
|
||||||
|
case strings.HasPrefix(region, "us-"):
|
||||||
|
return "us"
|
||||||
|
case strings.HasPrefix(region, "eu-"):
|
||||||
|
return "eu"
|
||||||
|
case region == "ap-northeast-1":
|
||||||
|
return "jp" // 日本区域使用独立的 jp 前缀(AWS 官方定义)
|
||||||
|
case region == "ap-southeast-2":
|
||||||
|
return "au" // 澳大利亚区域使用独立的 au 前缀(AWS 官方定义)
|
||||||
|
case strings.HasPrefix(region, "ap-"):
|
||||||
|
return "apac" // 其余亚太区域使用通用 apac 前缀
|
||||||
|
case strings.HasPrefix(region, "ca-"):
|
||||||
|
return "us" // 加拿大区域使用 us 前缀的跨区域推理
|
||||||
|
case strings.HasPrefix(region, "sa-"):
|
||||||
|
return "us" // 南美区域使用 us 前缀的跨区域推理
|
||||||
|
default:
|
||||||
|
return "us"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀
|
||||||
|
// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1"
|
||||||
|
// 特殊值 region="global" 强制使用 global. 前缀
|
||||||
|
func AdjustBedrockModelRegionPrefix(modelID, region string) string {
|
||||||
|
var targetPrefix string
|
||||||
|
if region == "global" {
|
||||||
|
targetPrefix = "global"
|
||||||
|
} else {
|
||||||
|
targetPrefix = BedrockCrossRegionPrefix(region)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range bedrockCrossRegionPrefixes {
|
||||||
|
if strings.HasPrefix(modelID, p) {
|
||||||
|
if p == targetPrefix+"." {
|
||||||
|
return modelID // 前缀已匹配,无需替换
|
||||||
|
}
|
||||||
|
return targetPrefix + "." + modelID[len(p):]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 模型 ID 没有已知区域前缀(如 "anthropic.claude-..."),不做修改
|
||||||
|
return modelID
|
||||||
|
}
|
||||||
|
|
||||||
|
func bedrockRuntimeRegion(account *Account) string {
|
||||||
|
if account == nil {
|
||||||
|
return defaultBedrockRegion
|
||||||
|
}
|
||||||
|
if region := account.GetCredential("aws_region"); region != "" {
|
||||||
|
return region
|
||||||
|
}
|
||||||
|
return defaultBedrockRegion
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldForceBedrockGlobal(account *Account) bool {
|
||||||
|
return account != nil && account.GetCredential("aws_force_global") == "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRegionalBedrockModelID(modelID string) bool {
|
||||||
|
for _, prefix := range bedrockCrossRegionPrefixes {
|
||||||
|
if strings.HasPrefix(modelID, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLikelyBedrockModelID(modelID string) bool {
|
||||||
|
lower := strings.ToLower(strings.TrimSpace(modelID))
|
||||||
|
if lower == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(lower, "arn:") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, prefix := range []string{
|
||||||
|
"anthropic.",
|
||||||
|
"amazon.",
|
||||||
|
"meta.",
|
||||||
|
"mistral.",
|
||||||
|
"cohere.",
|
||||||
|
"ai21.",
|
||||||
|
"deepseek.",
|
||||||
|
"stability.",
|
||||||
|
"writer.",
|
||||||
|
"nova.",
|
||||||
|
} {
|
||||||
|
if strings.HasPrefix(lower, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return isRegionalBedrockModelID(lower)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) {
|
||||||
|
modelID = strings.TrimSpace(modelID)
|
||||||
|
if modelID == "" {
|
||||||
|
return "", false, false
|
||||||
|
}
|
||||||
|
if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists {
|
||||||
|
return mapped, true, true
|
||||||
|
}
|
||||||
|
if isRegionalBedrockModelID(modelID) {
|
||||||
|
return modelID, true, true
|
||||||
|
}
|
||||||
|
if isLikelyBedrockModelID(modelID) {
|
||||||
|
return modelID, false, true
|
||||||
|
}
|
||||||
|
return "", false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID.
|
||||||
|
// It applies account model_mapping first, then default Bedrock aliases, and finally
|
||||||
|
// adjusts Anthropic cross-region prefixes to match the account region.
|
||||||
|
func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) {
|
||||||
|
if account == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
mappedModel := account.GetMappedModel(requestedModel)
|
||||||
|
modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if shouldAdjustRegion {
|
||||||
|
targetRegion := bedrockRuntimeRegion(account)
|
||||||
|
if shouldForceBedrockGlobal(account) {
|
||||||
|
targetRegion = "global"
|
||||||
|
}
|
||||||
|
modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion)
|
||||||
|
}
|
||||||
|
return modelID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildBedrockURL 构建 Bedrock InvokeModel 的 URL
|
||||||
|
// stream=true 时使用 invoke-with-response-stream 端点
|
||||||
|
// modelID 中的特殊字符会被 URL 编码(与 litellm 的 urllib.parse.quote(safe="") 对齐)
|
||||||
|
func BuildBedrockURL(region, modelID string, stream bool) string {
|
||||||
|
if region == "" {
|
||||||
|
region = defaultBedrockRegion
|
||||||
|
}
|
||||||
|
encodedModelID := url.PathEscape(modelID)
|
||||||
|
// url.PathEscape 不编码冒号(RFC 允许 path 中出现 ":"),
|
||||||
|
// 但 AWS Bedrock 期望模型 ID 中的冒号被编码为 %3A
|
||||||
|
encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A")
|
||||||
|
if stream {
|
||||||
|
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrepareBedrockRequestBody 处理请求体以适配 Bedrock API
|
||||||
|
// 1. 注入 anthropic_version
|
||||||
|
// 2. 注入 anthropic_beta(从客户端 anthropic-beta 头解析)
|
||||||
|
// 3. 移除 Bedrock 不支持的字段(model, stream, output_format, output_config)
|
||||||
|
// 4. 移除工具定义中的 custom 字段(Claude Code 会发送 custom: {defer_loading: true})
|
||||||
|
// 5. 清理 cache_control 中 Bedrock 不支持的字段(scope, ttl)
|
||||||
|
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
|
||||||
|
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
|
||||||
|
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
|
||||||
|
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 注入 anthropic_version(Bedrock 要求)
|
||||||
|
body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("inject anthropic_version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注入 anthropic_beta(Bedrock Invoke 通过请求体传递 beta 头,而非 HTTP 头)
|
||||||
|
// 1. 从客户端 anthropic-beta header 解析
|
||||||
|
// 2. 根据请求体内容自动补齐必要的 beta token
|
||||||
|
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock()
|
||||||
|
if len(betaTokens) > 0 {
|
||||||
|
body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("inject anthropic_beta: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除 model 字段(Bedrock 通过 URL 指定模型)
|
||||||
|
body, err = sjson.DeleteBytes(body, "model")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("remove model field: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除 stream 字段(Bedrock 通过不同端点控制流式,不接受请求体中的 stream 字段)
|
||||||
|
body, err = sjson.DeleteBytes(body, "stream")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("remove stream field: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 output_format(Bedrock Invoke 不支持此字段,但可将 schema 内联到最后一条 user message)
|
||||||
|
// 参考 litellm: _convert_output_format_to_inline_schema()
|
||||||
|
body = convertOutputFormatToInlineSchema(body)
|
||||||
|
|
||||||
|
// 移除 output_config 字段(Bedrock Invoke 不支持)
|
||||||
|
body, err = sjson.DeleteBytes(body, "output_config")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("remove output_config field: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除工具定义中的 custom 字段
|
||||||
|
// Claude Code (v2.1.69+) 在 tool 定义中发送 custom: {defer_loading: true},
|
||||||
|
// Anthropic API 接受但 Bedrock 会拒绝并报 "Extra inputs are not permitted"
|
||||||
|
body = removeCustomFieldFromTools(body)
|
||||||
|
|
||||||
|
// 清理 cache_control 中 Bedrock 不支持的字段
|
||||||
|
body = sanitizeBedrockCacheControl(body, modelID)
|
||||||
|
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering.
|
||||||
|
func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string {
|
||||||
|
betaTokens := parseAnthropicBetaHeader(betaHeader)
|
||||||
|
betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID)
|
||||||
|
return filterBedrockBetaTokens(betaTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertOutputFormatToInlineSchema 将 output_format 中的 JSON schema 内联到最后一条 user message
|
||||||
|
// Bedrock Invoke 不支持 output_format 参数,litellm 的做法是将 schema 追加到用户消息中
|
||||||
|
// 参考: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema()
|
||||||
|
func convertOutputFormatToInlineSchema(body []byte) []byte {
|
||||||
|
outputFormat := gjson.GetBytes(body, "output_format")
|
||||||
|
if !outputFormat.Exists() || !outputFormat.IsObject() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先从请求体中移除 output_format
|
||||||
|
body, _ = sjson.DeleteBytes(body, "output_format")
|
||||||
|
|
||||||
|
schema := outputFormat.Get("schema")
|
||||||
|
if !schema.Exists() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// 找到最后一条 user message
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
msgArr := messages.Array()
|
||||||
|
lastUserIdx := -1
|
||||||
|
for i := len(msgArr) - 1; i >= 0; i-- {
|
||||||
|
if msgArr[i].Get("role").String() == "user" {
|
||||||
|
lastUserIdx = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastUserIdx < 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将 schema 序列化为 JSON 文本追加到该 message 的 content 数组
|
||||||
|
schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw))
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
content := msgArr[lastUserIdx].Get("content")
|
||||||
|
basePath := fmt.Sprintf("messages.%d.content", lastUserIdx)
|
||||||
|
|
||||||
|
if content.IsArray() {
|
||||||
|
// 追加一个 text block 到 content 数组末尾
|
||||||
|
idx := len(content.Array())
|
||||||
|
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text")
|
||||||
|
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON))
|
||||||
|
} else if content.Type == gjson.String {
|
||||||
|
// content 是纯字符串,转换为数组格式
|
||||||
|
originalText := content.String()
|
||||||
|
body, _ = sjson.SetBytes(body, basePath, []map[string]string{
|
||||||
|
{"type": "text", "text": originalText},
|
||||||
|
{"type": "text", "text": string(schemaJSON)},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeCustomFieldFromTools 移除 tools 数组中每个工具定义的 custom 字段
|
||||||
|
func removeCustomFieldFromTools(body []byte) []byte {
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if !tools.Exists() || !tools.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
for i := range tools.Array() {
|
||||||
|
body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i))
|
||||||
|
if err != nil {
|
||||||
|
// 删除失败不影响整体流程,跳过
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// claudeVersionRe 匹配 Claude 模型 ID 中的版本号部分
|
||||||
|
// 支持 claude-{tier}-{major}-{minor} 和 claude-{tier}-{major}.{minor} 格式
|
||||||
|
var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`)
|
||||||
|
|
||||||
|
// isBedrockClaude45OrNewer 判断 Bedrock 模型 ID 是否为 Claude 4.5 或更新版本
|
||||||
|
// Claude 4.5+ 支持 cache_control 中的 ttl 字段("5m" 和 "1h")
|
||||||
|
func isBedrockClaude45OrNewer(modelID string) bool {
|
||||||
|
lower := strings.ToLower(modelID)
|
||||||
|
matches := claudeVersionRe.FindStringSubmatch(lower)
|
||||||
|
if matches == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
major, _ := strconv.Atoi(matches[1])
|
||||||
|
minor, _ := strconv.Atoi(matches[2])
|
||||||
|
return major > 4 || (major == 4 && minor >= 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeBedrockCacheControl 清理 system 和 messages 中 cache_control 里
|
||||||
|
// Bedrock 不支持的字段:
|
||||||
|
// - scope:Bedrock 不支持(如 "global" 跨请求缓存)
|
||||||
|
// - ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",旧模型需要移除
|
||||||
|
func sanitizeBedrockCacheControl(body []byte, modelID string) []byte {
|
||||||
|
isClaude45 := isBedrockClaude45OrNewer(modelID)
|
||||||
|
|
||||||
|
// 清理 system 数组中的 cache_control
|
||||||
|
systemArr := gjson.GetBytes(body, "system")
|
||||||
|
if systemArr.Exists() && systemArr.IsArray() {
|
||||||
|
for i, item := range systemArr.Array() {
|
||||||
|
if !item.IsObject() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cc := item.Get("cache_control")
|
||||||
|
if !cc.Exists() || !cc.IsObject() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理 messages 中的 cache_control
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
for mi, msg := range messages.Array() {
|
||||||
|
if !msg.IsObject() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := msg.Get("content")
|
||||||
|
if !content.Exists() || !content.IsArray() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for ci, block := range content.Array() {
|
||||||
|
if !block.IsObject() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cc := block.Get("cache_control")
|
||||||
|
if !cc.Exists() || !cc.IsObject() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteCacheControlUnsupportedFields 删除给定 cache_control 路径下 Bedrock 不支持的字段
|
||||||
|
func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte {
|
||||||
|
// Bedrock 不支持 scope(如 "global")
|
||||||
|
if cc.Get("scope").Exists() {
|
||||||
|
body, _ = sjson.DeleteBytes(body, basePath+".scope")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",其余情况移除
|
||||||
|
ttl := cc.Get("ttl")
|
||||||
|
if ttl.Exists() {
|
||||||
|
shouldRemove := true
|
||||||
|
if isClaude45 {
|
||||||
|
v := ttl.String()
|
||||||
|
if v == "5m" || v == "1h" {
|
||||||
|
shouldRemove = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if shouldRemove {
|
||||||
|
body, _ = sjson.DeleteBytes(body, basePath+".ttl")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAnthropicBetaHeader splits an anthropic-beta header value into its
// tokens. The value may be either a comma-separated list or a JSON array of
// strings; surrounding whitespace is trimmed and empty entries are dropped.
// An empty header yields nil.
func parseAnthropicBetaHeader(header string) []string {
	header = strings.TrimSpace(header)
	if header == "" {
		return nil
	}
	// A bracketed value is first attempted as a JSON array.
	if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") {
		var items []any
		if json.Unmarshal([]byte(header), &items) == nil {
			out := make([]string, 0, len(items))
			for _, it := range items {
				if tok := strings.TrimSpace(fmt.Sprint(it)); tok != "" {
					out = append(out, tok)
				}
			}
			return out
		}
	}
	// Comma-separated fallback (also used when JSON parsing fails).
	var out []string
	for _, piece := range strings.Split(header, ",") {
		if tok := strings.TrimSpace(piece); tok != "" {
			out = append(out, tok)
		}
	}
	return out
}
|
||||||
|
|
||||||
|
// bedrockSupportedBetaTokens is the allowlist of anthropic-beta tokens that
// Bedrock Invoke accepts.
// Reference: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
// Maintenance: extend this allowlist whenever AWS Bedrock starts accepting a
// new beta token.
var bedrockSupportedBetaTokens = map[string]bool{
	"computer-use-2025-01-24":         true,
	"computer-use-2025-11-24":         true,
	"context-1m-2025-08-07":           true,
	"context-management-2025-06-27":   true,
	"compact-2026-01-12":              true,
	"interleaved-thinking-2025-05-14": true,
	"tool-search-tool-2025-10-19":     true,
	"tool-examples-2025-10-29":        true,
}
|
||||||
|
|
||||||
|
// bedrockBetaTokenTransforms maps beta tokens used by the direct Anthropic API
// to the Bedrock-Invoke-specific replacements Bedrock expects instead.
var bedrockBetaTokenTransforms = map[string]string{
	"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
}
|
||||||
|
|
||||||
|
// autoInjectBedrockBetaTokens 根据请求体内容自动补齐必要的 beta token
|
||||||
|
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() 和
|
||||||
|
// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock()
|
||||||
|
//
|
||||||
|
// 客户端(特别是非 Claude Code 客户端)可能只在 body 中启用了功能而不在 header 中带对应 beta token,
|
||||||
|
// 这里通过检测请求体特征自动补齐,确保 Bedrock Invoke 不会因缺少必要 beta 头而 400。
|
||||||
|
func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string {
|
||||||
|
seen := make(map[string]bool, len(tokens))
|
||||||
|
for _, t := range tokens {
|
||||||
|
seen[t] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
inject := func(token string) {
|
||||||
|
if !seen[token] {
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
seen[token] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检测 thinking / interleaved thinking
|
||||||
|
// 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta
|
||||||
|
if gjson.GetBytes(body, "thinking").Exists() {
|
||||||
|
inject("interleaved-thinking-2025-05-14")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检测 computer_use 工具
|
||||||
|
// tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if tools.Exists() && tools.IsArray() {
|
||||||
|
toolSearchUsed := false
|
||||||
|
programmaticToolCallingUsed := false
|
||||||
|
inputExamplesUsed := false
|
||||||
|
for _, tool := range tools.Array() {
|
||||||
|
toolType := tool.Get("type").String()
|
||||||
|
if strings.HasPrefix(toolType, "computer_20") {
|
||||||
|
inject("computer-use-2025-11-24")
|
||||||
|
}
|
||||||
|
if isBedrockToolSearchType(toolType) {
|
||||||
|
toolSearchUsed = true
|
||||||
|
}
|
||||||
|
if hasCodeExecutionAllowedCallers(tool) {
|
||||||
|
programmaticToolCallingUsed = true
|
||||||
|
}
|
||||||
|
if hasInputExamples(tool) {
|
||||||
|
inputExamplesUsed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if programmaticToolCallingUsed || inputExamplesUsed {
|
||||||
|
// programmatic tool calling 和 input examples 需要 advanced-tool-use,
|
||||||
|
// 后续 filterBedrockBetaTokens 会将其转换为 Bedrock 特定的 tool-search-tool
|
||||||
|
inject("advanced-tool-use-2025-11-20")
|
||||||
|
}
|
||||||
|
if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) {
|
||||||
|
// 纯 tool search(无 programmatic/inputExamples)时直接注入 Bedrock 特定头,
|
||||||
|
// 跳过 advanced-tool-use → tool-search-tool 的转换步骤(与 litellm 对齐)
|
||||||
|
if !programmaticToolCallingUsed && !inputExamplesUsed {
|
||||||
|
inject("tool-search-tool-2025-10-19")
|
||||||
|
} else {
|
||||||
|
inject("advanced-tool-use-2025-11-20")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBedrockToolSearchType reports whether toolType names one of the tool
// search tool variants (regex or BM25 flavor).
func isBedrockToolSearchType(toolType string) bool {
	switch toolType {
	case "tool_search_tool_regex_20251119", "tool_search_tool_bm25_20251119":
		return true
	default:
		return false
	}
}
|
||||||
|
|
||||||
|
func hasCodeExecutionAllowedCallers(tool gjson.Result) bool {
|
||||||
|
allowedCallers := tool.Get("allowed_callers")
|
||||||
|
if containsStringInJSONArray(allowedCallers, "code_execution_20250825") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825")
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasInputExamples(tool gjson.Result) bool {
|
||||||
|
if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
arr := tool.Get("function.input_examples")
|
||||||
|
return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsStringInJSONArray(result gjson.Result, target string) bool {
|
||||||
|
if !result.Exists() || !result.IsArray() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range result.Array() {
|
||||||
|
if item.String() == target {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// bedrockModelSupportsToolSearch 判断 Bedrock 模型是否支持 tool search
|
||||||
|
// 目前仅 Claude Opus/Sonnet 4.5+ 支持,Haiku 不支持
|
||||||
|
func bedrockModelSupportsToolSearch(modelID string) bool {
|
||||||
|
lower := strings.ToLower(modelID)
|
||||||
|
matches := claudeVersionRe.FindStringSubmatch(lower)
|
||||||
|
if matches == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Haiku 不支持 tool search
|
||||||
|
if strings.Contains(lower, "haiku") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
major, _ := strconv.Atoi(matches[1])
|
||||||
|
minor, _ := strconv.Atoi(matches[2])
|
||||||
|
return major > 4 || (major == 4 && minor >= 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterBedrockBetaTokens 过滤并转换 beta token 列表,仅保留 Bedrock Invoke 支持的 token
|
||||||
|
// 1. 应用转换规则(如 advanced-tool-use → tool-search-tool)
|
||||||
|
// 2. 过滤掉 Bedrock 不支持的 token(如 output-128k, files-api, structured-outputs 等)
|
||||||
|
// 3. 自动关联 tool-examples(当 tool-search-tool 存在时)
|
||||||
|
func filterBedrockBetaTokens(tokens []string) []string {
|
||||||
|
seen := make(map[string]bool, len(tokens))
|
||||||
|
var result []string
|
||||||
|
|
||||||
|
for _, t := range tokens {
|
||||||
|
// 应用转换规则
|
||||||
|
if replacement, ok := bedrockBetaTokenTransforms[t]; ok {
|
||||||
|
t = replacement
|
||||||
|
}
|
||||||
|
// 只保留白名单中的 token,且去重
|
||||||
|
if bedrockSupportedBetaTokens[t] && !seen[t] {
|
||||||
|
result = append(result, t)
|
||||||
|
seen[t] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自动关联: tool-search-tool 存在时,确保 tool-examples 也存在
|
||||||
|
if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] {
|
||||||
|
result = append(result, "tool-examples-2025-10-29")
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
659
backend/internal/service/bedrock_request_test.go
Normal file
659
backend/internal/service/bedrock_request_test.go
Normal file
@@ -0,0 +1,659 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestPrepareBedrockRequestBody_BasicFields verifies the baseline body
// rewrite: version injection plus removal of fields Bedrock rejects.
func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) {
	input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
	result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
	require.NoError(t, err)

	// anthropic_version must be injected.
	assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
	// model and stream must be removed.
	assert.False(t, gjson.GetBytes(result, "model").Exists())
	assert.False(t, gjson.GetBytes(result, "stream").Exists())
	// max_tokens must be preserved.
	assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int())
}
|
||||||
|
|
||||||
|
// TestPrepareBedrockRequestBody_OutputFormatInlineSchema verifies that the
// unsupported output_format field is removed and, when it carries a schema,
// the schema is inlined as an extra text block in the last user message.
func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) {
	t.Run("schema inlined into last user message array content", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
		// The schema should be appended to the end of the last user message's
		// content array.
		contentArr := gjson.GetBytes(result, "messages.0.content").Array()
		require.Len(t, contentArr, 2)
		assert.Equal(t, "text", contentArr[1].Get("type").String())
		assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`)
	})

	t.Run("schema inlined into string content", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
		contentArr := gjson.GetBytes(result, "messages.0.content").Array()
		require.Len(t, contentArr, 2)
		assert.Equal(t, "compute this", contentArr[0].Get("text").String())
		assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`)
	})

	t.Run("no schema field just removes output_format", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
	})

	t.Run("no messages just removes output_format", func(t *testing.T) {
		input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
		require.NoError(t, err)

		assert.False(t, gjson.GetBytes(result, "output_format").Exists())
	})
}
|
||||||
|
|
||||||
|
// TestPrepareBedrockRequestBody_RemoveOutputConfig verifies that the
// Bedrock-unsupported output_config field is dropped from the body.
func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) {
	input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}`
	result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
	require.NoError(t, err)

	assert.False(t, gjson.GetBytes(result, "output_config").Exists())
}
|
||||||
|
|
||||||
|
// TestRemoveCustomFieldFromTools verifies that the "custom" field is stripped
// from every tool while the remaining tool fields are left intact.
func TestRemoveCustomFieldFromTools(t *testing.T) {
	input := `{
		"tools": [
			{"name":"tool1","custom":{"defer_loading":true},"description":"desc1"},
			{"name":"tool2","description":"desc2"},
			{"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"}
		]
	}`
	result := removeCustomFieldFromTools([]byte(input))

	tools := gjson.GetBytes(result, "tools").Array()
	require.Len(t, tools, 3)
	// custom must be removed.
	assert.False(t, tools[0].Get("custom").Exists())
	// name/description must be preserved.
	assert.Equal(t, "tool1", tools[0].Get("name").String())
	assert.Equal(t, "desc1", tools[0].Get("description").String())
	// A tool without custom is untouched.
	assert.Equal(t, "tool2", tools[1].Get("name").String())
	// The third tool's custom must be removed as well.
	assert.False(t, tools[2].Get("custom").Exists())
	assert.Equal(t, "tool3", tools[2].Get("name").String())
}
|
||||||
|
|
||||||
|
// TestRemoveCustomFieldFromTools_NoTools verifies the no-tools fast path.
func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) {
	input := `{"messages":[{"role":"user","content":"hi"}]}`
	result := removeCustomFieldFromTools([]byte(input))
	// Without tools the original data is returned unchanged.
	assert.JSONEq(t, input, string(result))
}
|
||||||
|
|
||||||
|
// TestSanitizeBedrockCacheControl_RemoveScope verifies that the unsupported
// cache_control.scope field is stripped from both system and message blocks.
func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) {
	input := `{
		"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}],
		"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}]
	}`
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")

	// scope must be removed.
	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
	assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists())
	// type must be preserved.
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String())
}
|
||||||
|
|
||||||
|
// TestSanitizeBedrockCacheControl_TTL_OldModel verifies that ttl is stripped
// for pre-4.5 models.
func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) {
	input := `{
		"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
	}`
	// Older models (Claude 3.5) do not support ttl.
	result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0")

	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
}
|
||||||
|
|
||||||
|
// TestSanitizeBedrockCacheControl_TTL_Claude45_Supported verifies that a
// legal ttl value survives on a 4.5+ model.
func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) {
	input := `{
		"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
	}`
	// Claude 4.5+ supports "5m" and "1h".
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")

	assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
	assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
}
|
||||||
|
|
||||||
|
// TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue verifies that
// a ttl outside {"5m","1h"} is stripped even on a 4.5+ model.
func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) {
	input := `{
		"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}]
	}`
	// Claude 4.5 does not support "10m".
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")

	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
}
|
||||||
|
|
||||||
|
// TestSanitizeBedrockCacheControl_TTL_Claude46 verifies that a "1h" ttl inside
// a message content block survives on Claude 4.6.
func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) {
	input := `{
		"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
	}`
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")

	assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists())
	assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
|
||||||
|
|
||||||
|
// TestSanitizeBedrockCacheControl_NoCacheControl verifies the no-op path.
func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) {
	input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
	result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
	// Without cache_control the original data is returned unchanged.
	assert.JSONEq(t, input, string(result))
}
|
||||||
|
|
||||||
|
// TestIsBedrockClaude45OrNewer covers the model-version detection table:
// 4.5+ (incl. Haiku 4.5), pre-4.5, future versions, and non-Claude IDs.
func TestIsBedrockClaude45OrNewer(t *testing.T) {
	tests := []struct {
		modelID string
		expect  bool
	}{
		{"us.anthropic.claude-opus-4-6-v1", true},
		{"us.anthropic.claude-sonnet-4-6", true},
		{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
		{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
		{"us.anthropic.claude-haiku-4-5-20251001-v1:0", true},
		{"anthropic.claude-3-5-sonnet-20241022-v2:0", false},
		{"anthropic.claude-3-opus-20240229-v1:0", false},
		{"anthropic.claude-3-haiku-20240307-v1:0", false},
		// Future versions should be supported automatically.
		{"us.anthropic.claude-sonnet-5-0-v1", true},
		{"us.anthropic.claude-opus-4-7-v1", true},
		// Older versions.
		{"anthropic.claude-opus-4-1-v1", false},
		{"anthropic.claude-sonnet-4-0-v1", false},
		// Non-Claude models.
		{"amazon.nova-pro-v1", false},
		{"meta.llama3-70b", false},
	}
	for _, tt := range tests {
		t.Run(tt.modelID, func(t *testing.T) {
			assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID))
		})
	}
}
|
||||||
|
|
||||||
|
// TestPrepareBedrockRequestBody_FullIntegration runs the whole pipeline on a
// realistic request and checks every transformation at once: version
// injection, field removal, beta-token pass-through, schema inlining, tool
// "custom" stripping, and cache_control sanitizing.
func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
	// Simulate a complete Claude Code request.
	input := `{
		"model": "claude-opus-4-6",
		"stream": true,
		"max_tokens": 16384,
		"output_format": {"type": "json", "schema": {"result": "string"}},
		"output_config": {"max_tokens": 100},
		"system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}],
		"messages": [
			{"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]}
		],
		"tools": [
			{"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}},
			{"name": "read", "description": "Read file", "input_schema": {"type": "object"}}
		]
	}`

	betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
	result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
	require.NoError(t, err)

	// Basic fields.
	assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
	assert.False(t, gjson.GetBytes(result, "model").Exists())
	assert.False(t, gjson.GetBytes(result, "stream").Exists())
	assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int())

	// anthropic_beta should contain every beta token.
	betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
	require.Len(t, betaArr, 3)
	assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
	assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
	assert.Equal(t, "compact-2026-01-12", betaArr[2].String())

	// output_format must be removed, its schema inlined into the last user message.
	assert.False(t, gjson.GetBytes(result, "output_format").Exists())
	assert.False(t, gjson.GetBytes(result, "output_config").Exists())
	// Content array: original text block + inlined schema block.
	contentArr := gjson.GetBytes(result, "messages.0.content").Array()
	require.Len(t, contentArr, 2)
	assert.Equal(t, "hello", contentArr[0].Get("text").String())
	assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`)

	// The "custom" field inside tools must be removed.
	assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists())
	assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String())
	assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String())

	// cache_control: scope must be removed; legal ttl values are kept on Claude 4.6.
	assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
	assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
	assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
	assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
|
||||||
|
|
||||||
|
// TestPrepareBedrockRequestBody_BetaHeader verifies how the anthropic-beta
// header is parsed (empty, single, comma-separated, JSON array) and mapped
// onto the anthropic_beta body field.
func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
	input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`

	t.Run("empty beta header", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
		require.NoError(t, err)
		assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
	})

	t.Run("single beta token", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		require.Len(t, arr, 1)
		assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
	})

	t.Run("multiple beta tokens with spaces", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		require.Len(t, arr, 2)
		assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
		assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
	})

	t.Run("json array beta header", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		require.Len(t, arr, 2)
		assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
		assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
	})
}
|
||||||
|
|
||||||
|
// TestParseAnthropicBetaHeader covers the header-parsing forms: empty,
// single token, comma lists (with/without spaces), and a JSON array.
func TestParseAnthropicBetaHeader(t *testing.T) {
	assert.Nil(t, parseAnthropicBetaHeader(""))
	assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a"))
	assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b"))
	assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b "))
	assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c"))
	assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`))
}
|
||||||
|
|
||||||
|
// TestFilterBedrockBetaTokens covers the allowlist filter, the
// advanced-tool-use → tool-search-tool transform, the tool-examples
// auto-association, deduplication, and the empty/all-unsupported cases.
func TestFilterBedrockBetaTokens(t *testing.T) {
	t.Run("supported tokens pass through", func(t *testing.T) {
		tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
		result := filterBedrockBetaTokens(tokens)
		assert.Equal(t, tokens, result)
	})

	t.Run("unsupported tokens are filtered out", func(t *testing.T) {
		tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
		result := filterBedrockBetaTokens(tokens)
		assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
	})

	t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
		tokens := []string{"advanced-tool-use-2025-11-20"}
		result := filterBedrockBetaTokens(tokens)
		assert.Contains(t, result, "tool-search-tool-2025-10-19")
		// tool-examples is auto-associated.
		assert.Contains(t, result, "tool-examples-2025-10-29")
	})

	t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) {
		tokens := []string{"tool-search-tool-2025-10-19"}
		result := filterBedrockBetaTokens(tokens)
		assert.Contains(t, result, "tool-search-tool-2025-10-19")
		assert.Contains(t, result, "tool-examples-2025-10-29")
	})

	t.Run("no duplication when tool-examples already present", func(t *testing.T) {
		tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"}
		result := filterBedrockBetaTokens(tokens)
		count := 0
		for _, t := range result {
			if t == "tool-examples-2025-10-29" {
				count++
			}
		}
		assert.Equal(t, 1, count)
	})

	t.Run("empty input returns nil", func(t *testing.T) {
		result := filterBedrockBetaTokens(nil)
		assert.Nil(t, result)
	})

	t.Run("all unsupported returns nil", func(t *testing.T) {
		result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"})
		assert.Nil(t, result)
	})

	t.Run("duplicate tokens are deduplicated", func(t *testing.T) {
		tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"}
		result := filterBedrockBetaTokens(tokens)
		assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
	})
}
|
||||||
|
|
||||||
|
// TestPrepareBedrockRequestBody_BetaFiltering verifies beta filtering and
// transformation through the full request-preparation pipeline.
func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
	input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`

	t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
			"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		require.Len(t, arr, 1)
		assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
	})

	t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
			"advanced-tool-use-2025-11-20")
		require.NoError(t, err)
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		require.Len(t, arr, 2)
		assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String())
		assert.Equal(t, "tool-examples-2025-10-29", arr[1].String())
	})
}
|
||||||
|
|
||||||
|
// TestBedrockCrossRegionPrefix maps AWS regions to the cross-region inference
// profile prefix expected in Bedrock model IDs.
func TestBedrockCrossRegionPrefix(t *testing.T) {
	tests := []struct {
		region string
		expect string
	}{
		// US regions
		{"us-east-1", "us"},
		{"us-east-2", "us"},
		{"us-west-1", "us"},
		{"us-west-2", "us"},
		// GovCloud
		{"us-gov-east-1", "us-gov"},
		{"us-gov-west-1", "us-gov"},
		// EU regions
		{"eu-west-1", "eu"},
		{"eu-west-2", "eu"},
		{"eu-west-3", "eu"},
		{"eu-central-1", "eu"},
		{"eu-central-2", "eu"},
		{"eu-north-1", "eu"},
		{"eu-south-1", "eu"},
		// APAC regions
		{"ap-northeast-1", "jp"},
		{"ap-northeast-2", "apac"},
		{"ap-southeast-1", "apac"},
		{"ap-southeast-2", "au"},
		{"ap-south-1", "apac"},
		// Canada / South America fallback to us
		{"ca-central-1", "us"},
		{"sa-east-1", "us"},
		// Unknown defaults to us
		{"me-south-1", "us"},
	}
	for _, tt := range tests {
		t.Run(tt.region, func(t *testing.T) {
			assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region))
		})
	}
}
|
||||||
|
|
||||||
|
// TestResolveBedrockModelID covers alias resolution against account
// credentials: default aliases with region prefixing, custom model_mapping,
// the aws_force_global rewrite, direct Bedrock IDs, and unsupported aliases.
func TestResolveBedrockModelID(t *testing.T) {
	t.Run("default alias resolves and adjusts region", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeBedrock,
			Credentials: map[string]any{
				"aws_region": "eu-west-1",
			},
		}

		modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
		require.True(t, ok)
		assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID)
	})

	t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeBedrock,
			Credentials: map[string]any{
				"aws_region": "ap-southeast-2",
				"model_mapping": map[string]any{
					"claude-*": "claude-opus-4-6",
				},
			},
		}

		modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-6-thinking")
		require.True(t, ok)
		assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
	})

	t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeBedrock,
			Credentials: map[string]any{
				"aws_region":       "us-east-1",
				"aws_force_global": "true",
				"model_mapping": map[string]any{
					"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
				},
			},
		}

		modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6")
		require.True(t, ok)
		assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID)
	})

	t.Run("direct bedrock model id passes through", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeBedrock,
			Credentials: map[string]any{
				"aws_region": "us-east-1",
			},
		}

		modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0")
		require.True(t, ok)
		assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID)
	})

	t.Run("unsupported alias returns false", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeBedrock,
			Credentials: map[string]any{
				"aws_region": "us-east-1",
			},
		}

		_, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022")
		assert.False(t, ok)
	})
}
|
||||||
|
|
||||||
|
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
|
||||||
|
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
|
||||||
|
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no duplicate when already present", func(t *testing.T) {
|
||||||
|
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
count := 0
|
||||||
|
for _, t := range result {
|
||||||
|
if t == "interleaved-thinking-2025-05-14" {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 1, count)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Contains(t, result, "computer-use-2025-11-24")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inject advanced-tool-use for input examples", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
|
||||||
|
// 纯 tool search 场景直接注入 Bedrock 特定头,不走 advanced-tool-use 转换
|
||||||
|
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||||
|
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
|
||||||
|
// 混合场景使用 advanced-tool-use(后续由 filter 转换为 tool-search-tool)
|
||||||
|
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||||
|
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
|
||||||
|
assert.NotContains(t, result, "tool-search-tool-2025-10-19")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no injection for regular tools", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Empty(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no injection when no features detected", func(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Empty(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves existing tokens", func(t *testing.T) {
|
||||||
|
body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"}
|
||||||
|
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Contains(t, result, "context-1m-2025-08-07")
|
||||||
|
assert.Contains(t, result, "compact-2026-01-12")
|
||||||
|
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestResolveBedrockBetaTokens covers end-to-end beta-token resolution for
// Bedrock: features detected in the request body are mapped to their final
// Bedrock token names, and client-supplied tokens that Bedrock does not
// accept are dropped.
func TestResolveBedrockBetaTokens(t *testing.T) {
	t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) {
		// Programmatic tool calling (allowed_callers) with no header tokens
		// resolves to the Bedrock-specific tool-search/tool-examples tokens.
		body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
		result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1")
		assert.Contains(t, result, "tool-search-tool-2025-10-19")
		assert.Contains(t, result, "tool-examples-2025-10-29")
	})

	t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
		// The files-api token is stripped as unsupported; the supported
		// interleaved-thinking token passes through unchanged.
		body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
		result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
		assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
	})
}
|
||||||
|
|
||||||
|
// TestPrepareBedrockRequestBody_AutoBetaInjection verifies that
// PrepareBedrockRequestBody writes auto-detected beta tokens into the body's
// anthropic_beta array, merging them with any tokens passed via the header
// argument.
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
	t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
		input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
		require.NoError(t, err)
		// The anthropic_beta array must contain the auto-injected token even
		// though no header tokens were supplied.
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		found := false
		for _, v := range arr {
			if v.String() == "interleaved-thinking-2025-05-14" {
				found = true
			}
		}
		assert.True(t, found, "interleaved-thinking should be auto-injected")
	})

	t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
		input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
		result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
		require.NoError(t, err)
		// Collect the resulting token list; both the header-supplied and the
		// auto-injected token must be present.
		arr := gjson.GetBytes(result, "anthropic_beta").Array()
		names := make([]string, len(arr))
		for i, v := range arr {
			names[i] = v.String()
		}
		assert.Contains(t, names, "context-1m-2025-08-07")
		assert.Contains(t, names, "interleaved-thinking-2025-05-14")
	})
}
|
||||||
|
|
||||||
|
// TestAdjustBedrockModelRegionPrefix is a table-driven test for
// AdjustBedrockModelRegionPrefix: the geo prefix of an inference-profile
// model ID (us./eu./jp./au./apac./us-gov./global.) is rewritten to match the
// target AWS region, while IDs without a known prefix are left untouched.
func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
	tests := []struct {
		name    string
		modelID string // input model ID, possibly carrying a geo prefix
		region  string // target AWS region (or the special value "global")
		expect  string // expected rewritten model ID
	}{
		// US region — no change needed
		{"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"},
		// EU region — replace us → eu
		{"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
		{"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"},
		// APAC region — jp and au have dedicated prefixes per AWS docs
		{"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"},
		{"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"},
		{"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"},
		// eu → us (user manually set eu prefix, moved to us region)
		{"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"},
		// global prefix — replace to match region
		{"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
		// No known prefix — leave unchanged
		{"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"},
		// GovCloud — uses independent us-gov prefix
		{"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"},
		{"govcloud already correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"},
		// Force global (special region value)
		{"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"},
		{"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region))
		})
	}
}
|
||||||
67
backend/internal/service/bedrock_signer.go
Normal file
67
backend/internal/service/bedrock_signer.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
|
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BedrockSigner signs Bedrock requests using AWS Signature Version 4 (SigV4).
type BedrockSigner struct {
	credentials aws.Credentials // static AWS credentials (access key, secret, optional session token)
	region      string          // AWS region the signature is scoped to
	signer      *v4.Signer      // reusable SigV4 signer
}
|
||||||
|
|
||||||
|
// NewBedrockSigner creates a BedrockSigner from static AWS credentials.
// sessionToken may be empty for long-lived (non-STS) credentials.
func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner {
	return &BedrockSigner{
		credentials: aws.Credentials{
			AccessKeyID:     accessKeyID,
			SecretAccessKey: secretAccessKey,
			SessionToken:    sessionToken,
		},
		region: region,
		signer: v4.NewSigner(),
	}
}
|
||||||
|
|
||||||
|
// NewBedrockSignerFromAccount 从 Account 凭证创建 BedrockSigner
|
||||||
|
func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) {
|
||||||
|
accessKeyID := account.GetCredential("aws_access_key_id")
|
||||||
|
if accessKeyID == "" {
|
||||||
|
return nil, fmt.Errorf("aws_access_key_id not found in credentials")
|
||||||
|
}
|
||||||
|
secretAccessKey := account.GetCredential("aws_secret_access_key")
|
||||||
|
if secretAccessKey == "" {
|
||||||
|
return nil, fmt.Errorf("aws_secret_access_key not found in credentials")
|
||||||
|
}
|
||||||
|
region := account.GetCredential("aws_region")
|
||||||
|
if region == "" {
|
||||||
|
region = defaultBedrockRegion
|
||||||
|
}
|
||||||
|
sessionToken := account.GetCredential("aws_session_token") // 可选
|
||||||
|
|
||||||
|
return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignRequest signs the HTTP request with AWS SigV4 for the "bedrock" service
// in the signer's region, using the current time as the signing time.
//
// Important constraint: before calling this method, req should carry only
// AWS-relevant headers (e.g. Content-Type, Accept). Non-AWS headers (such as
// anthropic-beta) participate in the signature computation, and if the
// Bedrock endpoint does not recognize them, signature validation may fail.
// litellm handles this via _filter_headers_for_aws_signature; in the current
// implementation buildUpstreamRequestBedrock only sets Content-Type and
// Accept, so this is safe.
func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error {
	// SigV4 signs over the SHA-256 of the payload, not the payload itself.
	payloadHash := sha256Hash(body)
	return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now())
}
|
||||||
|
|
||||||
|
// sha256Hash returns the lowercase hex encoding of the SHA-256 digest of data.
func sha256Hash(data []byte) string {
	digest := sha256.Sum256(data)
	return hex.EncodeToString(digest[:])
}
|
||||||
35
backend/internal/service/bedrock_signer_test.go
Normal file
35
backend/internal/service/bedrock_signer_test.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewBedrockSignerFromAccount_DefaultRegion verifies that when the
// account credentials omit aws_region, the signer falls back to
// defaultBedrockRegion.
func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) {
	// Credentials intentionally omit aws_region to exercise the fallback.
	account := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeBedrock,
		Credentials: map[string]any{
			"aws_access_key_id":     "test-akid",
			"aws_secret_access_key": "test-secret",
		},
	}

	signer, err := NewBedrockSignerFromAccount(account)
	require.NoError(t, err)
	require.NotNil(t, signer)
	assert.Equal(t, defaultBedrockRegion, signer.region)
}
|
||||||
|
|
||||||
|
// TestFilterBetaTokens verifies filterBetaTokens: tokens present in the
// filter set are removed, a nil filter set passes everything through, and a
// nil token slice yields nil.
func TestFilterBetaTokens(t *testing.T) {
	tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"}
	filterSet := map[string]struct{}{
		"tool-search-tool-2025-10-19": {},
	}

	assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet))
	assert.Equal(t, tokens, filterBetaTokens(tokens, nil))
	assert.Nil(t, filterBetaTokens(nil, filterSet))
}
|
||||||
414
backend/internal/service/bedrock_stream.go
Normal file
414
backend/internal/service/bedrock_stream.go
Normal file
@@ -0,0 +1,414 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
	"bufio"
	"context"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"fmt"
	"hash/crc32"
	"io"
	"net/http"
	"sync/atomic"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"

	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
|
||||||
|
|
||||||
|
// handleBedrockStreamingResponse handles the EventStream response of
// Bedrock's InvokeModelWithResponseStream API.
//
// Bedrock returns the AWS EventStream binary format; each event's payload
// carries a base64-encoded Claude SSE event JSON in chunk.bytes. This method
// decodes those events and re-emits them to the client as standard SSE.
//
// It returns the accumulated usage, the time to first token, and whether the
// client disconnected mid-stream; a non-nil error indicates an upstream read
// failure or a stream-data interval timeout.
func (s *GatewayService) handleBedrockStreamingResponse(
	ctx context.Context,
	resp *http.Response,
	c *gin.Context,
	account *Account,
	startTime time.Time,
	model string,
) (*streamingResult, error) {
	w := c.Writer
	flusher, ok := w.(http.Flusher)
	if !ok {
		return nil, errors.New("streaming not supported")
	}

	// Standard SSE response headers; X-Accel-Buffering disables proxy buffering.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")
	// Propagate the upstream AWS request id to the client for troubleshooting.
	if v := resp.Header.Get("x-amzn-requestid"); v != "" {
		c.Header("x-request-id", v)
	}

	usage := &ClaudeUsage{}
	var firstTokenMs *int
	clientDisconnected := false

	// Bedrock EventStream uses the application/vnd.amazon.eventstream binary
	// format. Each frame is:
	//   total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4)
	// We use an EventStream decoder to parse the frames correctly rather than
	// scanning for JSON chunks, since the response is wrapped in binary frames.
	decoder := newBedrockEventStreamDecoder(resp.Body)

	type decodeEvent struct {
		payload []byte
		err     error
	}
	events := make(chan decodeEvent, 16)
	done := make(chan struct{})
	// sendEvent forwards a decoded event unless the consumer has already
	// exited (done closed), so the decoder goroutine cannot block forever.
	sendEvent := func(ev decodeEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}
	// lastReadAt tracks the last successful upstream read (UnixNano), used by
	// the interval-timeout watchdog below.
	var lastReadAt atomic.Int64
	lastReadAt.Store(time.Now().UnixNano())

	// Decoder goroutine: reads frames until EOF or error. It exits when the
	// consumer closes done, so it does not leak.
	go func() {
		defer close(events)
		for {
			payload, err := decoder.Decode()
			if err != nil {
				if err == io.EOF {
					return
				}
				_ = sendEvent(decodeEvent{err: err})
				return
			}
			lastReadAt.Store(time.Now().UnixNano())
			if !sendEvent(decodeEvent{payload: payload}) {
				return
			}
		}
	}()
	defer close(done)

	// Optional watchdog: abort the stream if no upstream data arrives within
	// the configured interval.
	streamInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
		streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
	}
	var intervalTicker *time.Ticker
	if streamInterval > 0 {
		intervalTicker = time.NewTicker(streamInterval)
		defer intervalTicker.Stop()
	}
	// intervalCh stays nil (never fires) when the watchdog is disabled.
	var intervalCh <-chan time.Time
	if intervalTicker != nil {
		intervalCh = intervalTicker.C
	}

	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// Upstream finished cleanly.
				if !clientDisconnected {
					flusher.Flush()
				}
				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
			}
			if ev.err != nil {
				if clientDisconnected {
					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
				}
				if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
				}
				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err)
			}

			// The payload is JSON; extract chunk.bytes (base64-encoded Claude
			// SSE event data). Frames without usable data are skipped.
			sseData := extractBedrockChunkData(ev.payload)
			if sseData == nil {
				continue
			}

			if firstTokenMs == nil {
				ms := int(time.Since(startTime).Milliseconds())
				firstTokenMs = &ms
			}

			// Convert the Bedrock-specific amazon-bedrock-invocationMetrics
			// field into standard Anthropic usage format, and strip it so it
			// is not forwarded to the client.
			sseData = transformBedrockInvocationMetrics(sseData)

			// Parse the SSE event data to accumulate usage.
			s.parseSSEUsagePassthrough(string(sseData), usage)

			// Determine the SSE event type for the "event:" line.
			eventType := gjson.GetBytes(sseData, "type").String()

			// Write standard SSE framing. After a client disconnect we keep
			// draining upstream (without writing) so usage stays accurate.
			if !clientDisconnected {
				var writeErr error
				if eventType != "" {
					_, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData)
				} else {
					_, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData)
				}
				if writeErr != nil {
					clientDisconnected = true
					logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID)
				} else {
					flusher.Flush()
				}
			}

		case <-intervalCh:
			// Only time out when no data has arrived for a full interval.
			lastRead := time.Unix(0, lastReadAt.Load())
			if time.Since(lastRead) < streamInterval {
				continue
			}
			if clientDisconnected {
				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
			}
			logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
			if s.rateLimitService != nil {
				s.rateLimitService.HandleStreamTimeout(ctx, account, model)
			}
			return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
		}
	}
}
|
||||||
|
|
||||||
|
// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据
|
||||||
|
// Bedrock payload 格式:{"bytes":"<base64-encoded-json>"}
|
||||||
|
func extractBedrockChunkData(payload []byte) []byte {
|
||||||
|
b64 := gjson.GetBytes(payload, "bytes").String()
|
||||||
|
if b64 == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(b64)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return decoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// transformBedrockInvocationMetrics converts the Bedrock-specific
// amazon-bedrock-invocationMetrics field into the standard Anthropic usage
// format and removes it from the SSE data.
//
// A message_delta event returned by Bedrock Invoke may contain:
//
//	{"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}
//
// which is transformed into:
//
//	{"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}}
func transformBedrockInvocationMetrics(data []byte) []byte {
	metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics")
	if !metrics.Exists() || !metrics.IsObject() {
		return data
	}

	// Strip the Bedrock-specific field so it is not forwarded to the client.
	data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics")

	// If a standard usage field already exists, do not overwrite it.
	if gjson.GetBytes(data, "usage").Exists() {
		return data
	}

	// Map the camelCase metric names onto the snake_case usage fields.
	inputTokens := metrics.Get("inputTokenCount")
	outputTokens := metrics.Get("outputTokenCount")
	if inputTokens.Exists() {
		data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int())
	}
	if outputTokens.Exists() {
		data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int())
	}

	return data
}
|
||||||
|
|
||||||
|
// bedrockEventStreamDecoder decodes AWS EventStream binary frames.
//
// EventStream frame layout:
//
//	[total_byte_length: 4 bytes]
//	[headers_byte_length: 4 bytes]
//	[prelude_crc: 4 bytes]
//	[headers: variable]
//	[payload: variable]
//	[message_crc: 4 bytes]
type bedrockEventStreamDecoder struct {
	reader *bufio.Reader // buffered reader over the upstream response body
}

// newBedrockEventStreamDecoder wraps r in a 64 KiB buffered reader sized for
// typical EventStream frames.
func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder {
	return &bedrockEventStreamDecoder{
		reader: bufio.NewReaderSize(r, 64*1024),
	}
}
|
||||||
|
|
||||||
|
// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload
|
||||||
|
func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) {
|
||||||
|
for {
|
||||||
|
// 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes
|
||||||
|
prelude := make([]byte, 12)
|
||||||
|
if _, err := io.ReadFull(d.reader, prelude); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 prelude CRC(AWS EventStream 使用标准 CRC32 / IEEE)
|
||||||
|
preludeCRC := bedrockReadUint32(prelude[8:12])
|
||||||
|
if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC {
|
||||||
|
return nil, fmt.Errorf("eventstream prelude CRC mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
totalLength := bedrockReadUint32(prelude[0:4])
|
||||||
|
headersLength := bedrockReadUint32(prelude[4:8])
|
||||||
|
|
||||||
|
if totalLength < 16 { // minimum: 12 prelude + 4 message_crc
|
||||||
|
return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取 headers + payload + message_crc
|
||||||
|
remaining := int(totalLength) - 12
|
||||||
|
if remaining <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := make([]byte, remaining)
|
||||||
|
if _, err := io.ReadFull(d.reader, data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 message CRC(覆盖 prelude + headers + payload)
|
||||||
|
messageCRC := bedrockReadUint32(data[len(data)-4:])
|
||||||
|
h := crc32.New(crc32IEEETable)
|
||||||
|
_, _ = h.Write(prelude)
|
||||||
|
_, _ = h.Write(data[:len(data)-4])
|
||||||
|
if h.Sum32() != messageCRC {
|
||||||
|
return nil, fmt.Errorf("eventstream message CRC mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 headers
|
||||||
|
headers := data[:headersLength]
|
||||||
|
payload := data[headersLength : len(data)-4] // 去掉 message_crc
|
||||||
|
|
||||||
|
// 从 headers 中提取 :event-type
|
||||||
|
eventType := extractEventStreamHeaderValue(headers, ":event-type")
|
||||||
|
|
||||||
|
// 只处理 chunk 事件
|
||||||
|
if eventType == "chunk" {
|
||||||
|
// payload 是完整的 JSON,包含 bytes 字段
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查异常事件
|
||||||
|
exceptionType := extractEventStreamHeaderValue(headers, ":exception-type")
|
||||||
|
if exceptionType != "" {
|
||||||
|
return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
messageType := extractEventStreamHeaderValue(headers, ":message-type")
|
||||||
|
if messageType == "exception" || messageType == "error" {
|
||||||
|
return nil, fmt.Errorf("bedrock error: %s", string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 跳过其他事件类型(如 initial-response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractEventStreamHeaderValue extracts the string value of a named header
// from raw EventStream header bytes.
//
// Each header is encoded as:
//
//	[name_length: 1 byte][name: variable][value_type: 1 byte][value: variable]
//
// value_type 7 is a string whose value is prefixed by a 2-byte big-endian
// length. Only string headers yield their value; a matching header of bool
// type yields "true"/"false", any other scalar type yields "". Malformed or
// truncated header data also yields "".
//
// Fixes vs. original: removed a dead bounds check that duplicated the loop
// condition, and replaced the hand-rolled big-endian reads with
// binary.BigEndian.Uint16 from the standard library.
func extractEventStreamHeaderValue(headers []byte, targetName string) string {
	pos := 0
	for pos < len(headers) {
		nameLen := int(headers[pos])
		pos++
		if pos+nameLen > len(headers) {
			break
		}
		name := string(headers[pos : pos+nameLen])
		pos += nameLen

		if pos >= len(headers) {
			break
		}
		valueType := headers[pos]
		pos++

		switch valueType {
		case 7: // string: 2-byte big-endian length prefix + bytes
			if pos+2 > len(headers) {
				return ""
			}
			valueLen := int(binary.BigEndian.Uint16(headers[pos : pos+2]))
			pos += 2
			if pos+valueLen > len(headers) {
				return ""
			}
			value := string(headers[pos : pos+valueLen])
			pos += valueLen
			if name == targetName {
				return value
			}
		case 0: // bool true (no value bytes)
			if name == targetName {
				return "true"
			}
		case 1: // bool false (no value bytes)
			if name == targetName {
				return "false"
			}
		case 2: // byte — skipped; a matching non-string header resolves to ""
			pos++
			if name == targetName {
				return ""
			}
		case 3: // short
			pos += 2
			if name == targetName {
				return ""
			}
		case 4: // int
			pos += 4
			if name == targetName {
				return ""
			}
		case 5: // long
			pos += 8
			if name == targetName {
				return ""
			}
		case 6: // bytes: 2-byte length prefix, value skipped
			if pos+2 > len(headers) {
				return ""
			}
			valueLen := int(binary.BigEndian.Uint16(headers[pos : pos+2]))
			pos += 2 + valueLen
		case 8: // timestamp (8 bytes), skipped
			pos += 8
		case 9: // uuid (16 bytes), skipped
			pos += 16
		default:
			// Unknown value type: remaining bytes cannot be parsed safely.
			return ""
		}
	}
	return ""
}
|
||||||
|
|
||||||
|
// crc32IEEETable is the CRC32/IEEE table used by AWS EventStream framing.
var crc32IEEETable = crc32.MakeTable(crc32.IEEE)

// bedrockReadUint32 decodes a big-endian uint32, as used in EventStream
// preludes and CRC trailers. Delegates to the standard library instead of
// hand-rolled shifts.
func bedrockReadUint32(b []byte) uint32 {
	return binary.BigEndian.Uint32(b)
}

// bedrockReadUint16 decodes a big-endian uint16, as used for EventStream
// header value-length prefixes.
func bedrockReadUint16(b []byte) uint16 {
	return binary.BigEndian.Uint16(b)
}
|
||||||
261
backend/internal/service/bedrock_stream_test.go
Normal file
261
backend/internal/service/bedrock_stream_test.go
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/binary"
|
||||||
|
"hash/crc32"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestExtractBedrockChunkData verifies extractBedrockChunkData: a valid
// base64 "bytes" field round-trips to the original JSON, while an empty,
// missing, or non-base64 field yields nil.
func TestExtractBedrockChunkData(t *testing.T) {
	t.Run("valid base64 payload", func(t *testing.T) {
		original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`
		b64 := base64.StdEncoding.EncodeToString([]byte(original))
		payload := []byte(`{"bytes":"` + b64 + `"}`)

		result := extractBedrockChunkData(payload)
		require.NotNil(t, result)
		assert.JSONEq(t, original, string(result))
	})

	t.Run("empty bytes field", func(t *testing.T) {
		result := extractBedrockChunkData([]byte(`{"bytes":""}`))
		assert.Nil(t, result)
	})

	t.Run("no bytes field", func(t *testing.T) {
		result := extractBedrockChunkData([]byte(`{"other":"value"}`))
		assert.Nil(t, result)
	})

	t.Run("invalid base64", func(t *testing.T) {
		result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`))
		assert.Nil(t, result)
	})
}
|
||||||
|
|
||||||
|
// TestTransformBedrockInvocationMetrics verifies that
// transformBedrockInvocationMetrics strips amazon-bedrock-invocationMetrics,
// maps its token counts onto a standard usage object, leaves metric-free
// events untouched, and never overwrites an existing usage field.
func TestTransformBedrockInvocationMetrics(t *testing.T) {
	t.Run("converts metrics to usage", func(t *testing.T) {
		input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
		result := transformBedrockInvocationMetrics([]byte(input))

		// amazon-bedrock-invocationMetrics should be removed
		assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
		// usage should be set
		assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int())
		assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int())
		// original fields preserved
		assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String())
		assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String())
	})

	t.Run("no metrics present", func(t *testing.T) {
		// Events without metrics must pass through unchanged.
		input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}`
		result := transformBedrockInvocationMetrics([]byte(input))
		assert.JSONEq(t, input, string(result))
	})

	t.Run("does not overwrite existing usage", func(t *testing.T) {
		input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
		result := transformBedrockInvocationMetrics([]byte(input))

		// metrics removed but existing usage preserved
		assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
		assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int())
	})
}
|
||||||
|
|
||||||
|
func TestExtractEventStreamHeaderValue(t *testing.T) {
|
||||||
|
// Build a header with :event-type = "chunk" (string type = 7)
|
||||||
|
buildStringHeader := func(name, value string) []byte {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
// name length (1 byte)
|
||||||
|
_ = buf.WriteByte(byte(len(name)))
|
||||||
|
// name
|
||||||
|
_, _ = buf.WriteString(name)
|
||||||
|
// value type (7 = string)
|
||||||
|
_ = buf.WriteByte(7)
|
||||||
|
// value length (2 bytes, big-endian)
|
||||||
|
_ = binary.Write(&buf, binary.BigEndian, uint16(len(value)))
|
||||||
|
// value
|
||||||
|
_, _ = buf.WriteString(value)
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("find string header", func(t *testing.T) {
|
||||||
|
headers := buildStringHeader(":event-type", "chunk")
|
||||||
|
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("header not found", func(t *testing.T) {
|
||||||
|
headers := buildStringHeader(":event-type", "chunk")
|
||||||
|
assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple headers", func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
_, _ = buf.Write(buildStringHeader(":content-type", "application/json"))
|
||||||
|
_, _ = buf.Write(buildStringHeader(":event-type", "chunk"))
|
||||||
|
_, _ = buf.Write(buildStringHeader(":message-type", "event"))
|
||||||
|
|
||||||
|
headers := buf.Bytes()
|
||||||
|
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
|
||||||
|
assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type"))
|
||||||
|
assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty headers", func(t *testing.T) {
|
||||||
|
assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBedrockEventStreamDecoder covers frame decoding of the AWS EventStream
// binary protocol as used by Bedrock streaming responses: decoding "chunk"
// events, skipping other event types, EOF handling, and rejection of frames
// with invalid prelude/message CRCs (including frames checksummed with the
// wrong CRC32 polynomial).
func TestBedrockEventStreamDecoder(t *testing.T) {
	crc32IeeeTab := crc32.MakeTable(crc32.IEEE)

	// Build a valid EventStream frame with correct CRC32/IEEE checksums.
	// Frame layout: prelude(8) | prelude_crc(4) | headers | payload | message_crc(4).
	buildFrame := func(eventType string, payload []byte) []byte {
		// Build headers
		var headersBuf bytes.Buffer
		// :event-type header
		_ = headersBuf.WriteByte(byte(len(":event-type")))
		_, _ = headersBuf.WriteString(":event-type")
		_ = headersBuf.WriteByte(7) // string type
		_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType)))
		_, _ = headersBuf.WriteString(eventType)
		// :message-type header
		_ = headersBuf.WriteByte(byte(len(":message-type")))
		_, _ = headersBuf.WriteString(":message-type")
		_ = headersBuf.WriteByte(7)
		_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event")))
		_, _ = headersBuf.WriteString("event")

		headers := headersBuf.Bytes()
		headersLen := uint32(len(headers))
		// total = 12 (prelude) + headers + payload + 4 (message_crc)
		totalLen := uint32(12 + len(headers) + len(payload) + 4)

		// Prelude: total_length(4) + headers_length(4)
		var preludeBuf bytes.Buffer
		_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
		_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
		preludeBytes := preludeBuf.Bytes()
		preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab)

		// Build frame: prelude + prelude_crc + headers + payload
		var frame bytes.Buffer
		_, _ = frame.Write(preludeBytes)
		_ = binary.Write(&frame, binary.BigEndian, preludeCRC)
		_, _ = frame.Write(headers)
		_, _ = frame.Write(payload)

		// Message CRC covers everything before itself
		messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab)
		_ = binary.Write(&frame, binary.BigEndian, messageCRC)
		return frame.Bytes()
	}

	t.Run("decode chunk event", func(t *testing.T) {
		payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test")
		frame := buildFrame("chunk", payload)

		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
		result, err := decoder.Decode()
		require.NoError(t, err)
		assert.Equal(t, payload, result)
	})

	t.Run("skip non-chunk events", func(t *testing.T) {
		// Write initial-response followed by chunk; Decode should skip the
		// former and return the chunk payload.
		var buf bytes.Buffer
		_, _ = buf.Write(buildFrame("initial-response", []byte(`{}`)))
		chunkPayload := []byte(`{"bytes":"aGVsbG8="}`)
		_, _ = buf.Write(buildFrame("chunk", chunkPayload))

		decoder := newBedrockEventStreamDecoder(&buf)
		result, err := decoder.Decode()
		require.NoError(t, err)
		assert.Equal(t, chunkPayload, result)
	})

	t.Run("EOF on empty input", func(t *testing.T) {
		decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil))
		_, err := decoder.Decode()
		assert.Equal(t, io.EOF, err)
	})

	t.Run("corrupted prelude CRC", func(t *testing.T) {
		frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
		// Corrupt the prelude CRC (bytes 8-11)
		frame[8] ^= 0xFF
		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
		_, err := decoder.Decode()
		require.Error(t, err)
		assert.Contains(t, err.Error(), "prelude CRC mismatch")
	})

	t.Run("corrupted message CRC", func(t *testing.T) {
		frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
		// Corrupt the message CRC (last 4 bytes)
		frame[len(frame)-1] ^= 0xFF
		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
		_, err := decoder.Decode()
		require.Error(t, err)
		assert.Contains(t, err.Error(), "message CRC mismatch")
	})

	t.Run("castagnoli encoded frame is rejected", func(t *testing.T) {
		// EventStream mandates CRC32/IEEE; a frame checksummed with the
		// Castagnoli polynomial must fail prelude validation.
		castagnoliTab := crc32.MakeTable(crc32.Castagnoli)
		payload := []byte(`{"bytes":"dGVzdA=="}`)

		var headersBuf bytes.Buffer
		_ = headersBuf.WriteByte(byte(len(":event-type")))
		_, _ = headersBuf.WriteString(":event-type")
		_ = headersBuf.WriteByte(7)
		_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk")))
		_, _ = headersBuf.WriteString("chunk")

		headers := headersBuf.Bytes()
		headersLen := uint32(len(headers))
		totalLen := uint32(12 + len(headers) + len(payload) + 4)

		var preludeBuf bytes.Buffer
		_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
		_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
		preludeBytes := preludeBuf.Bytes()

		var frame bytes.Buffer
		_, _ = frame.Write(preludeBytes)
		_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab))
		_, _ = frame.Write(headers)
		_, _ = frame.Write(payload)
		_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab))

		decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes()))
		_, err := decoder.Decode()
		require.Error(t, err)
		assert.Contains(t, err.Error(), "prelude CRC mismatch")
	})
}
|
||||||
|
|
||||||
|
func TestBuildBedrockURL(t *testing.T) {
|
||||||
|
t.Run("stream URL with colon in model ID", func(t *testing.T) {
|
||||||
|
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true)
|
||||||
|
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-stream URL with colon in model ID", func(t *testing.T) {
|
||||||
|
url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false)
|
||||||
|
assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("model ID without colon", func(t *testing.T) {
|
||||||
|
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true)
|
||||||
|
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,450 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
// claudeMaxCacheBillingOutcome reports what the Claude-Max cache billing
// policy did to a usage record.
type claudeMaxCacheBillingOutcome struct {
	// Simulated is true when the usage was rewritten to project input
	// tokens into 1h cache-creation tokens.
	Simulated bool
}
|
|
||||||
|
|
||||||
func applyClaudeMaxCacheBillingPolicyToUsage(usage *ClaudeUsage, parsed *ParsedRequest, group *Group, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
|
||||||
var out claudeMaxCacheBillingOutcome
|
|
||||||
if usage == nil || !shouldApplyClaudeMaxBillingRulesForUsage(group, model, parsed) {
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
resolvedModel := strings.TrimSpace(model)
|
|
||||||
if resolvedModel == "" && parsed != nil {
|
|
||||||
resolvedModel = strings.TrimSpace(parsed.Model)
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasCacheCreationTokens(*usage) {
|
|
||||||
// Upstream already returned cache creation usage; keep original usage.
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
if !shouldSimulateClaudeMaxUsageForUsage(*usage, parsed) {
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
beforeInputTokens := usage.InputTokens
|
|
||||||
out.Simulated = safelyProjectUsageToClaudeMax1H(usage, parsed)
|
|
||||||
if out.Simulated {
|
|
||||||
logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage: model=%s account=%d input_tokens:%d->%d cache_creation_1h=%d",
|
|
||||||
resolvedModel,
|
|
||||||
accountID,
|
|
||||||
beforeInputTokens,
|
|
||||||
usage.InputTokens,
|
|
||||||
usage.CacheCreation1hTokens,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func isClaudeFamilyModel(model string) bool {
|
|
||||||
normalized := strings.ToLower(strings.TrimSpace(claude.NormalizeModelID(model)))
|
|
||||||
if normalized == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return strings.Contains(normalized, "claude-")
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldApplyClaudeMaxBillingRules(input *RecordUsageInput) bool {
|
|
||||||
if input == nil || input.Result == nil || input.APIKey == nil || input.APIKey.Group == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return shouldApplyClaudeMaxBillingRulesForUsage(input.APIKey.Group, input.Result.Model, input.ParsedRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldApplyClaudeMaxBillingRulesForUsage(group *Group, model string, parsed *ParsedRequest) bool {
|
|
||||||
if group == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if !group.SimulateClaudeMaxEnabled || group.Platform != PlatformAnthropic {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
resolvedModel := model
|
|
||||||
if resolvedModel == "" && parsed != nil {
|
|
||||||
resolvedModel = parsed.Model
|
|
||||||
}
|
|
||||||
if !isClaudeFamilyModel(resolvedModel) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasCacheCreationTokens(usage ClaudeUsage) bool {
|
|
||||||
return usage.CacheCreationInputTokens > 0 || usage.CacheCreation5mTokens > 0 || usage.CacheCreation1hTokens > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldSimulateClaudeMaxUsage(input *RecordUsageInput) bool {
|
|
||||||
if input == nil || input.Result == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if !shouldApplyClaudeMaxBillingRules(input) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return shouldSimulateClaudeMaxUsageForUsage(input.Result.Usage, input.ParsedRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldSimulateClaudeMaxUsageForUsage(usage ClaudeUsage, parsed *ParsedRequest) bool {
|
|
||||||
if usage.InputTokens <= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if hasCacheCreationTokens(usage) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if !hasClaudeCacheSignals(parsed) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// safelyProjectUsageToClaudeMax1H runs projectUsageToClaudeMax1H while
// converting any panic into a logged no-op (changed=false), so billing
// simulation can never crash the request path.
func safelyProjectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) (changed bool) {
	defer func() {
		if r := recover(); r != nil {
			logger.LegacyPrintf("service.gateway", "simulate_claude_max_usage skipped: panic=%v", r)
			// Named return lets the deferred handler override the result.
			changed = false
		}
	}()
	return projectUsageToClaudeMax1H(usage, parsed)
}
|
|
||||||
|
|
||||||
// projectUsageToClaudeMax1H redistributes the request's "window" tokens
// (input + 5m cache + 1h cache) so that most of them are billed as 1h
// cache-creation tokens, keeping the total unchanged. Returns true when the
// usage record was modified.
func projectUsageToClaudeMax1H(usage *ClaudeUsage, parsed *ParsedRequest) bool {
	if usage == nil {
		return false
	}
	// Conservation invariant: this sum must be identical before and after.
	totalWindowTokens := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
	if totalWindowTokens <= 1 {
		return false
	}

	simulatedInputTokens := computeClaudeMaxProjectedInputTokens(totalWindowTokens, parsed)
	// Clamp to [1, total-1] so both buckets stay strictly positive.
	if simulatedInputTokens <= 0 {
		simulatedInputTokens = 1
	}
	if simulatedInputTokens >= totalWindowTokens {
		simulatedInputTokens = totalWindowTokens - 1
	}

	cacheCreation1hTokens := totalWindowTokens - simulatedInputTokens
	// No-op guard: don't report a change when the projection matches the
	// existing distribution exactly.
	if usage.InputTokens == simulatedInputTokens &&
		usage.CacheCreation5mTokens == 0 &&
		usage.CacheCreation1hTokens == cacheCreation1hTokens &&
		usage.CacheCreationInputTokens == cacheCreation1hTokens {
		return false
	}

	usage.InputTokens = simulatedInputTokens
	usage.CacheCreation5mTokens = 0
	usage.CacheCreation1hTokens = cacheCreation1hTokens
	usage.CacheCreationInputTokens = cacheCreation1hTokens
	return true
}
|
|
||||||
|
|
||||||
// claudeCacheProjection summarizes cache breakpoints found in a request and
// the token estimates needed to split plain-input vs. cached tokens.
type claudeCacheProjection struct {
	// HasBreakpoint is true when at least one ephemeral cache_control
	// breakpoint (or a top-level cache_control marker) was detected.
	HasBreakpoint bool
	// BreakpointCount is the number of detected breakpoints.
	BreakpointCount int
	// TotalEstimatedTokens is the estimated token size of the whole prompt.
	TotalEstimatedTokens int
	// TailEstimatedTokens is the estimated token size after the last
	// breakpoint (the portion that remains billable as plain input).
	TailEstimatedTokens int
}
|
|
||||||
|
|
||||||
func computeClaudeMaxProjectedInputTokens(totalWindowTokens int, parsed *ParsedRequest) int {
|
|
||||||
if totalWindowTokens <= 1 {
|
|
||||||
return totalWindowTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
projection := analyzeClaudeCacheProjection(parsed)
|
|
||||||
if !projection.HasBreakpoint || projection.TotalEstimatedTokens <= 0 || projection.TailEstimatedTokens <= 0 {
|
|
||||||
return totalWindowTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
totalEstimate := int64(projection.TotalEstimatedTokens)
|
|
||||||
tailEstimate := int64(projection.TailEstimatedTokens)
|
|
||||||
if tailEstimate > totalEstimate {
|
|
||||||
tailEstimate = totalEstimate
|
|
||||||
}
|
|
||||||
|
|
||||||
scaled := (int64(totalWindowTokens)*tailEstimate + totalEstimate/2) / totalEstimate
|
|
||||||
if scaled <= 0 {
|
|
||||||
scaled = 1
|
|
||||||
}
|
|
||||||
if scaled >= int64(totalWindowTokens) {
|
|
||||||
scaled = int64(totalWindowTokens - 1)
|
|
||||||
}
|
|
||||||
return int(scaled)
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasClaudeCacheSignals(parsed *ParsedRequest) bool {
|
|
||||||
if parsed == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if hasTopLevelEphemeralCacheControl(parsed) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return countExplicitCacheBreakpoints(parsed) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasTopLevelEphemeralCacheControl(parsed *ParsedRequest) bool {
|
|
||||||
if parsed == nil || len(parsed.Body) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
cacheType := strings.TrimSpace(gjson.GetBytes(parsed.Body, "cache_control.type").String())
|
|
||||||
return strings.EqualFold(cacheType, "ephemeral")
|
|
||||||
}
|
|
||||||
|
|
||||||
// analyzeClaudeCacheProjection walks the request's system prompt and
// messages, accumulating a token estimate and tracking where the last
// ephemeral cache breakpoint falls, so callers can split the prompt into a
// "cached prefix" and a billable "tail". When no per-block breakpoint is
// found but the body has a top-level cache_control marker, the last user
// message is treated as the tail.
func analyzeClaudeCacheProjection(parsed *ParsedRequest) claudeCacheProjection {
	var projection claudeCacheProjection
	if parsed == nil {
		return projection
	}

	total := 0
	// Running-total position just after the last breakpoint; -1 = none seen.
	lastBreakpointAt := -1

	switch system := parsed.System.(type) {
	case string:
		total += claudeMaxMessageOverheadTokens + estimateClaudeTextTokens(system)
	case []any:
		for _, raw := range system {
			block, ok := raw.(map[string]any)
			if !ok {
				// Unrecognized entries get a flat token charge.
				total += claudeMaxUnknownContentTokens
				continue
			}
			total += estimateClaudeBlockTokens(block)
			if hasEphemeralCacheControl(block) {
				lastBreakpointAt = total
				projection.BreakpointCount++
				projection.HasBreakpoint = true
			}
		}
	}

	for _, rawMsg := range parsed.Messages {
		total += claudeMaxMessageOverheadTokens
		msg, ok := rawMsg.(map[string]any)
		if !ok {
			total += claudeMaxUnknownContentTokens
			continue
		}
		content, exists := msg["content"]
		if !exists {
			continue
		}
		msgTokens, msgLastBreak, msgBreakCount := estimateClaudeContentTokens(content)
		total += msgTokens
		if msgBreakCount > 0 {
			// msgLastBreak is relative to the message; rebase it onto the
			// running total (total already includes msgTokens here).
			lastBreakpointAt = total - msgTokens + msgLastBreak
			projection.BreakpointCount += msgBreakCount
			projection.HasBreakpoint = true
		}
	}

	if total <= 0 {
		total = 1
	}
	projection.TotalEstimatedTokens = total

	if projection.HasBreakpoint && lastBreakpointAt >= 0 {
		tail := total - lastBreakpointAt
		if tail <= 0 {
			tail = 1
		}
		projection.TailEstimatedTokens = tail
		return projection
	}

	// Fallback: a top-level cache_control marker counts as one breakpoint
	// with the last user message as the tail.
	if hasTopLevelEphemeralCacheControl(parsed) {
		tail := estimateLastUserMessageTokens(parsed)
		if tail <= 0 {
			tail = 1
		}
		projection.HasBreakpoint = true
		projection.BreakpointCount = 1
		projection.TailEstimatedTokens = tail
	}
	return projection
}
|
|
||||||
|
|
||||||
func countExplicitCacheBreakpoints(parsed *ParsedRequest) int {
|
|
||||||
if parsed == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
total := 0
|
|
||||||
if system, ok := parsed.System.([]any); ok {
|
|
||||||
for _, raw := range system {
|
|
||||||
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
|
|
||||||
total++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, rawMsg := range parsed.Messages {
|
|
||||||
msg, ok := rawMsg.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
content, ok := msg["content"].([]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, raw := range content {
|
|
||||||
if block, ok := raw.(map[string]any); ok && hasEphemeralCacheControl(block) {
|
|
||||||
total++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return total
|
|
||||||
}
|
|
||||||
|
|
||||||
// hasEphemeralCacheControl reports whether a content block carries a
// cache_control of type "ephemeral" (case-insensitive, whitespace-trimmed).
// Both map[string]any (JSON-decoded) and map[string]string shapes are
// accepted; anything else is treated as no cache control.
func hasEphemeralCacheControl(block map[string]any) bool {
	if block == nil {
		return false
	}
	raw, ok := block["cache_control"]
	if !ok || raw == nil {
		return false
	}
	var cacheType string
	switch cc := raw.(type) {
	case map[string]any:
		cacheType, _ = cc["type"].(string)
	case map[string]string:
		cacheType = cc["type"]
	default:
		return false
	}
	return strings.EqualFold(strings.TrimSpace(cacheType), "ephemeral")
}
|
|
||||||
|
|
||||||
func estimateClaudeContentTokens(content any) (tokens int, lastBreakAt int, breakpointCount int) {
|
|
||||||
switch value := content.(type) {
|
|
||||||
case string:
|
|
||||||
return estimateClaudeTextTokens(value), -1, 0
|
|
||||||
case []any:
|
|
||||||
total := 0
|
|
||||||
lastBreak := -1
|
|
||||||
breaks := 0
|
|
||||||
for _, raw := range value {
|
|
||||||
block, ok := raw.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
total += claudeMaxUnknownContentTokens
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
total += estimateClaudeBlockTokens(block)
|
|
||||||
if hasEphemeralCacheControl(block) {
|
|
||||||
lastBreak = total
|
|
||||||
breaks++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return total, lastBreak, breaks
|
|
||||||
default:
|
|
||||||
return estimateStructuredTokens(value), -1, 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// estimateClaudeBlockTokens estimates the token cost of one content block,
// dispatching on its "type" field and falling back to text/content fields
// for unknown types. Every block pays a fixed overhead; a block that yields
// no content tokens additionally gets a flat unknown-content charge.
func estimateClaudeBlockTokens(block map[string]any) int {
	if block == nil {
		return claudeMaxUnknownContentTokens
	}
	tokens := claudeMaxBlockOverheadTokens
	blockType, _ := block["type"].(string)
	switch blockType {
	case "text":
		if text, ok := block["text"].(string); ok {
			tokens += estimateClaudeTextTokens(text)
		}
	case "tool_result":
		// tool_result content can itself be a string or block list.
		if content, ok := block["content"]; ok {
			nested, _, _ := estimateClaudeContentTokens(content)
			tokens += nested
		}
	case "tool_use":
		if name, ok := block["name"].(string); ok {
			tokens += estimateClaudeTextTokens(name)
		}
		if input, ok := block["input"]; ok {
			tokens += estimateStructuredTokens(input)
		}
	default:
		// Unknown block types: charge whatever text/content we can find.
		if text, ok := block["text"].(string); ok {
			tokens += estimateClaudeTextTokens(text)
		} else if content, ok := block["content"]; ok {
			nested, _, _ := estimateClaudeContentTokens(content)
			tokens += nested
		}
	}
	// Blocks that contributed nothing beyond overhead still cost something.
	if tokens <= claudeMaxBlockOverheadTokens {
		tokens += claudeMaxUnknownContentTokens
	}
	return tokens
}
|
|
||||||
|
|
||||||
func estimateLastUserMessageTokens(parsed *ParsedRequest) int {
|
|
||||||
if parsed == nil || len(parsed.Messages) == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
for i := len(parsed.Messages) - 1; i >= 0; i-- {
|
|
||||||
msg, ok := parsed.Messages[i].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
role, _ := msg["role"].(string)
|
|
||||||
if !strings.EqualFold(strings.TrimSpace(role), "user") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
tokens, _, _ := estimateClaudeContentTokens(msg["content"])
|
|
||||||
return claudeMaxMessageOverheadTokens + tokens
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func estimateStructuredTokens(v any) int {
|
|
||||||
if v == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
raw, err := json.Marshal(v)
|
|
||||||
if err != nil {
|
|
||||||
return claudeMaxUnknownContentTokens
|
|
||||||
}
|
|
||||||
return estimateClaudeTextTokens(string(raw))
|
|
||||||
}
|
|
||||||
|
|
||||||
// estimateClaudeTextTokens estimates the token count of text, preferring the
// real tokenizer and falling back to a character/word heuristic when the
// tokenizer is unavailable.
func estimateClaudeTextTokens(text string) int {
	if tokens, ok := estimateTokensByThirdPartyTokenizer(text); ok {
		return tokens
	}
	return estimateClaudeTextTokensHeuristic(text)
}
|
|
||||||
|
|
||||||
// estimateClaudeTextTokensHeuristic approximates a token count without a
// tokenizer: roughly 4 ASCII characters per token, one token per non-ASCII
// rune, and never fewer tokens than whitespace-separated words. Returns 0
// for blank input, otherwise at least 1.
func estimateClaudeTextTokensHeuristic(text string) int {
	words := strings.Fields(strings.TrimSpace(text))
	if len(words) == 0 {
		return 0
	}
	// Collapse runs of whitespace to single spaces before counting runes.
	normalized := strings.Join(words, " ")

	ascii, wide := 0, 0
	for _, r := range normalized {
		if r <= 127 {
			ascii++
		} else {
			wide++
		}
	}

	// ~4 ASCII chars per token; each non-ASCII rune counts as one token.
	tokens := wide + (ascii+3)/4
	if n := len(words); n > tokens {
		tokens = n
	}
	if tokens <= 0 {
		return 1
	}
	return tokens
}
|
|
||||||
@@ -1,156 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestProjectUsageToClaudeMax1H_Conservation checks the projection's core
// invariants: total window tokens are conserved, 5m cache is zeroed, and the
// simulated input stays small (near the post-breakpoint tail).
func TestProjectUsageToClaudeMax1H_Conservation(t *testing.T) {
	usage := &ClaudeUsage{
		InputTokens:              1200,
		CacheCreationInputTokens: 0,
		CacheCreation5mTokens:    0,
		CacheCreation1hTokens:    0,
	}
	// Fixture: one large cached block followed by a short tail after the
	// breakpoint, so most tokens should move into the 1h cache bucket.
	parsed := &ParsedRequest{
		Model: "claude-sonnet-4-5",
		Messages: []any{
			map[string]any{
				"role": "user",
				"content": []any{
					map[string]any{
						"type":          "text",
						"text":          strings.Repeat("cached context ", 200),
						"cache_control": map[string]any{"type": "ephemeral"},
					},
					map[string]any{
						"type": "text",
						"text": "summarize quickly",
					},
				},
			},
		},
	}

	changed := projectUsageToClaudeMax1H(usage, parsed)
	if !changed {
		t.Fatalf("expected usage to be projected")
	}

	// Conservation: projection must only redistribute, never create tokens.
	total := usage.InputTokens + usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
	if total != 1200 {
		t.Fatalf("total tokens changed: got=%d want=%d", total, 1200)
	}
	if usage.CacheCreation5mTokens != 0 {
		t.Fatalf("cache_creation_5m should be 0, got=%d", usage.CacheCreation5mTokens)
	}
	if usage.InputTokens <= 0 || usage.InputTokens >= 1200 {
		t.Fatalf("simulated input out of range, got=%d", usage.InputTokens)
	}
	// The tail after the breakpoint is tiny relative to the cached prefix.
	if usage.InputTokens > 100 {
		t.Fatalf("simulated input should stay near cache breakpoint tail, got=%d", usage.InputTokens)
	}
	if usage.CacheCreation1hTokens <= 0 {
		t.Fatalf("cache_creation_1h should be > 0, got=%d", usage.CacheCreation1hTokens)
	}
	// Aggregate field mirrors the 1h bucket after projection.
	if usage.CacheCreationInputTokens != usage.CacheCreation1hTokens {
		t.Fatalf("cache_creation_input_tokens mismatch: got=%d want=%d", usage.CacheCreationInputTokens, usage.CacheCreation1hTokens)
	}
}
|
|
||||||
|
|
||||||
func TestComputeClaudeMaxProjectedInputTokens_Deterministic(t *testing.T) {
|
|
||||||
parsed := &ParsedRequest{
|
|
||||||
Model: "claude-opus-4-5",
|
|
||||||
Messages: []any{
|
|
||||||
map[string]any{
|
|
||||||
"role": "user",
|
|
||||||
"content": []any{
|
|
||||||
map[string]any{
|
|
||||||
"type": "text",
|
|
||||||
"text": "build context",
|
|
||||||
"cache_control": map[string]any{"type": "ephemeral"},
|
|
||||||
},
|
|
||||||
map[string]any{
|
|
||||||
"type": "text",
|
|
||||||
"text": "what is failing now",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
got1 := computeClaudeMaxProjectedInputTokens(4096, parsed)
|
|
||||||
got2 := computeClaudeMaxProjectedInputTokens(4096, parsed)
|
|
||||||
if got1 != got2 {
|
|
||||||
t.Fatalf("non-deterministic input tokens: %d != %d", got1, got2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestShouldSimulateClaudeMaxUsage covers the three gating cases: a cache
// signal present (simulate), no cache signal (skip), and cache creation
// already reported upstream (skip).
func TestShouldSimulateClaudeMaxUsage(t *testing.T) {
	group := &Group{
		Platform:                 PlatformAnthropic,
		SimulateClaudeMaxEnabled: true,
	}
	input := &RecordUsageInput{
		Result: &ForwardResult{
			Model: "claude-sonnet-4-5",
			Usage: ClaudeUsage{
				InputTokens:              3000,
				CacheCreationInputTokens: 0,
				CacheCreation5mTokens:    0,
				CacheCreation1hTokens:    0,
			},
		},
		ParsedRequest: &ParsedRequest{
			Messages: []any{
				map[string]any{
					"role": "user",
					"content": []any{
						map[string]any{
							"type":          "text",
							"text":          "cached",
							"cache_control": map[string]any{"type": "ephemeral"},
						},
						map[string]any{
							"type": "text",
							"text": "tail",
						},
					},
				},
			},
		},
		APIKey: &APIKey{Group: group},
	}

	if !shouldSimulateClaudeMaxUsage(input) {
		t.Fatalf("expected simulate=true for claude group with cache signal")
	}

	// Without any ephemeral cache_control marker, simulation must not fire.
	input.ParsedRequest = &ParsedRequest{
		Messages: []any{
			map[string]any{"role": "user", "content": "no cache signal"},
		},
	}
	if shouldSimulateClaudeMaxUsage(input) {
		t.Fatalf("expected simulate=false when request has no cache signal")
	}

	// When upstream already reported cache creation, keep its numbers.
	input.ParsedRequest = &ParsedRequest{
		Messages: []any{
			map[string]any{
				"role": "user",
				"content": []any{
					map[string]any{
						"type":          "text",
						"text":          "cached",
						"cache_control": map[string]any{"type": "ephemeral"},
					},
				},
			},
		},
	}
	input.Result.Usage.CacheCreationInputTokens = 100
	if shouldSimulateClaudeMaxUsage(input) {
		t.Fatalf("expected simulate=false when cache creation already exists")
	}
}
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
tiktoken "github.com/pkoukk/tiktoken-go"
|
|
||||||
tiktokenloader "github.com/pkoukk/tiktoken-go-loader"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
	// claudeTokenizerOnce guards one-time initialization of claudeTokenizer.
	claudeTokenizerOnce sync.Once
	// claudeTokenizer is the lazily-initialized shared encoder; it stays nil
	// when no encoding could be loaded.
	claudeTokenizer *tiktoken.Tiktoken
)
|
|
||||||
|
|
||||||
// getClaudeTokenizer lazily initializes and returns the shared tokenizer
// used to approximate Claude token counts. It returns nil when no encoding
// could be loaded; callers must handle that case.
func getClaudeTokenizer() *tiktoken.Tiktoken {
	claudeTokenizerOnce.Do(func() {
		// Use offline loader to avoid runtime dictionary download.
		tiktoken.SetBpeLoader(tiktokenloader.NewOfflineLoader())
		// Use a high-capacity tokenizer as the default approximation for
		// Claude payloads, falling back to cl100k_base on failure.
		enc, err := tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
		if err != nil {
			enc, err = tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
		}
		if err == nil {
			claudeTokenizer = enc
		}
	})
	return claudeTokenizer
}
|
|
||||||
|
|
||||||
func estimateTokensByThirdPartyTokenizer(text string) (int, bool) {
|
|
||||||
enc := getClaudeTokenizer()
|
|
||||||
if enc == nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
tokens := len(enc.EncodeOrdinary(text))
|
|
||||||
if tokens <= 0 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
return tokens, true
|
|
||||||
}
|
|
||||||
@@ -343,9 +343,8 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
|
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
||||||
// Uses a detached context with timeout to prevent HTTP request cancellation from
|
// Returns a map of accountID -> current concurrency count
|
||||||
// causing the entire batch to fail (which would show all concurrency as 0).
|
|
||||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||||
if len(accountIDs) == 0 {
|
if len(accountIDs) == 0 {
|
||||||
return map[int64]int{}, nil
|
return map[int64]int{}, nil
|
||||||
@@ -357,11 +356,5 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
|
|||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
|
||||||
// Use a detached context so that a cancelled HTTP request doesn't cause
|
|
||||||
// the Redis pipeline to fail and return all-zero concurrency counts.
|
|
||||||
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
|
|||||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||||
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||||
|
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
|
||||||
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
|||||||
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||||
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||||
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||||
|
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
|
||||||
|
|
||||||
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||||
if aggErr != nil {
|
if aggErr != nil {
|
||||||
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
|||||||
if usageErr != nil {
|
if usageErr != nil {
|
||||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||||
}
|
}
|
||||||
if aggErr == nil && usageErr == nil {
|
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
|
||||||
|
if dedupErr != nil {
|
||||||
|
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
|
||||||
|
}
|
||||||
|
if aggErr == nil && usageErr == nil && dedupErr == nil {
|
||||||
s.lastRetentionCleanup.Store(now)
|
s.lastRetentionCleanup.Store(now)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,12 +12,18 @@ import (
|
|||||||
|
|
||||||
type dashboardAggregationRepoTestStub struct {
|
type dashboardAggregationRepoTestStub struct {
|
||||||
aggregateCalls int
|
aggregateCalls int
|
||||||
|
recomputeCalls int
|
||||||
|
cleanupUsageCalls int
|
||||||
|
cleanupDedupCalls int
|
||||||
|
ensurePartitionCalls int
|
||||||
lastStart time.Time
|
lastStart time.Time
|
||||||
lastEnd time.Time
|
lastEnd time.Time
|
||||||
watermark time.Time
|
watermark time.Time
|
||||||
aggregateErr error
|
aggregateErr error
|
||||||
cleanupAggregatesErr error
|
cleanupAggregatesErr error
|
||||||
cleanupUsageErr error
|
cleanupUsageErr error
|
||||||
|
cleanupDedupErr error
|
||||||
|
ensurePartitionErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
s.recomputeCalls++
|
||||||
return s.AggregateRange(ctx, start, end)
|
return s.AggregateRange(ctx, start, end)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||||
|
s.cleanupUsageCalls++
|
||||||
return s.cleanupUsageErr
|
return s.cleanupUsageErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||||
|
s.cleanupDedupCalls++
|
||||||
|
return s.cleanupDedupErr
|
||||||
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
return nil
|
s.ensurePartitionCalls++
|
||||||
|
return s.ensurePartitionErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||||
@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
|
|||||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||||
|
|
||||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||||
|
require.Equal(t, 1, repo.cleanupUsageCalls)
|
||||||
|
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 1,
|
||||||
|
HourlyDays: 1,
|
||||||
|
DailyDays: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||||
|
|
||||||
|
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||||
|
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
IntervalSeconds: 60,
|
||||||
|
LookbackSeconds: 120,
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 1,
|
||||||
|
UsageBillingDedupDays: 2,
|
||||||
|
HourlyDays: 1,
|
||||||
|
DailyDays: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.runScheduledAggregation()
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.ensurePartitionCalls)
|
||||||
|
require.Equal(t, 1, repo.aggregateCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||||
|
|||||||
@@ -327,6 +327,14 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
|
|||||||
return trend, nil
|
return trend, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||||
|
ranking, err := s.usageRepo.GetUserSpendingRanking(ctx, startTime, endTime, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get user spending ranking: %w", err)
|
||||||
|
}
|
||||||
|
return ranking, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,10 +29,12 @@ const (
|
|||||||
|
|
||||||
// Account type constants
|
// Account type constants
|
||||||
const (
|
const (
|
||||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||||
|
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
|
||||||
|
AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem type constants
|
// Redeem type constants
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
|
|||||||
v, exists := c.Get(OpsSkipPassthroughKey)
|
v, exists := c.Get(OpsSkipPassthroughKey)
|
||||||
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
||||||
boolVal, ok := v.(bool)
|
boolVal, ok := v.(bool)
|
||||||
assert.True(t, ok, "value should be a bool")
|
assert.True(t, ok, "value should be bool")
|
||||||
assert.True(t, boolVal)
|
assert.True(t, boolVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
svc := &GatewayService{
|
cfg := &config.Config{
|
||||||
cfg: &config.Config{
|
Gateway: config.GatewayConfig{
|
||||||
Gateway: config.GatewayConfig{
|
MaxLineSize: defaultMaxLineSize,
|
||||||
MaxLineSize: defaultMaxLineSize,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
httpUpstream: upstream,
|
}
|
||||||
rateLimitService: &RateLimitService{},
|
svc := &GatewayService{
|
||||||
deferredService: &DeferredService{},
|
cfg: cfg,
|
||||||
billingCacheService: nil,
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
deferredService: &DeferredService{},
|
||||||
|
billingCacheService: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
account := &Account{
|
account := &Account{
|
||||||
@@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
svc := &GatewayService{
|
cfg := &config.Config{
|
||||||
cfg: &config.Config{
|
Gateway: config.GatewayConfig{
|
||||||
Gateway: config.GatewayConfig{
|
MaxLineSize: defaultMaxLineSize,
|
||||||
MaxLineSize: defaultMaxLineSize,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
httpUpstream: upstream,
|
}
|
||||||
rateLimitService: &RateLimitService{},
|
svc := &GatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
}
|
}
|
||||||
|
|
||||||
account := &Account{
|
account := &Account{
|
||||||
@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
|
|||||||
require.Equal(t, 5, result.usage.OutputTokens)
|
require.Equal(t, 5, result.usage.OutputTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||||
|
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
|
||||||
|
"",
|
||||||
|
`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
|
||||||
|
"",
|
||||||
|
}, "\n"))),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "missing terminal event")
|
||||||
|
require.NotNil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
|
|||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
<-done
|
<-done
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "stream usage incomplete after timeout")
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.clientDisconnect)
|
require.True(t, result.clientDisconnect)
|
||||||
require.Equal(t, 9, result.usage.InputTokens)
|
require.Equal(t, 9, result.usage.InputTokens)
|
||||||
@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
|
|||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
require.NoError(t, err)
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "stream usage incomplete")
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.clientDisconnect)
|
require.True(t, result.clientDisconnect)
|
||||||
}
|
}
|
||||||
@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
|
|||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
require.NoError(t, err)
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.clientDisconnect)
|
require.True(t, result.clientDisconnect)
|
||||||
require.Equal(t, 8, result.usage.InputTokens)
|
require.Equal(t, 8, result.usage.InputTokens)
|
||||||
|
|||||||
@@ -1,196 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
type claudeMaxResponseRewriteContext struct {
|
|
||||||
Parsed *ParsedRequest
|
|
||||||
Group *Group
|
|
||||||
}
|
|
||||||
|
|
||||||
type claudeMaxResponseRewriteContextKeyType struct{}
|
|
||||||
|
|
||||||
var claudeMaxResponseRewriteContextKey = claudeMaxResponseRewriteContextKeyType{}
|
|
||||||
|
|
||||||
func withClaudeMaxResponseRewriteContext(ctx context.Context, c *gin.Context, parsed *ParsedRequest) context.Context {
|
|
||||||
if ctx == nil {
|
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
value := claudeMaxResponseRewriteContext{
|
|
||||||
Parsed: parsed,
|
|
||||||
Group: claudeMaxGroupFromGinContext(c),
|
|
||||||
}
|
|
||||||
return context.WithValue(ctx, claudeMaxResponseRewriteContextKey, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func claudeMaxResponseRewriteContextFromContext(ctx context.Context) claudeMaxResponseRewriteContext {
|
|
||||||
if ctx == nil {
|
|
||||||
return claudeMaxResponseRewriteContext{}
|
|
||||||
}
|
|
||||||
value, _ := ctx.Value(claudeMaxResponseRewriteContextKey).(claudeMaxResponseRewriteContext)
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
func claudeMaxGroupFromGinContext(c *gin.Context) *Group {
|
|
||||||
if c == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
raw, exists := c.Get("api_key")
|
|
||||||
if !exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
apiKey, ok := raw.(*APIKey)
|
|
||||||
if !ok || apiKey == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return apiKey.Group
|
|
||||||
}
|
|
||||||
|
|
||||||
func parsedRequestFromGinContext(c *gin.Context) *ParsedRequest {
|
|
||||||
if c == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
raw, exists := c.Get("parsed_request")
|
|
||||||
if !exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
parsed, _ := raw.(*ParsedRequest)
|
|
||||||
return parsed
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyClaudeMaxSimulationToUsage(ctx context.Context, usage *ClaudeUsage, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
|
||||||
var out claudeMaxCacheBillingOutcome
|
|
||||||
if usage == nil {
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
rewriteCtx := claudeMaxResponseRewriteContextFromContext(ctx)
|
|
||||||
return applyClaudeMaxCacheBillingPolicyToUsage(usage, rewriteCtx.Parsed, rewriteCtx.Group, model, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyClaudeMaxSimulationToUsageJSONMap(ctx context.Context, usageObj map[string]any, model string, accountID int64) claudeMaxCacheBillingOutcome {
|
|
||||||
var out claudeMaxCacheBillingOutcome
|
|
||||||
if usageObj == nil {
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
usage := claudeUsageFromJSONMap(usageObj)
|
|
||||||
out = applyClaudeMaxSimulationToUsage(ctx, &usage, model, accountID)
|
|
||||||
if out.Simulated {
|
|
||||||
rewriteClaudeUsageJSONMap(usageObj, usage)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func rewriteClaudeUsageJSONBytes(body []byte, usage ClaudeUsage) []byte {
|
|
||||||
updated := body
|
|
||||||
var err error
|
|
||||||
|
|
||||||
updated, err = sjson.SetBytes(updated, "usage.input_tokens", usage.InputTokens)
|
|
||||||
if err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
updated, err = sjson.SetBytes(updated, "usage.cache_creation_input_tokens", usage.CacheCreationInputTokens)
|
|
||||||
if err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_5m_input_tokens", usage.CacheCreation5mTokens)
|
|
||||||
if err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
updated, err = sjson.SetBytes(updated, "usage.cache_creation.ephemeral_1h_input_tokens", usage.CacheCreation1hTokens)
|
|
||||||
if err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
return updated
|
|
||||||
}
|
|
||||||
|
|
||||||
func claudeUsageFromJSONMap(usageObj map[string]any) ClaudeUsage {
|
|
||||||
var usage ClaudeUsage
|
|
||||||
if usageObj == nil {
|
|
||||||
return usage
|
|
||||||
}
|
|
||||||
|
|
||||||
usage.InputTokens = usageIntFromAny(usageObj["input_tokens"])
|
|
||||||
usage.OutputTokens = usageIntFromAny(usageObj["output_tokens"])
|
|
||||||
usage.CacheCreationInputTokens = usageIntFromAny(usageObj["cache_creation_input_tokens"])
|
|
||||||
usage.CacheReadInputTokens = usageIntFromAny(usageObj["cache_read_input_tokens"])
|
|
||||||
|
|
||||||
if ccObj, ok := usageObj["cache_creation"].(map[string]any); ok {
|
|
||||||
usage.CacheCreation5mTokens = usageIntFromAny(ccObj["ephemeral_5m_input_tokens"])
|
|
||||||
usage.CacheCreation1hTokens = usageIntFromAny(ccObj["ephemeral_1h_input_tokens"])
|
|
||||||
}
|
|
||||||
return usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func rewriteClaudeUsageJSONMap(usageObj map[string]any, usage ClaudeUsage) {
|
|
||||||
if usageObj == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
usageObj["input_tokens"] = usage.InputTokens
|
|
||||||
usageObj["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
|
|
||||||
|
|
||||||
ccObj, _ := usageObj["cache_creation"].(map[string]any)
|
|
||||||
if ccObj == nil {
|
|
||||||
ccObj = make(map[string]any, 2)
|
|
||||||
usageObj["cache_creation"] = ccObj
|
|
||||||
}
|
|
||||||
ccObj["ephemeral_5m_input_tokens"] = usage.CacheCreation5mTokens
|
|
||||||
ccObj["ephemeral_1h_input_tokens"] = usage.CacheCreation1hTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
func usageIntFromAny(v any) int {
|
|
||||||
switch value := v.(type) {
|
|
||||||
case int:
|
|
||||||
return value
|
|
||||||
case int64:
|
|
||||||
return int(value)
|
|
||||||
case float64:
|
|
||||||
return int(value)
|
|
||||||
case json.Number:
|
|
||||||
if n, err := value.Int64(); err == nil {
|
|
||||||
return int(n)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupClaudeMaxStreamingHook 为 Antigravity 流式路径设置 SSE usage 改写 hook。
|
|
||||||
func setupClaudeMaxStreamingHook(c *gin.Context, processor *antigravity.StreamingProcessor, originalModel string, accountID int64) {
|
|
||||||
group := claudeMaxGroupFromGinContext(c)
|
|
||||||
parsed := parsedRequestFromGinContext(c)
|
|
||||||
if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
processor.SetUsageMapHook(func(usageMap map[string]any) {
|
|
||||||
svcUsage := claudeUsageFromJSONMap(usageMap)
|
|
||||||
outcome := applyClaudeMaxCacheBillingPolicyToUsage(&svcUsage, parsed, group, originalModel, accountID)
|
|
||||||
if outcome.Simulated {
|
|
||||||
rewriteClaudeUsageJSONMap(usageMap, svcUsage)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyClaudeMaxNonStreamingRewrite 为 Antigravity 非流式路径改写响应体中的 usage。
|
|
||||||
func applyClaudeMaxNonStreamingRewrite(c *gin.Context, claudeResp []byte, agUsage *antigravity.ClaudeUsage, originalModel string, accountID int64) []byte {
|
|
||||||
group := claudeMaxGroupFromGinContext(c)
|
|
||||||
parsed := parsedRequestFromGinContext(c)
|
|
||||||
if !shouldApplyClaudeMaxBillingRulesForUsage(group, originalModel, parsed) {
|
|
||||||
return claudeResp
|
|
||||||
}
|
|
||||||
svcUsage := &ClaudeUsage{
|
|
||||||
InputTokens: agUsage.InputTokens,
|
|
||||||
OutputTokens: agUsage.OutputTokens,
|
|
||||||
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
|
|
||||||
CacheReadInputTokens: agUsage.CacheReadInputTokens,
|
|
||||||
}
|
|
||||||
outcome := applyClaudeMaxCacheBillingPolicyToUsage(svcUsage, parsed, group, originalModel, accountID)
|
|
||||||
if outcome.Simulated {
|
|
||||||
return rewriteClaudeUsageJSONBytes(claudeResp, *svcUsage)
|
|
||||||
}
|
|
||||||
return claudeResp
|
|
||||||
}
|
|
||||||
@@ -1,199 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type usageLogRepoRecordUsageStub struct {
|
|
||||||
UsageLogRepository
|
|
||||||
|
|
||||||
last *UsageLog
|
|
||||||
inserted bool
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *usageLogRepoRecordUsageStub) Create(_ context.Context, log *UsageLog) (bool, error) {
|
|
||||||
copied := *log
|
|
||||||
s.last = &copied
|
|
||||||
return s.inserted, s.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func newGatewayServiceForRecordUsageTest(repo UsageLogRepository) *GatewayService {
|
|
||||||
return &GatewayService{
|
|
||||||
usageLogRepo: repo,
|
|
||||||
billingService: NewBillingService(&config.Config{}, nil),
|
|
||||||
cfg: &config.Config{RunMode: config.RunModeSimple},
|
|
||||||
deferredService: &DeferredService{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRecordUsage_SimulateClaudeMaxEnabled_ProjectsUsageAndSkipsTTLOverride(t *testing.T) {
|
|
||||||
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
|
||||||
svc := newGatewayServiceForRecordUsageTest(repo)
|
|
||||||
|
|
||||||
groupID := int64(11)
|
|
||||||
input := &RecordUsageInput{
|
|
||||||
Result: &ForwardResult{
|
|
||||||
RequestID: "req-sim-1",
|
|
||||||
Model: "claude-sonnet-4",
|
|
||||||
Duration: time.Second,
|
|
||||||
Usage: ClaudeUsage{
|
|
||||||
InputTokens: 160,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ParsedRequest: &ParsedRequest{
|
|
||||||
Model: "claude-sonnet-4",
|
|
||||||
Messages: []any{
|
|
||||||
map[string]any{
|
|
||||||
"role": "user",
|
|
||||||
"content": []any{
|
|
||||||
map[string]any{
|
|
||||||
"type": "text",
|
|
||||||
"text": "long cached context for prior turns",
|
|
||||||
"cache_control": map[string]any{"type": "ephemeral"},
|
|
||||||
},
|
|
||||||
map[string]any{
|
|
||||||
"type": "text",
|
|
||||||
"text": "please summarize the logs and provide root cause analysis",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
APIKey: &APIKey{
|
|
||||||
ID: 1,
|
|
||||||
GroupID: &groupID,
|
|
||||||
Group: &Group{
|
|
||||||
ID: groupID,
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
RateMultiplier: 1,
|
|
||||||
SimulateClaudeMaxEnabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
User: &User{ID: 2},
|
|
||||||
Account: &Account{
|
|
||||||
ID: 3,
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Extra: map[string]any{
|
|
||||||
"cache_ttl_override_enabled": true,
|
|
||||||
"cache_ttl_override_target": "5m",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err := svc.RecordUsage(context.Background(), input)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, repo.last)
|
|
||||||
|
|
||||||
log := repo.last
|
|
||||||
require.Equal(t, 80, log.InputTokens)
|
|
||||||
require.Equal(t, 80, log.CacheCreationTokens)
|
|
||||||
require.Equal(t, 0, log.CacheCreation5mTokens)
|
|
||||||
require.Equal(t, 80, log.CacheCreation1hTokens)
|
|
||||||
require.False(t, log.CacheTTLOverridden, "simulate outcome should skip account ttl override")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRecordUsage_SimulateClaudeMaxDisabled_AppliesTTLOverride(t *testing.T) {
|
|
||||||
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
|
||||||
svc := newGatewayServiceForRecordUsageTest(repo)
|
|
||||||
|
|
||||||
groupID := int64(12)
|
|
||||||
input := &RecordUsageInput{
|
|
||||||
Result: &ForwardResult{
|
|
||||||
RequestID: "req-sim-2",
|
|
||||||
Model: "claude-sonnet-4",
|
|
||||||
Duration: time.Second,
|
|
||||||
Usage: ClaudeUsage{
|
|
||||||
InputTokens: 40,
|
|
||||||
CacheCreationInputTokens: 120,
|
|
||||||
CacheCreation1hTokens: 120,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
APIKey: &APIKey{
|
|
||||||
ID: 2,
|
|
||||||
GroupID: &groupID,
|
|
||||||
Group: &Group{
|
|
||||||
ID: groupID,
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
RateMultiplier: 1,
|
|
||||||
SimulateClaudeMaxEnabled: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
User: &User{ID: 3},
|
|
||||||
Account: &Account{
|
|
||||||
ID: 4,
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Extra: map[string]any{
|
|
||||||
"cache_ttl_override_enabled": true,
|
|
||||||
"cache_ttl_override_target": "5m",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err := svc.RecordUsage(context.Background(), input)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, repo.last)
|
|
||||||
|
|
||||||
log := repo.last
|
|
||||||
require.Equal(t, 120, log.CacheCreationTokens)
|
|
||||||
require.Equal(t, 120, log.CacheCreation5mTokens)
|
|
||||||
require.Equal(t, 0, log.CacheCreation1hTokens)
|
|
||||||
require.True(t, log.CacheTTLOverridden)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRecordUsage_SimulateClaudeMaxEnabled_ExistingCacheCreationBypassesSimulation(t *testing.T) {
|
|
||||||
repo := &usageLogRepoRecordUsageStub{inserted: true}
|
|
||||||
svc := newGatewayServiceForRecordUsageTest(repo)
|
|
||||||
|
|
||||||
groupID := int64(13)
|
|
||||||
input := &RecordUsageInput{
|
|
||||||
Result: &ForwardResult{
|
|
||||||
RequestID: "req-sim-3",
|
|
||||||
Model: "claude-sonnet-4",
|
|
||||||
Duration: time.Second,
|
|
||||||
Usage: ClaudeUsage{
|
|
||||||
InputTokens: 20,
|
|
||||||
CacheCreationInputTokens: 120,
|
|
||||||
CacheCreation5mTokens: 120,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
APIKey: &APIKey{
|
|
||||||
ID: 3,
|
|
||||||
GroupID: &groupID,
|
|
||||||
Group: &Group{
|
|
||||||
ID: groupID,
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
RateMultiplier: 1,
|
|
||||||
SimulateClaudeMaxEnabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
User: &User{ID: 4},
|
|
||||||
Account: &Account{
|
|
||||||
ID: 5,
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Extra: map[string]any{
|
|
||||||
"cache_ttl_override_enabled": true,
|
|
||||||
"cache_ttl_override_target": "5m",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err := svc.RecordUsage(context.Background(), input)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, repo.last)
|
|
||||||
|
|
||||||
log := repo.last
|
|
||||||
require.Equal(t, 20, log.InputTokens)
|
|
||||||
require.Equal(t, 120, log.CacheCreation5mTokens)
|
|
||||||
require.Equal(t, 0, log.CacheCreation1hTokens)
|
|
||||||
require.Equal(t, 120, log.CacheCreationTokens)
|
|
||||||
require.False(t, log.CacheTTLOverridden, "existing cache_creation with SimulateClaudeMax enabled should skip account ttl override")
|
|
||||||
}
|
|
||||||
371
backend/internal/service/gateway_record_usage_test.go
Normal file
371
backend/internal/service/gateway_record_usage_test.go
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Default.RateMultiplier = 1.1
|
||||||
|
return NewGatewayService(
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
usageRepo,
|
||||||
|
nil,
|
||||||
|
userRepo,
|
||||||
|
subRepo,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
cfg,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
NewBillingService(cfg, nil),
|
||||||
|
nil,
|
||||||
|
&BillingCacheService{},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
&DeferredService{},
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
|
||||||
|
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||||
|
svc.usageBillingRepo = billingRepo
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIRecordUsageBestEffortLogRepoStub struct {
|
||||||
|
UsageLogRepository
|
||||||
|
|
||||||
|
bestEffortErr error
|
||||||
|
createErr error
|
||||||
|
bestEffortCalls int
|
||||||
|
createCalls int
|
||||||
|
lastLog *UsageLog
|
||||||
|
lastCtxErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error {
|
||||||
|
s.bestEffortCalls++
|
||||||
|
s.lastLog = log
|
||||||
|
s.lastCtxErr = ctx.Err()
|
||||||
|
return s.bestEffortErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
|
||||||
|
s.createCalls++
|
||||||
|
s.lastLog = log
|
||||||
|
s.lastCtxErr = ctx.Err()
|
||||||
|
return false, s.createErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||||
|
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||||
|
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := svc.RecordUsage(reqCtx, &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "gateway_detached_ctx",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 501,
|
||||||
|
Quota: 100,
|
||||||
|
},
|
||||||
|
User: &User{ID: 601},
|
||||||
|
Account: &Account{ID: 701},
|
||||||
|
APIKeyService: quotaSvc,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
|
require.Equal(t, 1, userRepo.deductCalls)
|
||||||
|
require.NoError(t, userRepo.lastCtxErr)
|
||||||
|
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||||
|
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||||
|
|
||||||
|
payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
|
||||||
|
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "gateway_payload_hash",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||||
|
User: &User{ID: 601},
|
||||||
|
Account: &Account{ID: 701},
|
||||||
|
RequestPayloadHash: payloadHash,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
|
||||||
|
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "gateway_payload_fallback",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||||
|
User: &User{ID: 601},
|
||||||
|
Account: &Account{ID: 701},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||||
|
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "gateway_not_persisted",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 503,
|
||||||
|
Quota: 100,
|
||||||
|
},
|
||||||
|
User: &User{ID: 603},
|
||||||
|
Account: &Account{ID: 703},
|
||||||
|
APIKeyService: quotaSvc,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
|
require.Equal(t, 1, userRepo.deductCalls)
|
||||||
|
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||||
|
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||||
|
|
||||||
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "gateway_long_context_detached_ctx",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 12,
|
||||||
|
OutputTokens: 8,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 502,
|
||||||
|
Quota: 100,
|
||||||
|
},
|
||||||
|
User: &User{ID: 602},
|
||||||
|
Account: &Account{ID: 702},
|
||||||
|
LongContextThreshold: 200000,
|
||||||
|
LongContextMultiplier: 2,
|
||||||
|
APIKeyService: quotaSvc,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, usageRepo.calls)
|
||||||
|
require.Equal(t, 1, userRepo.deductCalls)
|
||||||
|
require.NoError(t, userRepo.lastCtxErr)
|
||||||
|
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||||
|
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
|
||||||
|
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 504},
|
||||||
|
User: &User{ID: 604},
|
||||||
|
Account: &Account{ID: 704},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123")
|
||||||
|
ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored")
|
||||||
|
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "upstream-volatile-456",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 506},
|
||||||
|
User: &User{ID: 606},
|
||||||
|
Account: &Account{ID: 706},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 507},
|
||||||
|
User: &User{ID: 607},
|
||||||
|
Account: &Account{ID: 707},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{
|
||||||
|
bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")),
|
||||||
|
}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "gateway_drop_usage_log",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 508},
|
||||||
|
User: &User{ID: 608},
|
||||||
|
Account: &Account{ID: 708},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, usageRepo.bestEffortCalls)
|
||||||
|
require.Equal(t, 0, usageRepo.createCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "gateway_billing_fail",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 505},
|
||||||
|
User: &User{ID: 605},
|
||||||
|
Account: &Account{ID: 705},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, 1, billingRepo.calls)
|
||||||
|
require.Equal(t, 0, usageRepo.calls)
|
||||||
|
}
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestHandleNonStreamingResponse_UsageAlignedWithClaudeMaxSimulation(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
svc := &GatewayService{
|
|
||||||
cfg: &config.Config{},
|
|
||||||
rateLimitService: &RateLimitService{},
|
|
||||||
}
|
|
||||||
|
|
||||||
account := &Account{
|
|
||||||
ID: 11,
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Extra: map[string]any{
|
|
||||||
"cache_ttl_override_enabled": true,
|
|
||||||
"cache_ttl_override_target": "5m",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
group := &Group{
|
|
||||||
ID: 99,
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
SimulateClaudeMaxEnabled: true,
|
|
||||||
}
|
|
||||||
parsed := &ParsedRequest{
|
|
||||||
Model: "claude-sonnet-4",
|
|
||||||
Messages: []any{
|
|
||||||
map[string]any{
|
|
||||||
"role": "user",
|
|
||||||
"content": []any{
|
|
||||||
map[string]any{
|
|
||||||
"type": "text",
|
|
||||||
"text": "long cached context",
|
|
||||||
"cache_control": map[string]any{"type": "ephemeral"},
|
|
||||||
},
|
|
||||||
map[string]any{
|
|
||||||
"type": "text",
|
|
||||||
"text": "new user question",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
upstreamBody := []byte(`{"id":"msg_1","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
|
|
||||||
resp := &http.Response{
|
|
||||||
StatusCode: http.StatusOK,
|
|
||||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
||||||
Body: ioNopCloserBytes(upstreamBody),
|
|
||||||
}
|
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(rec)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
|
|
||||||
c.Set("api_key", &APIKey{Group: group})
|
|
||||||
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
|
|
||||||
|
|
||||||
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, usage)
|
|
||||||
|
|
||||||
var rendered struct {
|
|
||||||
Usage ClaudeUsage `json:"usage"`
|
|
||||||
}
|
|
||||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &rendered))
|
|
||||||
rendered.Usage.CacheCreation5mTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_5m_input_tokens").Int())
|
|
||||||
rendered.Usage.CacheCreation1hTokens = int(gjson.GetBytes(rec.Body.Bytes(), "usage.cache_creation.ephemeral_1h_input_tokens").Int())
|
|
||||||
|
|
||||||
require.Equal(t, rendered.Usage.InputTokens, usage.InputTokens)
|
|
||||||
require.Equal(t, rendered.Usage.OutputTokens, usage.OutputTokens)
|
|
||||||
require.Equal(t, rendered.Usage.CacheCreationInputTokens, usage.CacheCreationInputTokens)
|
|
||||||
require.Equal(t, rendered.Usage.CacheCreation5mTokens, usage.CacheCreation5mTokens)
|
|
||||||
require.Equal(t, rendered.Usage.CacheCreation1hTokens, usage.CacheCreation1hTokens)
|
|
||||||
require.Equal(t, rendered.Usage.CacheReadInputTokens, usage.CacheReadInputTokens)
|
|
||||||
|
|
||||||
require.Greater(t, usage.CacheCreation1hTokens, 0)
|
|
||||||
require.Equal(t, 0, usage.CacheCreation5mTokens)
|
|
||||||
require.Less(t, usage.InputTokens, 120)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleNonStreamingResponse_ClaudeMaxDisabled_NoSimulationIntercept(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
svc := &GatewayService{
|
|
||||||
cfg: &config.Config{},
|
|
||||||
rateLimitService: &RateLimitService{},
|
|
||||||
}
|
|
||||||
|
|
||||||
account := &Account{
|
|
||||||
ID: 12,
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Extra: map[string]any{
|
|
||||||
"cache_ttl_override_enabled": true,
|
|
||||||
"cache_ttl_override_target": "5m",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
group := &Group{
|
|
||||||
ID: 100,
|
|
||||||
Platform: PlatformAnthropic,
|
|
||||||
SimulateClaudeMaxEnabled: false,
|
|
||||||
}
|
|
||||||
parsed := &ParsedRequest{
|
|
||||||
Model: "claude-sonnet-4",
|
|
||||||
Messages: []any{
|
|
||||||
map[string]any{
|
|
||||||
"role": "user",
|
|
||||||
"content": []any{
|
|
||||||
map[string]any{
|
|
||||||
"type": "text",
|
|
||||||
"text": "long cached context",
|
|
||||||
"cache_control": map[string]any{"type": "ephemeral"},
|
|
||||||
},
|
|
||||||
map[string]any{
|
|
||||||
"type": "text",
|
|
||||||
"text": "new user question",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
upstreamBody := []byte(`{"id":"msg_2","model":"claude-sonnet-4","usage":{"input_tokens":120,"output_tokens":8}}`)
|
|
||||||
resp := &http.Response{
|
|
||||||
StatusCode: http.StatusOK,
|
|
||||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
|
||||||
Body: ioNopCloserBytes(upstreamBody),
|
|
||||||
}
|
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(rec)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(nil))
|
|
||||||
c.Set("api_key", &APIKey{Group: group})
|
|
||||||
requestCtx := withClaudeMaxResponseRewriteContext(context.Background(), c, parsed)
|
|
||||||
|
|
||||||
usage, err := svc.handleNonStreamingResponse(requestCtx, resp, c, account, "claude-sonnet-4", "claude-sonnet-4")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, usage)
|
|
||||||
|
|
||||||
require.Equal(t, 120, usage.InputTokens)
|
|
||||||
require.Equal(t, 0, usage.CacheCreationInputTokens)
|
|
||||||
require.Equal(t, 0, usage.CacheCreation5mTokens)
|
|
||||||
require.Equal(t, 0, usage.CacheCreation1hTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ioNopCloserBytes(b []byte) *readCloserFromBytes {
|
|
||||||
return &readCloserFromBytes{Reader: bytes.NewReader(b)}
|
|
||||||
}
|
|
||||||
|
|
||||||
type readCloserFromBytes struct {
|
|
||||||
*bytes.Reader
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *readCloserFromBytes) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
267
backend/internal/service/gateway_service_bedrock_beta_test.go
Normal file
267
backend/internal/service/gateway_service_bedrock_beta_test.go
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type betaPolicySettingRepoStub struct {
|
||||||
|
values map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *betaPolicySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||||
|
panic("unexpected Get call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *betaPolicySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||||
|
if v, ok := s.values[key]; ok {
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
return "", ErrSettingNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *betaPolicySettingRepoStub) Set(ctx context.Context, key, value string) error {
|
||||||
|
panic("unexpected Set call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *betaPolicySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||||
|
panic("unexpected GetMultiple call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *betaPolicySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||||
|
panic("unexpected SetMultiple call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *betaPolicySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||||
|
panic("unexpected GetAll call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *betaPolicySettingRepoStub) Delete(ctx context.Context, key string) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveBedrockBetaTokensForRequest_BlocksOnOriginalAnthropicToken(t *testing.T) {
|
||||||
|
settings := &BetaPolicySettings{
|
||||||
|
Rules: []BetaPolicyRule{
|
||||||
|
{
|
||||||
|
BetaToken: "advanced-tool-use-2025-11-20",
|
||||||
|
Action: BetaPolicyActionBlock,
|
||||||
|
Scope: BetaPolicyScopeAll,
|
||||||
|
ErrorMessage: "advanced tool use is blocked",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
raw, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal settings: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
settingService: NewSettingService(
|
||||||
|
&betaPolicySettingRepoStub{values: map[string]string{
|
||||||
|
SettingKeyBetaPolicySettings: string(raw),
|
||||||
|
}},
|
||||||
|
&config.Config{},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||||
|
|
||||||
|
_, err = svc.resolveBedrockBetaTokensForRequest(
|
||||||
|
context.Background(),
|
||||||
|
account,
|
||||||
|
"advanced-tool-use-2025-11-20",
|
||||||
|
[]byte(`{"messages":[{"role":"user","content":"hi"}]}`),
|
||||||
|
"us.anthropic.claude-opus-4-6-v1",
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected raw advanced-tool-use token to be blocked before Bedrock transform")
|
||||||
|
}
|
||||||
|
if err.Error() != "advanced tool use is blocked" {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform(t *testing.T) {
|
||||||
|
settings := &BetaPolicySettings{
|
||||||
|
Rules: []BetaPolicyRule{
|
||||||
|
{
|
||||||
|
BetaToken: "tool-search-tool-2025-10-19",
|
||||||
|
Action: BetaPolicyActionFilter,
|
||||||
|
Scope: BetaPolicyScopeAll,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
raw, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal settings: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
settingService: NewSettingService(
|
||||||
|
&betaPolicySettingRepoStub{values: map[string]string{
|
||||||
|
SettingKeyBetaPolicySettings: string(raw),
|
||||||
|
}},
|
||||||
|
&config.Config{},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||||
|
|
||||||
|
betaTokens, err := svc.resolveBedrockBetaTokensForRequest(
|
||||||
|
context.Background(),
|
||||||
|
account,
|
||||||
|
"advanced-tool-use-2025-11-20",
|
||||||
|
[]byte(`{"messages":[{"role":"user","content":"hi"}]}`),
|
||||||
|
"us.anthropic.claude-opus-4-6-v1",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
for _, token := range betaTokens {
|
||||||
|
if token == "tool-search-tool-2025-10-19" {
|
||||||
|
t.Fatalf("expected transformed Bedrock token to be filtered")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking 验证:
|
||||||
|
// 管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
|
||||||
|
// 但请求体包含 thinking 字段 → 自动注入后应被 block。
|
||||||
|
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *testing.T) {
|
||||||
|
settings := &BetaPolicySettings{
|
||||||
|
Rules: []BetaPolicyRule{
|
||||||
|
{
|
||||||
|
BetaToken: "interleaved-thinking-2025-05-14",
|
||||||
|
Action: BetaPolicyActionBlock,
|
||||||
|
Scope: BetaPolicyScopeAll,
|
||||||
|
ErrorMessage: "thinking is blocked",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
raw, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal settings: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
settingService: NewSettingService(
|
||||||
|
&betaPolicySettingRepoStub{values: map[string]string{
|
||||||
|
SettingKeyBetaPolicySettings: string(raw),
|
||||||
|
}},
|
||||||
|
&config.Config{},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||||
|
|
||||||
|
// header 中不带 beta token,但 body 中有 thinking 字段
|
||||||
|
_, err = svc.resolveBedrockBetaTokensForRequest(
|
||||||
|
context.Background(),
|
||||||
|
account,
|
||||||
|
"", // 空 header
|
||||||
|
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
|
||||||
|
"us.anthropic.claude-opus-4-6-v1",
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected body-injected interleaved-thinking to be blocked")
|
||||||
|
}
|
||||||
|
if err.Error() != "thinking is blocked" {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch 验证:
|
||||||
|
// 管理员 block 了 tool-search-tool,客户端不在 header 中带 beta token,
|
||||||
|
// 但请求体包含 tool search 工具 → 自动注入后应被 block。
|
||||||
|
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch(t *testing.T) {
|
||||||
|
settings := &BetaPolicySettings{
|
||||||
|
Rules: []BetaPolicyRule{
|
||||||
|
{
|
||||||
|
BetaToken: "tool-search-tool-2025-10-19",
|
||||||
|
Action: BetaPolicyActionBlock,
|
||||||
|
Scope: BetaPolicyScopeAll,
|
||||||
|
ErrorMessage: "tool search is blocked",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
raw, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal settings: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
settingService: NewSettingService(
|
||||||
|
&betaPolicySettingRepoStub{values: map[string]string{
|
||||||
|
SettingKeyBetaPolicySettings: string(raw),
|
||||||
|
}},
|
||||||
|
&config.Config{},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||||
|
|
||||||
|
// header 中不带 beta token,但 body 中有 tool_search_tool 工具
|
||||||
|
_, err = svc.resolveBedrockBetaTokensForRequest(
|
||||||
|
context.Background(),
|
||||||
|
account,
|
||||||
|
"",
|
||||||
|
[]byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`),
|
||||||
|
"us.anthropic.claude-sonnet-4-6",
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected body-injected tool-search-tool to be blocked")
|
||||||
|
}
|
||||||
|
if err.Error() != "tool search is blocked" {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches 验证:
|
||||||
|
// body 自动注入的 token 如果没有对应的 block 规则,应正常通过。
|
||||||
|
func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *testing.T) {
|
||||||
|
settings := &BetaPolicySettings{
|
||||||
|
Rules: []BetaPolicyRule{
|
||||||
|
{
|
||||||
|
BetaToken: "computer-use-2025-11-24",
|
||||||
|
Action: BetaPolicyActionBlock,
|
||||||
|
Scope: BetaPolicyScopeAll,
|
||||||
|
ErrorMessage: "computer use is blocked",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
raw, err := json.Marshal(settings)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal settings: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
settingService: NewSettingService(
|
||||||
|
&betaPolicySettingRepoStub{values: map[string]string{
|
||||||
|
SettingKeyBetaPolicySettings: string(raw),
|
||||||
|
}},
|
||||||
|
&config.Config{},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||||
|
|
||||||
|
// body 中有 thinking(会注入 interleaved-thinking),但 block 规则只针对 computer-use
|
||||||
|
tokens, err := svc.resolveBedrockBetaTokensForRequest(
|
||||||
|
context.Background(),
|
||||||
|
account,
|
||||||
|
"",
|
||||||
|
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
|
||||||
|
"us.anthropic.claude-opus-4-6-v1",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, token := range tokens {
|
||||||
|
if token == "interleaved-thinking-2025-05-14" {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("expected interleaved-thinking token to be present")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestGatewayServiceIsModelSupportedByAccount_BedrockDefaultMappingRestrictsModels(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeBedrock,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"aws_region": "us-east-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-5") {
|
||||||
|
t.Fatalf("expected default Bedrock alias to be supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") {
|
||||||
|
t.Fatalf("expected unsupported alias to be rejected for Bedrock account")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceIsModelSupportedByAccount_BedrockCustomMappingStillActsAsAllowlist(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeBedrock,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"aws_region": "eu-west-1",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"claude-sonnet-*": "claude-sonnet-4-6",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-6") {
|
||||||
|
t.Fatalf("expected matched custom mapping to be supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !svc.isModelSupportedByAccount(account, "claude-opus-4-6") {
|
||||||
|
t.Fatalf("expected default Bedrock alias fallback to remain supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") {
|
||||||
|
t.Fatalf("expected unsupported model to still be rejected")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
|
|||||||
|
|
||||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||||
_ = pr.Close()
|
_ = pr.Close()
|
||||||
require.NoError(t, err)
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "missing terminal event")
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -50,9 +50,6 @@ type Group struct {
|
|||||||
// MCP XML 协议注入开关(仅 antigravity 平台使用)
|
// MCP XML 协议注入开关(仅 antigravity 平台使用)
|
||||||
MCPXMLInject bool
|
MCPXMLInject bool
|
||||||
|
|
||||||
// Claude usage 模拟开关:将无写缓存 usage 模拟为 claude-max 风格
|
|
||||||
SimulateClaudeMaxEnabled bool
|
|
||||||
|
|
||||||
// 支持的模型系列(仅 antigravity 平台使用)
|
// 支持的模型系列(仅 antigravity 平台使用)
|
||||||
// 可选值: claude, gemini_text, gemini_image
|
// 可选值: claude, gemini_text, gemini_image
|
||||||
SupportedModelScopes []string
|
SupportedModelScopes []string
|
||||||
|
|||||||
@@ -129,6 +129,41 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 兼容遗留的 functions 和 function_call,转换为 tools 和 tool_choice
|
||||||
|
if functionsRaw, ok := reqBody["functions"]; ok {
|
||||||
|
if functions, k := functionsRaw.([]any); k {
|
||||||
|
tools := make([]any, 0, len(functions))
|
||||||
|
for _, f := range functions {
|
||||||
|
tools = append(tools, map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": f,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
reqBody["tools"] = tools
|
||||||
|
}
|
||||||
|
delete(reqBody, "functions")
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if fcRaw, ok := reqBody["function_call"]; ok {
|
||||||
|
if fcStr, ok := fcRaw.(string); ok {
|
||||||
|
// e.g. "auto", "none"
|
||||||
|
reqBody["tool_choice"] = fcStr
|
||||||
|
} else if fcObj, ok := fcRaw.(map[string]any); ok {
|
||||||
|
// e.g. {"name": "my_func"}
|
||||||
|
if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
|
||||||
|
reqBody["tool_choice"] = map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(reqBody, "function_call")
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
|
||||||
if normalizeCodexTools(reqBody) {
|
if normalizeCodexTools(reqBody) {
|
||||||
result.Modified = true
|
result.Modified = true
|
||||||
}
|
}
|
||||||
@@ -303,6 +338,18 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
typ, _ := m["type"].(string)
|
typ, _ := m["type"].(string)
|
||||||
|
|
||||||
|
// 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'"
|
||||||
|
fixIDPrefix := func(id string) string {
|
||||||
|
if id == "" || strings.HasPrefix(id, "fc") {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(id, "call_") {
|
||||||
|
return "fc" + strings.TrimPrefix(id, "call_")
|
||||||
|
}
|
||||||
|
return "fc_" + id
|
||||||
|
}
|
||||||
|
|
||||||
if typ == "item_reference" {
|
if typ == "item_reference" {
|
||||||
if !preserveReferences {
|
if !preserveReferences {
|
||||||
continue
|
continue
|
||||||
@@ -311,6 +358,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
|||||||
for key, value := range m {
|
for key, value := range m {
|
||||||
newItem[key] = value
|
newItem[key] = value
|
||||||
}
|
}
|
||||||
|
if id, ok := newItem["id"].(string); ok && id != "" {
|
||||||
|
newItem["id"] = fixIDPrefix(id)
|
||||||
|
}
|
||||||
filtered = append(filtered, newItem)
|
filtered = append(filtered, newItem)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -330,10 +380,20 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if isCodexToolCallItemType(typ) {
|
if isCodexToolCallItemType(typ) {
|
||||||
if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
|
callID, ok := m["call_id"].(string)
|
||||||
|
if !ok || strings.TrimSpace(callID) == "" {
|
||||||
if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
|
if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
|
||||||
|
callID = id
|
||||||
ensureCopy()
|
ensureCopy()
|
||||||
newItem["call_id"] = id
|
newItem["call_id"] = callID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if callID != "" {
|
||||||
|
fixedCallID := fixIDPrefix(callID)
|
||||||
|
if fixedCallID != callID {
|
||||||
|
ensureCopy()
|
||||||
|
newItem["call_id"] = fixedCallID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -344,6 +404,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
|||||||
if !isCodexToolCallItemType(typ) {
|
if !isCodexToolCallItemType(typ) {
|
||||||
delete(newItem, "call_id")
|
delete(newItem, "call_id")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
if id, ok := newItem["id"].(string); ok && id != "" {
|
||||||
|
fixedID := fixIDPrefix(id)
|
||||||
|
if fixedID != id {
|
||||||
|
ensureCopy()
|
||||||
|
newItem["id"] = fixedID
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
filtered = append(filtered, newItem)
|
filtered = append(filtered, newItem)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user