mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 07:22:13 +08:00
Compare commits
80 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6da5fa01b9 | ||
|
|
616930f9d3 | ||
|
|
b9c31fa7c4 | ||
|
|
17b339972c | ||
|
|
39f8bd91b9 | ||
|
|
aa4e37d085 | ||
|
|
f59b66b7d4 | ||
|
|
8f0ea7a02d | ||
|
|
a1dc00890e | ||
|
|
dfbcc363d1 | ||
|
|
1047f973d5 | ||
|
|
e32977dd73 | ||
|
|
b5f78ec1e8 | ||
|
|
e0f290fdc8 | ||
|
|
fc00a4e3b2 | ||
|
|
db1f6ded88 | ||
|
|
4644af2ccc | ||
|
|
2e3e8687e1 | ||
|
|
ca42a45802 | ||
|
|
9350ecb62b | ||
|
|
a4a026e8da | ||
|
|
342fd03e72 | ||
|
|
e3f1fd9b63 | ||
|
|
e4a4dfd038 | ||
|
|
a377e99088 | ||
|
|
1d3d7a3033 | ||
|
|
e7086cb3a3 | ||
|
|
4f2a97073e | ||
|
|
7407e3b45d | ||
|
|
01ef7340aa | ||
|
|
1c960d22c1 | ||
|
|
ece0606fed | ||
|
|
2666422b99 | ||
|
|
e6d59216d4 | ||
|
|
4e8615f276 | ||
|
|
91e4d95660 | ||
|
|
45456fa24c | ||
|
|
4588258d80 | ||
|
|
c12e48f966 | ||
|
|
ec8f50a658 | ||
|
|
99c9191784 | ||
|
|
6bb02d141f | ||
|
|
07bb2a5f3f | ||
|
|
417861a48e | ||
|
|
b7e878de64 | ||
|
|
05edb5514b | ||
|
|
e90ec847b6 | ||
|
|
6344fa2a86 | ||
|
|
7e288acc90 | ||
|
|
29b0e4a8a5 | ||
|
|
27ff222cfb | ||
|
|
11f7b83522 | ||
|
|
f7177be3b6 | ||
|
|
875b417fde | ||
|
|
2573107b32 | ||
|
|
5b85005945 | ||
|
|
1ee984478f | ||
|
|
fd693dc526 | ||
|
|
e73531ce9b | ||
|
|
53ad1645cf | ||
|
|
ecea13757b | ||
|
|
af9c4a7dd0 | ||
|
|
80d8d6c3bc | ||
|
|
d648811233 | ||
|
|
34695acb85 | ||
|
|
a63de12182 | ||
|
|
f16910d616 | ||
|
|
64b3f3cec1 | ||
|
|
6a685727d0 | ||
|
|
32d25f76fc | ||
|
|
69cafe8674 | ||
|
|
18ba8d9166 | ||
|
|
e97fd7e81c | ||
|
|
cdb64b0d33 | ||
|
|
8d4d3b03bb | ||
|
|
addefe79e1 | ||
|
|
b764d3b8f6 | ||
|
|
611fd884bd | ||
|
|
6826149a8f | ||
|
|
c9debc50b1 |
20
Dockerfile
20
Dockerfile
@@ -9,6 +9,7 @@
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
|
||||
./cmd/server
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 3: Final Runtime Image
|
||||
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 4: Final Runtime Image
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
@@ -86,8 +92,20 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||
# This ensures version consistency between backup tools and the database server
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
10
README.md
10
README.md
@@ -39,6 +39,16 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
- **Concurrency Control** - Per-user and per-account concurrency limits
|
||||
- **Rate Limiting** - Configurable request and token rate limits
|
||||
- **Admin Dashboard** - Web interface for monitoring and management
|
||||
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||
|
||||
## Ecosystem
|
||||
|
||||
Community projects that extend or integrate with Sub2API:
|
||||
|
||||
| Project | Description | Features |
|
||||
|---------|-------------|----------|
|
||||
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
|
||||
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
|
||||
|
||||
## Tech Stack
|
||||
|
||||
|
||||
10
README_CN.md
10
README_CN.md
@@ -39,6 +39,16 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
||||
- **并发控制** - 用户级和账号级并发限制
|
||||
- **速率限制** - 可配置的请求和 Token 速率限制
|
||||
- **管理后台** - Web 界面进行监控和管理
|
||||
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||
|
||||
## 生态项目
|
||||
|
||||
围绕 Sub2API 的社区扩展与集成项目:
|
||||
|
||||
| 项目 | 说明 | 功能 |
|
||||
|------|------|------|
|
||||
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
|
||||
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
|
||||
|
||||
## 技术栈
|
||||
|
||||
|
||||
@@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
// Server layer ProviderSet
|
||||
server.ProviderSet,
|
||||
|
||||
// Privacy client factory for OpenAI training opt-out
|
||||
providePrivacyClientFactory,
|
||||
|
||||
// BuildInfo provider
|
||||
provideServiceBuildInfo,
|
||||
|
||||
@@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||
return repository.CreatePrivacyReqClient
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
@@ -87,6 +94,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -223,6 +231,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||
@@ -104,7 +105,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
|
||||
privacyClientFactory := providePrivacyClientFactory()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
@@ -144,6 +146,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
dataManagementService := service.NewDataManagementService()
|
||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||
dbDumper := repository.NewPgDumper(configConfig)
|
||||
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -162,9 +168,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
@@ -199,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -226,11 +232,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -245,6 +251,10 @@ type Application struct {
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||
return repository.CreatePrivacyReqClient
|
||||
}
|
||||
|
||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||
return service.BuildInfo{
|
||||
Version: buildInfo.Version,
|
||||
@@ -279,6 +289,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -414,6 +425,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -7,7 +7,7 @@ require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||
@@ -66,7 +66,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||
github.com/aws/smithy-go v1.24.1 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||
|
||||
@@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
|
||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
||||
@@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||
|
||||
@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
|
||||
|
||||
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||
type DashboardAggregationRetentionConfig struct {
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
}
|
||||
|
||||
// UsageCleanupConfig 使用记录清理任务配置
|
||||
@@ -1301,6 +1302,7 @@ func setDefaults() {
|
||||
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
|
||||
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||
}
|
||||
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||
}
|
||||
|
||||
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||
}
|
||||
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention smaller than usage logs",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageLogsDays = 30
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation disabled interval",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -113,3 +114,27 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||
}
|
||||
|
||||
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
|
||||
var DefaultBedrockModelMapping = map[string]string{
|
||||
// Claude Opus
|
||||
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
|
||||
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
|
||||
// Claude Sonnet
|
||||
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
|
||||
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
// Claude Haiku
|
||||
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
}
|
||||
|
||||
@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -865,6 +865,9 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI OAuth: 刷新成功后检查并设置 privacy_mode
|
||||
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
|
||||
|
||||
return updatedAccount, "", nil
|
||||
}
|
||||
|
||||
@@ -1715,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||
if account.IsOpenAIPassthroughEnabled() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type availableModelsAdminService struct {
|
||||
*stubAdminService
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID == id {
|
||||
acc := s.account
|
||||
return &acc, nil
|
||||
}
|
||||
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||
}
|
||||
|
||||
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 42,
|
||||
Name: "openai-oauth",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 1)
|
||||
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 43,
|
||||
Name: "openai-oauth-passthrough",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
@@ -175,6 +175,18 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
@@ -429,5 +441,9 @@ func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
204
backend/internal/handler/admin/backup_handler.go
Normal file
204
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BackupHandler struct {
|
||||
backupService *service.BackupService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
|
||||
return &BackupHandler{
|
||||
backupService: backupService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置 ───
|
||||
|
||||
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||
}
|
||||
|
||||
// ─── 定时备份 ───
|
||||
|
||||
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||
var req service.BackupScheduleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// ─── 备份操作 ───
|
||||
|
||||
type CreateBackupRequest struct {
|
||||
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||
}
|
||||
|
||||
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||
var req CreateBackupRequest
|
||||
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||
|
||||
expireDays := 14 // 默认14天过期
|
||||
if req.ExpireDays != nil {
|
||||
expireDays = *req.ExpireDays
|
||||
}
|
||||
|
||||
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if records == nil {
|
||||
records = []service.BackupRecord{}
|
||||
}
|
||||
response.Success(c, gin.H{"items": records})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||
|
||||
type RestoreBackupRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req RestoreBackupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "password is required for restore operation")
|
||||
return
|
||||
}
|
||||
|
||||
// 从上下文获取当前管理员用户 ID
|
||||
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取管理员用户并验证密码
|
||||
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if !user.CheckPassword(req.Password) {
|
||||
response.BadRequest(c, "incorrect admin password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"restored": true})
|
||||
}
|
||||
@@ -466,9 +466,60 @@ type BatchUsersUsageRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
func parseRankingLimit(raw string) int {
|
||||
limit, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||
if err != nil || limit <= 0 {
|
||||
return 12
|
||||
}
|
||||
if limit > 50 {
|
||||
return 50
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
// GetUserSpendingRanking handles getting user spending ranking data.
|
||||
// GET /api/v1/admin/dashboard/users-ranking
|
||||
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
|
||||
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
Start string `json:"start"`
|
||||
End string `json:"end"`
|
||||
Limit int `json:"limit"`
|
||||
}{
|
||||
Start: startTime.UTC().Format(time.RFC3339),
|
||||
End: endTime.UTC().Format(time.RFC3339),
|
||||
Limit: limit,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user spending ranking")
|
||||
return
|
||||
}
|
||||
|
||||
payload := gin.H{
|
||||
"ranking": ranking.Ranking,
|
||||
"total_actual_cost": ranking.TotalActualCost,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
}
|
||||
dashboardUsersRankingCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||
// POST /api/v1/admin/dashboard/users-usage
|
||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
|
||||
@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
|
||||
trendStream *bool
|
||||
modelRequestType *int16
|
||||
modelStream *bool
|
||||
rankingLimit int
|
||||
ranking []usagestats.UserSpendingRankingItem
|
||||
rankingTotal float64
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||
@@ -49,6 +52,18 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
||||
return []usagestats.ModelStat{}, nil
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
limit int,
|
||||
) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
s.rankingLimit = limit
|
||||
return &usagestats.UserSpendingRankingResponse{
|
||||
Ranking: s.ranking,
|
||||
TotalActualCost: s.rankingTotal,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
@@ -56,6 +71,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
|
||||
return router
|
||||
}
|
||||
|
||||
@@ -130,3 +146,30 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||
repo := &dashboardUsageRepoCapture{
|
||||
ranking: []usagestats.UserSpendingRankingItem{
|
||||
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
|
||||
},
|
||||
rankingTotal: 88.8,
|
||||
}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 50, repo.rankingLimit)
|
||||
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||
}
|
||||
|
||||
@@ -335,6 +335,72 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
|
||||
// GET /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if entries == nil {
|
||||
entries = []service.UserGroupRateEntry{}
|
||||
}
|
||||
response.Success(c, entries)
|
||||
}
|
||||
|
||||
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
|
||||
// DELETE /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
|
||||
}
|
||||
|
||||
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
|
||||
type BatchSetGroupRateMultipliersRequest struct {
|
||||
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
|
||||
// PUT /api/v1/admin/groups/:id/rate-multipliers
|
||||
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchSetGroupRateMultipliersRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||
}
|
||||
|
||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||
type UpdateSortOrderRequest struct {
|
||||
Updates []struct {
|
||||
|
||||
@@ -289,6 +289,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
Platform: platform,
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
Extra: nil,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
|
||||
@@ -41,12 +41,15 @@ type GenerateRedeemCodesRequest struct {
|
||||
}
|
||||
|
||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||
type CreateAndRedeemCodeRequest struct {
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
Notes string `json:"notes"`
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
req.Code = strings.TrimSpace(req.Code)
|
||||
// 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。
|
||||
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
|
||||
if req.Type == "" {
|
||||
req.Type = "balance"
|
||||
}
|
||||
|
||||
if req.Type == "subscription" {
|
||||
if req.GroupID == nil {
|
||||
response.BadRequest(c, "group_id is required for subscription type")
|
||||
return
|
||||
}
|
||||
if req.ValidityDays <= 0 {
|
||||
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||
@@ -147,11 +166,13 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
}
|
||||
|
||||
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if createErr != nil {
|
||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||
|
||||
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
|
||||
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
|
||||
// parameter-validation layer that runs before any service call.
|
||||
func newCreateAndRedeemHandler() *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
adminService: newStubAdminService(),
|
||||
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
|
||||
}
|
||||
}
|
||||
|
||||
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
|
||||
// status code. For cases that pass validation and proceed into the service layer,
|
||||
// a panic may occur (because RedeemService internals are nil); this is expected
|
||||
// and treated as "validation passed" (returns 0 to indicate panic).
|
||||
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
jsonBytes, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Panic means we passed validation and entered service layer (expected for minimal stub).
|
||||
code = 0
|
||||
}
|
||||
}()
|
||||
handler.CreateAndRedeem(c)
|
||||
return w.Code
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
|
||||
// 不传 type 字段时应默认 balance,不触发 subscription 校验。
|
||||
// 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-default",
|
||||
"value": 10.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"omitting type should default to balance and pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-no-group",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"validity_days": 30,
|
||||
// group_id 缺失
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
validityDays int
|
||||
}{
|
||||
{"zero", 0},
|
||||
{"negative", -1},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-bad-days-" + tc.name,
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": tc.validityDays,
|
||||
})
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
|
||||
groupID := int64(5)
|
||||
h := newCreateAndRedeemHandler()
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-sub-valid",
|
||||
"type": "subscription",
|
||||
"value": 29.9,
|
||||
"user_id": 1,
|
||||
"group_id": groupID,
|
||||
"validity_days": 31,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"valid subscription params should pass validation")
|
||||
}
|
||||
|
||||
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||
h := newCreateAndRedeemHandler()
|
||||
// balance 类型不传 group_id 和 validity_days,不应报 400
|
||||
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||
"code": "test-balance-no-extras",
|
||||
"type": "balance",
|
||||
"value": 50.0,
|
||||
"user_id": 1,
|
||||
})
|
||||
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"balance type should not require group_id or validity_days")
|
||||
}
|
||||
@@ -125,6 +125,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -199,6 +200,9 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -473,6 +477,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -571,6 +576,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -725,6 +731,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||
}
|
||||
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||
changed = append(changed, "backend_mode_enabled")
|
||||
}
|
||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||
changed = append(changed, "purchase_subscription_enabled")
|
||||
}
|
||||
|
||||
@@ -218,11 +218,12 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
|
||||
// ResetSubscriptionQuotaRequest represents the reset quota request
|
||||
type ResetSubscriptionQuotaRequest struct {
|
||||
Daily bool `json:"daily"`
|
||||
Weekly bool `json:"weekly"`
|
||||
Daily bool `json:"daily"`
|
||||
Weekly bool `json:"weekly"`
|
||||
Monthly bool `json:"monthly"`
|
||||
}
|
||||
|
||||
// ResetQuota resets daily and/or weekly usage for a subscription.
|
||||
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
|
||||
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
@@ -235,11 +236,11 @@ func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Daily && !req.Weekly {
|
||||
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
|
||||
if !req.Daily && !req.Weekly && !req.Monthly {
|
||||
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
|
||||
return
|
||||
}
|
||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
|
||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
// Get the user (before session deletion so we can check backend mode)
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session (only after all checks pass)
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: block non-admin token refresh
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RefreshTokenResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresIn: result.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -264,8 +264,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||
if a.IsAPIKeyOrBedrock() {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
used := a.GetQuotaUsed()
|
||||
@@ -281,6 +281,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
// 固定时间重置配置
|
||||
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||
out.QuotaDailyResetMode = &mode
|
||||
hour := a.GetQuotaDailyResetHour()
|
||||
out.QuotaDailyResetHour = &hour
|
||||
}
|
||||
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||
out.QuotaWeeklyResetMode = &mode
|
||||
day := a.GetQuotaWeeklyResetDay()
|
||||
out.QuotaWeeklyResetDay = &day
|
||||
hour := a.GetQuotaWeeklyResetHour()
|
||||
out.QuotaWeeklyResetHour = &hour
|
||||
}
|
||||
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
tz := a.GetQuotaResetTimezone()
|
||||
out.QuotaResetTimezone = &tz
|
||||
}
|
||||
if a.Extra != nil {
|
||||
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaDailyResetAt = &v
|
||||
}
|
||||
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaWeeklyResetAt = &v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
@@ -81,6 +81,9 @@ type SystemSettings struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -111,6 +114,7 @@ type PublicSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
|
||||
@@ -203,6 +203,16 @@ type Account struct {
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
// 配额固定时间重置配置
|
||||
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
|
||||
@@ -434,19 +434,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -736,19 +738,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
|
||||
@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||
&fakeGroupRepo{group: group},
|
||||
nil, // usageLogRepo
|
||||
nil, // usageBillingRepo
|
||||
nil, // userRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
|
||||
@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
|
||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
Backup *admin.BackupHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
|
||||
@@ -181,13 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
|
||||
@@ -352,18 +352,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
@@ -653,14 +655,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
// 如果使用了降级模型调度,强制使用降级模型
|
||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
@@ -732,17 +729,19 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
@@ -1231,14 +1230,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
|
||||
@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
||||
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
|
||||
@@ -343,6 +343,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e
|
||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -431,6 +434,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
testutil.StubGatewayCache{},
|
||||
cfg,
|
||||
nil,
|
||||
|
||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
dataManagementHandler *admin.DataManagementHandler,
|
||||
backupHandler *admin.BackupHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
DataManagement: dataManagementHandler,
|
||||
Backup: backupHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewDataManagementHandler,
|
||||
admin.NewBackupHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
|
||||
@@ -19,6 +19,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
)
|
||||
|
||||
// ForbiddenError 表示上游返回 403 Forbidden
|
||||
type ForbiddenError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *ForbiddenError) Error() string {
|
||||
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
@@ -514,7 +524,20 @@ type ModelQuotaInfo struct {
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportsImages *bool `json:"supportsImages,omitempty"`
|
||||
SupportsThinking *bool `json:"supportsThinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"maxTokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
|
||||
}
|
||||
|
||||
// DeprecatedModelInfo 废弃模型转发信息
|
||||
type DeprecatedModelInfo struct {
|
||||
NewModelID string `json:"newModelId"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||
@@ -524,7 +547,8 @@ type FetchAvailableModelsRequest struct {
|
||||
|
||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||
type FetchAvailableModelsResponse struct {
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
@@ -573,6 +597,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return nil, nil, &ForbiddenError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBodyBytes),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
@@ -189,6 +189,5 @@ var DefaultStopSequences = []string{
|
||||
"<|user|>",
|
||||
"<|endoftext|>",
|
||||
"<|end_of_turn|>",
|
||||
"[DONE]",
|
||||
"\n\nHuman:",
|
||||
}
|
||||
|
||||
@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||
assert.Empty(t, items[2].ID)
|
||||
assert.Equal(t, "function_call_output", items[3].Type)
|
||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||
|
||||
@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
||||
CallID: fcID,
|
||||
Name: b.Name,
|
||||
Arguments: args,
|
||||
ID: fcID,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||
// Check function_call item
|
||||
assert.Equal(t, "function_call", items[1].Type)
|
||||
assert.Equal(t, "call_1", items[1].CallID)
|
||||
assert.Empty(t, items[1].ID)
|
||||
assert.Equal(t, "ping", items[1].Name)
|
||||
|
||||
// Check function_call_output item
|
||||
@@ -252,6 +253,55 @@ func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Empty(t, items[2].ID)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Equal(t, "AB", parts[0].Text)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
|
||||
assert.Contains(t, parts[0].Text, "final answer")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -344,8 +394,8 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
// Reasoning summary is prepended to text
|
||||
assert.Equal(t, "I thought about it.The answer is 42.", content)
|
||||
assert.Equal(t, "The answer is 42.", content)
|
||||
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||
@@ -582,8 +632,35 @@ func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
||||
Delta: "Thinking...",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.done",
|
||||
}, state)
|
||||
require.Len(t, chunks, 0)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
Delta: "plan",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "answer",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package apicompat
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||
@@ -174,8 +175,11 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
|
||||
// Emit assistant message with output_text if content is non-empty.
|
||||
if len(m.Content) > 0 {
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
|
||||
s, err := parseAssistantContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
@@ -196,13 +200,82 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
ID: tc.ID,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// parseAssistantContent returns assistant content as plain text.
|
||||
//
|
||||
// Supported formats:
|
||||
// - JSON string
|
||||
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
|
||||
//
|
||||
// For structured thinking/reasoning parts, it preserves semantics by wrapping
|
||||
// the text in explicit tags so downstream can still distinguish it from normal text.
|
||||
func parseAssistantContent(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var parts []map[string]any
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
// Keep compatibility with prior behavior: unsupported assistant content
|
||||
// formats are ignored instead of failing the whole request conversion.
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
write := func(v string) error {
|
||||
_, err := b.WriteString(v)
|
||||
return err
|
||||
}
|
||||
for _, p := range parts {
|
||||
typ, _ := p["type"].(string)
|
||||
text, _ := p["text"].(string)
|
||||
thinking, _ := p["thinking"].(string)
|
||||
|
||||
switch typ {
|
||||
case "thinking", "reasoning":
|
||||
if thinking != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(thinking); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else if text != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
default:
|
||||
if text != "" {
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||
// function_call_output item.
|
||||
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
|
||||
@@ -29,6 +29,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
}
|
||||
|
||||
var contentText string
|
||||
var reasoningText string
|
||||
var toolCalls []ChatToolCall
|
||||
|
||||
for _, item := range resp.Output {
|
||||
@@ -51,7 +52,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
case "reasoning":
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
contentText += s.Text
|
||||
reasoningText += s.Text
|
||||
}
|
||||
}
|
||||
case "web_search_call":
|
||||
@@ -67,6 +68,9 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
raw, _ := json.Marshal(contentText)
|
||||
msg.Content = raw
|
||||
}
|
||||
if reasoningText != "" {
|
||||
msg.ReasoningContent = reasoningText
|
||||
}
|
||||
|
||||
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||
|
||||
@@ -153,6 +157,8 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
||||
return resToChatHandleFuncArgsDelta(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return nil
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
@@ -276,8 +282,8 @@ func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEv
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
content := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||
reasoning := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
|
||||
}
|
||||
|
||||
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
|
||||
@@ -361,11 +361,12 @@ type ChatStreamOptions struct {
|
||||
|
||||
// ChatMessage is a single message in the Chat Completions conversation.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
|
||||
// Legacy function calling
|
||||
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||
@@ -466,9 +467,10 @@ type ChatChunkChoice struct {
|
||||
|
||||
// ChatDelta carries incremental content in a streaming chunk.
|
||||
type ChatDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -96,12 +96,28 @@ type UserUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// UserSpendingRankingItem represents a user spending ranking row.
|
||||
type UserSpendingRankingItem struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}
|
||||
|
||||
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
|
||||
type UserSpendingRankingResponse struct {
|
||||
Ranking []UserSpendingRankingItem `json:"ranking"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
type APIKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
|
||||
@@ -397,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
}
|
||||
// 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照,
|
||||
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1727,8 +1727,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
|
||||
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
|
||||
const dailyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
|
||||
const weeklyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextDailyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute today's reset point in the configured timezone, then pick next future one
|
||||
CASE WHEN NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is at or past today's reset point → next reset is tomorrow
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '1 day'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is before today's reset point → next reset is today
|
||||
ELSE (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextWeeklyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute this week's reset point in the configured timezone
|
||||
-- Step 1: get today's date at reset hour in configured tz
|
||||
-- Step 2: compute days forward to target weekday
|
||||
-- Step 3: if same day but past reset hour, advance 7 days
|
||||
CASE
|
||||
WHEN (
|
||||
-- days_forward = (target_day - current_day + 7) % 7
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) = 0 AND NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- Same weekday and past reset hour → next week
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '7 days'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
ELSE (
|
||||
-- Advance to target weekday this week (or next if days_forward > 0)
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ ((
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) || ' days')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
@@ -1739,31 +1827,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
@@ -1796,12 +1888,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
||||
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
|
||||
@@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "sync-credentials-update",
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.2",
|
||||
},
|
||||
}
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
|
||||
s.Require().True(ok)
|
||||
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
|
||||
98
backend/internal/repository/backup_pg_dumper.go
Normal file
98
backend/internal/repository/backup_pg_dumper.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// PgDumper implements service.DBDumper using pg_dump/psql
|
||||
type PgDumper struct {
|
||||
cfg *config.DatabaseConfig
|
||||
}
|
||||
|
||||
// NewPgDumper creates a new PgDumper
|
||||
func NewPgDumper(cfg *config.Config) service.DBDumper {
|
||||
return &PgDumper{cfg: &cfg.Database}
|
||||
}
|
||||
|
||||
// Dump executes pg_dump and returns a streaming reader of the output
|
||||
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--no-owner",
|
||||
"--no-acl",
|
||||
"--clean",
|
||||
"--if-exists",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_dump", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 返回一个 ReadCloser:读 stdout,关闭时等待进程退出
|
||||
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
|
||||
}
|
||||
|
||||
// Restore executes psql to restore from a streaming reader
|
||||
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--single-transaction",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
cmd.Stdin = data
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %s", err, string(output))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
|
||||
type cmdReadCloser struct {
|
||||
io.ReadCloser
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func (c *cmdReadCloser) Close() error {
|
||||
// Close the pipe first
|
||||
_ = c.ReadCloser.Close()
|
||||
// Wait for the process to exit
|
||||
if err := c.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("pg_dump exited with error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
116
backend/internal/repository/backup_s3_store.go
Normal file
116
backend/internal/repository/backup_s3_store.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
|
||||
type S3BackupStore struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
}
|
||||
|
||||
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
|
||||
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
||||
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
|
||||
region := cfg.Region
|
||||
if region == "" {
|
||||
region = "auto" // Cloudflare R2 默认 region
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(region),
|
||||
awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = &cfg.Endpoint
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
})
|
||||
|
||||
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
Body: bytes.NewReader(data),
|
||||
ContentType: &contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S3 PutObject: %w", err)
|
||||
}
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("S3 GetObject: %w", err)
|
||||
}
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
|
||||
presignClient := s3.NewPresignClient(s.client)
|
||||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
}, s3.WithPresignExpires(expiry))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
|
||||
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &s.bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
const usageLogsCleanupBatchSize = 10000
|
||||
const usageBillingDedupCleanupBatchSize = 10000
|
||||
|
||||
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||
if sqlDB == nil {
|
||||
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
if r == nil || r.sql == nil {
|
||||
return nil
|
||||
}
|
||||
loc := timezone.Location()
|
||||
startLocal := start.In(loc)
|
||||
endLocal := end.In(loc)
|
||||
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
if db, ok := r.sql.(*sql.DB); ok {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
|
||||
if isPartitioned {
|
||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||
}
|
||||
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
||||
return err
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid
|
||||
FROM usage_logs
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM usage_logs
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageLogsCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageLogsCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM usage_billing_dedup
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
), archived AS (
|
||||
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
|
||||
SELECT request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM victims
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
)
|
||||
DELETE FROM usage_billing_dedup
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageBillingDedupCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
|
||||
@@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
|
||||
SetKey(k.Key).
|
||||
SetName(k.Name).
|
||||
SetStatus(k.Status)
|
||||
if k.Quota != 0 {
|
||||
create.SetQuota(k.Quota)
|
||||
}
|
||||
if k.QuotaUsed != 0 {
|
||||
create.SetQuotaUsed(k.QuotaUsed)
|
||||
}
|
||||
if k.RateLimit5h != 0 {
|
||||
create.SetRateLimit5h(k.RateLimit5h)
|
||||
}
|
||||
if k.RateLimit1d != 0 {
|
||||
create.SetRateLimit1d(k.RateLimit1d)
|
||||
}
|
||||
if k.RateLimit7d != 0 {
|
||||
create.SetRateLimit7d(k.RateLimit7d)
|
||||
}
|
||||
if k.Usage5h != 0 {
|
||||
create.SetUsage5h(k.Usage5h)
|
||||
}
|
||||
if k.Usage1d != 0 {
|
||||
create.SetUsage1d(k.Usage1d)
|
||||
}
|
||||
if k.Usage7d != 0 {
|
||||
create.SetUsage7d(k.Usage7d)
|
||||
}
|
||||
if k.Window5hStart != nil {
|
||||
create.SetWindow5hStart(*k.Window5hStart)
|
||||
}
|
||||
if k.Window1dStart != nil {
|
||||
create.SetWindow1dStart(*k.Window1dStart)
|
||||
}
|
||||
if k.Window7dStart != nil {
|
||||
create.SetWindow7dStart(*k.Window7dStart)
|
||||
}
|
||||
if k.ExpiresAt != nil {
|
||||
create.SetExpiresAt(*k.ExpiresAt)
|
||||
}
|
||||
if k.GroupID != nil {
|
||||
create.SetGroupID(*k.GroupID)
|
||||
}
|
||||
|
||||
@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
||||
|
||||
// usage_billing_dedup: billing idempotency narrow table
|
||||
var usageBillingDedupRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
|
||||
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
|
||||
|
||||
var usageBillingDedupArchiveRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
|
||||
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
|
||||
|
||||
// settings table should exist
|
||||
var settingsRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
|
||||
}
|
||||
|
||||
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
|
||||
t.Helper()
|
||||
|
||||
var exists bool
|
||||
err := tx.QueryRowContext(context.Background(), `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'public'
|
||||
AND tablename = $1
|
||||
AND indexname = $2
|
||||
)
|
||||
`, table, index).Scan(&exists)
|
||||
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
|
||||
require.True(t, exists, "expected index %s on %s", index, table)
|
||||
}
|
||||
|
||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -73,3 +73,14 @@ func buildReqClientKey(opts reqClientOptions) string {
|
||||
opts.ForceHTTP2,
|
||||
)
|
||||
}
|
||||
|
||||
// CreatePrivacyReqClient creates an HTTP client for OpenAI privacy settings API
|
||||
// This is exported for use by OpenAIPrivacyService
|
||||
// Uses Chrome TLS fingerprint impersonation to bypass Cloudflare checks
|
||||
func CreatePrivacyReqClient(proxyURL string) (*req.Client, error) {
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
Impersonate: true, // Enable Chrome TLS fingerprint impersonation
|
||||
})
|
||||
}
|
||||
|
||||
308
backend/internal/repository/usage_billing_repo.go
Normal file
308
backend/internal/repository/usage_billing_repo.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type usageBillingRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
|
||||
return &usageBillingRepository{db: sqlDB}
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
|
||||
if cmd == nil {
|
||||
return &service.UsageBillingApplyResult{}, nil
|
||||
}
|
||||
if r == nil || r.db == nil {
|
||||
return nil, errors.New("usage billing repository db is nil")
|
||||
}
|
||||
|
||||
cmd.Normalize()
|
||||
if cmd.RequestID == "" {
|
||||
return nil, service.ErrUsageBillingRequestIDRequired
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !applied {
|
||||
return &service.UsageBillingApplyResult{Applied: false}, nil
|
||||
}
|
||||
|
||||
result := &service.UsageBillingApplyResult{Applied: true}
|
||||
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tx = nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
|
||||
var id int64
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id
|
||||
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
var existingFingerprint string
|
||||
if err := tx.QueryRowContext(ctx, `
|
||||
SELECT request_fingerprint
|
||||
FROM usage_billing_dedup
|
||||
WHERE request_id = $1 AND api_key_id = $2
|
||||
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||
return false, service.ErrUsageBillingRequestConflict
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var archivedFingerprint string
|
||||
err = tx.QueryRowContext(ctx, `
|
||||
SELECT request_fingerprint
|
||||
FROM usage_billing_dedup_archive
|
||||
WHERE request_id = $1 AND api_key_id = $2
|
||||
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
|
||||
if err == nil {
|
||||
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||
return false, service.ErrUsageBillingRequestConflict
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
|
||||
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
|
||||
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.BalanceCost > 0 {
|
||||
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.APIKeyQuotaCost > 0 {
|
||||
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.APIKeyQuotaExhausted = exhausted
|
||||
}
|
||||
|
||||
if cmd.APIKeyRateLimitCost > 0 {
|
||||
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
|
||||
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
|
||||
const updateSQL = `
|
||||
UPDATE user_subscriptions us
|
||||
SET
|
||||
daily_usage_usd = us.daily_usage_usd + $1,
|
||||
weekly_usage_usd = us.weekly_usage_usd + $1,
|
||||
monthly_usage_usd = us.monthly_usage_usd + $1,
|
||||
updated_at = NOW()
|
||||
FROM groups g
|
||||
WHERE us.id = $2
|
||||
AND us.deleted_at IS NULL
|
||||
AND us.group_id = g.id
|
||||
AND g.deleted_at IS NULL
|
||||
`
|
||||
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE users
|
||||
SET balance = balance - $1,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, amount, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
|
||||
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
|
||||
var exhausted bool
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
UPDATE api_keys
|
||||
SET quota_used = quota_used + $1,
|
||||
status = CASE
|
||||
WHEN quota > 0
|
||||
AND status = $3
|
||||
AND quota_used < quota
|
||||
AND quota_used + $1 >= quota
|
||||
THEN $4
|
||||
ELSE status
|
||||
END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
|
||||
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, service.ErrAPIKeyNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exhausted, nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, cost, apiKeyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
|
||||
rows, err := tx.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
||||
amount, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var newUsed, limit float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&newUsed, &limit); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-" + uuid.NewString(),
|
||||
Name: "billing",
|
||||
Quota: 1,
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
BalanceCost: 1.25,
|
||||
APIKeyQuotaCost: 1.25,
|
||||
APIKeyRateLimitCost: 1.25,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result1)
|
||||
require.True(t, result1.Applied)
|
||||
require.True(t, result1.APIKeyQuotaExhausted)
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result2)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var balance float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||
require.InDelta(t, 98.75, balance, 0.000001)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan("aUsed))
|
||||
require.InDelta(t, 1.25, quotaUsed, 0.000001)
|
||||
|
||||
var usage5h float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
|
||||
require.InDelta(t, 1.25, usage5h, 0.000001)
|
||||
|
||||
var status string
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
|
||||
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
|
||||
|
||||
var dedupCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
|
||||
require.Equal(t, 1, dedupCount)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
group := mustCreateGroup(t, client, &service.Group{
|
||||
Name: "usage-billing-group-" + uuid.NewString(),
|
||||
Platform: service.PlatformAnthropic,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
GroupID: &group.ID,
|
||||
Key: "sk-usage-billing-sub-" + uuid.NewString(),
|
||||
Name: "billing-sub",
|
||||
})
|
||||
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: 0,
|
||||
SubscriptionID: &subscription.ID,
|
||||
SubscriptionCost: 2.5,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result1.Applied)
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var dailyUsage float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
|
||||
require.InDelta(t, 2.5, dailyUsage, 0.000001)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
|
||||
Name: "billing-conflict",
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 1.25,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 2.50,
|
||||
})
|
||||
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-account-" + uuid.NewString(),
|
||||
Name: "billing-account",
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-quota-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_limit": 100.0,
|
||||
},
|
||||
})
|
||||
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: uuid.NewString(),
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
AccountQuotaCost: 3.5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan("aUsed))
|
||||
require.InDelta(t, 3.5, quotaUsed, 0.000001)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||
|
||||
oldRequestID := "dedup-old-" + uuid.NewString()
|
||||
newRequestID := "dedup-new-" + uuid.NewString()
|
||||
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
|
||||
newCreatedAt := time.Now().UTC().Add(-time.Hour)
|
||||
|
||||
_, err := integrationDB.ExecContext(ctx, `
|
||||
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
|
||||
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
|
||||
`,
|
||||
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
|
||||
newRequestID, strings.Repeat("b", 64), newCreatedAt,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||
|
||||
var oldCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
|
||||
require.Equal(t, 0, oldCount)
|
||||
|
||||
var newCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
|
||||
require.Equal(t, 1, newCount)
|
||||
|
||||
var archivedCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
|
||||
require.Equal(t, 1, archivedCount)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-archive-" + uuid.NewString(),
|
||||
Name: "billing-archive",
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 1.25,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result1.Applied)
|
||||
|
||||
_, err = integrationDB.ExecContext(ctx, `
|
||||
UPDATE usage_billing_dedup
|
||||
SET created_at = $1
|
||||
WHERE request_id = $2 AND api_key_id = $3
|
||||
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var balance float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||
require.InDelta(t, 98.75, balance, 0.000001)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -14,6 +16,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
@@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
||||
s.Require().NotZero(log.ID)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
|
||||
|
||||
const total = 16
|
||||
results := make([]bool, total)
|
||||
errs := make([]error, total)
|
||||
logs := make([]*service.UsageLog, total)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(total)
|
||||
for i := 0; i < total; i++ {
|
||||
i := i
|
||||
logs[i] = &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10 + i,
|
||||
OutputTokens: 20 + i,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results[i], errs[i] = repo.Create(ctx, logs[i])
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
require.NoError(t, errs[i])
|
||||
require.True(t, results[i])
|
||||
require.NotZero(t, logs[i].ID)
|
||||
}
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
|
||||
require.Equal(t, total, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
inserted1, err1 := repo.Create(ctx, log1)
|
||||
inserted2, err2 := repo.Create(ctx, log2)
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
require.True(t, inserted1)
|
||||
require.False(t, inserted2)
|
||||
require.Equal(t, log1.ID, log2.ID)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
const total = 8
|
||||
batch := make([]usageLogCreateRequest, 0, total)
|
||||
logs := make([]*service.UsageLog, 0, total)
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10 + i,
|
||||
OutputTokens: 20 + i,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
logs = append(logs, log)
|
||||
batch = append(batch, usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
})
|
||||
}
|
||||
|
||||
repo.flushCreateBatch(integrationDB, batch)
|
||||
|
||||
insertedCount := 0
|
||||
var firstID int64
|
||||
for idx, req := range batch {
|
||||
res := <-req.resultCh
|
||||
require.NoError(t, res.err)
|
||||
if res.inserted {
|
||||
insertedCount++
|
||||
}
|
||||
require.NotZero(t, logs[idx].ID)
|
||||
if idx == 0 {
|
||||
firstID = logs[idx].ID
|
||||
} else {
|
||||
require.Equal(t, firstID, logs[idx].ID)
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, insertedCount)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log1))
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log2))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
var count int
|
||||
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
|
||||
return err == nil && count == 1
|
||||
}, 3*time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
|
||||
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
|
||||
|
||||
err := repo.CreateBestEffort(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateDropped(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.False(t, inserted)
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||
repo.createBatchCh <- usageLogCreateRequest{}
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
|
||||
|
||||
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.False(t, inserted)
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
_, err := repo.createBatched(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
req := <-repo.createBatchCh
|
||||
require.NotNil(t, req.shared)
|
||||
cancel()
|
||||
|
||||
err := <-errCh
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
|
||||
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
req := usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
shared: &usageLogCreateShared{},
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
}
|
||||
req.shared.state.Store(usageLogCreateStateCanceled)
|
||||
|
||||
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
|
||||
|
||||
res := <-req.resultCh
|
||||
require.False(t, res.inserted)
|
||||
require.Error(t, res.err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
|
||||
@@ -248,6 +248,35 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
|
||||
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
|
||||
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
|
||||
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
|
||||
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
|
||||
|
||||
mock.ExpectQuery("WITH user_spend AS \\(").
|
||||
WithArgs(start, end, 12).
|
||||
WillReturnRows(rows)
|
||||
|
||||
got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &usagestats.UserSpendingRankingResponse{
|
||||
Ranking: []usagestats.UserSpendingRankingItem{
|
||||
{UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900},
|
||||
{UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800},
|
||||
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
|
||||
},
|
||||
TotalActualCost: 40.0,
|
||||
}, got)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -3,8 +3,11 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-batch-no-update",
|
||||
Model: "gpt-5",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalCost: 1.2,
|
||||
ActualCost: 1.2,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
prepared := prepareUsageLogInsert(log)
|
||||
|
||||
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
|
||||
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
|
||||
})
|
||||
|
||||
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
|
||||
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
|
||||
}
|
||||
|
||||
@@ -95,6 +95,35 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
||||
query := `
|
||||
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
|
||||
FROM user_group_rate_multipliers ugr
|
||||
JOIN users u ON u.id = ugr.user_id
|
||||
WHERE ugr.group_id = $1
|
||||
ORDER BY ugr.user_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var result []service.UserGroupRateEntry
|
||||
for rows.Next() {
|
||||
var entry service.UserGroupRateEntry
|
||||
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
@@ -164,6 +193,31 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
|
||||
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
|
||||
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
userIDs := make([]int64, len(entries))
|
||||
rates := make([]float64, len(entries))
|
||||
for i, e := range entries {
|
||||
userIDs[i] = e.UserID
|
||||
rates[i] = e.RateMultiplier
|
||||
}
|
||||
now := time.Now()
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||
SELECT data.user_id, $1::bigint, data.rate_multiplier, $2::timestamptz, $2::timestamptz
|
||||
FROM unnest($3::bigint[], $4::double precision[]) AS data(user_id, rate_multiplier)
|
||||
ON CONFLICT (user_id, group_id)
|
||||
DO UPDATE SET rate_multiplier = EXCLUDED.rate_multiplier, updated_at = EXCLUDED.updated_at
|
||||
`, groupID, now, pq.Array(userIDs), pq.Array(rates))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
||||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||||
|
||||
@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAnnouncementRepository,
|
||||
NewAnnouncementReadRepository,
|
||||
NewUsageLogRepository,
|
||||
NewUsageBillingRepository,
|
||||
NewIdempotencyRepository,
|
||||
NewUsageCleanupRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
@@ -99,6 +100,10 @@ var ProviderSet = wire.NewSet(
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
// Backup infrastructure
|
||||
NewPgDumper,
|
||||
NewS3BackupStoreFactory,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
ProvidePricingRemoteClient,
|
||||
|
||||
@@ -537,6 +537,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"purchase_subscription_url": "",
|
||||
"min_claude_code_version": "",
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"backend_mode_enabled": false,
|
||||
"custom_menu_items": []
|
||||
}
|
||||
}`,
|
||||
@@ -645,7 +646,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
@@ -1635,6 +1636,10 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
logs := r.userLogs[userID]
|
||||
if len(logs) == 0 {
|
||||
|
||||
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
|
||||
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
|
||||
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
role, _ := GetUserRoleFromContext(c)
|
||||
if role == "admin" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
|
||||
// Allows: login, login/2fa, logout, refresh (admin needs these).
|
||||
// Blocks: register, forgot-password, reset-password, OAuth, etc.
|
||||
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
// Allow login, 2FA, logout, refresh, public settings
|
||||
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
|
||||
for _, suffix := range allowedSuffixes {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type bmSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
v, ok := r.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
if r.values == nil {
|
||||
r.values = make(map[string]string, len(settings))
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
|
||||
t.Helper()
|
||||
|
||||
repo := &bmSettingRepo{
|
||||
values: map[string]string{
|
||||
service.SettingKeyBackendModeEnabled: enabled,
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{})
|
||||
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
|
||||
BackendModeEnabled: enabled == "true",
|
||||
}))
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
func stringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestBackendModeUserGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
role *string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_admin_allowed",
|
||||
enabled: "true",
|
||||
role: stringPtr("admin"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_user_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_no_role_blocked",
|
||||
enabled: "true",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_empty_role_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr(""),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
if tc.role != nil {
|
||||
role := *tc.role
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(ContextKeyUserRole), role)
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeUserGuard(svc))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendModeAuthGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
path string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login_2fa",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login/2fa",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_logout",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/logout",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_refresh",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/refresh",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_register",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_forgot_password",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/forgot-password",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeAuthGuard(svc))
|
||||
r.Any("/*path", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -107,9 +107,9 @@ func registerRoutes(
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
}
|
||||
|
||||
@@ -58,6 +58,9 @@ func RegisterAdminRoutes(
|
||||
// 数据管理
|
||||
registerDataManagementRoutes(admin, h)
|
||||
|
||||
// 数据库备份恢复
|
||||
registerBackupRoutes(admin, h)
|
||||
|
||||
// 运维监控(Ops)
|
||||
registerOpsRoutes(admin, h)
|
||||
|
||||
@@ -192,6 +195,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
|
||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
|
||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
||||
@@ -228,6 +232,9 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
groups.PUT("/:id", h.Admin.Group.Update)
|
||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
|
||||
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
|
||||
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
|
||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||
}
|
||||
}
|
||||
@@ -436,6 +443,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
backup := admin.Group("/backups")
|
||||
{
|
||||
// S3 存储配置
|
||||
backup.GET("/s3-config", h.Admin.Backup.GetS3Config)
|
||||
backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config)
|
||||
backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection)
|
||||
|
||||
// 定时备份配置
|
||||
backup.GET("/schedule", h.Admin.Backup.GetSchedule)
|
||||
backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule)
|
||||
|
||||
// 备份操作
|
||||
backup.POST("", h.Admin.Backup.CreateBackup)
|
||||
backup.GET("", h.Admin.Backup.ListBackups)
|
||||
backup.GET("/:id", h.Admin.Backup.GetBackup)
|
||||
backup.DELETE("/:id", h.Admin.Backup.DeleteBackup)
|
||||
backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL)
|
||||
|
||||
// 恢复操作
|
||||
backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
system := admin.Group("/system")
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
|
||||
h *handler.Handlers,
|
||||
jwtAuth servermiddleware.JWTAuthMiddleware,
|
||||
redisClient *redis.Client,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
// 创建速率限制器
|
||||
rateLimiter := middleware.NewRateLimiter(redisClient)
|
||||
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
auth.Use(servermiddleware.BackendModeAuthGuard(settingService))
|
||||
{
|
||||
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||
@@ -78,6 +81,7 @@ func RegisterAuthRoutes(
|
||||
// 需要认证的当前用户信息
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
// 撤销所有会话(需要认证)
|
||||
|
||||
@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||
c.Next()
|
||||
}),
|
||||
redisClient,
|
||||
nil,
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
if h.SoraClient == nil {
|
||||
return
|
||||
@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
|
||||
|
||||
authenticated := v1.Group("/sora")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.POST("/generate", h.SoraClient.Generate)
|
||||
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
||||
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,9 +13,11 @@ func RegisterUserRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
// 用户接口
|
||||
user := authenticated.Group("/user")
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
"reflect"
|
||||
"sort"
|
||||
@@ -412,6 +413,7 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
}
|
||||
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
|
||||
return nil
|
||||
}
|
||||
if len(rawMapping) == 0 {
|
||||
@@ -521,16 +523,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
||||
// 如果未配置 mapping,返回原始模型名
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mappedModel, _ := a.ResolveMappedModel(requestedModel)
|
||||
return mappedModel
|
||||
}
|
||||
|
||||
// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。
|
||||
// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。
|
||||
func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel
|
||||
return requestedModel, false
|
||||
}
|
||||
// 精确匹配优先
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
return mappedModel
|
||||
return mappedModel, true
|
||||
}
|
||||
// 通配符匹配(最长优先)
|
||||
return matchWildcardMapping(mapping, requestedModel)
|
||||
return matchWildcardMappingResult(mapping, requestedModel)
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
@@ -604,9 +613,7 @@ func matchWildcard(pattern, str string) bool {
|
||||
return matchAntigravityWildcard(pattern, str)
|
||||
}
|
||||
|
||||
// matchWildcardMapping 通配符映射匹配(最长优先)
|
||||
// 如果没有匹配,返回原始字符串
|
||||
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
|
||||
func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) {
|
||||
// 收集所有匹配的 pattern,按长度降序排序(最长优先)
|
||||
type patternMatch struct {
|
||||
pattern string
|
||||
@@ -621,7 +628,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return requestedModel // 无匹配,返回原始模型名
|
||||
return requestedModel, false // 无匹配,返回原始模型名
|
||||
}
|
||||
|
||||
// 按 pattern 长度降序排序
|
||||
@@ -632,7 +639,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
||||
return matches[i].pattern < matches[j].pattern
|
||||
})
|
||||
|
||||
return matches[0].target
|
||||
return matches[0].target, true
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
@@ -650,7 +657,7 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
||||
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
||||
func (a *Account) IsPoolMode() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
if !a.IsAPIKeyOrBedrock() || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["pool_mode"]; ok {
|
||||
@@ -764,6 +771,19 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrock() bool {
|
||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrockAPIKey() bool {
|
||||
return a.IsBedrock() && a.GetCredential("auth_mode") == "apikey"
|
||||
}
|
||||
|
||||
// IsAPIKeyOrBedrock 返回账号类型是否支持配额和池模式等特性
|
||||
func (a *Account) IsAPIKeyOrBedrock() bool {
|
||||
return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAI() bool {
|
||||
return a.Platform == PlatformOpenAI
|
||||
}
|
||||
@@ -1260,6 +1280,240 @@ func (a *Account) getExtraTime(key string) time.Time {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// getExtraString 从 Extra 中读取指定 key 的字符串值
|
||||
func (a *Account) getExtraString(key string) string {
|
||||
if a.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getExtraInt 从 Extra 中读取指定 key 的 int 值
|
||||
func (a *Account) getExtraInt(key string) int {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
return int(parseExtraFloat64(v))
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetQuotaDailyResetMode 获取日额度重置模式:"rolling"(默认)或 "fixed"
|
||||
func (a *Account) GetQuotaDailyResetMode() string {
|
||||
if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" {
|
||||
return "fixed"
|
||||
}
|
||||
return "rolling"
|
||||
}
|
||||
|
||||
// GetQuotaDailyResetHour 获取固定重置的小时(0-23),默认 0
|
||||
func (a *Account) GetQuotaDailyResetHour() int {
|
||||
return a.getExtraInt("quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetMode 获取周额度重置模式:"rolling"(默认)或 "fixed"
|
||||
func (a *Account) GetQuotaWeeklyResetMode() string {
|
||||
if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" {
|
||||
return "fixed"
|
||||
}
|
||||
return "rolling"
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetDay 获取固定重置的星期几(0=周日, 1=周一, ..., 6=周六),默认 1(周一)
|
||||
func (a *Account) GetQuotaWeeklyResetDay() int {
|
||||
if a.Extra == nil {
|
||||
return 1
|
||||
}
|
||||
if _, ok := a.Extra["quota_weekly_reset_day"]; !ok {
|
||||
return 1
|
||||
}
|
||||
return a.getExtraInt("quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetHour 获取周配额固定重置的小时(0-23),默认 0
|
||||
func (a *Account) GetQuotaWeeklyResetHour() int {
|
||||
return a.getExtraInt("quota_weekly_reset_hour")
|
||||
}
|
||||
|
||||
// GetQuotaResetTimezone 获取固定重置的时区名(IANA),默认 "UTC"
|
||||
func (a *Account) GetQuotaResetTimezone() string {
|
||||
if tz := a.getExtraString("quota_reset_timezone"); tz != "" {
|
||||
return tz
|
||||
}
|
||||
return "UTC"
|
||||
}
|
||||
|
||||
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
|
||||
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
|
||||
t := after.In(tz)
|
||||
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
if !after.Before(today) {
|
||||
return today.AddDate(0, 0, 1)
|
||||
}
|
||||
return today
|
||||
}
|
||||
|
||||
// lastFixedDailyReset 计算 now 之前最近一次的每日固定重置时间点
|
||||
func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time {
|
||||
t := now.In(tz)
|
||||
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
if now.Before(today) {
|
||||
return today.AddDate(0, 0, -1)
|
||||
}
|
||||
return today
|
||||
}
|
||||
|
||||
// nextFixedWeeklyReset 计算在 after 之后的下一个每周固定重置时间点
|
||||
// day: 0=Sunday, 1=Monday, ..., 6=Saturday
|
||||
func nextFixedWeeklyReset(day, hour int, tz *time.Location, after time.Time) time.Time {
|
||||
t := after.In(tz)
|
||||
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
currentDay := int(todayReset.Weekday())
|
||||
|
||||
daysForward := (day - currentDay + 7) % 7
|
||||
if daysForward == 0 && !after.Before(todayReset) {
|
||||
daysForward = 7
|
||||
}
|
||||
return todayReset.AddDate(0, 0, daysForward)
|
||||
}
|
||||
|
||||
// lastFixedWeeklyReset 计算 now 之前最近一次的每周固定重置时间点
|
||||
func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time {
|
||||
t := now.In(tz)
|
||||
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
currentDay := int(todayReset.Weekday())
|
||||
|
||||
daysBack := (currentDay - day + 7) % 7
|
||||
if daysBack == 0 && now.Before(todayReset) {
|
||||
daysBack = 7
|
||||
}
|
||||
return todayReset.AddDate(0, 0, -daysBack)
|
||||
}
|
||||
|
||||
// isFixedDailyPeriodExpired 检查日配额是否在固定时间模式下已过期
|
||||
func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true
|
||||
}
|
||||
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now())
|
||||
return periodStart.Before(lastReset)
|
||||
}
|
||||
|
||||
// isFixedWeeklyPeriodExpired 检查周配额是否在固定时间模式下已过期
|
||||
func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true
|
||||
}
|
||||
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now())
|
||||
return periodStart.Before(lastReset)
|
||||
}
|
||||
|
||||
// ComputeQuotaResetAt 根据当前配置计算并填充 extra 中的 quota_daily_reset_at / quota_weekly_reset_at
|
||||
// 在保存账号配置时调用
|
||||
func ComputeQuotaResetAt(extra map[string]any) {
|
||||
now := time.Now()
|
||||
tzName, _ := extra["quota_reset_timezone"].(string)
|
||||
if tzName == "" {
|
||||
tzName = "UTC"
|
||||
}
|
||||
tz, err := time.LoadLocation(tzName)
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
|
||||
// 日配额固定重置时间
|
||||
if mode, _ := extra["quota_daily_reset_mode"].(string); mode == "fixed" {
|
||||
hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"]))
|
||||
if hour < 0 || hour > 23 {
|
||||
hour = 0
|
||||
}
|
||||
resetAt := nextFixedDailyReset(hour, tz, now)
|
||||
extra["quota_daily_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||
} else {
|
||||
delete(extra, "quota_daily_reset_at")
|
||||
}
|
||||
|
||||
// 周配额固定重置时间
|
||||
if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" {
|
||||
day := 1 // 默认周一
|
||||
if d, ok := extra["quota_weekly_reset_day"]; ok {
|
||||
day = int(parseExtraFloat64(d))
|
||||
}
|
||||
if day < 0 || day > 6 {
|
||||
day = 1
|
||||
}
|
||||
hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"]))
|
||||
if hour < 0 || hour > 23 {
|
||||
hour = 0
|
||||
}
|
||||
resetAt := nextFixedWeeklyReset(day, hour, tz, now)
|
||||
extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||
} else {
|
||||
delete(extra, "quota_weekly_reset_at")
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateQuotaResetConfig 校验配额固定重置时间配置的合法性
|
||||
func ValidateQuotaResetConfig(extra map[string]any) error {
|
||||
if extra == nil {
|
||||
return nil
|
||||
}
|
||||
// 校验时区
|
||||
if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" {
|
||||
if _, err := time.LoadLocation(tz); err != nil {
|
||||
return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name")
|
||||
}
|
||||
}
|
||||
// 日配额重置模式
|
||||
if mode, ok := extra["quota_daily_reset_mode"].(string); ok {
|
||||
if mode != "rolling" && mode != "fixed" {
|
||||
return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'")
|
||||
}
|
||||
}
|
||||
// 日配额重置小时
|
||||
if v, ok := extra["quota_daily_reset_hour"]; ok {
|
||||
hour := int(parseExtraFloat64(v))
|
||||
if hour < 0 || hour > 23 {
|
||||
return errors.New("quota_daily_reset_hour must be between 0 and 23")
|
||||
}
|
||||
}
|
||||
// 周配额重置模式
|
||||
if mode, ok := extra["quota_weekly_reset_mode"].(string); ok {
|
||||
if mode != "rolling" && mode != "fixed" {
|
||||
return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'")
|
||||
}
|
||||
}
|
||||
// 周配额重置星期几
|
||||
if v, ok := extra["quota_weekly_reset_day"]; ok {
|
||||
day := int(parseExtraFloat64(v))
|
||||
if day < 0 || day > 6 {
|
||||
return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)")
|
||||
}
|
||||
}
|
||||
// 周配额重置小时
|
||||
if v, ok := extra["quota_weekly_reset_hour"]; ok {
|
||||
hour := int(parseExtraFloat64(v))
|
||||
if hour < 0 || hour > 23 {
|
||||
return errors.New("quota_weekly_reset_hour must be between 0 and 23")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
||||
func (a *Account) HasAnyQuotaLimit() bool {
|
||||
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
||||
@@ -1282,14 +1536,26 @@ func (a *Account) IsQuotaExceeded() bool {
|
||||
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_daily_start")
|
||||
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
|
||||
var expired bool
|
||||
if a.GetQuotaDailyResetMode() == "fixed" {
|
||||
expired = a.isFixedDailyPeriodExpired(start)
|
||||
} else {
|
||||
expired = isPeriodExpired(start, 24*time.Hour)
|
||||
}
|
||||
if !expired && a.GetQuotaDailyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 周额度
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_weekly_start")
|
||||
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
|
||||
var expired bool
|
||||
if a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
expired = a.isFixedWeeklyPeriodExpired(start)
|
||||
} else {
|
||||
expired = isPeriodExpired(start, 7*24*time.Hour)
|
||||
}
|
||||
if !expired && a.GetQuotaWeeklyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
516
backend/internal/service/account_quota_reset_test.go
Normal file
516
backend/internal/service/account_quota_reset_test.go
Normal file
@@ -0,0 +1,516 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 06:00 UTC, reset hour = 9
|
||||
after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Exactly at reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// After reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_MidnightReset(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Reset at hour 0 (midnight), currently 23:59
|
||||
after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz)
|
||||
got := nextFixedDailyReset(0, tz, after)
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
|
||||
after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// Before today's 9:00 → yesterday 9:00
|
||||
want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// At exactly 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// After 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
|
||||
after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
|
||||
after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
|
||||
after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-23
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_Sunday(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(0, 0, tz, after)
|
||||
// Next Sunday = 2026-03-15
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
|
||||
now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
|
||||
now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday at 9:00 = 2026-03-09
|
||||
want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedDailyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started after the most recent reset → not expired
|
||||
// (This test uses a time very close to "now", which is after the last reset)
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 3 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Invalid/Timezone",
|
||||
}}
|
||||
// Invalid timezone falls back to UTC
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedWeeklyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 1 minute ago → not expired
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 10 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-240 * time.Hour)
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ValidateQuotaResetConfig
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestValidateQuotaResetConfig_NilExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(nil))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{}))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_reset_timezone": "Not/A/Timezone",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_reset_timezone")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "invalid",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(24),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "unknown",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(7),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_hour": float64(25),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) {
|
||||
// All boundary values should be valid
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(23),
|
||||
"quota_weekly_reset_day": float64(0), // Sunday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
|
||||
extra2 := map[string]any{
|
||||
"quota_daily_reset_hour": float64(0),
|
||||
"quota_weekly_reset_day": float64(6), // Saturday
|
||||
"quota_weekly_reset_hour": float64(23),
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra2))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ComputeQuotaResetAt
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
"quota_daily_reset_at": "2026-03-14T09:00:00Z",
|
||||
"quota_weekly_reset_at": "2026-03-16T09:00:00Z",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok, "quota_daily_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset hour should be 9 UTC
|
||||
assert.Equal(t, 9, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1), // Monday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_weekly_reset_at"].(string)
|
||||
require.True(t, ok, "quota_weekly_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset day should be Monday
|
||||
assert.Equal(t, time.Monday, resetAt.UTC().Weekday())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// In Shanghai timezone, the hour should be 9
|
||||
assert.Equal(t, 9, resetAt.In(tz).Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(12),
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Default timezone is UTC
|
||||
assert.Equal(t, 12, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(99),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Invalid hour → clamped to 0
|
||||
assert.Equal(t, 0, resetAt.UTC().Hour())
|
||||
}
|
||||
@@ -207,14 +207,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
testModelID = claude.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
// API Key 账号测试连接时也需要应用通配符模型映射。
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
testModelID = account.GetMappedModel(testModelID)
|
||||
}
|
||||
|
||||
// Bedrock accounts use a separate test path
|
||||
if account.IsBedrock() {
|
||||
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
|
||||
}
|
||||
|
||||
// Determine authentication method and API URL
|
||||
@@ -312,6 +312,109 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
|
||||
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
|
||||
region := bedrockRuntimeRegion(account)
|
||||
resolvedModelID, ok := ResolveBedrockModelID(account, testModelID)
|
||||
if !ok {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Bedrock model: %s", testModelID))
|
||||
}
|
||||
testModelID = resolvedModelID
|
||||
|
||||
// Set SSE headers (test UI expects SSE)
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create a minimal Bedrock-compatible payload (no stream, no cache_control)
|
||||
bedrockPayload := map[string]any{
|
||||
"anthropic_version": "bedrock-2023-05-31",
|
||||
"messages": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hi",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"max_tokens": 256,
|
||||
"temperature": 1,
|
||||
}
|
||||
bedrockBody, _ := json.Marshal(bedrockPayload)
|
||||
|
||||
// Use non-streaming endpoint (response is standard Claude JSON)
|
||||
apiURL := BuildBedrockURL(region, testModelID, false)
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bedrockBody))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Sign or set auth based on account type
|
||||
if account.IsBedrockAPIKey() {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
} else {
|
||||
signer, err := NewBedrockSignerFromAccount(account)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Bedrock signer: %s", err.Error()))
|
||||
}
|
||||
if err := signer.SignRequest(ctx, req, bedrockBody); err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to sign request: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Bedrock non-streaming response is standard Claude JSON, extract the text
|
||||
var result struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
|
||||
}
|
||||
|
||||
text := ""
|
||||
if len(result.Content) > 0 {
|
||||
text = result.Content[0].Text
|
||||
}
|
||||
if text == "" {
|
||||
text = "(empty response)"
|
||||
}
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -47,6 +48,7 @@ type UsageLogRepository interface {
|
||||
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||
|
||||
@@ -99,6 +101,7 @@ type antigravityUsageCache struct {
|
||||
const (
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
||||
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
|
||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
openAIProbeCacheTTL = 10 * time.Minute
|
||||
@@ -107,11 +110,12 @@ const (
|
||||
|
||||
// UsageCache 封装账户使用量相关的缓存
|
||||
type UsageCache struct {
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
|
||||
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
}
|
||||
|
||||
// NewUsageCache 创建 UsageCache 实例
|
||||
@@ -148,6 +152,18 @@ type AntigravityModelQuota struct {
|
||||
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
|
||||
}
|
||||
|
||||
// AntigravityModelDetail Antigravity 单个模型的详细能力信息
|
||||
type AntigravityModelDetail struct {
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
SupportsImages *bool `json:"supports_images,omitempty"`
|
||||
SupportsThinking *bool `json:"supports_thinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinking_budget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
@@ -163,6 +179,33 @@ type UsageInfo struct {
|
||||
|
||||
// Antigravity 多模型配额
|
||||
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
||||
|
||||
// Antigravity 账号级信息
|
||||
SubscriptionTier string `json:"subscription_tier,omitempty"` // 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN
|
||||
SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // 上游原始订阅等级名称
|
||||
|
||||
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key)
|
||||
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
|
||||
|
||||
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
||||
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
||||
|
||||
// Antigravity 账号是否被上游禁止 (HTTP 403)
|
||||
IsForbidden bool `json:"is_forbidden,omitempty"`
|
||||
ForbiddenReason string `json:"forbidden_reason,omitempty"`
|
||||
ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden"
|
||||
ValidationURL string `json:"validation_url,omitempty"` // 验证/申诉链接
|
||||
|
||||
// 状态标记(从 ForbiddenType / HTTP 错误码推导)
|
||||
NeedsVerify bool `json:"needs_verify,omitempty"` // 需要人工验证(forbidden_type=validation)
|
||||
IsBanned bool `json:"is_banned,omitempty"` // 账号被封(forbidden_type=violation)
|
||||
NeedsReauth bool `json:"needs_reauth,omitempty"` // token 失效需重新授权(401)
|
||||
|
||||
// 错误码(机器可读):forbidden / unauthenticated / rate_limited / network_error
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
|
||||
// 获取 usage 时的错误信息(降级返回,而非 500)
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||
@@ -647,34 +690,157 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
|
||||
// 1. 检查缓存(10 分钟)
|
||||
// 1. 检查缓存
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
// 重新计算 RemainingSeconds
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||
if time.Since(cache.timestamp) < ttl {
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取代理 URL
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
|
||||
// 2. singleflight 防止并发击穿
|
||||
flightKey := fmt.Sprintf("ag-usage:%d", account.ID)
|
||||
result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) {
|
||||
// 再次检查缓存(等待期间可能已被填充)
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||
if time.Since(cache.timestamp) < ttl {
|
||||
usage := cache.usageInfo
|
||||
// 重新计算 RemainingSeconds,避免返回过时的剩余秒数
|
||||
recalcAntigravityRemainingSeconds(usage)
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 调用 API 获取额度
|
||||
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
|
||||
}
|
||||
// 使用独立 context,避免调用方 cancel 导致所有共享 flight 的请求失败
|
||||
fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer fetchCancel()
|
||||
|
||||
// 4. 缓存结果
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: result.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account)
|
||||
fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL)
|
||||
if err != nil {
|
||||
degraded := buildAntigravityDegradedUsage(err)
|
||||
enrichUsageWithAccountError(degraded, account)
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: degraded,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return degraded, nil
|
||||
}
|
||||
|
||||
enrichUsageWithAccountError(fetchResult.UsageInfo, account)
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: fetchResult.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return fetchResult.UsageInfo, nil
|
||||
})
|
||||
|
||||
return result.UsageInfo, nil
|
||||
if flightErr != nil {
|
||||
return nil, flightErr
|
||||
}
|
||||
usage, ok := result.(*UsageInfo)
|
||||
if !ok || usage == nil {
|
||||
now := time.Now()
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
|
||||
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
|
||||
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
if info.FiveHour != nil && info.FiveHour.ResetsAt != nil {
|
||||
remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds())
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
info.FiveHour.RemainingSeconds = remaining
|
||||
}
|
||||
}
|
||||
|
||||
// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL
|
||||
// 403 forbidden 状态稳定,缓存与成功相同(3 分钟);
|
||||
// 其他错误(401/网络)可能快速恢复,缓存 1 分钟。
|
||||
func antigravityCacheTTL(info *UsageInfo) time.Duration {
|
||||
if info == nil {
|
||||
return antigravityErrorTTL
|
||||
}
|
||||
if info.IsForbidden {
|
||||
return apiCacheTTL // 封号/验证状态不会很快变
|
||||
}
|
||||
if info.ErrorCode != "" || info.Error != "" {
|
||||
return antigravityErrorTTL
|
||||
}
|
||||
return apiCacheTTL
|
||||
}
|
||||
|
||||
// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo
|
||||
func buildAntigravityDegradedUsage(err error) *UsageInfo {
|
||||
now := time.Now()
|
||||
errMsg := fmt.Sprintf("usage API error: %v", err)
|
||||
slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err)
|
||||
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
Error: errMsg,
|
||||
}
|
||||
|
||||
// 从错误信息推断 error_code 和状态标记
|
||||
// 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..."
|
||||
errStr := err.Error()
|
||||
switch {
|
||||
case strings.Contains(errStr, "HTTP 401") ||
|
||||
strings.Contains(errStr, "UNAUTHENTICATED") ||
|
||||
strings.Contains(errStr, "invalid_grant"):
|
||||
info.ErrorCode = errorCodeUnauthenticated
|
||||
info.NeedsReauth = true
|
||||
case strings.Contains(errStr, "HTTP 429"):
|
||||
info.ErrorCode = errorCodeRateLimited
|
||||
default:
|
||||
info.ErrorCode = errorCodeNetworkError
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo
|
||||
// 场景 1(成功路径):FetchAvailableModels 正常返回,但账号已因 403 被标记为 error,
|
||||
//
|
||||
// 需要在正常 usage 数据上附加 forbidden/validation 信息。
|
||||
//
|
||||
// 场景 2(降级路径):被封号的账号 OAuth token 失效,FetchAvailableModels 返回 401,
|
||||
//
|
||||
// 降级逻辑设置了 needs_reauth,但账号实际是 403 封号/需验证,需覆盖为正确状态。
|
||||
func enrichUsageWithAccountError(info *UsageInfo, account *Account) {
|
||||
if info == nil || account == nil || account.Status != StatusError {
|
||||
return
|
||||
}
|
||||
msg := strings.ToLower(account.ErrorMessage)
|
||||
if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") &&
|
||||
!strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") {
|
||||
return
|
||||
}
|
||||
fbType := classifyForbiddenType(account.ErrorMessage)
|
||||
info.IsForbidden = true
|
||||
info.ForbiddenType = fbType
|
||||
info.ForbiddenReason = account.ErrorMessage
|
||||
info.NeedsVerify = fbType == forbiddenTypeValidation
|
||||
info.IsBanned = fbType == forbiddenTypeViolation
|
||||
info.ValidationURL = extractValidationURL(account.ErrorMessage)
|
||||
info.ErrorCode = errorCodeForbidden
|
||||
info.NeedsReauth = false
|
||||
}
|
||||
|
||||
// addWindowStats 为 usage 数据添加窗口期统计
|
||||
|
||||
@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchWildcardMapping(t *testing.T) {
|
||||
func TestMatchWildcardMappingResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mapping map[string]string
|
||||
requestedModel string
|
||||
expected string
|
||||
matched bool
|
||||
}{
|
||||
// 精确匹配优先于通配符
|
||||
{
|
||||
@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5-exact",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 最长通配符优先
|
||||
@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-series",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 单个通配符
|
||||
@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-opus-4-5",
|
||||
expected: "claude-mapped",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 无匹配返回原始模型
|
||||
@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "gemini-3-flash",
|
||||
expected: "gemini-3-flash",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// 空映射返回原始模型
|
||||
@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
mapping: map[string]string{},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// Gemini 模型映射
|
||||
@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "gemini-3-flash-preview",
|
||||
expected: "gemini-3-pro-high",
|
||||
matched: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchWildcardMapping(tt.mapping, tt.requestedModel)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected)
|
||||
result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
|
||||
if result != tt.expected || matched != tt.matched {
|
||||
t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountResolveMappedModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expectedModel string
|
||||
expectedMatch bool
|
||||
}{
|
||||
{
|
||||
name: "no mapping reports unmatched",
|
||||
credentials: nil,
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
{
|
||||
name: "exact passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-*": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "missing mapping reports unmatched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.2": "gpt-5.2",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
|
||||
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
|
||||
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
|
||||
@@ -42,6 +42,9 @@ type AdminService interface {
|
||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
||||
DeleteGroup(ctx context.Context, id int64) error
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
|
||||
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||
|
||||
// API Key management (admin)
|
||||
@@ -57,6 +60,8 @@ type AdminService interface {
|
||||
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
||||
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号 privacy_mode,未设置则尝试关闭训练数据共享并持久化。
|
||||
EnsureOpenAIPrivacy(ctx context.Context, account *Account) string
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
|
||||
@@ -433,6 +438,7 @@ type adminServiceImpl struct {
|
||||
settingService *SettingService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
userSubRepo UserSubscriptionRepository
|
||||
privacyClientFactory PrivacyClientFactory
|
||||
}
|
||||
|
||||
type userGroupRateBatchReader interface {
|
||||
@@ -461,6 +467,7 @@ func NewAdminService(
|
||||
settingService *SettingService,
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
privacyClientFactory PrivacyClientFactory,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
@@ -479,6 +486,7 @@ func NewAdminService(
|
||||
settingService: settingService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
userSubRepo: userSubRepo,
|
||||
privacyClientFactory: privacyClientFactory,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1244,6 +1252,27 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return keys, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.userGroupRateRepo.GetByGroupID(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil
|
||||
}
|
||||
return s.userGroupRateRepo.DeleteByGroupID(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil
|
||||
}
|
||||
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
||||
}
|
||||
@@ -1433,6 +1462,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
// 预计算固定时间重置的下次重置时间
|
||||
if account.Extra != nil {
|
||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ComputeQuotaResetAt(account.Extra)
|
||||
}
|
||||
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
|
||||
expiresAt := time.Unix(*input.ExpiresAt, 0)
|
||||
account.ExpiresAt = &expiresAt
|
||||
@@ -1506,6 +1542,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
// 校验并预计算固定时间重置的下次重置时间
|
||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ComputeQuotaResetAt(account.Extra)
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
|
||||
@@ -2502,3 +2543,39 @@ func (e *MixedChannelError) Error() string {
|
||||
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||
return s.accountRepo.ResetQuotaUsed(ctx, id)
|
||||
}
|
||||
|
||||
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号是否已设置 privacy_mode,
|
||||
// 未设置则调用 disableOpenAITraining 并持久化到 Extra,返回设置的 mode 值。
|
||||
func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Account) string {
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||
return ""
|
||||
}
|
||||
if s.privacyClientFactory == nil {
|
||||
return ""
|
||||
}
|
||||
if account.Extra != nil {
|
||||
if _, ok := account.Extra["privacy_mode"]; ok {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
token, _ := account.Credentials["access_token"].(string)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil {
|
||||
proxyURL = p.URL()
|
||||
}
|
||||
}
|
||||
|
||||
mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL)
|
||||
if mode == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode})
|
||||
return mode
|
||||
}
|
||||
|
||||
176
backend/internal/service/admin_service_group_rate_test.go
Normal file
176
backend/internal/service/admin_service_group_rate_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// userGroupRateRepoStubForGroupRate implements UserGroupRateRepository for group rate tests.
|
||||
type userGroupRateRepoStubForGroupRate struct {
|
||||
getByGroupIDData map[int64][]UserGroupRateEntry
|
||||
getByGroupIDErr error
|
||||
|
||||
deletedGroupIDs []int64
|
||||
deleteByGroupErr error
|
||||
|
||||
syncedGroupID int64
|
||||
syncedEntries []GroupRateMultiplierInput
|
||||
syncGroupErr error
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
|
||||
panic("unexpected GetByUserID call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, _, _ int64) (*float64, error) {
|
||||
panic("unexpected GetByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||
if s.getByGroupIDErr != nil {
|
||||
return nil, s.getByGroupIDErr
|
||||
}
|
||||
return s.getByGroupIDData[groupID], nil
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) SyncUserGroupRates(_ context.Context, _ int64, _ map[int64]*float64) error {
|
||||
panic("unexpected SyncUserGroupRates call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
|
||||
s.syncedGroupID = groupID
|
||||
s.syncedEntries = entries
|
||||
return s.syncGroupErr
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
|
||||
return s.deleteByGroupErr
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) DeleteByUserID(_ context.Context, _ int64) error {
|
||||
panic("unexpected DeleteByUserID call")
|
||||
}
|
||||
|
||||
func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
|
||||
t.Run("returns entries for group", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
getByGroupIDData: map[int64][]UserGroupRateEntry{
|
||||
10: {
|
||||
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
|
||||
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 2)
|
||||
require.Equal(t, int64(1), entries[0].UserID)
|
||||
require.Equal(t, "alice", entries[0].UserName)
|
||||
require.Equal(t, 1.5, entries[0].RateMultiplier)
|
||||
require.Equal(t, int64(2), entries[1].UserID)
|
||||
require.Equal(t, 0.8, entries[1].RateMultiplier)
|
||||
})
|
||||
|
||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||
|
||||
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, entries)
|
||||
})
|
||||
|
||||
t.Run("returns empty slice for group with no entries", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
getByGroupIDData: map[int64][]UserGroupRateEntry{},
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
entries, err := svc.GetGroupRateMultipliers(context.Background(), 99)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, entries)
|
||||
})
|
||||
|
||||
t.Run("propagates repo error", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
getByGroupIDErr: errors.New("db error"),
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
_, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "db error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ClearGroupRateMultipliers(t *testing.T) {
|
||||
t.Run("deletes by group ID", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{42}, repo.deletedGroupIDs)
|
||||
})
|
||||
|
||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||
|
||||
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("propagates repo error", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
deleteByGroupErr: errors.New("delete failed"),
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "delete failed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
|
||||
t.Run("syncs entries to repo", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
entries := []GroupRateMultiplierInput{
|
||||
{UserID: 1, RateMultiplier: 1.5},
|
||||
{UserID: 2, RateMultiplier: 0.8},
|
||||
}
|
||||
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, entries)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), repo.syncedGroupID)
|
||||
require.Equal(t, entries, repo.syncedEntries)
|
||||
})
|
||||
|
||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||
|
||||
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("propagates repo error", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
syncGroupErr: errors.New("sync failed"),
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, []GroupRateMultiplierInput{
|
||||
{UserID: 1, RateMultiplier: 1.0},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "sync failed")
|
||||
})
|
||||
}
|
||||
@@ -68,7 +68,15 @@ func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context
|
||||
panic("unexpected SyncUserGroupRates call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, _ int64) ([]UserGroupRateEntry, error) {
|
||||
panic("unexpected GetByGroupID call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.Context, _ int64, _ []GroupRateMultiplierInput) error {
|
||||
panic("unexpected SyncGroupRateMultipliers call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
|
||||
panic("unexpected DeleteByGroupID call")
|
||||
}
|
||||
|
||||
|
||||
@@ -2,12 +2,29 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
const (
|
||||
forbiddenTypeValidation = "validation"
|
||||
forbiddenTypeViolation = "violation"
|
||||
forbiddenTypeForbidden = "forbidden"
|
||||
|
||||
// 机器可读的错误码
|
||||
errorCodeForbidden = "forbidden"
|
||||
errorCodeUnauthenticated = "unauthenticated"
|
||||
errorCodeRateLimited = "rate_limited"
|
||||
errorCodeNetworkError = "network_error"
|
||||
)
|
||||
|
||||
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
|
||||
type AntigravityQuotaFetcher struct {
|
||||
proxyRepo ProxyRepository
|
||||
@@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
// 403 Forbidden: 不报错,返回 is_forbidden 标记
|
||||
var forbiddenErr *antigravity.ForbiddenError
|
||||
if errors.As(err, &forbiddenErr) {
|
||||
now := time.Now()
|
||||
fbType := classifyForbiddenType(forbiddenErr.Body)
|
||||
return &QuotaResult{
|
||||
UsageInfo: &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
IsForbidden: true,
|
||||
ForbiddenReason: forbiddenErr.Body,
|
||||
ForbiddenType: fbType,
|
||||
ValidationURL: extractValidationURL(forbiddenErr.Body),
|
||||
NeedsVerify: fbType == forbiddenTypeValidation,
|
||||
IsBanned: fbType == forbiddenTypeViolation,
|
||||
ErrorCode: errorCodeForbidden,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程)
|
||||
tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken)
|
||||
|
||||
// 转换为 UsageInfo
|
||||
usageInfo := f.buildUsageInfo(modelsResp)
|
||||
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized)
|
||||
|
||||
return &QuotaResult{
|
||||
UsageInfo: usageInfo,
|
||||
@@ -52,15 +90,52 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
|
||||
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) {
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
if err != nil {
|
||||
slog.Warn("failed to fetch subscription tier", "error", err)
|
||||
return "", ""
|
||||
}
|
||||
if loadResp == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota
|
||||
raw = loadResp.GetTier() // 已有方法:paidTier > currentTier
|
||||
normalized = normalizeTier(raw)
|
||||
return raw, normalized
|
||||
}
|
||||
|
||||
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
|
||||
func normalizeTier(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(raw)
|
||||
switch {
|
||||
case strings.Contains(lower, "ultra"):
|
||||
return "ULTRA"
|
||||
case strings.Contains(lower, "pro"):
|
||||
return "PRO"
|
||||
case strings.Contains(lower, "free"):
|
||||
return "FREE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail),
|
||||
SubscriptionTier: tierNormalized,
|
||||
SubscriptionTierRaw: tierRaw,
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
@@ -73,6 +148,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
|
||||
Utilization: utilization,
|
||||
ResetTime: modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
|
||||
// 填充模型详细能力信息
|
||||
detail := &AntigravityModelDetail{
|
||||
DisplayName: modelInfo.DisplayName,
|
||||
SupportsImages: modelInfo.SupportsImages,
|
||||
SupportsThinking: modelInfo.SupportsThinking,
|
||||
ThinkingBudget: modelInfo.ThinkingBudget,
|
||||
Recommended: modelInfo.Recommended,
|
||||
MaxTokens: modelInfo.MaxTokens,
|
||||
MaxOutputTokens: modelInfo.MaxOutputTokens,
|
||||
SupportedMimeTypes: modelInfo.SupportedMimeTypes,
|
||||
}
|
||||
info.AntigravityQuotaDetails[modelName] = detail
|
||||
}
|
||||
|
||||
// 废弃模型转发规则
|
||||
if len(modelsResp.DeprecatedModelIDs) > 0 {
|
||||
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs))
|
||||
for oldID, deprecated := range modelsResp.DeprecatedModelIDs {
|
||||
info.ModelForwardingRules[oldID] = deprecated.NewModelID
|
||||
}
|
||||
}
|
||||
|
||||
// 同时设置 FiveHour 用于兼容展示(取主要模型)
|
||||
@@ -108,3 +204,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco
|
||||
}
|
||||
return proxy.URL()
|
||||
}
|
||||
|
||||
// classifyForbiddenType 根据 403 响应体判断禁止类型
|
||||
func classifyForbiddenType(body string) string {
|
||||
lower := strings.ToLower(body)
|
||||
switch {
|
||||
case strings.Contains(lower, "validation_required") ||
|
||||
strings.Contains(lower, "verify your account") ||
|
||||
strings.Contains(lower, "validation_url"):
|
||||
return forbiddenTypeValidation
|
||||
case strings.Contains(lower, "terms of service") ||
|
||||
strings.Contains(lower, "violation"):
|
||||
return forbiddenTypeViolation
|
||||
default:
|
||||
return forbiddenTypeForbidden
|
||||
}
|
||||
}
|
||||
|
||||
// urlPattern 用于从 403 响应体中提取 URL(降级方案)
|
||||
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)
|
||||
|
||||
// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接
|
||||
func extractValidationURL(body string) string {
|
||||
// 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url
|
||||
var parsed struct {
|
||||
Error struct {
|
||||
Details []struct {
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
} `json:"details"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal([]byte(body), &parsed) == nil {
|
||||
for _, detail := range parsed.Error.Details {
|
||||
if u := detail.Metadata["validation_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
if u := detail.Metadata["appeal_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 降级:正则匹配 URL
|
||||
lower := strings.ToLower(body)
|
||||
if !strings.Contains(lower, "validation") &&
|
||||
!strings.Contains(lower, "verify") &&
|
||||
!strings.Contains(lower, "appeal") {
|
||||
return ""
|
||||
}
|
||||
// 先解码常见转义再匹配
|
||||
normalized := strings.ReplaceAll(body, `\u0026`, "&")
|
||||
if m := urlPattern.FindString(normalized); m != "" {
|
||||
return m
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
497
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
497
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
@@ -0,0 +1,497 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// normalizeTier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNormalizeTier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
expected string
|
||||
}{
|
||||
{name: "empty string", raw: "", expected: ""},
|
||||
{name: "free-tier", raw: "free-tier", expected: "FREE"},
|
||||
{name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"},
|
||||
{name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"},
|
||||
{name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"},
|
||||
{name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"},
|
||||
{name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"},
|
||||
{name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"},
|
||||
{name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeTier(tt.raw)
|
||||
require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildUsageInfo
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func aqfBoolPtr(v bool) *bool { return &v }
|
||||
func aqfIntPtr(v int) *int { return &v }
|
||||
|
||||
func TestBuildUsageInfo_BasicModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.75,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
DisplayName: "Claude Sonnet 4",
|
||||
SupportsImages: aqfBoolPtr(true),
|
||||
SupportsThinking: aqfBoolPtr(false),
|
||||
ThinkingBudget: aqfIntPtr(0),
|
||||
Recommended: aqfBoolPtr(true),
|
||||
MaxTokens: aqfIntPtr(200000),
|
||||
MaxOutputTokens: aqfIntPtr(16384),
|
||||
SupportedMimeTypes: map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "2026-03-08T15:00:00Z",
|
||||
},
|
||||
DisplayName: "Gemini 2.5 Pro",
|
||||
MaxTokens: aqfIntPtr(1000000),
|
||||
MaxOutputTokens: aqfIntPtr(65536),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO")
|
||||
|
||||
// 基本字段
|
||||
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
|
||||
require.Equal(t, "PRO", info.SubscriptionTier)
|
||||
require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw)
|
||||
|
||||
// AntigravityQuota
|
||||
require.Len(t, info.AntigravityQuota, 2)
|
||||
|
||||
sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetQuota)
|
||||
require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25
|
||||
require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime)
|
||||
|
||||
geminiQuota := info.AntigravityQuota["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiQuota)
|
||||
require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50
|
||||
require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime)
|
||||
|
||||
// AntigravityQuotaDetails
|
||||
require.Len(t, info.AntigravityQuotaDetails, 2)
|
||||
|
||||
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetDetail)
|
||||
require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages)
|
||||
require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended)
|
||||
require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens)
|
||||
require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes)
|
||||
|
||||
geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiDetail)
|
||||
require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName)
|
||||
require.Nil(t, geminiDetail.SupportsImages)
|
||||
require.Nil(t, geminiDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{
|
||||
"claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"},
|
||||
"claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Len(t, info.ModelForwardingRules, 2)
|
||||
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
|
||||
require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"])
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_EmptyModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.NotNil(t, info.AntigravityQuota)
|
||||
require.Empty(t, info.AntigravityQuota)
|
||||
require.NotNil(t, info.AntigravityQuotaDetails)
|
||||
require.Empty(t, info.AntigravityQuotaDetails)
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"model-without-quota": {
|
||||
DisplayName: "No Quota Model",
|
||||
// QuotaInfo is nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
|
||||
require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
|
||||
// When the first priority model exists, it should be used for FiveHour
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.40,
|
||||
ResetTime: "2026-03-08T18:00:00Z",
|
||||
},
|
||||
},
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.80,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
|
||||
// claude-sonnet-4-20250514 is first in priority list, so it should be used
|
||||
expectedUtilization := (1.0 - 0.80) * 100 // 20
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.60,
|
||||
ResetTime: "2026-03-08T14:00:00Z",
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.60) * 100 // 40
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only gemini-2.5-pro exists (third in priority list)
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
"other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.90,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.30) * 100 // 70
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// None of the priority models exist
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "", // empty reset time
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
|
||||
require.Equal(t, 0, info.FiveHour.RemainingSeconds)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FullUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.0, // fully used
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 100, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0, // fully available
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 0, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
|
||||
// 模拟 FetchQuota 遇到 403 时的行为:
|
||||
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
|
||||
forbiddenErr := &antigravity.ForbiddenError{
|
||||
StatusCode: 403,
|
||||
Body: "Access denied",
|
||||
}
|
||||
|
||||
// 验证 ForbiddenError 满足 errors.As
|
||||
var target *antigravity.ForbiddenError
|
||||
require.True(t, errors.As(forbiddenErr, &target))
|
||||
require.Equal(t, 403, target.StatusCode)
|
||||
require.Equal(t, "Access denied", target.Body)
|
||||
require.Contains(t, forbiddenErr.Error(), "403")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// classifyForbiddenType
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestClassifyForbiddenType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "VALIDATION_REQUIRED keyword",
|
||||
body: `{"error":{"message":"VALIDATION_REQUIRED"}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "verify your account",
|
||||
body: `Please verify your account to continue`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "contains validation_url field",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "terms of service violation",
|
||||
body: `Your account has been suspended for Terms of Service violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "violation keyword",
|
||||
body: `Account suspended due to policy violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "generic 403",
|
||||
body: `Access denied`,
|
||||
expected: "forbidden",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "forbidden",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyForbiddenType(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extractValidationURL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractValidationURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "structured validation_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
|
||||
expected: "https://accounts.google.com/verify?token=abc",
|
||||
},
|
||||
{
|
||||
name: "structured appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`,
|
||||
expected: "https://support.google.com/appeal/123",
|
||||
},
|
||||
{
|
||||
name: "validation_url takes priority over appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`,
|
||||
expected: "https://v.com",
|
||||
},
|
||||
{
|
||||
name: "fallback regex with verify keyword",
|
||||
body: `Please verify your account at https://accounts.google.com/verify`,
|
||||
expected: "https://accounts.google.com/verify",
|
||||
},
|
||||
{
|
||||
name: "no URL in generic forbidden",
|
||||
body: `Access denied`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "URL present but no validation keywords",
|
||||
body: `Error at https://example.com/something`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "unicode escaped ampersand",
|
||||
body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`,
|
||||
expected: "https://accounts.google.com/verify?a=1&b=2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractValidationURL(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1087,6 +1087,12 @@ type TokenPair struct {
|
||||
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
|
||||
}
|
||||
|
||||
// TokenPairWithUser extends TokenPair with user role for backend mode checks
|
||||
type TokenPairWithUser struct {
|
||||
TokenPair
|
||||
UserRole string
|
||||
}
|
||||
|
||||
// GenerateTokenPair 生成Access Token和Refresh Token对
|
||||
// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系
|
||||
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
|
||||
@@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
|
||||
|
||||
// RefreshTokenPair 使用Refresh Token刷新Token对
|
||||
// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
@@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
||||
}
|
||||
|
||||
// 生成新的Token对,保持同一个家族ID
|
||||
return s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenPairWithUser{
|
||||
TokenPair: *pair,
|
||||
UserRole: user.Role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeRefreshToken 撤销单个Refresh Token
|
||||
|
||||
770
backend/internal/service/backup_service.go
Normal file
770
backend/internal/service/backup_service.go
Normal file
@@ -0,0 +1,770 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
settingKeyBackupS3Config = "backup_s3_config"
|
||||
settingKeyBackupSchedule = "backup_schedule"
|
||||
settingKeyBackupRecords = "backup_records"
|
||||
|
||||
maxBackupRecords = 100
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured")
|
||||
ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found")
|
||||
ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress")
|
||||
ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress")
|
||||
ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted")
|
||||
ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted")
|
||||
)
|
||||
|
||||
// ─── 接口定义 ───
|
||||
|
||||
// DBDumper abstracts database dump/restore operations
|
||||
type DBDumper interface {
|
||||
Dump(ctx context.Context) (io.ReadCloser, error)
|
||||
Restore(ctx context.Context, data io.Reader) error
|
||||
}
|
||||
|
||||
// BackupObjectStore abstracts object storage for backup files
|
||||
type BackupObjectStore interface {
|
||||
Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error)
|
||||
Download(ctx context.Context, key string) (io.ReadCloser, error)
|
||||
Delete(ctx context.Context, key string) error
|
||||
PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error)
|
||||
HeadBucket(ctx context.Context) error
|
||||
}
|
||||
|
||||
// BackupObjectStoreFactory creates an object store from S3 config
|
||||
type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error)
|
||||
|
||||
// ─── 数据模型 ───
|
||||
|
||||
// BackupS3Config S3 兼容存储配置(支持 Cloudflare R2)
|
||||
type BackupS3Config struct {
|
||||
Endpoint string `json:"endpoint"` // e.g. https://<account_id>.r2.cloudflarestorage.com
|
||||
Region string `json:"region"` // R2 用 "auto"
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key,omitempty"` //nolint:revive // field name follows AWS convention
|
||||
Prefix string `json:"prefix"` // S3 key 前缀,如 "backups/"
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
}
|
||||
|
||||
// IsConfigured 检查必要字段是否已配置
|
||||
func (c *BackupS3Config) IsConfigured() bool {
|
||||
return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != ""
|
||||
}
|
||||
|
||||
// BackupScheduleConfig 定时备份配置
|
||||
type BackupScheduleConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CronExpr string `json:"cron_expr"` // cron 表达式,如 "0 2 * * *" 每天凌晨2点
|
||||
RetainDays int `json:"retain_days"` // 备份文件过期天数,默认14,0=不自动清理
|
||||
RetainCount int `json:"retain_count"` // 最多保留份数,0=不限制
|
||||
}
|
||||
|
||||
// BackupRecord 备份记录
|
||||
type BackupRecord struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"` // pending, running, completed, failed
|
||||
BackupType string `json:"backup_type"` // postgres
|
||||
FileName string `json:"file_name"`
|
||||
S3Key string `json:"s3_key"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TriggeredBy string `json:"triggered_by"` // manual, scheduled
|
||||
ErrorMsg string `json:"error_message,omitempty"`
|
||||
StartedAt string `json:"started_at"`
|
||||
FinishedAt string `json:"finished_at,omitempty"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
|
||||
}
|
||||
|
||||
// BackupService 数据库备份恢复服务
|
||||
type BackupService struct {
|
||||
settingRepo SettingRepository
|
||||
dbCfg *config.DatabaseConfig
|
||||
encryptor SecretEncryptor
|
||||
storeFactory BackupObjectStoreFactory
|
||||
dumper DBDumper
|
||||
|
||||
mu sync.Mutex
|
||||
store BackupObjectStore
|
||||
s3Cfg *BackupS3Config
|
||||
backingUp bool
|
||||
restoring bool
|
||||
|
||||
recordsMu sync.Mutex // 保护 records 的 load/save 操作
|
||||
|
||||
cronMu sync.Mutex
|
||||
cronSched *cron.Cron
|
||||
cronEntryID cron.EntryID
|
||||
}
|
||||
|
||||
func NewBackupService(
|
||||
settingRepo SettingRepository,
|
||||
cfg *config.Config,
|
||||
encryptor SecretEncryptor,
|
||||
storeFactory BackupObjectStoreFactory,
|
||||
dumper DBDumper,
|
||||
) *BackupService {
|
||||
return &BackupService{
|
||||
settingRepo: settingRepo,
|
||||
dbCfg: &cfg.Database,
|
||||
encryptor: encryptor,
|
||||
storeFactory: storeFactory,
|
||||
dumper: dumper,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动定时备份调度器
|
||||
func (s *BackupService) Start() {
|
||||
s.cronSched = cron.New()
|
||||
s.cronSched.Start()
|
||||
|
||||
// 加载已有的定时配置
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
schedule, err := s.GetSchedule(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 加载定时备份配置失败: %v", err)
|
||||
return
|
||||
}
|
||||
if schedule.Enabled && schedule.CronExpr != "" {
|
||||
if err := s.applyCronSchedule(schedule); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 应用定时备份配置失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止定时备份
|
||||
func (s *BackupService) Stop() {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
if s.cronSched != nil {
|
||||
s.cronSched.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置管理 ───
|
||||
|
||||
func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg == nil {
|
||||
return &BackupS3Config{}, nil
|
||||
}
|
||||
// 脱敏返回
|
||||
cfg.SecretAccessKey = ""
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) {
|
||||
// 如果没提供 secret,保留原有值
|
||||
if cfg.SecretAccessKey == "" {
|
||||
old, _ := s.loadS3Config(ctx)
|
||||
if old != nil {
|
||||
cfg.SecretAccessKey = old.SecretAccessKey
|
||||
}
|
||||
} else {
|
||||
// 加密 SecretAccessKey
|
||||
encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encrypt secret: %w", err)
|
||||
}
|
||||
cfg.SecretAccessKey = encrypted
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal s3 config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil {
|
||||
return nil, fmt.Errorf("save s3 config: %w", err)
|
||||
}
|
||||
|
||||
// 清除缓存的 S3 客户端
|
||||
s.mu.Lock()
|
||||
s.store = nil
|
||||
s.s3Cfg = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
cfg.SecretAccessKey = ""
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error {
|
||||
// 如果没提供 secret,用已保存的
|
||||
if cfg.SecretAccessKey == "" {
|
||||
old, _ := s.loadS3Config(ctx)
|
||||
if old != nil {
|
||||
cfg.SecretAccessKey = old.SecretAccessKey
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
|
||||
return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required")
|
||||
}
|
||||
|
||||
store, err := s.storeFactory(ctx, &cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return store.HeadBucket(ctx)
|
||||
}
|
||||
|
||||
// ─── 定时备份管理 ───
|
||||
|
||||
func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule)
|
||||
if err != nil || raw == "" {
|
||||
return &BackupScheduleConfig{}, nil
|
||||
}
|
||||
var cfg BackupScheduleConfig
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return &BackupScheduleConfig{}, nil
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) {
|
||||
if cfg.Enabled && cfg.CronExpr == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled")
|
||||
}
|
||||
// 验证 cron 表达式
|
||||
if cfg.CronExpr != "" {
|
||||
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
if _, err := parser.Parse(cfg.CronExpr); err != nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal schedule config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil {
|
||||
return nil, fmt.Errorf("save schedule config: %w", err)
|
||||
}
|
||||
|
||||
// 应用或停止定时任务
|
||||
if cfg.Enabled {
|
||||
if err := s.applyCronSchedule(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
s.removeCronSchedule()
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
|
||||
if s.cronSched == nil {
|
||||
return fmt.Errorf("cron scheduler not initialized")
|
||||
}
|
||||
|
||||
// 移除旧任务
|
||||
if s.cronEntryID != 0 {
|
||||
s.cronSched.Remove(s.cronEntryID)
|
||||
s.cronEntryID = 0
|
||||
}
|
||||
|
||||
entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() {
|
||||
s.runScheduledBackup()
|
||||
})
|
||||
if err != nil {
|
||||
return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err))
|
||||
}
|
||||
s.cronEntryID = entryID
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已启用: %s", cfg.CronExpr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) removeCronSchedule() {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
if s.cronSched != nil && s.cronEntryID != 0 {
|
||||
s.cronSched.Remove(s.cronEntryID)
|
||||
s.cronEntryID = 0
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已停用")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupService) runScheduledBackup() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// 读取定时备份配置中的过期天数
|
||||
schedule, _ := s.GetSchedule(ctx)
|
||||
expireDays := 14 // 默认14天过期
|
||||
if schedule != nil && schedule.RetainDays > 0 {
|
||||
expireDays = schedule.RetainDays
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
|
||||
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
|
||||
|
||||
// 清理过期备份(复用已加载的 schedule)
|
||||
if schedule == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cleanupOldBackups(ctx, schedule); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 清理过期备份失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── 备份/恢复核心 ───
|
||||
|
||||
// CreateBackup 创建全量数据库备份并上传到 S3(流式处理)
|
||||
// expireDays: 备份过期天数,0=永不过期,默认14天
|
||||
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
||||
s.mu.Lock()
|
||||
if s.backingUp {
|
||||
s.mu.Unlock()
|
||||
return nil, ErrBackupInProgress
|
||||
}
|
||||
s.backingUp = true
|
||||
s.mu.Unlock()
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.backingUp = false
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s3Cfg == nil || !s3Cfg.IsConfigured() {
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
backupID := uuid.New().String()[:8]
|
||||
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
|
||||
s3Key := s.buildS3Key(s3Cfg, fileName)
|
||||
|
||||
var expiresAt string
|
||||
if expireDays > 0 {
|
||||
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
record := &BackupRecord{
|
||||
ID: backupID,
|
||||
Status: "running",
|
||||
BackupType: "postgres",
|
||||
FileName: fileName,
|
||||
S3Key: s3Key,
|
||||
TriggeredBy: triggeredBy,
|
||||
StartedAt: now.Format(time.RFC3339),
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
|
||||
// 流式执行: pg_dump -> gzip -> S3 upload
|
||||
dumpReader, err := s.dumper.Dump(ctx)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
|
||||
pr, pw := io.Pipe()
|
||||
var gzipErr error
|
||||
go func() {
|
||||
gzWriter := gzip.NewWriter(pw)
|
||||
_, gzipErr = io.Copy(gzWriter, dumpReader)
|
||||
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if gzipErr != nil {
|
||||
_ = pw.CloseWithError(gzipErr)
|
||||
} else {
|
||||
_ = pw.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
contentType := "application/gzip"
|
||||
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
||||
if gzipErr != nil {
|
||||
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
|
||||
}
|
||||
record.ErrorMsg = errMsg
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("backup upload: %w", err)
|
||||
}
|
||||
|
||||
record.SizeBytes = sizeBytes
|
||||
record.Status = "completed"
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
if err := s.saveRecord(ctx, record); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
|
||||
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
|
||||
s.mu.Lock()
|
||||
if s.restoring {
|
||||
s.mu.Unlock()
|
||||
return ErrRestoreInProgress
|
||||
}
|
||||
s.restoring = true
|
||||
s.mu.Unlock()
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.restoring = false
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
record, err := s.GetBackupRecord(ctx, backupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if record.Status != "completed" {
|
||||
return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
|
||||
}
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
// 从 S3 流式下载
|
||||
body, err := objectStore.Download(ctx, record.S3Key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 download failed: %w", err)
|
||||
}
|
||||
defer func() { _ = body.Close() }()
|
||||
|
||||
// 流式解压 gzip -> psql(不将全部数据加载到内存)
|
||||
gzReader, err := gzip.NewReader(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gzip reader: %w", err)
|
||||
}
|
||||
defer func() { _ = gzReader.Close() }()
|
||||
|
||||
// 流式恢复
|
||||
if err := s.dumper.Restore(ctx, gzReader); err != nil {
|
||||
return fmt.Errorf("pg restore: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ─── 备份记录管理 ───
|
||||
|
||||
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
|
||||
records, err := s.loadRecords(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 倒序返回(最新在前)
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].StartedAt > records[j].StartedAt
|
||||
})
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) {
|
||||
records, err := s.loadRecords(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range records {
|
||||
if records[i].ID == backupID {
|
||||
return &records[i], nil
|
||||
}
|
||||
}
|
||||
return nil, ErrBackupNotFound
|
||||
}
|
||||
|
||||
func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var found *BackupRecord
|
||||
var remaining []BackupRecord
|
||||
for i := range records {
|
||||
if records[i].ID == backupID {
|
||||
found = &records[i]
|
||||
} else {
|
||||
remaining = append(remaining, records[i])
|
||||
}
|
||||
}
|
||||
if found == nil {
|
||||
return ErrBackupNotFound
|
||||
}
|
||||
|
||||
// 从 S3 删除
|
||||
if found.S3Key != "" && found.Status == "completed" {
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() {
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err == nil {
|
||||
_ = objectStore.Delete(ctx, found.S3Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.saveRecordsLocked(ctx, remaining)
|
||||
}
|
||||
|
||||
// GetBackupDownloadURL 获取备份文件预签名下载 URL
|
||||
func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) {
|
||||
record, err := s.GetBackupRecord(ctx, backupID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if record.Status != "completed" {
|
||||
return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed")
|
||||
}
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// ─── 内部方法 ───
|
||||
|
||||
func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil //nolint:nilnil // no config is a valid state
|
||||
}
|
||||
var cfg BackupS3Config
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return nil, ErrBackupS3ConfigCorrupt
|
||||
}
|
||||
// 解密 SecretAccessKey
|
||||
if cfg.SecretAccessKey != "" {
|
||||
decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
// 兼容未加密的旧数据:如果解密失败,保持原值
|
||||
logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err)
|
||||
} else {
|
||||
cfg.SecretAccessKey = decrypted
|
||||
}
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.store != nil && s.s3Cfg != nil {
|
||||
return s.store, nil
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
store, err := s.storeFactory(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.store = store
|
||||
s.s3Cfg = cfg
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string {
|
||||
prefix := strings.TrimRight(cfg.Prefix, "/")
|
||||
if prefix == "" {
|
||||
prefix = "backups"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName)
|
||||
}
|
||||
|
||||
// loadRecords 加载备份记录,区分"无数据"和"数据损坏"
|
||||
func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
return s.loadRecordsLocked(ctx)
|
||||
}
|
||||
|
||||
// loadRecordsLocked 在已持有 recordsMu 锁的情况下加载记录
|
||||
func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil //nolint:nilnil // no records is a valid state
|
||||
}
|
||||
var records []BackupRecord
|
||||
if err := json.Unmarshal([]byte(raw), &records); err != nil {
|
||||
return nil, ErrBackupRecordsCorrupt
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// saveRecordsLocked 在已持有 recordsMu 锁的情况下保存记录
|
||||
func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error {
|
||||
data, err := json.Marshal(records)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data))
|
||||
}
|
||||
|
||||
// saveRecord 保存单条记录(带互斥锁保护)
|
||||
func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, _ := s.loadRecordsLocked(ctx)
|
||||
|
||||
// 更新已有记录或追加
|
||||
found := false
|
||||
for i := range records {
|
||||
if records[i].ID == record.ID {
|
||||
records[i] = *record
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
records = append(records, *record)
|
||||
}
|
||||
|
||||
// 限制记录数量
|
||||
if len(records) > maxBackupRecords {
|
||||
records = records[len(records)-maxBackupRecords:]
|
||||
}
|
||||
|
||||
return s.saveRecordsLocked(ctx, records)
|
||||
}
|
||||
|
||||
func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error {
|
||||
if schedule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 按时间倒序
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].StartedAt > records[j].StartedAt
|
||||
})
|
||||
|
||||
var toDelete []BackupRecord
|
||||
var toKeep []BackupRecord
|
||||
|
||||
for i, r := range records {
|
||||
shouldDelete := false
|
||||
|
||||
// 按保留份数清理
|
||||
if schedule.RetainCount > 0 && i >= schedule.RetainCount {
|
||||
shouldDelete = true
|
||||
}
|
||||
|
||||
// 按保留天数清理
|
||||
if schedule.RetainDays > 0 && r.StartedAt != "" {
|
||||
startedAt, err := time.Parse(time.RFC3339, r.StartedAt)
|
||||
if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour {
|
||||
shouldDelete = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDelete && r.Status == "completed" {
|
||||
toDelete = append(toDelete, r)
|
||||
} else {
|
||||
toKeep = append(toKeep, r)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除 S3 上的文件
|
||||
for _, r := range toDelete {
|
||||
if r.S3Key != "" {
|
||||
_ = s.deleteS3Object(ctx, r.S3Key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toDelete) > 0 {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete))
|
||||
return s.saveRecordsLocked(ctx, toKeep)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteS3Object(ctx context.Context, key string) error {
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil || s3Cfg == nil {
|
||||
return nil
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return objectStore.Delete(ctx, key)
|
||||
}
|
||||
528
backend/internal/service/backup_service_test.go
Normal file
528
backend/internal/service/backup_service_test.go
Normal file
@@ -0,0 +1,528 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// ─── Mocks ───
|
||||
|
||||
type mockSettingRepo struct {
|
||||
mu sync.Mutex
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func newMockSettingRepo() *mockSettingRepo {
|
||||
return &mockSettingRepo{data: make(map[string]string)}
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
v, ok := m.data[key]
|
||||
if !ok {
|
||||
return nil, ErrSettingNotFound
|
||||
}
|
||||
return &Setting{Key: key, Value: v}, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
v, ok := m.data[key]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Set(_ context.Context, key, value string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make(map[string]string)
|
||||
for _, k := range keys {
|
||||
if v, ok := m.data[k]; ok {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for k, v := range settings {
|
||||
m.data[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make(map[string]string, len(m.data))
|
||||
for k, v := range m.data {
|
||||
result[k] = v
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Delete(_ context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.data, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// plainEncryptor 仅做 base64-like 包装,用于测试
|
||||
type plainEncryptor struct{}
|
||||
|
||||
func (e *plainEncryptor) Encrypt(plaintext string) (string, error) {
|
||||
return "ENC:" + plaintext, nil
|
||||
}
|
||||
|
||||
func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||
if strings.HasPrefix(ciphertext, "ENC:") {
|
||||
return strings.TrimPrefix(ciphertext, "ENC:"), nil
|
||||
}
|
||||
return ciphertext, fmt.Errorf("not encrypted")
|
||||
}
|
||||
|
||||
type mockDumper struct {
|
||||
dumpData []byte
|
||||
dumpErr error
|
||||
restored []byte
|
||||
restErr error
|
||||
}
|
||||
|
||||
func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) {
|
||||
if m.dumpErr != nil {
|
||||
return nil, m.dumpErr
|
||||
}
|
||||
return io.NopCloser(bytes.NewReader(m.dumpData)), nil
|
||||
}
|
||||
|
||||
func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
|
||||
if m.restErr != nil {
|
||||
return m.restErr
|
||||
}
|
||||
d, err := io.ReadAll(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.restored = d
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockObjectStore struct {
|
||||
objects map[string][]byte
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockObjectStore() *mockObjectStore {
|
||||
return &mockObjectStore{objects: make(map[string][]byte)}
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) {
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.objects[key] = data
|
||||
m.mu.Unlock()
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) {
|
||||
m.mu.Lock()
|
||||
data, ok := m.objects[key]
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("not found: %s", key)
|
||||
}
|
||||
return io.NopCloser(bytes.NewReader(data)), nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Delete(_ context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
delete(m.objects, key)
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) {
|
||||
return "https://presigned.example.com/" + key, nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) HeadBucket(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "test",
|
||||
DBName: "testdb",
|
||||
},
|
||||
}
|
||||
factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) {
|
||||
return store, nil
|
||||
}
|
||||
return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper)
|
||||
}
|
||||
|
||||
func seedS3Config(t *testing.T, repo *mockSettingRepo) {
|
||||
t.Helper()
|
||||
cfg := BackupS3Config{
|
||||
Bucket: "test-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "ENC:secret123",
|
||||
Prefix: "backups",
|
||||
}
|
||||
data, _ := json.Marshal(cfg)
|
||||
require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data)))
|
||||
}
|
||||
|
||||
// ─── Tests ───
|
||||
|
||||
func TestBackupService_S3ConfigEncryption(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 保存配置 -> SecretAccessKey 应被加密
|
||||
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "my-secret",
|
||||
Prefix: "backups",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 直接读取数据库中存储的值,应该是加密后的
|
||||
raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config)
|
||||
var stored BackupS3Config
|
||||
require.NoError(t, json.Unmarshal([]byte(raw), &stored))
|
||||
require.Equal(t, "ENC:my-secret", stored.SecretAccessKey)
|
||||
|
||||
// 通过 GetS3Config 获取应该脱敏
|
||||
cfg, err := svc.GetS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cfg.SecretAccessKey)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
|
||||
// loadS3Config 内部应解密
|
||||
internal, err := svc.loadS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "my-secret", internal.SecretAccessKey)
|
||||
}
|
||||
|
||||
func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 先保存一个有 secret 的配置
|
||||
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "original-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 再更新时不提供 secret,应保留原值
|
||||
_, err = svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID-NEW",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
internal, err := svc.loadS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "original-secret", internal.SecretAccessKey)
|
||||
require.Equal(t, "AKID-NEW", internal.AccessKeyID)
|
||||
}
|
||||
|
||||
func TestBackupService_SaveRecordConcurrency(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
n := 20
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
record := &BackupRecord{
|
||||
ID: fmt.Sprintf("rec-%d", idx),
|
||||
Status: "completed",
|
||||
StartedAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
_ = svc.saveRecord(context.Background(), record)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, records, n)
|
||||
}
|
||||
|
||||
func TestBackupService_LoadRecords_Empty(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, records) // 无数据时返回 nil
|
||||
}
|
||||
|
||||
func TestBackupService_LoadRecords_Corrupted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
_ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{")
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.Error(t, err) // 损坏数据应返回错误
|
||||
require.Nil(t, records)
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_Streaming(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "completed", record.Status)
|
||||
require.Greater(t, record.SizeBytes, int64(0))
|
||||
require.NotEmpty(t, record.S3Key)
|
||||
|
||||
// 验证 S3 上确实有文件
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 1)
|
||||
store.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_DumpFailure(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "failed", record.Status)
|
||||
require.Contains(t, record.ErrorMsg, "pg_dump")
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_NoS3Config(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.ErrorIs(t, err, ErrBackupS3NotConfigured)
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
// 使用一个慢速 dumper 来模拟正在进行的备份
|
||||
dumper := &mockDumper{dumpData: []byte("data")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
// 手动设置 backingUp 标志
|
||||
svc.mu.Lock()
|
||||
svc.backingUp = true
|
||||
svc.mu.Unlock()
|
||||
|
||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.ErrorIs(t, err, ErrBackupInProgress)
|
||||
}
|
||||
|
||||
func TestBackupService_RestoreBackup_Streaming(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
// 先创建一个备份
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 恢复
|
||||
err = svc.RestoreBackup(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证 psql 收到的数据是否与原始 dump 内容一致
|
||||
require.Equal(t, dumpContent, string(dumper.restored))
|
||||
}
|
||||
|
||||
func TestBackupService_RestoreBackup_NotCompleted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 手动插入一条 failed 记录
|
||||
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||
ID: "fail-1",
|
||||
Status: "failed",
|
||||
})
|
||||
|
||||
err := svc.RestoreBackup(context.Background(), "fail-1")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_DeleteBackup(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "data"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S3 中应有文件
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 1)
|
||||
store.mu.Unlock()
|
||||
|
||||
// 删除
|
||||
err = svc.DeleteBackup(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S3 中文件应被删除
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 0)
|
||||
store.mu.Unlock()
|
||||
|
||||
// 记录应不存在
|
||||
_, err = svc.GetBackupRecord(context.Background(), record.ID)
|
||||
require.ErrorIs(t, err, ErrBackupNotFound)
|
||||
}
|
||||
|
||||
func TestBackupService_GetDownloadURL(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumper := &mockDumper{dumpData: []byte("data")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
url, err := svc.GetBackupDownloadURL(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, url, "https://presigned.example.com/")
|
||||
}
|
||||
|
||||
func TestBackupService_ListBackups_Sorted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
now := time.Now()
|
||||
for i := 0; i < 3; i++ {
|
||||
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||
ID: fmt.Sprintf("rec-%d", i),
|
||||
Status: "completed",
|
||||
StartedAt: now.Add(time.Duration(i) * time.Hour).Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
records, err := svc.ListBackups(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, records, 3)
|
||||
// 最新在前
|
||||
require.Equal(t, "rec-2", records[0].ID)
|
||||
require.Equal(t, "rec-0", records[2].ID)
|
||||
}
|
||||
|
||||
func TestBackupService_TestS3Connection(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, store)
|
||||
|
||||
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||
Bucket: "test",
|
||||
AccessKeyID: "ak",
|
||||
SecretAccessKey: "sk",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_TestS3Connection_Incomplete(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||
Bucket: "test",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "incomplete")
|
||||
}
|
||||
|
||||
func TestBackupService_Schedule_CronValidation(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
svc.cronSched = nil // 未初始化 cron
|
||||
|
||||
// 启用但 cron 为空
|
||||
_, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||
Enabled: true,
|
||||
CronExpr: "",
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// 无效的 cron 表达式
|
||||
_, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||
Enabled: true,
|
||||
CronExpr: "invalid",
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
_ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!")
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
cfg, err := svc.loadS3Config(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Nil(t, cfg)
|
||||
}
|
||||
607
backend/internal/service/bedrock_request.go
Normal file
607
backend/internal/service/bedrock_request.go
Normal file
@@ -0,0 +1,607 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const defaultBedrockRegion = "us-east-1"
|
||||
|
||||
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
|
||||
|
||||
// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀
|
||||
// 参考: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
|
||||
func BedrockCrossRegionPrefix(region string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(region, "us-gov"):
|
||||
return "us-gov" // GovCloud 使用独立的 us-gov 前缀
|
||||
case strings.HasPrefix(region, "us-"):
|
||||
return "us"
|
||||
case strings.HasPrefix(region, "eu-"):
|
||||
return "eu"
|
||||
case region == "ap-northeast-1":
|
||||
return "jp" // 日本区域使用独立的 jp 前缀(AWS 官方定义)
|
||||
case region == "ap-southeast-2":
|
||||
return "au" // 澳大利亚区域使用独立的 au 前缀(AWS 官方定义)
|
||||
case strings.HasPrefix(region, "ap-"):
|
||||
return "apac" // 其余亚太区域使用通用 apac 前缀
|
||||
case strings.HasPrefix(region, "ca-"):
|
||||
return "us" // 加拿大区域使用 us 前缀的跨区域推理
|
||||
case strings.HasPrefix(region, "sa-"):
|
||||
return "us" // 南美区域使用 us 前缀的跨区域推理
|
||||
default:
|
||||
return "us"
|
||||
}
|
||||
}
|
||||
|
||||
// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀
|
||||
// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1"
|
||||
// 特殊值 region="global" 强制使用 global. 前缀
|
||||
func AdjustBedrockModelRegionPrefix(modelID, region string) string {
|
||||
var targetPrefix string
|
||||
if region == "global" {
|
||||
targetPrefix = "global"
|
||||
} else {
|
||||
targetPrefix = BedrockCrossRegionPrefix(region)
|
||||
}
|
||||
|
||||
for _, p := range bedrockCrossRegionPrefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
if p == targetPrefix+"." {
|
||||
return modelID // 前缀已匹配,无需替换
|
||||
}
|
||||
return targetPrefix + "." + modelID[len(p):]
|
||||
}
|
||||
}
|
||||
|
||||
// 模型 ID 没有已知区域前缀(如 "anthropic.claude-..."),不做修改
|
||||
return modelID
|
||||
}
|
||||
|
||||
func bedrockRuntimeRegion(account *Account) string {
|
||||
if account == nil {
|
||||
return defaultBedrockRegion
|
||||
}
|
||||
if region := account.GetCredential("aws_region"); region != "" {
|
||||
return region
|
||||
}
|
||||
return defaultBedrockRegion
|
||||
}
|
||||
|
||||
func shouldForceBedrockGlobal(account *Account) bool {
|
||||
return account != nil && account.GetCredential("aws_force_global") == "true"
|
||||
}
|
||||
|
||||
func isRegionalBedrockModelID(modelID string) bool {
|
||||
for _, prefix := range bedrockCrossRegionPrefixes {
|
||||
if strings.HasPrefix(modelID, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isLikelyBedrockModelID(modelID string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(modelID))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(lower, "arn:") {
|
||||
return true
|
||||
}
|
||||
for _, prefix := range []string{
|
||||
"anthropic.",
|
||||
"amazon.",
|
||||
"meta.",
|
||||
"mistral.",
|
||||
"cohere.",
|
||||
"ai21.",
|
||||
"deepseek.",
|
||||
"stability.",
|
||||
"writer.",
|
||||
"nova.",
|
||||
} {
|
||||
if strings.HasPrefix(lower, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return isRegionalBedrockModelID(lower)
|
||||
}
|
||||
|
||||
func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
return "", false, false
|
||||
}
|
||||
if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists {
|
||||
return mapped, true, true
|
||||
}
|
||||
if isRegionalBedrockModelID(modelID) {
|
||||
return modelID, true, true
|
||||
}
|
||||
if isLikelyBedrockModelID(modelID) {
|
||||
return modelID, false, true
|
||||
}
|
||||
return "", false, false
|
||||
}
|
||||
|
||||
// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID.
|
||||
// It applies account model_mapping first, then default Bedrock aliases, and finally
|
||||
// adjusts Anthropic cross-region prefixes to match the account region.
|
||||
func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) {
|
||||
if account == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
mappedModel := account.GetMappedModel(requestedModel)
|
||||
modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if shouldAdjustRegion {
|
||||
targetRegion := bedrockRuntimeRegion(account)
|
||||
if shouldForceBedrockGlobal(account) {
|
||||
targetRegion = "global"
|
||||
}
|
||||
modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion)
|
||||
}
|
||||
return modelID, true
|
||||
}
|
||||
|
||||
// BuildBedrockURL 构建 Bedrock InvokeModel 的 URL
|
||||
// stream=true 时使用 invoke-with-response-stream 端点
|
||||
// modelID 中的特殊字符会被 URL 编码(与 litellm 的 urllib.parse.quote(safe="") 对齐)
|
||||
func BuildBedrockURL(region, modelID string, stream bool) string {
|
||||
if region == "" {
|
||||
region = defaultBedrockRegion
|
||||
}
|
||||
encodedModelID := url.PathEscape(modelID)
|
||||
// url.PathEscape 不编码冒号(RFC 允许 path 中出现 ":"),
|
||||
// 但 AWS Bedrock 期望模型 ID 中的冒号被编码为 %3A
|
||||
encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A")
|
||||
if stream {
|
||||
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID)
|
||||
}
|
||||
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID)
|
||||
}
|
||||
|
||||
// PrepareBedrockRequestBody 处理请求体以适配 Bedrock API
|
||||
// 1. 注入 anthropic_version
|
||||
// 2. 注入 anthropic_beta(从客户端 anthropic-beta 头解析)
|
||||
// 3. 移除 Bedrock 不支持的字段(model, stream, output_format, output_config)
|
||||
// 4. 移除工具定义中的 custom 字段(Claude Code 会发送 custom: {defer_loading: true})
|
||||
// 5. 清理 cache_control 中 Bedrock 不支持的字段(scope, ttl)
|
||||
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
|
||||
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
|
||||
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
|
||||
}
|
||||
|
||||
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
|
||||
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
|
||||
var err error
|
||||
|
||||
// 注入 anthropic_version(Bedrock 要求)
|
||||
body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("inject anthropic_version: %w", err)
|
||||
}
|
||||
|
||||
// 注入 anthropic_beta(Bedrock Invoke 通过请求体传递 beta 头,而非 HTTP 头)
|
||||
// 1. 从客户端 anthropic-beta header 解析
|
||||
// 2. 根据请求体内容自动补齐必要的 beta token
|
||||
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock()
|
||||
if len(betaTokens) > 0 {
|
||||
body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("inject anthropic_beta: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 移除 model 字段(Bedrock 通过 URL 指定模型)
|
||||
body, err = sjson.DeleteBytes(body, "model")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remove model field: %w", err)
|
||||
}
|
||||
|
||||
// 移除 stream 字段(Bedrock 通过不同端点控制流式,不接受请求体中的 stream 字段)
|
||||
body, err = sjson.DeleteBytes(body, "stream")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remove stream field: %w", err)
|
||||
}
|
||||
|
||||
// 转换 output_format(Bedrock Invoke 不支持此字段,但可将 schema 内联到最后一条 user message)
|
||||
// 参考 litellm: _convert_output_format_to_inline_schema()
|
||||
body = convertOutputFormatToInlineSchema(body)
|
||||
|
||||
// 移除 output_config 字段(Bedrock Invoke 不支持)
|
||||
body, err = sjson.DeleteBytes(body, "output_config")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remove output_config field: %w", err)
|
||||
}
|
||||
|
||||
// 移除工具定义中的 custom 字段
|
||||
// Claude Code (v2.1.69+) 在 tool 定义中发送 custom: {defer_loading: true},
|
||||
// Anthropic API 接受但 Bedrock 会拒绝并报 "Extra inputs are not permitted"
|
||||
body = removeCustomFieldFromTools(body)
|
||||
|
||||
// 清理 cache_control 中 Bedrock 不支持的字段
|
||||
body = sanitizeBedrockCacheControl(body, modelID)
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering.
|
||||
func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string {
|
||||
betaTokens := parseAnthropicBetaHeader(betaHeader)
|
||||
betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID)
|
||||
return filterBedrockBetaTokens(betaTokens)
|
||||
}
|
||||
|
||||
// convertOutputFormatToInlineSchema 将 output_format 中的 JSON schema 内联到最后一条 user message
|
||||
// Bedrock Invoke 不支持 output_format 参数,litellm 的做法是将 schema 追加到用户消息中
|
||||
// 参考: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema()
|
||||
func convertOutputFormatToInlineSchema(body []byte) []byte {
|
||||
outputFormat := gjson.GetBytes(body, "output_format")
|
||||
if !outputFormat.Exists() || !outputFormat.IsObject() {
|
||||
return body
|
||||
}
|
||||
|
||||
// 先从请求体中移除 output_format
|
||||
body, _ = sjson.DeleteBytes(body, "output_format")
|
||||
|
||||
schema := outputFormat.Get("schema")
|
||||
if !schema.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// 找到最后一条 user message
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
msgArr := messages.Array()
|
||||
lastUserIdx := -1
|
||||
for i := len(msgArr) - 1; i >= 0; i-- {
|
||||
if msgArr[i].Get("role").String() == "user" {
|
||||
lastUserIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if lastUserIdx < 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
// 将 schema 序列化为 JSON 文本追加到该 message 的 content 数组
|
||||
schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw))
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
content := msgArr[lastUserIdx].Get("content")
|
||||
basePath := fmt.Sprintf("messages.%d.content", lastUserIdx)
|
||||
|
||||
if content.IsArray() {
|
||||
// 追加一个 text block 到 content 数组末尾
|
||||
idx := len(content.Array())
|
||||
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text")
|
||||
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON))
|
||||
} else if content.Type == gjson.String {
|
||||
// content 是纯字符串,转换为数组格式
|
||||
originalText := content.String()
|
||||
body, _ = sjson.SetBytes(body, basePath, []map[string]string{
|
||||
{"type": "text", "text": originalText},
|
||||
{"type": "text", "text": string(schemaJSON)},
|
||||
})
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// removeCustomFieldFromTools 移除 tools 数组中每个工具定义的 custom 字段
|
||||
func removeCustomFieldFromTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return body
|
||||
}
|
||||
var err error
|
||||
for i := range tools.Array() {
|
||||
body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i))
|
||||
if err != nil {
|
||||
// 删除失败不影响整体流程,跳过
|
||||
continue
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// claudeVersionRe 匹配 Claude 模型 ID 中的版本号部分
|
||||
// 支持 claude-{tier}-{major}-{minor} 和 claude-{tier}-{major}.{minor} 格式
|
||||
var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`)
|
||||
|
||||
// isBedrockClaude45OrNewer 判断 Bedrock 模型 ID 是否为 Claude 4.5 或更新版本
|
||||
// Claude 4.5+ 支持 cache_control 中的 ttl 字段("5m" 和 "1h")
|
||||
func isBedrockClaude45OrNewer(modelID string) bool {
|
||||
lower := strings.ToLower(modelID)
|
||||
matches := claudeVersionRe.FindStringSubmatch(lower)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
major, _ := strconv.Atoi(matches[1])
|
||||
minor, _ := strconv.Atoi(matches[2])
|
||||
return major > 4 || (major == 4 && minor >= 5)
|
||||
}
|
||||
|
||||
// sanitizeBedrockCacheControl 清理 system 和 messages 中 cache_control 里
|
||||
// Bedrock 不支持的字段:
|
||||
// - scope:Bedrock 不支持(如 "global" 跨请求缓存)
|
||||
// - ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",旧模型需要移除
|
||||
func sanitizeBedrockCacheControl(body []byte, modelID string) []byte {
|
||||
isClaude45 := isBedrockClaude45OrNewer(modelID)
|
||||
|
||||
// 清理 system 数组中的 cache_control
|
||||
systemArr := gjson.GetBytes(body, "system")
|
||||
if systemArr.Exists() && systemArr.IsArray() {
|
||||
for i, item := range systemArr.Array() {
|
||||
if !item.IsObject() {
|
||||
continue
|
||||
}
|
||||
cc := item.Get("cache_control")
|
||||
if !cc.Exists() || !cc.IsObject() {
|
||||
continue
|
||||
}
|
||||
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45)
|
||||
}
|
||||
}
|
||||
|
||||
// 清理 messages 中的 cache_control
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
for mi, msg := range messages.Array() {
|
||||
if !msg.IsObject() {
|
||||
continue
|
||||
}
|
||||
content := msg.Get("content")
|
||||
if !content.Exists() || !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
for ci, block := range content.Array() {
|
||||
if !block.IsObject() {
|
||||
continue
|
||||
}
|
||||
cc := block.Get("cache_control")
|
||||
if !cc.Exists() || !cc.IsObject() {
|
||||
continue
|
||||
}
|
||||
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45)
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// deleteCacheControlUnsupportedFields 删除给定 cache_control 路径下 Bedrock 不支持的字段
|
||||
func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte {
|
||||
// Bedrock 不支持 scope(如 "global")
|
||||
if cc.Get("scope").Exists() {
|
||||
body, _ = sjson.DeleteBytes(body, basePath+".scope")
|
||||
}
|
||||
|
||||
// ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",其余情况移除
|
||||
ttl := cc.Get("ttl")
|
||||
if ttl.Exists() {
|
||||
shouldRemove := true
|
||||
if isClaude45 {
|
||||
v := ttl.String()
|
||||
if v == "5m" || v == "1h" {
|
||||
shouldRemove = false
|
||||
}
|
||||
}
|
||||
if shouldRemove {
|
||||
body, _ = sjson.DeleteBytes(body, basePath+".ttl")
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// parseAnthropicBetaHeader 解析 anthropic-beta 头的逗号分隔字符串为 token 列表
|
||||
func parseAnthropicBetaHeader(header string) []string {
|
||||
header = strings.TrimSpace(header)
|
||||
if header == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") {
|
||||
var parsed []any
|
||||
if err := json.Unmarshal([]byte(header), &parsed); err == nil {
|
||||
tokens := make([]string, 0, len(parsed))
|
||||
for _, item := range parsed {
|
||||
token := strings.TrimSpace(fmt.Sprint(item))
|
||||
if token != "" {
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
}
|
||||
var tokens []string
|
||||
for _, part := range strings.Split(header, ",") {
|
||||
t := strings.TrimSpace(part)
|
||||
if t != "" {
|
||||
tokens = append(tokens, t)
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// bedrockSupportedBetaTokens 是 Bedrock Invoke 支持的 beta 头白名单
|
||||
// 参考: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
|
||||
// 更新策略: 当 AWS Bedrock 新增支持的 beta token 时需同步更新此白名单
|
||||
var bedrockSupportedBetaTokens = map[string]bool{
|
||||
"computer-use-2025-01-24": true,
|
||||
"computer-use-2025-11-24": true,
|
||||
"context-1m-2025-08-07": true,
|
||||
"context-management-2025-06-27": true,
|
||||
"compact-2026-01-12": true,
|
||||
"interleaved-thinking-2025-05-14": true,
|
||||
"tool-search-tool-2025-10-19": true,
|
||||
"tool-examples-2025-10-29": true,
|
||||
}
|
||||
|
||||
// bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则
|
||||
// Anthropic 直接 API 使用通用头,Bedrock Invoke 需要特定的替代头
|
||||
var bedrockBetaTokenTransforms = map[string]string{
|
||||
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
|
||||
}
|
||||
|
||||
// autoInjectBedrockBetaTokens 根据请求体内容自动补齐必要的 beta token
|
||||
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() 和
|
||||
// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock()
|
||||
//
|
||||
// 客户端(特别是非 Claude Code 客户端)可能只在 body 中启用了功能而不在 header 中带对应 beta token,
|
||||
// 这里通过检测请求体特征自动补齐,确保 Bedrock Invoke 不会因缺少必要 beta 头而 400。
|
||||
func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string {
|
||||
seen := make(map[string]bool, len(tokens))
|
||||
for _, t := range tokens {
|
||||
seen[t] = true
|
||||
}
|
||||
|
||||
inject := func(token string) {
|
||||
if !seen[token] {
|
||||
tokens = append(tokens, token)
|
||||
seen[token] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 检测 thinking / interleaved thinking
|
||||
// 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta
|
||||
if gjson.GetBytes(body, "thinking").Exists() {
|
||||
inject("interleaved-thinking-2025-05-14")
|
||||
}
|
||||
|
||||
// 检测 computer_use 工具
|
||||
// tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() && tools.IsArray() {
|
||||
toolSearchUsed := false
|
||||
programmaticToolCallingUsed := false
|
||||
inputExamplesUsed := false
|
||||
for _, tool := range tools.Array() {
|
||||
toolType := tool.Get("type").String()
|
||||
if strings.HasPrefix(toolType, "computer_20") {
|
||||
inject("computer-use-2025-11-24")
|
||||
}
|
||||
if isBedrockToolSearchType(toolType) {
|
||||
toolSearchUsed = true
|
||||
}
|
||||
if hasCodeExecutionAllowedCallers(tool) {
|
||||
programmaticToolCallingUsed = true
|
||||
}
|
||||
if hasInputExamples(tool) {
|
||||
inputExamplesUsed = true
|
||||
}
|
||||
}
|
||||
if programmaticToolCallingUsed || inputExamplesUsed {
|
||||
// programmatic tool calling 和 input examples 需要 advanced-tool-use,
|
||||
// 后续 filterBedrockBetaTokens 会将其转换为 Bedrock 特定的 tool-search-tool
|
||||
inject("advanced-tool-use-2025-11-20")
|
||||
}
|
||||
if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) {
|
||||
// 纯 tool search(无 programmatic/inputExamples)时直接注入 Bedrock 特定头,
|
||||
// 跳过 advanced-tool-use → tool-search-tool 的转换步骤(与 litellm 对齐)
|
||||
if !programmaticToolCallingUsed && !inputExamplesUsed {
|
||||
inject("tool-search-tool-2025-10-19")
|
||||
} else {
|
||||
inject("advanced-tool-use-2025-11-20")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
func isBedrockToolSearchType(toolType string) bool {
|
||||
return toolType == "tool_search_tool_regex_20251119" || toolType == "tool_search_tool_bm25_20251119"
|
||||
}
|
||||
|
||||
func hasCodeExecutionAllowedCallers(tool gjson.Result) bool {
|
||||
allowedCallers := tool.Get("allowed_callers")
|
||||
if containsStringInJSONArray(allowedCallers, "code_execution_20250825") {
|
||||
return true
|
||||
}
|
||||
return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825")
|
||||
}
|
||||
|
||||
func hasInputExamples(tool gjson.Result) bool {
|
||||
if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 {
|
||||
return true
|
||||
}
|
||||
arr := tool.Get("function.input_examples")
|
||||
return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0
|
||||
}
|
||||
|
||||
func containsStringInJSONArray(result gjson.Result, target string) bool {
|
||||
if !result.Exists() || !result.IsArray() {
|
||||
return false
|
||||
}
|
||||
for _, item := range result.Array() {
|
||||
if item.String() == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// bedrockModelSupportsToolSearch 判断 Bedrock 模型是否支持 tool search
|
||||
// 目前仅 Claude Opus/Sonnet 4.5+ 支持,Haiku 不支持
|
||||
func bedrockModelSupportsToolSearch(modelID string) bool {
|
||||
lower := strings.ToLower(modelID)
|
||||
matches := claudeVersionRe.FindStringSubmatch(lower)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
// Haiku 不支持 tool search
|
||||
if strings.Contains(lower, "haiku") {
|
||||
return false
|
||||
}
|
||||
major, _ := strconv.Atoi(matches[1])
|
||||
minor, _ := strconv.Atoi(matches[2])
|
||||
return major > 4 || (major == 4 && minor >= 5)
|
||||
}
|
||||
|
||||
// filterBedrockBetaTokens 过滤并转换 beta token 列表,仅保留 Bedrock Invoke 支持的 token
|
||||
// 1. 应用转换规则(如 advanced-tool-use → tool-search-tool)
|
||||
// 2. 过滤掉 Bedrock 不支持的 token(如 output-128k, files-api, structured-outputs 等)
|
||||
// 3. 自动关联 tool-examples(当 tool-search-tool 存在时)
|
||||
func filterBedrockBetaTokens(tokens []string) []string {
|
||||
seen := make(map[string]bool, len(tokens))
|
||||
var result []string
|
||||
|
||||
for _, t := range tokens {
|
||||
// 应用转换规则
|
||||
if replacement, ok := bedrockBetaTokenTransforms[t]; ok {
|
||||
t = replacement
|
||||
}
|
||||
// 只保留白名单中的 token,且去重
|
||||
if bedrockSupportedBetaTokens[t] && !seen[t] {
|
||||
result = append(result, t)
|
||||
seen[t] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 自动关联: tool-search-tool 存在时,确保 tool-examples 也存在
|
||||
if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] {
|
||||
result = append(result, "tool-examples-2025-10-29")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
659
backend/internal/service/bedrock_request_test.go
Normal file
659
backend/internal/service/bedrock_request_test.go
Normal file
@@ -0,0 +1,659 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) {
|
||||
input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// anthropic_version 应被注入
|
||||
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
|
||||
// model 和 stream 应被移除
|
||||
assert.False(t, gjson.GetBytes(result, "model").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "stream").Exists())
|
||||
// max_tokens 应保留
|
||||
assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int())
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) {
|
||||
t.Run("schema inlined into last user message array content", func(t *testing.T) {
|
||||
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||
// schema 应内联到最后一条 user message 的 content 数组末尾
|
||||
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
|
||||
require.Len(t, contentArr, 2)
|
||||
assert.Equal(t, "text", contentArr[1].Get("type").String())
|
||||
assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`)
|
||||
})
|
||||
|
||||
t.Run("schema inlined into string content", func(t *testing.T) {
|
||||
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
|
||||
require.Len(t, contentArr, 2)
|
||||
assert.Equal(t, "compute this", contentArr[0].Get("text").String())
|
||||
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`)
|
||||
})
|
||||
|
||||
t.Run("no schema field just removes output_format", func(t *testing.T) {
|
||||
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||
})
|
||||
|
||||
t.Run("no messages just removes output_format", func(t *testing.T) {
|
||||
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) {
|
||||
input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
|
||||
}
|
||||
|
||||
func TestRemoveCustomFieldFromTools(t *testing.T) {
|
||||
input := `{
|
||||
"tools": [
|
||||
{"name":"tool1","custom":{"defer_loading":true},"description":"desc1"},
|
||||
{"name":"tool2","description":"desc2"},
|
||||
{"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"}
|
||||
]
|
||||
}`
|
||||
result := removeCustomFieldFromTools([]byte(input))
|
||||
|
||||
tools := gjson.GetBytes(result, "tools").Array()
|
||||
require.Len(t, tools, 3)
|
||||
// custom 应被移除
|
||||
assert.False(t, tools[0].Get("custom").Exists())
|
||||
// name/description 应保留
|
||||
assert.Equal(t, "tool1", tools[0].Get("name").String())
|
||||
assert.Equal(t, "desc1", tools[0].Get("description").String())
|
||||
// 没有 custom 的工具不受影响
|
||||
assert.Equal(t, "tool2", tools[1].Get("name").String())
|
||||
// 第三个工具的 custom 也应被移除
|
||||
assert.False(t, tools[2].Get("custom").Exists())
|
||||
assert.Equal(t, "tool3", tools[2].Get("name").String())
|
||||
}
|
||||
|
||||
func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}]}`
|
||||
result := removeCustomFieldFromTools([]byte(input))
|
||||
// 无 tools 时不改变原始数据
|
||||
assert.JSONEq(t, input, string(result))
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) {
|
||||
input := `{
|
||||
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}],
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}]
|
||||
}`
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
|
||||
|
||||
// scope 应被移除
|
||||
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists())
|
||||
// type 应保留
|
||||
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
|
||||
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String())
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) {
|
||||
input := `{
|
||||
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
|
||||
}`
|
||||
// 旧模型(Claude 3.5)不支持 ttl
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
|
||||
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) {
|
||||
input := `{
|
||||
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
|
||||
}`
|
||||
// Claude 4.5+ 支持 "5m" 和 "1h"
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
|
||||
|
||||
assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
|
||||
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) {
|
||||
input := `{
|
||||
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}]
|
||||
}`
|
||||
// Claude 4.5 不支持 "10m"
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) {
|
||||
input := `{
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
|
||||
}`
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
|
||||
|
||||
assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists())
|
||||
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) {
|
||||
input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
|
||||
// 无 cache_control 时不改变原始数据
|
||||
assert.JSONEq(t, input, string(result))
|
||||
}
|
||||
|
||||
func TestIsBedrockClaude45OrNewer(t *testing.T) {
|
||||
tests := []struct {
|
||||
modelID string
|
||||
expect bool
|
||||
}{
|
||||
{"us.anthropic.claude-opus-4-6-v1", true},
|
||||
{"us.anthropic.claude-sonnet-4-6", true},
|
||||
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
|
||||
{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
|
||||
{"us.anthropic.claude-haiku-4-5-20251001-v1:0", true},
|
||||
{"anthropic.claude-3-5-sonnet-20241022-v2:0", false},
|
||||
{"anthropic.claude-3-opus-20240229-v1:0", false},
|
||||
{"anthropic.claude-3-haiku-20240307-v1:0", false},
|
||||
// 未来版本应自动支持
|
||||
{"us.anthropic.claude-sonnet-5-0-v1", true},
|
||||
{"us.anthropic.claude-opus-4-7-v1", true},
|
||||
// 旧版本
|
||||
{"anthropic.claude-opus-4-1-v1", false},
|
||||
{"anthropic.claude-sonnet-4-0-v1", false},
|
||||
// 非 Claude 模型
|
||||
{"amazon.nova-pro-v1", false},
|
||||
{"meta.llama3-70b", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.modelID, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
|
||||
// 模拟一个完整的 Claude Code 请求
|
||||
input := `{
|
||||
"model": "claude-opus-4-6",
|
||||
"stream": true,
|
||||
"max_tokens": 16384,
|
||||
"output_format": {"type": "json", "schema": {"result": "string"}},
|
||||
"output_config": {"max_tokens": 100},
|
||||
"system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}],
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]}
|
||||
],
|
||||
"tools": [
|
||||
{"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}},
|
||||
{"name": "read", "description": "Read file", "input_schema": {"type": "object"}}
|
||||
]
|
||||
}`
|
||||
|
||||
betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 基本字段
|
||||
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
|
||||
assert.False(t, gjson.GetBytes(result, "model").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "stream").Exists())
|
||||
assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int())
|
||||
|
||||
// anthropic_beta 应包含所有 beta tokens
|
||||
betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, betaArr, 3)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
|
||||
assert.Equal(t, "compact-2026-01-12", betaArr[2].String())
|
||||
|
||||
// output_format 应被移除,schema 内联到最后一条 user message
|
||||
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
|
||||
// content 数组:原始 text block + 内联 schema block
|
||||
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
|
||||
require.Len(t, contentArr, 2)
|
||||
assert.Equal(t, "hello", contentArr[0].Get("text").String())
|
||||
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`)
|
||||
|
||||
// tools 中的 custom 应被移除
|
||||
assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists())
|
||||
assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String())
|
||||
assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String())
|
||||
|
||||
// cache_control: scope 应被移除,ttl 在 Claude 4.6 上保留合法值
|
||||
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
|
||||
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
|
||||
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
|
||||
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
|
||||
|
||||
t.Run("empty beta header", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
|
||||
})
|
||||
|
||||
t.Run("single beta token", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 1)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
})
|
||||
|
||||
t.Run("multiple beta tokens with spaces", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 2)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
|
||||
})
|
||||
|
||||
t.Run("json array beta header", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 2)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseAnthropicBetaHeader(t *testing.T) {
|
||||
assert.Nil(t, parseAnthropicBetaHeader(""))
|
||||
assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a"))
|
||||
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b"))
|
||||
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b "))
|
||||
assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c"))
|
||||
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`))
|
||||
}
|
||||
|
||||
func TestFilterBedrockBetaTokens(t *testing.T) {
|
||||
t.Run("supported tokens pass through", func(t *testing.T) {
|
||||
tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Equal(t, tokens, result)
|
||||
})
|
||||
|
||||
t.Run("unsupported tokens are filtered out", func(t *testing.T) {
|
||||
tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
|
||||
})
|
||||
|
||||
t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
|
||||
tokens := []string{"advanced-tool-use-2025-11-20"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
// tool-examples 自动关联
|
||||
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||
})
|
||||
|
||||
t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) {
|
||||
tokens := []string{"tool-search-tool-2025-10-19"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||
})
|
||||
|
||||
t.Run("no duplication when tool-examples already present", func(t *testing.T) {
|
||||
tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
count := 0
|
||||
for _, t := range result {
|
||||
if t == "tool-examples-2025-10-29" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("empty input returns nil", func(t *testing.T) {
|
||||
result := filterBedrockBetaTokens(nil)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("all unsupported returns nil", func(t *testing.T) {
|
||||
result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"})
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("duplicate tokens are deduplicated", func(t *testing.T) {
|
||||
tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
|
||||
|
||||
t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
|
||||
"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 1)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
})
|
||||
|
||||
t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
|
||||
"advanced-tool-use-2025-11-20")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 2)
|
||||
assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String())
|
||||
assert.Equal(t, "tool-examples-2025-10-29", arr[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestBedrockCrossRegionPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
region string
|
||||
expect string
|
||||
}{
|
||||
// US regions
|
||||
{"us-east-1", "us"},
|
||||
{"us-east-2", "us"},
|
||||
{"us-west-1", "us"},
|
||||
{"us-west-2", "us"},
|
||||
// GovCloud
|
||||
{"us-gov-east-1", "us-gov"},
|
||||
{"us-gov-west-1", "us-gov"},
|
||||
// EU regions
|
||||
{"eu-west-1", "eu"},
|
||||
{"eu-west-2", "eu"},
|
||||
{"eu-west-3", "eu"},
|
||||
{"eu-central-1", "eu"},
|
||||
{"eu-central-2", "eu"},
|
||||
{"eu-north-1", "eu"},
|
||||
{"eu-south-1", "eu"},
|
||||
// APAC regions
|
||||
{"ap-northeast-1", "jp"},
|
||||
{"ap-northeast-2", "apac"},
|
||||
{"ap-southeast-1", "apac"},
|
||||
{"ap-southeast-2", "au"},
|
||||
{"ap-south-1", "apac"},
|
||||
// Canada / South America fallback to us
|
||||
{"ca-central-1", "us"},
|
||||
{"sa-east-1", "us"},
|
||||
// Unknown defaults to us
|
||||
{"me-south-1", "us"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.region, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveBedrockModelID(t *testing.T) {
|
||||
t.Run("default alias resolves and adjusts region", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "eu-west-1",
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID)
|
||||
})
|
||||
|
||||
t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "ap-southeast-2",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-*": "claude-opus-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-6-thinking")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
|
||||
})
|
||||
|
||||
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "us-east-1",
|
||||
"aws_force_global": "true",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID)
|
||||
})
|
||||
|
||||
t.Run("direct bedrock model id passes through", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "us-east-1",
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID)
|
||||
})
|
||||
|
||||
t.Run("unsupported alias returns false", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "us-east-1",
|
||||
},
|
||||
}
|
||||
|
||||
_, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
|
||||
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
|
||||
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||
})
|
||||
|
||||
t.Run("no duplicate when already present", func(t *testing.T) {
|
||||
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
count := 0
|
||||
for _, t := range result {
|
||||
if t == "interleaved-thinking-2025-05-14" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "computer-use-2025-11-24")
|
||||
})
|
||||
|
||||
t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("inject advanced-tool-use for input examples", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
|
||||
// 纯 tool search 场景直接注入 Bedrock 特定头,不走 advanced-tool-use 转换
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
|
||||
// 混合场景使用 advanced-tool-use(后续由 filter 转换为 tool-search-tool)
|
||||
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
|
||||
assert.NotContains(t, result, "tool-search-tool-2025-10-19")
|
||||
})
|
||||
|
||||
t.Run("no injection for regular tools", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("no injection when no features detected", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("preserves existing tokens", func(t *testing.T) {
|
||||
body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`)
|
||||
existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"}
|
||||
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "context-1m-2025-08-07")
|
||||
assert.Contains(t, result, "compact-2026-01-12")
|
||||
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveBedrockBetaTokens(t *testing.T) {
|
||||
t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||
})
|
||||
|
||||
t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
|
||||
t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
found := false
|
||||
for _, v := range arr {
|
||||
if v.String() == "interleaved-thinking-2025-05-14" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "interleaved-thinking should be auto-injected")
|
||||
})
|
||||
|
||||
t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
names := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
names[i] = v.String()
|
||||
}
|
||||
assert.Contains(t, names, "context-1m-2025-08-07")
|
||||
assert.Contains(t, names, "interleaved-thinking-2025-05-14")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelID string
|
||||
region string
|
||||
expect string
|
||||
}{
|
||||
// US region — no change needed
|
||||
{"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"},
|
||||
// EU region — replace us → eu
|
||||
{"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
|
||||
{"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"},
|
||||
// APAC region — jp and au have dedicated prefixes per AWS docs
|
||||
{"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"},
|
||||
{"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"},
|
||||
{"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"},
|
||||
// eu → us (user manually set eu prefix, moved to us region)
|
||||
{"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"},
|
||||
// global prefix — replace to match region
|
||||
{"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
|
||||
// No known prefix — leave unchanged
|
||||
{"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"},
|
||||
// GovCloud — uses independent us-gov prefix
|
||||
{"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"},
|
||||
{"govcloud already correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"},
|
||||
// Force global (special region value)
|
||||
{"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"},
|
||||
{"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region))
|
||||
})
|
||||
}
|
||||
}
|
||||
67
backend/internal/service/bedrock_signer.go
Normal file
67
backend/internal/service/bedrock_signer.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
)
|
||||
|
||||
// BedrockSigner 使用 AWS SigV4 对 Bedrock 请求签名
|
||||
type BedrockSigner struct {
|
||||
credentials aws.Credentials
|
||||
region string
|
||||
signer *v4.Signer
|
||||
}
|
||||
|
||||
// NewBedrockSigner 创建 BedrockSigner
|
||||
func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner {
|
||||
return &BedrockSigner{
|
||||
credentials: aws.Credentials{
|
||||
AccessKeyID: accessKeyID,
|
||||
SecretAccessKey: secretAccessKey,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
region: region,
|
||||
signer: v4.NewSigner(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewBedrockSignerFromAccount 从 Account 凭证创建 BedrockSigner
|
||||
func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) {
|
||||
accessKeyID := account.GetCredential("aws_access_key_id")
|
||||
if accessKeyID == "" {
|
||||
return nil, fmt.Errorf("aws_access_key_id not found in credentials")
|
||||
}
|
||||
secretAccessKey := account.GetCredential("aws_secret_access_key")
|
||||
if secretAccessKey == "" {
|
||||
return nil, fmt.Errorf("aws_secret_access_key not found in credentials")
|
||||
}
|
||||
region := account.GetCredential("aws_region")
|
||||
if region == "" {
|
||||
region = defaultBedrockRegion
|
||||
}
|
||||
sessionToken := account.GetCredential("aws_session_token") // 可选
|
||||
|
||||
return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil
|
||||
}
|
||||
|
||||
// SignRequest 对 HTTP 请求进行 SigV4 签名
|
||||
// 重要约束:调用此方法前,req 应只包含 AWS 相关的 header(如 Content-Type、Accept)。
|
||||
// 非 AWS header(如 anthropic-beta)会参与签名计算,如果 Bedrock 服务端不识别这些 header,
|
||||
// 签名验证可能失败。litellm 通过 _filter_headers_for_aws_signature 实现头过滤,
|
||||
// 当前实现中 buildUpstreamRequestBedrock 仅设置了 Content-Type 和 Accept,因此是安全的。
|
||||
func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error {
|
||||
payloadHash := sha256Hash(body)
|
||||
return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now())
|
||||
}
|
||||
|
||||
func sha256Hash(data []byte) string {
|
||||
h := sha256.Sum256(data)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
35
backend/internal/service/bedrock_signer_test.go
Normal file
35
backend/internal/service/bedrock_signer_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_access_key_id": "test-akid",
|
||||
"aws_secret_access_key": "test-secret",
|
||||
},
|
||||
}
|
||||
|
||||
signer, err := NewBedrockSignerFromAccount(account)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signer)
|
||||
assert.Equal(t, defaultBedrockRegion, signer.region)
|
||||
}
|
||||
|
||||
func TestFilterBetaTokens(t *testing.T) {
|
||||
tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"}
|
||||
filterSet := map[string]struct{}{
|
||||
"tool-search-tool-2025-10-19": {},
|
||||
}
|
||||
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet))
|
||||
assert.Equal(t, tokens, filterBetaTokens(tokens, nil))
|
||||
assert.Nil(t, filterBetaTokens(nil, filterSet))
|
||||
}
|
||||
414
backend/internal/service/bedrock_stream.go
Normal file
414
backend/internal/service/bedrock_stream.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// handleBedrockStreamingResponse 处理 Bedrock InvokeModelWithResponseStream 的 EventStream 响应
|
||||
// Bedrock 返回 AWS EventStream 二进制格式,每个事件的 payload 中 chunk.bytes 是 base64 编码的
|
||||
// Claude SSE 事件 JSON。本方法解码后转换为标准 SSE 格式写入客户端。
|
||||
func (s *GatewayService) handleBedrockStreamingResponse(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
startTime time.Time,
|
||||
model string,
|
||||
) (*streamingResult, error) {
|
||||
w := c.Writer
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
if v := resp.Header.Get("x-amzn-requestid"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
|
||||
// Bedrock EventStream 使用 application/vnd.amazon.eventstream 二进制格式。
|
||||
// 每个帧结构:total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4)
|
||||
// 但更实用的方式是使用行扫描找 JSON chunks,因为 Bedrock 的响应在二进制帧中。
|
||||
// 我们使用 EventStream decoder 来正确解析。
|
||||
decoder := newBedrockEventStreamDecoder(resp.Body)
|
||||
|
||||
type decodeEvent struct {
|
||||
payload []byte
|
||||
err error
|
||||
}
|
||||
events := make(chan decodeEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev decodeEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt atomic.Int64
|
||||
lastReadAt.Store(time.Now().UnixNano())
|
||||
|
||||
go func() {
|
||||
defer close(events)
|
||||
for {
|
||||
payload, err := decoder.Decode()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
_ = sendEvent(decodeEvent{err: err})
|
||||
return
|
||||
}
|
||||
lastReadAt.Store(time.Now().UnixNano())
|
||||
if !sendEvent(decodeEvent{payload: payload}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
if !clientDisconnected {
|
||||
flusher.Flush()
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if clientDisconnected {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err)
|
||||
}
|
||||
|
||||
// payload 是 JSON,提取 chunk.bytes(base64 编码的 Claude SSE 事件数据)
|
||||
sseData := extractBedrockChunkData(ev.payload)
|
||||
if sseData == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
// 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式
|
||||
// 同时移除该字段避免透传给客户端
|
||||
sseData = transformBedrockInvocationMetrics(sseData)
|
||||
|
||||
// 解析 SSE 事件数据提取 usage
|
||||
s.parseSSEUsagePassthrough(string(sseData), usage)
|
||||
|
||||
// 确定 SSE event type
|
||||
eventType := gjson.GetBytes(sseData, "type").String()
|
||||
|
||||
// 写入标准 SSE 格式
|
||||
if !clientDisconnected {
|
||||
var writeErr error
|
||||
if eventType != "" {
|
||||
_, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData)
|
||||
} else {
|
||||
_, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData)
|
||||
}
|
||||
if writeErr != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID)
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, lastReadAt.Load())
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据
|
||||
// Bedrock payload 格式:{"bytes":"<base64-encoded-json>"}
|
||||
func extractBedrockChunkData(payload []byte) []byte {
|
||||
b64 := gjson.GetBytes(payload, "bytes").String()
|
||||
if b64 == "" {
|
||||
return nil
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return decoded
|
||||
}
|
||||
|
||||
// transformBedrockInvocationMetrics 将 Bedrock 特有的 amazon-bedrock-invocationMetrics
|
||||
// 转换为标准 Anthropic usage 格式,并从 SSE 数据中移除该字段。
|
||||
//
|
||||
// Bedrock Invoke 返回的 message_delta 事件可能包含:
|
||||
//
|
||||
// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}
|
||||
//
|
||||
// 转换为:
|
||||
//
|
||||
// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}}
|
||||
func transformBedrockInvocationMetrics(data []byte) []byte {
|
||||
metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics")
|
||||
if !metrics.Exists() || !metrics.IsObject() {
|
||||
return data
|
||||
}
|
||||
|
||||
// 移除 Bedrock 特有字段
|
||||
data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics")
|
||||
|
||||
// 如果已有标准 usage 字段,不覆盖
|
||||
if gjson.GetBytes(data, "usage").Exists() {
|
||||
return data
|
||||
}
|
||||
|
||||
// 转换 camelCase → snake_case 写入 usage
|
||||
inputTokens := metrics.Get("inputTokenCount")
|
||||
outputTokens := metrics.Get("outputTokenCount")
|
||||
if inputTokens.Exists() {
|
||||
data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int())
|
||||
}
|
||||
if outputTokens.Exists() {
|
||||
data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int())
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// bedrockEventStreamDecoder 解码 AWS EventStream 二进制帧
|
||||
// EventStream 帧格式:
|
||||
//
|
||||
// [total_byte_length: 4 bytes]
|
||||
// [headers_byte_length: 4 bytes]
|
||||
// [prelude_crc: 4 bytes]
|
||||
// [headers: variable]
|
||||
// [payload: variable]
|
||||
// [message_crc: 4 bytes]
|
||||
type bedrockEventStreamDecoder struct {
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder {
|
||||
return &bedrockEventStreamDecoder{
|
||||
reader: bufio.NewReaderSize(r, 64*1024),
|
||||
}
|
||||
}
|
||||
|
||||
// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload
|
||||
func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) {
|
||||
for {
|
||||
// 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes
|
||||
prelude := make([]byte, 12)
|
||||
if _, err := io.ReadFull(d.reader, prelude); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证 prelude CRC(AWS EventStream 使用标准 CRC32 / IEEE)
|
||||
preludeCRC := bedrockReadUint32(prelude[8:12])
|
||||
if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC {
|
||||
return nil, fmt.Errorf("eventstream prelude CRC mismatch")
|
||||
}
|
||||
|
||||
totalLength := bedrockReadUint32(prelude[0:4])
|
||||
headersLength := bedrockReadUint32(prelude[4:8])
|
||||
|
||||
if totalLength < 16 { // minimum: 12 prelude + 4 message_crc
|
||||
return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength)
|
||||
}
|
||||
|
||||
// 读取 headers + payload + message_crc
|
||||
remaining := int(totalLength) - 12
|
||||
if remaining <= 0 {
|
||||
continue
|
||||
}
|
||||
data := make([]byte, remaining)
|
||||
if _, err := io.ReadFull(d.reader, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证 message CRC(覆盖 prelude + headers + payload)
|
||||
messageCRC := bedrockReadUint32(data[len(data)-4:])
|
||||
h := crc32.New(crc32IEEETable)
|
||||
_, _ = h.Write(prelude)
|
||||
_, _ = h.Write(data[:len(data)-4])
|
||||
if h.Sum32() != messageCRC {
|
||||
return nil, fmt.Errorf("eventstream message CRC mismatch")
|
||||
}
|
||||
|
||||
// 解析 headers
|
||||
headers := data[:headersLength]
|
||||
payload := data[headersLength : len(data)-4] // 去掉 message_crc
|
||||
|
||||
// 从 headers 中提取 :event-type
|
||||
eventType := extractEventStreamHeaderValue(headers, ":event-type")
|
||||
|
||||
// 只处理 chunk 事件
|
||||
if eventType == "chunk" {
|
||||
// payload 是完整的 JSON,包含 bytes 字段
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// 检查异常事件
|
||||
exceptionType := extractEventStreamHeaderValue(headers, ":exception-type")
|
||||
if exceptionType != "" {
|
||||
return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload))
|
||||
}
|
||||
|
||||
messageType := extractEventStreamHeaderValue(headers, ":message-type")
|
||||
if messageType == "exception" || messageType == "error" {
|
||||
return nil, fmt.Errorf("bedrock error: %s", string(payload))
|
||||
}
|
||||
|
||||
// 跳过其他事件类型(如 initial-response)
|
||||
}
|
||||
}
|
||||
|
||||
// extractEventStreamHeaderValue 从 EventStream headers 二进制数据中提取指定 header 的字符串值
|
||||
// EventStream header 格式:
|
||||
//
|
||||
// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable]
|
||||
//
|
||||
// value_type = 7 表示 string 类型,前 2 bytes 为长度
|
||||
func extractEventStreamHeaderValue(headers []byte, targetName string) string {
|
||||
pos := 0
|
||||
for pos < len(headers) {
|
||||
if pos >= len(headers) {
|
||||
break
|
||||
}
|
||||
nameLen := int(headers[pos])
|
||||
pos++
|
||||
if pos+nameLen > len(headers) {
|
||||
break
|
||||
}
|
||||
name := string(headers[pos : pos+nameLen])
|
||||
pos += nameLen
|
||||
|
||||
if pos >= len(headers) {
|
||||
break
|
||||
}
|
||||
valueType := headers[pos]
|
||||
pos++
|
||||
|
||||
switch valueType {
|
||||
case 7: // string
|
||||
if pos+2 > len(headers) {
|
||||
return ""
|
||||
}
|
||||
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
|
||||
pos += 2
|
||||
if pos+valueLen > len(headers) {
|
||||
return ""
|
||||
}
|
||||
value := string(headers[pos : pos+valueLen])
|
||||
pos += valueLen
|
||||
if name == targetName {
|
||||
return value
|
||||
}
|
||||
case 0: // bool true
|
||||
if name == targetName {
|
||||
return "true"
|
||||
}
|
||||
case 1: // bool false
|
||||
if name == targetName {
|
||||
return "false"
|
||||
}
|
||||
case 2: // byte
|
||||
pos++
|
||||
if name == targetName {
|
||||
return ""
|
||||
}
|
||||
case 3: // short
|
||||
pos += 2
|
||||
if name == targetName {
|
||||
return ""
|
||||
}
|
||||
case 4: // int
|
||||
pos += 4
|
||||
if name == targetName {
|
||||
return ""
|
||||
}
|
||||
case 5: // long
|
||||
pos += 8
|
||||
if name == targetName {
|
||||
return ""
|
||||
}
|
||||
case 6: // bytes
|
||||
if pos+2 > len(headers) {
|
||||
return ""
|
||||
}
|
||||
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
|
||||
pos += 2 + valueLen
|
||||
case 8: // timestamp
|
||||
pos += 8
|
||||
case 9: // uuid
|
||||
pos += 16
|
||||
default:
|
||||
return "" // 未知类型,无法继续解析
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// crc32IEEETable is the CRC32 / IEEE table used by AWS EventStream.
|
||||
var crc32IEEETable = crc32.MakeTable(crc32.IEEE)
|
||||
|
||||
func bedrockReadUint32(b []byte) uint32 {
|
||||
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||
}
|
||||
|
||||
func bedrockReadUint16(b []byte) uint16 {
|
||||
return uint16(b[0])<<8 | uint16(b[1])
|
||||
}
|
||||
261
backend/internal/service/bedrock_stream_test.go
Normal file
261
backend/internal/service/bedrock_stream_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestExtractBedrockChunkData(t *testing.T) {
|
||||
t.Run("valid base64 payload", func(t *testing.T) {
|
||||
original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`
|
||||
b64 := base64.StdEncoding.EncodeToString([]byte(original))
|
||||
payload := []byte(`{"bytes":"` + b64 + `"}`)
|
||||
|
||||
result := extractBedrockChunkData(payload)
|
||||
require.NotNil(t, result)
|
||||
assert.JSONEq(t, original, string(result))
|
||||
})
|
||||
|
||||
t.Run("empty bytes field", func(t *testing.T) {
|
||||
result := extractBedrockChunkData([]byte(`{"bytes":""}`))
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("no bytes field", func(t *testing.T) {
|
||||
result := extractBedrockChunkData([]byte(`{"other":"value"}`))
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("invalid base64", func(t *testing.T) {
|
||||
result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`))
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransformBedrockInvocationMetrics(t *testing.T) {
|
||||
t.Run("converts metrics to usage", func(t *testing.T) {
|
||||
input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
|
||||
result := transformBedrockInvocationMetrics([]byte(input))
|
||||
|
||||
// amazon-bedrock-invocationMetrics should be removed
|
||||
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
|
||||
// usage should be set
|
||||
assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int())
|
||||
assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int())
|
||||
// original fields preserved
|
||||
assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String())
|
||||
assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String())
|
||||
})
|
||||
|
||||
t.Run("no metrics present", func(t *testing.T) {
|
||||
input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}`
|
||||
result := transformBedrockInvocationMetrics([]byte(input))
|
||||
assert.JSONEq(t, input, string(result))
|
||||
})
|
||||
|
||||
t.Run("does not overwrite existing usage", func(t *testing.T) {
|
||||
input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
|
||||
result := transformBedrockInvocationMetrics([]byte(input))
|
||||
|
||||
// metrics removed but existing usage preserved
|
||||
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
|
||||
assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int())
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractEventStreamHeaderValue(t *testing.T) {
|
||||
// Build a header with :event-type = "chunk" (string type = 7)
|
||||
buildStringHeader := func(name, value string) []byte {
|
||||
var buf bytes.Buffer
|
||||
// name length (1 byte)
|
||||
_ = buf.WriteByte(byte(len(name)))
|
||||
// name
|
||||
_, _ = buf.WriteString(name)
|
||||
// value type (7 = string)
|
||||
_ = buf.WriteByte(7)
|
||||
// value length (2 bytes, big-endian)
|
||||
_ = binary.Write(&buf, binary.BigEndian, uint16(len(value)))
|
||||
// value
|
||||
_, _ = buf.WriteString(value)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
t.Run("find string header", func(t *testing.T) {
|
||||
headers := buildStringHeader(":event-type", "chunk")
|
||||
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
|
||||
})
|
||||
|
||||
t.Run("header not found", func(t *testing.T) {
|
||||
headers := buildStringHeader(":event-type", "chunk")
|
||||
assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type"))
|
||||
})
|
||||
|
||||
t.Run("multiple headers", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
_, _ = buf.Write(buildStringHeader(":content-type", "application/json"))
|
||||
_, _ = buf.Write(buildStringHeader(":event-type", "chunk"))
|
||||
_, _ = buf.Write(buildStringHeader(":message-type", "event"))
|
||||
|
||||
headers := buf.Bytes()
|
||||
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
|
||||
assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type"))
|
||||
assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type"))
|
||||
})
|
||||
|
||||
t.Run("empty headers", func(t *testing.T) {
|
||||
assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestBedrockEventStreamDecoder(t *testing.T) {
|
||||
crc32IeeeTab := crc32.MakeTable(crc32.IEEE)
|
||||
|
||||
// Build a valid EventStream frame with correct CRC32/IEEE checksums.
|
||||
buildFrame := func(eventType string, payload []byte) []byte {
|
||||
// Build headers
|
||||
var headersBuf bytes.Buffer
|
||||
// :event-type header
|
||||
_ = headersBuf.WriteByte(byte(len(":event-type")))
|
||||
_, _ = headersBuf.WriteString(":event-type")
|
||||
_ = headersBuf.WriteByte(7) // string type
|
||||
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType)))
|
||||
_, _ = headersBuf.WriteString(eventType)
|
||||
// :message-type header
|
||||
_ = headersBuf.WriteByte(byte(len(":message-type")))
|
||||
_, _ = headersBuf.WriteString(":message-type")
|
||||
_ = headersBuf.WriteByte(7)
|
||||
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event")))
|
||||
_, _ = headersBuf.WriteString("event")
|
||||
|
||||
headers := headersBuf.Bytes()
|
||||
headersLen := uint32(len(headers))
|
||||
// total = 12 (prelude) + headers + payload + 4 (message_crc)
|
||||
totalLen := uint32(12 + len(headers) + len(payload) + 4)
|
||||
|
||||
// Prelude: total_length(4) + headers_length(4)
|
||||
var preludeBuf bytes.Buffer
|
||||
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
|
||||
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
|
||||
preludeBytes := preludeBuf.Bytes()
|
||||
preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab)
|
||||
|
||||
// Build frame: prelude + prelude_crc + headers + payload
|
||||
var frame bytes.Buffer
|
||||
_, _ = frame.Write(preludeBytes)
|
||||
_ = binary.Write(&frame, binary.BigEndian, preludeCRC)
|
||||
_, _ = frame.Write(headers)
|
||||
_, _ = frame.Write(payload)
|
||||
|
||||
// Message CRC covers everything before itself
|
||||
messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab)
|
||||
_ = binary.Write(&frame, binary.BigEndian, messageCRC)
|
||||
return frame.Bytes()
|
||||
}
|
||||
|
||||
t.Run("decode chunk event", func(t *testing.T) {
|
||||
payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test")
|
||||
frame := buildFrame("chunk", payload)
|
||||
|
||||
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
|
||||
result, err := decoder.Decode()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, payload, result)
|
||||
})
|
||||
|
||||
t.Run("skip non-chunk events", func(t *testing.T) {
|
||||
// Write initial-response followed by chunk
|
||||
var buf bytes.Buffer
|
||||
_, _ = buf.Write(buildFrame("initial-response", []byte(`{}`)))
|
||||
chunkPayload := []byte(`{"bytes":"aGVsbG8="}`)
|
||||
_, _ = buf.Write(buildFrame("chunk", chunkPayload))
|
||||
|
||||
decoder := newBedrockEventStreamDecoder(&buf)
|
||||
result, err := decoder.Decode()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, chunkPayload, result)
|
||||
})
|
||||
|
||||
t.Run("EOF on empty input", func(t *testing.T) {
|
||||
decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil))
|
||||
_, err := decoder.Decode()
|
||||
assert.Equal(t, io.EOF, err)
|
||||
})
|
||||
|
||||
t.Run("corrupted prelude CRC", func(t *testing.T) {
|
||||
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
|
||||
// Corrupt the prelude CRC (bytes 8-11)
|
||||
frame[8] ^= 0xFF
|
||||
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
|
||||
_, err := decoder.Decode()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "prelude CRC mismatch")
|
||||
})
|
||||
|
||||
t.Run("corrupted message CRC", func(t *testing.T) {
|
||||
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
|
||||
// Corrupt the message CRC (last 4 bytes)
|
||||
frame[len(frame)-1] ^= 0xFF
|
||||
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
|
||||
_, err := decoder.Decode()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "message CRC mismatch")
|
||||
})
|
||||
|
||||
t.Run("castagnoli encoded frame is rejected", func(t *testing.T) {
|
||||
castagnoliTab := crc32.MakeTable(crc32.Castagnoli)
|
||||
payload := []byte(`{"bytes":"dGVzdA=="}`)
|
||||
|
||||
var headersBuf bytes.Buffer
|
||||
_ = headersBuf.WriteByte(byte(len(":event-type")))
|
||||
_, _ = headersBuf.WriteString(":event-type")
|
||||
_ = headersBuf.WriteByte(7)
|
||||
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk")))
|
||||
_, _ = headersBuf.WriteString("chunk")
|
||||
|
||||
headers := headersBuf.Bytes()
|
||||
headersLen := uint32(len(headers))
|
||||
totalLen := uint32(12 + len(headers) + len(payload) + 4)
|
||||
|
||||
var preludeBuf bytes.Buffer
|
||||
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
|
||||
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
|
||||
preludeBytes := preludeBuf.Bytes()
|
||||
|
||||
var frame bytes.Buffer
|
||||
_, _ = frame.Write(preludeBytes)
|
||||
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab))
|
||||
_, _ = frame.Write(headers)
|
||||
_, _ = frame.Write(payload)
|
||||
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab))
|
||||
|
||||
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes()))
|
||||
_, err := decoder.Decode()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "prelude CRC mismatch")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildBedrockURL(t *testing.T) {
|
||||
t.Run("stream URL with colon in model ID", func(t *testing.T) {
|
||||
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true)
|
||||
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url)
|
||||
})
|
||||
|
||||
t.Run("non-stream URL with colon in model ID", func(t *testing.T) {
|
||||
url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false)
|
||||
assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url)
|
||||
})
|
||||
|
||||
t.Run("model ID without colon", func(t *testing.T) {
|
||||
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true)
|
||||
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url)
|
||||
})
|
||||
}
|
||||
@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
|
||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
|
||||
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||
}
|
||||
|
||||
@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
||||
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
|
||||
|
||||
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||
if aggErr != nil {
|
||||
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
||||
if usageErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil {
|
||||
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
|
||||
if dedupErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil && dedupErr == nil {
|
||||
s.lastRetentionCleanup.Store(now)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,12 +12,18 @@ import (
|
||||
|
||||
type dashboardAggregationRepoTestStub struct {
|
||||
aggregateCalls int
|
||||
recomputeCalls int
|
||||
cleanupUsageCalls int
|
||||
cleanupDedupCalls int
|
||||
ensurePartitionCalls int
|
||||
lastStart time.Time
|
||||
lastEnd time.Time
|
||||
watermark time.Time
|
||||
aggregateErr error
|
||||
cleanupAggregatesErr error
|
||||
cleanupUsageErr error
|
||||
cleanupDedupErr error
|
||||
ensurePartitionErr error
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
s.recomputeCalls++
|
||||
return s.AggregateRange(ctx, start, end)
|
||||
}
|
||||
|
||||
@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
s.cleanupUsageCalls++
|
||||
return s.cleanupUsageErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
s.cleanupDedupCalls++
|
||||
return s.cleanupDedupErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
s.ensurePartitionCalls++
|
||||
return s.ensurePartitionErr
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||
@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
|
||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||
|
||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||
require.Equal(t, 1, repo.cleanupUsageCalls)
|
||||
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||
|
||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
IntervalSeconds: 60,
|
||||
LookbackSeconds: 120,
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
UsageBillingDedupDays: 2,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.runScheduledAggregation()
|
||||
|
||||
require.Equal(t, 1, repo.ensurePartitionCalls)
|
||||
require.Equal(t, 1, repo.aggregateCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||
|
||||
@@ -327,6 +327,14 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
ranking, err := s.usageRepo.GetUserSpendingRanking(ctx, startTime, endTime, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user spending ranking: %w", err)
|
||||
}
|
||||
return ranking, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
||||
if err != nil {
|
||||
|
||||
@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ const (
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -219,6 +220,9 @@ const (
|
||||
|
||||
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403)
|
||||
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
|
||||
|
||||
// SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录
|
||||
SettingKeyBackendModeEnabled = "backend_mode_enabled"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
|
||||
@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
|
||||
// Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
|
||||
// second hit 仍然返回 TempUnscheduled。
|
||||
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
|
||||
account: &Account{
|
||||
ID: 15,
|
||||
Type: AccountTypeOAuth,
|
||||
@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyNone,
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
|
||||
@@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
deferredService: &DeferredService{},
|
||||
billingCacheService: nil,
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: cfg,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
deferredService: &DeferredService{},
|
||||
billingCacheService: nil,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
@@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: cfg,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
|
||||
require.Equal(t, 5, result.usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
},
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
|
||||
"",
|
||||
`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
|
||||
"",
|
||||
}, "\n"))),
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing terminal event")
|
||||
require.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
|
||||
_ = pr.Close()
|
||||
<-done
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete after timeout")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.Equal(t, 9, result.usage.InputTokens)
|
||||
@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
}
|
||||
@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.Equal(t, 8, result.usage.InputTokens)
|
||||
|
||||
371
backend/internal/service/gateway_record_usage_test.go
Normal file
371
backend/internal/service/gateway_record_usage_test.go
Normal file
@@ -0,0 +1,371 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
|
||||
cfg := &config.Config{}
|
||||
cfg.Default.RateMultiplier = 1.1
|
||||
return NewGatewayService(
|
||||
nil,
|
||||
nil,
|
||||
usageRepo,
|
||||
nil,
|
||||
userRepo,
|
||||
subRepo,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
NewBillingService(cfg, nil),
|
||||
nil,
|
||||
&BillingCacheService{},
|
||||
nil,
|
||||
nil,
|
||||
&DeferredService{},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
svc.usageBillingRepo = billingRepo
|
||||
return svc
|
||||
}
|
||||
|
||||
type openAIRecordUsageBestEffortLogRepoStub struct {
|
||||
UsageLogRepository
|
||||
|
||||
bestEffortErr error
|
||||
createErr error
|
||||
bestEffortCalls int
|
||||
createCalls int
|
||||
lastLog *UsageLog
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error {
|
||||
s.bestEffortCalls++
|
||||
s.lastLog = log
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return s.bestEffortErr
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
|
||||
s.createCalls++
|
||||
s.lastLog = log
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return false, s.createErr
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsage(reqCtx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_detached_ctx",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 501,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.NoError(t, userRepo.lastCtxErr)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_payload_hash",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
RequestPayloadHash: payloadHash,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
|
||||
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_payload_fallback",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_not_persisted",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 503,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 603},
|
||||
Account: &Account{ID: 703},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_long_context_detached_ctx",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 12,
|
||||
OutputTokens: 8,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 502,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 602},
|
||||
Account: &Account{ID: 702},
|
||||
LongContextThreshold: 200000,
|
||||
LongContextMultiplier: 2,
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.NoError(t, userRepo.lastCtxErr)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
|
||||
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 504},
|
||||
User: &User{ID: 604},
|
||||
Account: &Account{ID: 704},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123")
|
||||
ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored")
|
||||
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "upstream-volatile-456",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 506},
|
||||
User: &User{ID: 606},
|
||||
Account: &Account{ID: 706},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 507},
|
||||
User: &User{ID: 607},
|
||||
Account: &Account{ID: 707},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{
|
||||
bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")),
|
||||
}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_drop_usage_log",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 508},
|
||||
User: &User{ID: 608},
|
||||
Account: &Account{ID: 708},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.bestEffortCalls)
|
||||
require.Equal(t, 0, usageRepo.createCalls)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_billing_fail",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 505},
|
||||
User: &User{ID: 605},
|
||||
Account: &Account{ID: 705},
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 0, usageRepo.calls)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user