mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-06 00:10:21 +08:00
Compare commits
108 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d3a9f5bb88 | ||
|
|
7eb0415a8a | ||
|
|
bdbc8fa08f | ||
|
|
63f3af0f94 | ||
|
|
686f890fbf | ||
|
|
220fbe6544 | ||
|
|
ae44a94325 | ||
|
|
3718d6dcd4 | ||
|
|
90b3838173 | ||
|
|
19d3ecc76f | ||
|
|
6fba4ebb13 | ||
|
|
c31974c913 | ||
|
|
6177fa5dd8 | ||
|
|
cfe72159d0 | ||
|
|
8321e4a647 | ||
|
|
3084330d0c | ||
|
|
b566649e79 | ||
|
|
10a6180e4a | ||
|
|
cbe9e78977 | ||
|
|
74145b1f39 | ||
|
|
359e56751b | ||
|
|
5899784aa4 | ||
|
|
9e8959c56d | ||
|
|
1bff2292a6 | ||
|
|
cf9247754e | ||
|
|
eefab15958 | ||
|
|
0e23732631 | ||
|
|
37c044fb4b | ||
|
|
6da5fa01b9 | ||
|
|
616930f9d3 | ||
|
|
b9c31fa7c4 | ||
|
|
17b339972c | ||
|
|
39f8bd91b9 | ||
|
|
aa4e37d085 | ||
|
|
f59b66b7d4 | ||
|
|
8f0ea7a02d | ||
|
|
a1dc00890e | ||
|
|
dfbcc363d1 | ||
|
|
1047f973d5 | ||
|
|
e32977dd73 | ||
|
|
b5f78ec1e8 | ||
|
|
e0f290fdc8 | ||
|
|
fc00a4e3b2 | ||
|
|
db1f6ded88 | ||
|
|
4644af2ccc | ||
|
|
2e3e8687e1 | ||
|
|
ca42a45802 | ||
|
|
9350ecb62b | ||
|
|
a4a026e8da | ||
|
|
342fd03e72 | ||
|
|
e3f1fd9b63 | ||
|
|
e4a4dfd038 | ||
|
|
a377e99088 | ||
|
|
1d3d7a3033 | ||
|
|
e7086cb3a3 | ||
|
|
4f2a97073e | ||
|
|
7407e3b45d | ||
|
|
01ef7340aa | ||
|
|
1c960d22c1 | ||
|
|
ece0606fed | ||
|
|
2666422b99 | ||
|
|
e6d59216d4 | ||
|
|
4e8615f276 | ||
|
|
91e4d95660 | ||
|
|
45456fa24c | ||
|
|
4588258d80 | ||
|
|
c12e48f966 | ||
|
|
ec8f50a658 | ||
|
|
99c9191784 | ||
|
|
6bb02d141f | ||
|
|
07bb2a5f3f | ||
|
|
417861a48e | ||
|
|
b7e878de64 | ||
|
|
05edb5514b | ||
|
|
e90ec847b6 | ||
|
|
6344fa2a86 | ||
|
|
7e288acc90 | ||
|
|
29b0e4a8a5 | ||
|
|
27ff222cfb | ||
|
|
11f7b83522 | ||
|
|
f7177be3b6 | ||
|
|
875b417fde | ||
|
|
2573107b32 | ||
|
|
5b85005945 | ||
|
|
1ee984478f | ||
|
|
fd693dc526 | ||
|
|
e73531ce9b | ||
|
|
53ad1645cf | ||
|
|
ecea13757b | ||
|
|
af9c4a7dd0 | ||
|
|
80d8d6c3bc | ||
|
|
d648811233 | ||
|
|
34695acb85 | ||
|
|
a63de12182 | ||
|
|
f16910d616 | ||
|
|
64b3f3cec1 | ||
|
|
6a685727d0 | ||
|
|
32d25f76fc | ||
|
|
69cafe8674 | ||
|
|
18ba8d9166 | ||
|
|
e97fd7e81c | ||
|
|
cdb64b0d33 | ||
|
|
8d4d3b03bb | ||
|
|
addefe79e1 | ||
|
|
b764d3b8f6 | ||
|
|
611fd884bd | ||
|
|
6826149a8f | ||
|
|
c9debc50b1 |
20
Dockerfile
20
Dockerfile
@@ -9,6 +9,7 @@
|
|||||||
ARG NODE_IMAGE=node:24-alpine
|
ARG NODE_IMAGE=node:24-alpine
|
||||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||||
ARG ALPINE_IMAGE=alpine:3.21
|
ARG ALPINE_IMAGE=alpine:3.21
|
||||||
|
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||||
ARG GOPROXY=https://goproxy.cn,direct
|
ARG GOPROXY=https://goproxy.cn,direct
|
||||||
ARG GOSUMDB=sum.golang.google.cn
|
ARG GOSUMDB=sum.golang.google.cn
|
||||||
|
|
||||||
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
|
|||||||
./cmd/server
|
./cmd/server
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Stage 3: Final Runtime Image
|
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Stage 4: Final Runtime Image
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
FROM ${ALPINE_IMAGE}
|
FROM ${ALPINE_IMAGE}
|
||||||
|
|
||||||
@@ -86,8 +92,20 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
|||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
ca-certificates \
|
ca-certificates \
|
||||||
tzdata \
|
tzdata \
|
||||||
|
libpq \
|
||||||
|
zstd-libs \
|
||||||
|
lz4-libs \
|
||||||
|
krb5-libs \
|
||||||
|
libldap \
|
||||||
|
libedit \
|
||||||
&& rm -rf /var/cache/apk/*
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
|
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||||
|
# This ensures version consistency between backup tools and the database server
|
||||||
|
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||||
|
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||||
|
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||||
|
|
||||||
# Create non-root user
|
# Create non-root user
|
||||||
RUN addgroup -g 1000 sub2api && \
|
RUN addgroup -g 1000 sub2api && \
|
||||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||||
|
|||||||
@@ -5,7 +5,12 @@
|
|||||||
# It only packages the pre-built binary, no compilation needed.
|
# It only packages the pre-built binary, no compilation needed.
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
FROM alpine:3.19
|
ARG ALPINE_IMAGE=alpine:3.21
|
||||||
|
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||||
|
|
||||||
|
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||||
|
|
||||||
|
FROM ${ALPINE_IMAGE}
|
||||||
|
|
||||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||||
LABEL description="Sub2API - AI API Gateway Platform"
|
LABEL description="Sub2API - AI API Gateway Platform"
|
||||||
@@ -16,8 +21,20 @@ RUN apk add --no-cache \
|
|||||||
ca-certificates \
|
ca-certificates \
|
||||||
tzdata \
|
tzdata \
|
||||||
curl \
|
curl \
|
||||||
|
libpq \
|
||||||
|
zstd-libs \
|
||||||
|
lz4-libs \
|
||||||
|
krb5-libs \
|
||||||
|
libldap \
|
||||||
|
libedit \
|
||||||
&& rm -rf /var/cache/apk/*
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
|
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
|
||||||
|
# restore work in the runtime container without requiring Docker socket access.
|
||||||
|
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||||
|
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||||
|
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||||
|
|
||||||
# Create non-root user
|
# Create non-root user
|
||||||
RUN addgroup -g 1000 sub2api && \
|
RUN addgroup -g 1000 sub2api && \
|
||||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||||
|
|||||||
10
README.md
10
README.md
@@ -39,6 +39,16 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
|||||||
- **Concurrency Control** - Per-user and per-account concurrency limits
|
- **Concurrency Control** - Per-user and per-account concurrency limits
|
||||||
- **Rate Limiting** - Configurable request and token rate limits
|
- **Rate Limiting** - Configurable request and token rate limits
|
||||||
- **Admin Dashboard** - Web interface for monitoring and management
|
- **Admin Dashboard** - Web interface for monitoring and management
|
||||||
|
- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
|
||||||
|
|
||||||
|
## Ecosystem
|
||||||
|
|
||||||
|
Community projects that extend or integrate with Sub2API:
|
||||||
|
|
||||||
|
| Project | Description | Features |
|
||||||
|
|---------|-------------|----------|
|
||||||
|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
|
||||||
|
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
|
||||||
|
|
||||||
## Tech Stack
|
## Tech Stack
|
||||||
|
|
||||||
|
|||||||
10
README_CN.md
10
README_CN.md
@@ -39,6 +39,16 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
|
|||||||
- **并发控制** - 用户级和账号级并发限制
|
- **并发控制** - 用户级和账号级并发限制
|
||||||
- **速率限制** - 可配置的请求和 Token 速率限制
|
- **速率限制** - 可配置的请求和 Token 速率限制
|
||||||
- **管理后台** - Web 界面进行监控和管理
|
- **管理后台** - Web 界面进行监控和管理
|
||||||
|
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
|
||||||
|
|
||||||
|
## 生态项目
|
||||||
|
|
||||||
|
围绕 Sub2API 的社区扩展与集成项目:
|
||||||
|
|
||||||
|
| 项目 | 说明 | 功能 |
|
||||||
|
|------|------|------|
|
||||||
|
| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
|
||||||
|
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
|
||||||
|
|
||||||
## 技术栈
|
## 技术栈
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
// Server layer ProviderSet
|
// Server layer ProviderSet
|
||||||
server.ProviderSet,
|
server.ProviderSet,
|
||||||
|
|
||||||
|
// Privacy client factory for OpenAI training opt-out
|
||||||
|
providePrivacyClientFactory,
|
||||||
|
|
||||||
// BuildInfo provider
|
// BuildInfo provider
|
||||||
provideServiceBuildInfo,
|
provideServiceBuildInfo,
|
||||||
|
|
||||||
@@ -53,6 +56,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||||
|
return repository.CreatePrivacyReqClient
|
||||||
|
}
|
||||||
|
|
||||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
return service.BuildInfo{
|
return service.BuildInfo{
|
||||||
Version: buildInfo.Version,
|
Version: buildInfo.Version,
|
||||||
@@ -87,6 +94,7 @@ func provideCleanup(
|
|||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
openAIGateway *service.OpenAIGatewayService,
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||||
|
backupSvc *service.BackupService,
|
||||||
) func() {
|
) func() {
|
||||||
return func() {
|
return func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
@@ -223,6 +231,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"BackupService", func() error {
|
||||||
|
if backupSvc != nil {
|
||||||
|
backupSvc.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
infraSteps := []cleanupStep{
|
infraSteps := []cleanupStep{
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
userHandler := handler.NewUserHandler(userService)
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
|
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
@@ -104,7 +105,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository)
|
privacyClientFactory := providePrivacyClientFactory()
|
||||||
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||||
@@ -144,6 +146,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||||
dataManagementService := service.NewDataManagementService()
|
dataManagementService := service.NewDataManagementService()
|
||||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||||
|
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||||
|
dbDumper := repository.NewPgDumper(configConfig)
|
||||||
|
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||||
|
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||||
@@ -162,9 +168,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
|
||||||
digestSessionStore := service.NewDigestSessionStore()
|
digestSessionStore := service.NewDigestSessionStore()
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService)
|
||||||
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||||
@@ -199,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||||
@@ -226,11 +232,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@@ -245,6 +251,10 @@ type Application struct {
|
|||||||
Cleanup func()
|
Cleanup func()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func providePrivacyClientFactory() service.PrivacyClientFactory {
|
||||||
|
return repository.CreatePrivacyReqClient
|
||||||
|
}
|
||||||
|
|
||||||
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
return service.BuildInfo{
|
return service.BuildInfo{
|
||||||
Version: buildInfo.Version,
|
Version: buildInfo.Version,
|
||||||
@@ -279,6 +289,7 @@ func provideCleanup(
|
|||||||
antigravityOAuth *service.AntigravityOAuthService,
|
antigravityOAuth *service.AntigravityOAuthService,
|
||||||
openAIGateway *service.OpenAIGatewayService,
|
openAIGateway *service.OpenAIGatewayService,
|
||||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||||
|
backupSvc *service.BackupService,
|
||||||
) func() {
|
) func() {
|
||||||
return func() {
|
return func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
@@ -414,6 +425,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"BackupService", func() error {
|
||||||
|
if backupSvc != nil {
|
||||||
|
backupSvc.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
infraSteps := []cleanupStep{
|
infraSteps := []cleanupStep{
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
|||||||
antigravityOAuthSvc,
|
antigravityOAuthSvc,
|
||||||
nil, // openAIGateway
|
nil, // openAIGateway
|
||||||
nil, // scheduledTestRunner
|
nil, // scheduledTestRunner
|
||||||
|
nil, // backupSvc
|
||||||
)
|
)
|
||||||
|
|
||||||
require.NotPanics(t, func() {
|
require.NotPanics(t, func() {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ require (
|
|||||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
github.com/DouDOU-start/go-sora2api v1.1.0
|
github.com/DouDOU-start/go-sora2api v1.1.0
|
||||||
github.com/alitto/pond/v2 v2.6.2
|
github.com/alitto/pond/v2 v2.6.2
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2
|
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||||
@@ -66,7 +66,7 @@ require (
|
|||||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||||
github.com/aws/smithy-go v1.24.1 // indirect
|
github.com/aws/smithy-go v1.24.2 // indirect
|
||||||
github.com/bdandy/go-errors v1.2.2 // indirect
|
github.com/bdandy/go-errors v1.2.2 // indirect
|
||||||
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
github.com/bdandy/go-socks4 v1.2.3 // indirect
|
||||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
|
|||||||
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
|
||||||
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
|
||||||
@@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
|
|||||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
|
||||||
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
|
||||||
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||||
|
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||||
|
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||||
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
|
||||||
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
|
||||||
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
|
||||||
|
|||||||
@@ -935,6 +935,7 @@ type DashboardAggregationConfig struct {
|
|||||||
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||||
type DashboardAggregationRetentionConfig struct {
|
type DashboardAggregationRetentionConfig struct {
|
||||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||||
|
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
|
||||||
HourlyDays int `mapstructure:"hourly_days"`
|
HourlyDays int `mapstructure:"hourly_days"`
|
||||||
DailyDays int `mapstructure:"daily_days"`
|
DailyDays int `mapstructure:"daily_days"`
|
||||||
}
|
}
|
||||||
@@ -1301,6 +1302,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||||
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||||
|
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||||
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
|
|||||||
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||||
}
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||||
|
}
|
||||||
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||||
}
|
}
|
||||||
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
|
|||||||
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||||
}
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
|
||||||
|
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||||
|
}
|
||||||
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
|||||||
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||||
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||||
}
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
|
||||||
|
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
|
||||||
|
}
|
||||||
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||||
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||||
}
|
}
|
||||||
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
|
|||||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "dashboard aggregation dedup retention",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.DashboardAgg.Enabled = true
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
|
||||||
|
},
|
||||||
|
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dashboard aggregation dedup retention smaller than usage logs",
|
||||||
|
mutate: func(c *Config) {
|
||||||
|
c.DashboardAgg.Enabled = true
|
||||||
|
c.DashboardAgg.Retention.UsageLogsDays = 30
|
||||||
|
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
|
||||||
|
},
|
||||||
|
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "dashboard aggregation disabled interval",
|
name: "dashboard aggregation disabled interval",
|
||||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ const (
|
|||||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||||
|
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem type constants
|
// Redeem type constants
|
||||||
@@ -113,3 +114,27 @@ var DefaultAntigravityModelMapping = map[string]string{
|
|||||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射
|
||||||
|
// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID
|
||||||
|
// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的
|
||||||
|
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
|
||||||
|
var DefaultBedrockModelMapping = map[string]string{
|
||||||
|
// Claude Opus
|
||||||
|
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
|
||||||
|
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
|
||||||
|
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||||
|
"claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||||
|
"claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0",
|
||||||
|
"claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0",
|
||||||
|
// Claude Sonnet
|
||||||
|
"claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||||
|
"claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||||
|
// Claude Haiku
|
||||||
|
"claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||||
|
"claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||||
|
}
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
|
|||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Platform string `json:"platform" binding:"required"`
|
Platform string `json:"platform" binding:"required"`
|
||||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"`
|
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
|
|||||||
type UpdateAccountRequest struct {
|
type UpdateAccountRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Notes *string `json:"notes"`
|
Notes *string `json:"notes"`
|
||||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"`
|
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||||
Credentials map[string]any `json:"credentials"`
|
Credentials map[string]any `json:"credentials"`
|
||||||
Extra map[string]any `json:"extra"`
|
Extra map[string]any `json:"extra"`
|
||||||
ProxyID *int64 `json:"proxy_id"`
|
ProxyID *int64 `json:"proxy_id"`
|
||||||
@@ -865,6 +865,9 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenAI OAuth: 刷新成功后检查并设置 privacy_mode
|
||||||
|
h.adminService.EnsureOpenAIPrivacy(ctx, updatedAccount)
|
||||||
|
|
||||||
return updatedAccount, "", nil
|
return updatedAccount, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1715,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
|||||||
|
|
||||||
// Handle OpenAI accounts
|
// Handle OpenAI accounts
|
||||||
if account.IsOpenAI() {
|
if account.IsOpenAI() {
|
||||||
// For OAuth accounts: return default OpenAI models
|
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||||
if account.IsOAuth() {
|
if account.IsOpenAIPassthroughEnabled() {
|
||||||
response.Success(c, openai.DefaultModels)
|
response.Success(c, openai.DefaultModels)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// For API Key accounts: check model_mapping
|
|
||||||
mapping := account.GetModelMapping()
|
mapping := account.GetModelMapping()
|
||||||
if len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
response.Success(c, openai.DefaultModels)
|
response.Success(c, openai.DefaultModels)
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type availableModelsAdminService struct {
|
||||||
|
*stubAdminService
|
||||||
|
account service.Account
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||||
|
if s.account.ID == id {
|
||||||
|
acc := s.account
|
||||||
|
return &acc, nil
|
||||||
|
}
|
||||||
|
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||||
|
svc := &availableModelsAdminService{
|
||||||
|
stubAdminService: newStubAdminService(),
|
||||||
|
account: service.Account{
|
||||||
|
ID: 42,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5": "gpt-5.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := setupAvailableModelsRouter(svc)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data []struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.Len(t, resp.Data, 1)
|
||||||
|
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||||
|
svc := &availableModelsAdminService{
|
||||||
|
stubAdminService: newStubAdminService(),
|
||||||
|
account: service.Account{
|
||||||
|
ID: 43,
|
||||||
|
Name: "openai-oauth-passthrough",
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5": "gpt-5.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_passthrough": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
router := setupAvailableModelsRouter(svc)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Data []struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.NotEmpty(t, resp.Data)
|
||||||
|
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||||
|
}
|
||||||
@@ -175,6 +175,18 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
|||||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||||
return s.accounts, int64(len(s.accounts)), nil
|
return s.accounts, int64(len(s.accounts)), nil
|
||||||
}
|
}
|
||||||
@@ -429,5 +441,9 @@ func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure stub implements interface.
|
// Ensure stub implements interface.
|
||||||
var _ service.AdminService = (*stubAdminService)(nil)
|
var _ service.AdminService = (*stubAdminService)(nil)
|
||||||
|
|||||||
204
backend/internal/handler/admin/backup_handler.go
Normal file
204
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BackupHandler struct {
|
||||||
|
backupService *service.BackupService
|
||||||
|
userService *service.UserService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
|
||||||
|
return &BackupHandler{
|
||||||
|
backupService: backupService,
|
||||||
|
userService: userService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── S3 配置 ───
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||||
|
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||||
|
var req service.BackupS3Config
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||||
|
var req service.BackupS3Config
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 定时备份 ───
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||||
|
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||||
|
var req service.BackupScheduleConfig
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 备份操作 ───
|
||||||
|
|
||||||
|
type CreateBackupRequest struct {
|
||||||
|
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||||
|
var req CreateBackupRequest
|
||||||
|
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||||
|
|
||||||
|
expireDays := 14 // 默认14天过期
|
||||||
|
if req.ExpireDays != nil {
|
||||||
|
expireDays = *req.ExpireDays
|
||||||
|
}
|
||||||
|
|
||||||
|
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||||
|
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if records == nil {
|
||||||
|
records = []service.BackupRecord{}
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"items": records})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"deleted": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"url": url})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||||
|
|
||||||
|
type RestoreBackupRequest struct {
|
||||||
|
Password string `json:"password" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||||
|
backupID := c.Param("id")
|
||||||
|
if backupID == "" {
|
||||||
|
response.BadRequest(c, "backup ID is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req RestoreBackupRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "password is required for restore operation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从上下文获取当前管理员用户 ID
|
||||||
|
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取管理员用户并验证密码
|
||||||
|
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !user.CheckPassword(req.Password) {
|
||||||
|
response.BadRequest(c, "incorrect admin password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"restored": true})
|
||||||
|
}
|
||||||
@@ -466,9 +466,60 @@ type BatchUsersUsageRequest struct {
|
|||||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||||
|
|
||||||
|
func parseRankingLimit(raw string) int {
|
||||||
|
limit, err := strconv.Atoi(strings.TrimSpace(raw))
|
||||||
|
if err != nil || limit <= 0 {
|
||||||
|
return 12
|
||||||
|
}
|
||||||
|
if limit > 50 {
|
||||||
|
return 50
|
||||||
|
}
|
||||||
|
return limit
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserSpendingRanking handles getting user spending ranking data.
|
||||||
|
// GET /api/v1/admin/dashboard/users-ranking
|
||||||
|
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
|
||||||
|
startTime, endTime := parseTimeRange(c)
|
||||||
|
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
|
||||||
|
|
||||||
|
keyRaw, _ := json.Marshal(struct {
|
||||||
|
Start string `json:"start"`
|
||||||
|
End string `json:"end"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
}{
|
||||||
|
Start: startTime.UTC().Format(time.RFC3339),
|
||||||
|
End: endTime.UTC().Format(time.RFC3339),
|
||||||
|
Limit: limit,
|
||||||
|
})
|
||||||
|
cacheKey := string(keyRaw)
|
||||||
|
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
|
||||||
|
c.Header("X-Snapshot-Cache", "hit")
|
||||||
|
response.Success(c, cached.Payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
|
||||||
|
if err != nil {
|
||||||
|
response.Error(c, 500, "Failed to get user spending ranking")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := gin.H{
|
||||||
|
"ranking": ranking.Ranking,
|
||||||
|
"total_actual_cost": ranking.TotalActualCost,
|
||||||
|
"start_date": startTime.Format("2006-01-02"),
|
||||||
|
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||||
|
}
|
||||||
|
dashboardUsersRankingCache.Set(cacheKey, payload)
|
||||||
|
c.Header("X-Snapshot-Cache", "miss")
|
||||||
|
response.Success(c, payload)
|
||||||
|
}
|
||||||
|
|
||||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||||
// POST /api/v1/admin/dashboard/users-usage
|
// POST /api/v1/admin/dashboard/users-usage
|
||||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ type dashboardUsageRepoCapture struct {
|
|||||||
trendStream *bool
|
trendStream *bool
|
||||||
modelRequestType *int16
|
modelRequestType *int16
|
||||||
modelStream *bool
|
modelStream *bool
|
||||||
|
rankingLimit int
|
||||||
|
ranking []usagestats.UserSpendingRankingItem
|
||||||
|
rankingTotal float64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||||
@@ -49,6 +52,18 @@ func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
|||||||
return []usagestats.ModelStat{}, nil
|
return []usagestats.ModelStat{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
|
||||||
|
ctx context.Context,
|
||||||
|
startTime, endTime time.Time,
|
||||||
|
limit int,
|
||||||
|
) (*usagestats.UserSpendingRankingResponse, error) {
|
||||||
|
s.rankingLimit = limit
|
||||||
|
return &usagestats.UserSpendingRankingResponse{
|
||||||
|
Ranking: s.ranking,
|
||||||
|
TotalActualCost: s.rankingTotal,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||||
@@ -56,6 +71,7 @@ func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Eng
|
|||||||
router := gin.New()
|
router := gin.New()
|
||||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||||
|
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,3 +146,30 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
|||||||
|
|
||||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
|
||||||
|
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
|
||||||
|
repo := &dashboardUsageRepoCapture{
|
||||||
|
ranking: []usagestats.UserSpendingRankingItem{
|
||||||
|
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
|
||||||
|
},
|
||||||
|
rankingTotal: 88.8,
|
||||||
|
}
|
||||||
|
router := newDashboardRequestTypeTestRouter(repo)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, 50, repo.rankingLimit)
|
||||||
|
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
|
||||||
|
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
|
||||||
|
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rec2, req2)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec2.Code)
|
||||||
|
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -16,6 +19,55 @@ type GroupHandler struct {
|
|||||||
adminService service.AdminService
|
adminService service.AdminService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type optionalLimitField struct {
|
||||||
|
set bool
|
||||||
|
value *float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
|
||||||
|
f.set = true
|
||||||
|
|
||||||
|
trimmed := bytes.TrimSpace(data)
|
||||||
|
if bytes.Equal(trimmed, []byte("null")) {
|
||||||
|
f.value = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var number float64
|
||||||
|
if err := json.Unmarshal(trimmed, &number); err == nil {
|
||||||
|
f.value = &number
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var text string
|
||||||
|
if err := json.Unmarshal(trimmed, &text); err == nil {
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
if text == "" {
|
||||||
|
f.value = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
number, err = strconv.ParseFloat(text, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
|
||||||
|
}
|
||||||
|
f.value = &number
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("invalid limit value: %s", string(trimmed))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f optionalLimitField) ToServiceInput() *float64 {
|
||||||
|
if !f.set {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if f.value != nil {
|
||||||
|
return f.value
|
||||||
|
}
|
||||||
|
zero := 0.0
|
||||||
|
return &zero
|
||||||
|
}
|
||||||
|
|
||||||
// NewGroupHandler creates a new admin group handler
|
// NewGroupHandler creates a new admin group handler
|
||||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||||
return &GroupHandler{
|
return &GroupHandler{
|
||||||
@@ -31,9 +83,9 @@ type CreateGroupRequest struct {
|
|||||||
RateMultiplier float64 `json:"rate_multiplier"`
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
IsExclusive bool `json:"is_exclusive"`
|
IsExclusive bool `json:"is_exclusive"`
|
||||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
@@ -69,9 +121,9 @@ type UpdateGroupRequest struct {
|
|||||||
IsExclusive *bool `json:"is_exclusive"`
|
IsExclusive *bool `json:"is_exclusive"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
@@ -191,9 +243,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
RateMultiplier: req.RateMultiplier,
|
RateMultiplier: req.RateMultiplier,
|
||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
@@ -244,9 +296,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
IsExclusive: req.IsExclusive,
|
IsExclusive: req.IsExclusive,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
SubscriptionType: req.SubscriptionType,
|
SubscriptionType: req.SubscriptionType,
|
||||||
DailyLimitUSD: req.DailyLimitUSD,
|
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
@@ -335,6 +387,72 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
|||||||
response.Paginated(c, outKeys, total, page, pageSize)
|
response.Paginated(c, outKeys, total, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
|
||||||
|
// GET /api/v1/admin/groups/:id/rate-multipliers
|
||||||
|
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if entries == nil {
|
||||||
|
entries = []service.UserGroupRateEntry{}
|
||||||
|
}
|
||||||
|
response.Success(c, entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
|
||||||
|
// DELETE /api/v1/admin/groups/:id/rate-multipliers
|
||||||
|
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
|
||||||
|
type BatchSetGroupRateMultipliersRequest struct {
|
||||||
|
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
|
||||||
|
// PUT /api/v1/admin/groups/:id/rate-multipliers
|
||||||
|
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
||||||
|
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid group ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req BatchSetGroupRateMultipliersRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||||
type UpdateSortOrderRequest struct {
|
type UpdateSortOrderRequest struct {
|
||||||
Updates []struct {
|
Updates []struct {
|
||||||
|
|||||||
@@ -289,6 +289,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
|||||||
Platform: platform,
|
Platform: platform,
|
||||||
Type: "oauth",
|
Type: "oauth",
|
||||||
Credentials: credentials,
|
Credentials: credentials,
|
||||||
|
Extra: nil,
|
||||||
ProxyID: req.ProxyID,
|
ProxyID: req.ProxyID,
|
||||||
Concurrency: req.Concurrency,
|
Concurrency: req.Concurrency,
|
||||||
Priority: req.Priority,
|
Priority: req.Priority,
|
||||||
|
|||||||
@@ -41,11 +41,14 @@ type GenerateRedeemCodesRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||||
|
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||||
type CreateAndRedeemCodeRequest struct {
|
type CreateAndRedeemCodeRequest struct {
|
||||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||||
Value float64 `json:"value" binding:"required,gt=0"`
|
Value float64 `json:"value" binding:"required,gt=0"`
|
||||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||||
|
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||||
|
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
|
||||||
Notes string `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,6 +139,22 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Code = strings.TrimSpace(req.Code)
|
req.Code = strings.TrimSpace(req.Code)
|
||||||
|
// 向后兼容:旧版调用方(如 Sub2ApiPay)不传 type 字段,默认当作 balance 充值处理。
|
||||||
|
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
|
||||||
|
if req.Type == "" {
|
||||||
|
req.Type = "balance"
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Type == "subscription" {
|
||||||
|
if req.GroupID == nil {
|
||||||
|
response.BadRequest(c, "group_id is required for subscription type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.ValidityDays <= 0 {
|
||||||
|
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||||
@@ -152,6 +171,8 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
|||||||
Value: req.Value,
|
Value: req.Value,
|
||||||
Status: service.StatusUnused,
|
Status: service.StatusUnused,
|
||||||
Notes: req.Notes,
|
Notes: req.Notes,
|
||||||
|
GroupID: req.GroupID,
|
||||||
|
ValidityDays: req.ValidityDays,
|
||||||
})
|
})
|
||||||
if createErr != nil {
|
if createErr != nil {
|
||||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||||
|
|||||||
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
135
backend/internal/handler/admin/redeem_handler_test.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
|
||||||
|
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
|
||||||
|
// parameter-validation layer that runs before any service call.
|
||||||
|
func newCreateAndRedeemHandler() *RedeemHandler {
|
||||||
|
return &RedeemHandler{
|
||||||
|
adminService: newStubAdminService(),
|
||||||
|
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
|
||||||
|
// status code. For cases that pass validation and proceed into the service layer,
|
||||||
|
// a panic may occur (because RedeemService internals are nil); this is expected
|
||||||
|
// and treated as "validation passed" (returns 0 to indicate panic).
|
||||||
|
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
|
||||||
|
t.Helper()
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
// Panic means we passed validation and entered service layer (expected for minimal stub).
|
||||||
|
code = 0
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
handler.CreateAndRedeem(c)
|
||||||
|
return w.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
|
||||||
|
// 不传 type 字段时应默认 balance,不触发 subscription 校验。
|
||||||
|
// 验证通过后进入 service 层会 panic(返回 0),说明默认值生效。
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-balance-default",
|
||||||
|
"value": 10.0,
|
||||||
|
"user_id": 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"omitting type should default to balance and pass validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-no-group",
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"validity_days": 30,
|
||||||
|
// group_id 缺失
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
|
||||||
|
groupID := int64(5)
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
validityDays int
|
||||||
|
}{
|
||||||
|
{"zero", 0},
|
||||||
|
{"negative", -1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-bad-days-" + tc.name,
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"group_id": groupID,
|
||||||
|
"validity_days": tc.validityDays,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
|
||||||
|
groupID := int64(5)
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-sub-valid",
|
||||||
|
"type": "subscription",
|
||||||
|
"value": 29.9,
|
||||||
|
"user_id": 1,
|
||||||
|
"group_id": groupID,
|
||||||
|
"validity_days": 31,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"valid subscription params should pass validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||||
|
h := newCreateAndRedeemHandler()
|
||||||
|
// balance 类型不传 group_id 和 validity_days,不应报 400
|
||||||
|
code := postCreateAndRedeemValidation(t, h, map[string]any{
|
||||||
|
"code": "test-balance-no-extras",
|
||||||
|
"type": "balance",
|
||||||
|
"value": 50.0,
|
||||||
|
"user_id": 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||||
|
"balance type should not require group_id or validity_days")
|
||||||
|
}
|
||||||
@@ -80,6 +80,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||||
|
FrontendURL: settings.FrontendURL,
|
||||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||||
TotpEnabled: settings.TotpEnabled,
|
TotpEnabled: settings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
@@ -125,6 +126,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||||
|
BackendModeEnabled: settings.BackendModeEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,6 +138,7 @@ type UpdateSettingsRequest struct {
|
|||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
FrontendURL string `json:"frontend_url"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
|
|
||||||
@@ -199,6 +202,9 @@ type UpdateSettingsRequest struct {
|
|||||||
|
|
||||||
// 分组隔离
|
// 分组隔离
|
||||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||||
|
|
||||||
|
// Backend Mode
|
||||||
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSettings 更新系统设置
|
// UpdateSettings 更新系统设置
|
||||||
@@ -322,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Frontend URL 验证
|
||||||
|
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
|
||||||
|
if req.FrontendURL != "" {
|
||||||
|
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
|
||||||
|
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 自定义菜单项验证
|
// 自定义菜单项验证
|
||||||
const (
|
const (
|
||||||
maxCustomMenuItems = 20
|
maxCustomMenuItems = 20
|
||||||
@@ -433,6 +448,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||||
|
FrontendURL: req.FrontendURL,
|
||||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||||
TotpEnabled: req.TotpEnabled,
|
TotpEnabled: req.TotpEnabled,
|
||||||
SMTPHost: req.SMTPHost,
|
SMTPHost: req.SMTPHost,
|
||||||
@@ -473,6 +489,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||||
|
BackendModeEnabled: req.BackendModeEnabled,
|
||||||
OpsMonitoringEnabled: func() bool {
|
OpsMonitoringEnabled: func() bool {
|
||||||
if req.OpsMonitoringEnabled != nil {
|
if req.OpsMonitoringEnabled != nil {
|
||||||
return *req.OpsMonitoringEnabled
|
return *req.OpsMonitoringEnabled
|
||||||
@@ -526,6 +543,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||||
|
FrontendURL: updatedSettings.FrontendURL,
|
||||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||||
TotpEnabled: updatedSettings.TotpEnabled,
|
TotpEnabled: updatedSettings.TotpEnabled,
|
||||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||||
@@ -571,6 +589,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||||
|
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -608,6 +627,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||||
changed = append(changed, "password_reset_enabled")
|
changed = append(changed, "password_reset_enabled")
|
||||||
}
|
}
|
||||||
|
if before.FrontendURL != after.FrontendURL {
|
||||||
|
changed = append(changed, "frontend_url")
|
||||||
|
}
|
||||||
if before.TotpEnabled != after.TotpEnabled {
|
if before.TotpEnabled != after.TotpEnabled {
|
||||||
changed = append(changed, "totp_enabled")
|
changed = append(changed, "totp_enabled")
|
||||||
}
|
}
|
||||||
@@ -725,6 +747,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||||
}
|
}
|
||||||
|
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||||
|
changed = append(changed, "backend_mode_enabled")
|
||||||
|
}
|
||||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||||
changed = append(changed, "purchase_subscription_enabled")
|
changed = append(changed, "purchase_subscription_enabled")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -220,9 +220,10 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
|||||||
type ResetSubscriptionQuotaRequest struct {
|
type ResetSubscriptionQuotaRequest struct {
|
||||||
Daily bool `json:"daily"`
|
Daily bool `json:"daily"`
|
||||||
Weekly bool `json:"weekly"`
|
Weekly bool `json:"weekly"`
|
||||||
|
Monthly bool `json:"monthly"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetQuota resets daily and/or weekly usage for a subscription.
|
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
|
||||||
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
// POST /api/v1/admin/subscriptions/:id/reset-quota
|
||||||
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
||||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
@@ -235,11 +236,11 @@ func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
|
|||||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !req.Daily && !req.Weekly {
|
if !req.Daily && !req.Weekly && !req.Monthly {
|
||||||
response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true")
|
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly)
|
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backend mode: only admin can login
|
||||||
|
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||||
|
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.respondWithTokenPair(c, user)
|
h.respondWithTokenPair(c, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete the login session
|
// Get the user (before session deletion so we can check backend mode)
|
||||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
|
||||||
|
|
||||||
// Get the user
|
|
||||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||||
|
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||||
|
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the login session (only after all checks pass)
|
||||||
|
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||||
|
|
||||||
h.respondWithTokenPair(c, user)
|
h.respondWithTokenPair(c, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -447,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
|
||||||
if frontendBaseURL == "" {
|
if frontendBaseURL == "" {
|
||||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
|
||||||
response.InternalError(c, "Password reset is not configured")
|
response.InternalError(c, "Password reset is not configured")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backend mode: block non-admin token refresh
|
||||||
|
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||||
|
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
response.Success(c, RefreshTokenResponse{
|
response.Success(c, RefreshTokenResponse{
|
||||||
AccessToken: tokenPair.AccessToken,
|
AccessToken: result.AccessToken,
|
||||||
RefreshToken: tokenPair.RefreshToken,
|
RefreshToken: result.RefreshToken,
|
||||||
ExpiresIn: tokenPair.ExpiresIn,
|
ExpiresIn: result.ExpiresIn,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -264,8 +264,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||||
if a.Type == service.AccountTypeAPIKey {
|
if a.IsAPIKeyOrBedrock() {
|
||||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||||
out.QuotaLimit = &limit
|
out.QuotaLimit = &limit
|
||||||
used := a.GetQuotaUsed()
|
used := a.GetQuotaUsed()
|
||||||
@@ -281,6 +281,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
|||||||
used := a.GetQuotaWeeklyUsed()
|
used := a.GetQuotaWeeklyUsed()
|
||||||
out.QuotaWeeklyUsed = &used
|
out.QuotaWeeklyUsed = &used
|
||||||
}
|
}
|
||||||
|
// 固定时间重置配置
|
||||||
|
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||||
|
out.QuotaDailyResetMode = &mode
|
||||||
|
hour := a.GetQuotaDailyResetHour()
|
||||||
|
out.QuotaDailyResetHour = &hour
|
||||||
|
}
|
||||||
|
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||||
|
out.QuotaWeeklyResetMode = &mode
|
||||||
|
day := a.GetQuotaWeeklyResetDay()
|
||||||
|
out.QuotaWeeklyResetDay = &day
|
||||||
|
hour := a.GetQuotaWeeklyResetHour()
|
||||||
|
out.QuotaWeeklyResetHour = &hour
|
||||||
|
}
|
||||||
|
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||||
|
tz := a.GetQuotaResetTimezone()
|
||||||
|
out.QuotaResetTimezone = &tz
|
||||||
|
}
|
||||||
|
if a.Extra != nil {
|
||||||
|
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||||
|
out.QuotaDailyResetAt = &v
|
||||||
|
}
|
||||||
|
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||||
|
out.QuotaWeeklyResetAt = &v
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return out
|
return out
|
||||||
@@ -498,6 +523,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
|||||||
Model: l.Model,
|
Model: l.Model,
|
||||||
ServiceTier: l.ServiceTier,
|
ServiceTier: l.ServiceTier,
|
||||||
ReasoningEffort: l.ReasoningEffort,
|
ReasoningEffort: l.ReasoningEffort,
|
||||||
|
InboundEndpoint: l.InboundEndpoint,
|
||||||
|
UpstreamEndpoint: l.UpstreamEndpoint,
|
||||||
GroupID: l.GroupID,
|
GroupID: l.GroupID,
|
||||||
SubscriptionID: l.SubscriptionID,
|
SubscriptionID: l.SubscriptionID,
|
||||||
InputTokens: l.InputTokens,
|
InputTokens: l.InputTokens,
|
||||||
|
|||||||
@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
serviceTier := "priority"
|
serviceTier := "priority"
|
||||||
|
inboundEndpoint := "/v1/chat/completions"
|
||||||
|
upstreamEndpoint := "/v1/responses"
|
||||||
log := &service.UsageLog{
|
log := &service.UsageLog{
|
||||||
RequestID: "req_3",
|
RequestID: "req_3",
|
||||||
Model: "gpt-5.4",
|
Model: "gpt-5.4",
|
||||||
ServiceTier: &serviceTier,
|
ServiceTier: &serviceTier,
|
||||||
|
InboundEndpoint: &inboundEndpoint,
|
||||||
|
UpstreamEndpoint: &upstreamEndpoint,
|
||||||
AccountRateMultiplier: f64Ptr(1.5),
|
AccountRateMultiplier: f64Ptr(1.5),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
|||||||
|
|
||||||
require.NotNil(t, userDTO.ServiceTier)
|
require.NotNil(t, userDTO.ServiceTier)
|
||||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||||
|
require.NotNil(t, userDTO.InboundEndpoint)
|
||||||
|
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
|
||||||
|
require.NotNil(t, userDTO.UpstreamEndpoint)
|
||||||
|
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
|
||||||
require.NotNil(t, adminDTO.ServiceTier)
|
require.NotNil(t, adminDTO.ServiceTier)
|
||||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||||
|
require.NotNil(t, adminDTO.InboundEndpoint)
|
||||||
|
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
|
||||||
|
require.NotNil(t, adminDTO.UpstreamEndpoint)
|
||||||
|
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
|
||||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ type SystemSettings struct {
|
|||||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||||
|
FrontendURL string `json:"frontend_url"`
|
||||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||||
@@ -81,6 +82,9 @@ type SystemSettings struct {
|
|||||||
|
|
||||||
// 分组隔离
|
// 分组隔离
|
||||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||||
|
|
||||||
|
// Backend Mode
|
||||||
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultSubscriptionSetting struct {
|
type DefaultSubscriptionSetting struct {
|
||||||
@@ -111,6 +115,7 @@ type PublicSettings struct {
|
|||||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||||
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -203,6 +203,16 @@ type Account struct {
|
|||||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||||
|
|
||||||
|
// 配额固定时间重置配置
|
||||||
|
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||||
|
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||||
|
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||||
|
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||||
|
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||||
|
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||||
|
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||||
|
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||||
|
|
||||||
Proxy *Proxy `json:"proxy,omitempty"`
|
Proxy *Proxy `json:"proxy,omitempty"`
|
||||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||||
|
|
||||||
@@ -324,9 +334,13 @@ type UsageLog struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||||
ServiceTier *string `json:"service_tier,omitempty"`
|
ServiceTier *string `json:"service_tier,omitempty"`
|
||||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
// ReasoningEffort is the request's reasoning effort level.
|
||||||
// nil means not provided / not applicable.
|
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
|
||||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||||
|
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||||
|
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
|
||||||
|
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||||
|
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
|
||||||
|
|
||||||
GroupID *int64 `json:"group_id"`
|
GroupID *int64 `json:"group_id"`
|
||||||
SubscriptionID *int64 `json:"subscription_id"`
|
SubscriptionID *int64 `json:"subscription_id"`
|
||||||
|
|||||||
@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||||
}
|
}
|
||||||
|
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||||
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||||
} else {
|
} else {
|
||||||
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
|
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||||
|
if c.Writer.Size() != writerSizeBeforeForward {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
switch action {
|
switch action {
|
||||||
case FailoverContinue:
|
case FailoverContinue:
|
||||||
@@ -434,6 +441,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
|
if result.ReasoningEffort == nil {
|
||||||
|
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||||
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
@@ -445,6 +457,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -635,6 +648,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||||
}
|
}
|
||||||
|
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||||
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||||
} else {
|
} else {
|
||||||
@@ -704,6 +719,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
|
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||||
|
if c.Writer.Size() != writerSizeBeforeForward {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
switch action {
|
switch action {
|
||||||
case FailoverContinue:
|
case FailoverContinue:
|
||||||
@@ -736,6 +756,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
|
if result.ReasoningEffort == nil {
|
||||||
|
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||||
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
@@ -747,6 +772,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
Subscription: currentSubscription,
|
Subscription: currentSubscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
|||||||
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。
|
||||||
|
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
|
||||||
|
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
|
||||||
|
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
|
||||||
|
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
|
||||||
|
// 具体验证:
|
||||||
|
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
|
||||||
|
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
|
||||||
|
// 3. 响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化)
|
||||||
|
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
// 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size())
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)")
|
||||||
|
|
||||||
|
// 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start)
|
||||||
|
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward)
|
||||||
|
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
|
||||||
|
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true")
|
||||||
|
|
||||||
|
// 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403)
|
||||||
|
failoverErr := &service.UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
// 断言 A:响应体中包含最初写入的 message_start SSE 事件行
|
||||||
|
require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件")
|
||||||
|
|
||||||
|
// 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n)
|
||||||
|
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
|
||||||
|
"响应体应以 JSON 对象结尾(SSE error event 的 data 字段)")
|
||||||
|
require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件")
|
||||||
|
|
||||||
|
// 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
|
||||||
|
firstIdx := strings.Index(body, "event: message_start")
|
||||||
|
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||||
|
assert.Equal(t, firstIdx, lastIdx,
|
||||||
|
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
|
||||||
|
// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。
|
||||||
|
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
|
||||||
|
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
|
||||||
|
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
|
||||||
|
|
||||||
|
failoverErr := &service.UpstreamFailoverError{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
}
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
|
||||||
|
require.Contains(t, body, "event: message_start")
|
||||||
|
require.Contains(t, body, `"type":"error"`)
|
||||||
|
|
||||||
|
firstIdx := strings.Index(body, "event: message_start")
|
||||||
|
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||||
|
assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
|
||||||
|
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
|
||||||
|
// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。
|
||||||
|
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
// 模拟 writerSizeBeforeForward:初始为 -1
|
||||||
|
sizeBeforeForward := c.Writer.Size()
|
||||||
|
|
||||||
|
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
|
||||||
|
// c.Writer.Size() 仍为 -1
|
||||||
|
|
||||||
|
// 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发
|
||||||
|
guardTriggered := c.Writer.Size() != sizeBeforeForward
|
||||||
|
require.False(t, guardTriggered,
|
||||||
|
"未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续")
|
||||||
|
}
|
||||||
@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
|||||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||||
&fakeGroupRepo{group: group},
|
&fakeGroupRepo{group: group},
|
||||||
nil, // usageLogRepo
|
nil, // usageLogRepo
|
||||||
|
nil, // usageBillingRepo
|
||||||
nil, // userRepo
|
nil, // userRepo
|
||||||
nil, // userSubRepo
|
nil, // userSubRepo
|
||||||
nil, // userGroupRateRepo
|
nil, // userGroupRateRepo
|
||||||
|
|||||||
@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
|||||||
Account *admin.AccountHandler
|
Account *admin.AccountHandler
|
||||||
Announcement *admin.AnnouncementHandler
|
Announcement *admin.AnnouncementHandler
|
||||||
DataManagement *admin.DataManagementHandler
|
DataManagement *admin.DataManagementHandler
|
||||||
|
Backup *admin.BackupHandler
|
||||||
OAuth *admin.OAuthHandler
|
OAuth *admin.OAuthHandler
|
||||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||||
GeminiOAuth *admin.GeminiOAuthHandler
|
GeminiOAuth *admin.GeminiOAuthHandler
|
||||||
|
|||||||
@@ -181,13 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
forwardStart := time.Now()
|
forwardStart := time.Now()
|
||||||
|
|
||||||
defaultMappedModel := ""
|
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||||
if apiKey.Group != nil {
|
|
||||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
|
||||||
}
|
|
||||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
|
||||||
defaultMappedModel = fallbackModel
|
|
||||||
}
|
|
||||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||||
|
|
||||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
@@ -267,6 +261,8 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointChatCompletions),
|
||||||
|
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizedOpenAIUpstreamEndpoint(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
fallback string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "responses root maps to responses upstream",
|
||||||
|
path: "/v1/responses",
|
||||||
|
fallback: openAIUpstreamEndpointResponses,
|
||||||
|
want: "/v1/responses",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses compact keeps compact suffix",
|
||||||
|
path: "/openai/v1/responses/compact",
|
||||||
|
fallback: openAIUpstreamEndpointResponses,
|
||||||
|
want: "/v1/responses/compact",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses nested suffix preserved",
|
||||||
|
path: "/openai/v1/responses/compact/detail",
|
||||||
|
fallback: openAIUpstreamEndpointResponses,
|
||||||
|
want: "/v1/responses/compact/detail",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non responses path uses fallback",
|
||||||
|
path: "/v1/messages",
|
||||||
|
fallback: openAIUpstreamEndpointResponses,
|
||||||
|
want: "/v1/responses",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
|
||||||
|
|
||||||
|
got := normalizedOpenAIUpstreamEndpoint(c, tt.fallback)
|
||||||
|
require.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -37,6 +37,13 @@ type OpenAIGatewayHandler struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
openAIInboundEndpointResponses = "/v1/responses"
|
||||||
|
openAIInboundEndpointMessages = "/v1/messages"
|
||||||
|
openAIInboundEndpointChatCompletions = "/v1/chat/completions"
|
||||||
|
openAIUpstreamEndpointResponses = "/v1/responses"
|
||||||
|
)
|
||||||
|
|
||||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||||
func NewOpenAIGatewayHandler(
|
func NewOpenAIGatewayHandler(
|
||||||
gatewayService *service.OpenAIGatewayService,
|
gatewayService *service.OpenAIGatewayService,
|
||||||
@@ -352,6 +359,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
@@ -361,8 +369,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses),
|
||||||
|
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
@@ -653,14 +664,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
forwardStart := time.Now()
|
forwardStart := time.Now()
|
||||||
|
|
||||||
defaultMappedModel := ""
|
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||||
if apiKey.Group != nil {
|
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||||
}
|
|
||||||
// 如果使用了降级模型调度,强制使用降级模型
|
|
||||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
|
||||||
defaultMappedModel = fallbackModel
|
|
||||||
}
|
|
||||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||||
|
|
||||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
@@ -732,6 +738,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
@@ -740,8 +747,11 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointMessages),
|
||||||
|
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
@@ -1236,8 +1246,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
|||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses),
|
||||||
|
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
reqLog.Error("openai.websocket_record_usage_failed",
|
reqLog.Error("openai.websocket_record_usage_failed",
|
||||||
@@ -1530,6 +1543,62 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64)
|
|||||||
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
|
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizedOpenAIInboundEndpoint(c *gin.Context, fallback string) string {
|
||||||
|
path := strings.TrimSpace(fallback)
|
||||||
|
if c != nil {
|
||||||
|
if fullPath := strings.TrimSpace(c.FullPath()); fullPath != "" {
|
||||||
|
path = fullPath
|
||||||
|
} else if c.Request != nil && c.Request.URL != nil {
|
||||||
|
if requestPath := strings.TrimSpace(c.Request.URL.Path); requestPath != "" {
|
||||||
|
path = requestPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.Contains(path, openAIInboundEndpointChatCompletions):
|
||||||
|
return openAIInboundEndpointChatCompletions
|
||||||
|
case strings.Contains(path, openAIInboundEndpointMessages):
|
||||||
|
return openAIInboundEndpointMessages
|
||||||
|
case strings.Contains(path, openAIInboundEndpointResponses):
|
||||||
|
return openAIInboundEndpointResponses
|
||||||
|
default:
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizedOpenAIUpstreamEndpoint(c *gin.Context, fallback string) string {
|
||||||
|
base := strings.TrimSpace(fallback)
|
||||||
|
if base == "" {
|
||||||
|
base = openAIUpstreamEndpointResponses
|
||||||
|
}
|
||||||
|
base = strings.TrimRight(base, "/")
|
||||||
|
|
||||||
|
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
path := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
|
||||||
|
if path == "" {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := strings.LastIndex(path, "/responses")
|
||||||
|
if idx < 0 {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
suffix := strings.TrimSpace(path[idx+len("/responses"):])
|
||||||
|
if suffix == "" || suffix == "/" {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(suffix, "/") {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
return base + suffix
|
||||||
|
}
|
||||||
|
|
||||||
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
|
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -26,6 +26,22 @@ const (
|
|||||||
opsStreamKey = "ops_stream"
|
opsStreamKey = "ops_stream"
|
||||||
opsRequestBodyKey = "ops_request_body"
|
opsRequestBodyKey = "ops_request_body"
|
||||||
opsAccountIDKey = "ops_account_id"
|
opsAccountIDKey = "ops_account_id"
|
||||||
|
|
||||||
|
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||||
|
opsErrContextCanceled = "context canceled"
|
||||||
|
opsErrNoAvailableAccounts = "no available accounts"
|
||||||
|
opsErrInvalidAPIKey = "invalid_api_key"
|
||||||
|
opsErrAPIKeyRequired = "api_key_required"
|
||||||
|
opsErrInsufficientBalance = "insufficient balance"
|
||||||
|
opsErrInsufficientAccountBalance = "insufficient account balance"
|
||||||
|
opsErrInsufficientQuota = "insufficient_quota"
|
||||||
|
|
||||||
|
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
|
||||||
|
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||||
|
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||||
|
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||||
|
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||||
|
opsCodeUserInactive = "USER_INACTIVE"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -1024,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
|
|||||||
return errType
|
return errType
|
||||||
}
|
}
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE":
|
case opsCodeInsufficientBalance:
|
||||||
return "billing_error"
|
return "billing_error"
|
||||||
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||||
return "subscription_error"
|
return "subscription_error"
|
||||||
default:
|
default:
|
||||||
return "api_error"
|
return "api_error"
|
||||||
@@ -1038,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
|||||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||||
// Map billing/concurrency/response => request; scheduling => routing.
|
// Map billing/concurrency/response => request; scheduling => routing.
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||||
return "request"
|
return "request"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1057,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
|||||||
case "upstream_error", "overloaded_error":
|
case "upstream_error", "overloaded_error":
|
||||||
return "upstream"
|
return "upstream"
|
||||||
case "api_error":
|
case "api_error":
|
||||||
if strings.Contains(msg, "no available accounts") {
|
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||||
return "routing"
|
return "routing"
|
||||||
}
|
}
|
||||||
return "internal"
|
return "internal"
|
||||||
@@ -1103,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
|||||||
|
|
||||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||||
switch strings.TrimSpace(code) {
|
switch strings.TrimSpace(code) {
|
||||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if phase == "billing" || phase == "concurrency" {
|
if phase == "billing" || phase == "concurrency" {
|
||||||
@@ -1197,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
|||||||
|
|
||||||
// Check if context canceled errors should be ignored (client disconnects)
|
// Check if context canceled errors should be ignored (client disconnects)
|
||||||
if settings.IgnoreContextCanceled {
|
if settings.IgnoreContextCanceled {
|
||||||
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if "no available accounts" errors should be ignored
|
// Check if "no available accounts" errors should be ignored
|
||||||
if settings.IgnoreNoAvailableAccounts {
|
if settings.IgnoreNoAvailableAccounts {
|
||||||
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||||
if settings.IgnoreInvalidApiKeyErrors {
|
if settings.IgnoreInvalidApiKeyErrors {
|
||||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if insufficient balance errors should be ignored
|
||||||
|
if settings.IgnoreInsufficientBalanceErrors {
|
||||||
|
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
|
||||||
|
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
|
||||||
|
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||||
SoraClientEnabled: settings.SoraClientEnabled,
|
SoraClientEnabled: settings.SoraClientEnabled,
|
||||||
|
BackendModeEnabled: settings.BackendModeEnabled,
|
||||||
Version: h.version,
|
Version: h.version,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
|||||||
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
||||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||||
return service.NewGatewayService(
|
return service.NewGatewayService(
|
||||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -399,6 +399,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
|
||||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
@@ -410,6 +411,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||||
|
|||||||
@@ -334,6 +334,14 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
|
|||||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return []usagestats.EndpointStat{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return []usagestats.EndpointStat{}, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -343,6 +351,9 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e
|
|||||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -431,6 +442,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
testutil.StubGatewayCache{},
|
testutil.StubGatewayCache{},
|
||||||
cfg,
|
cfg,
|
||||||
nil,
|
nil,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
|||||||
accountHandler *admin.AccountHandler,
|
accountHandler *admin.AccountHandler,
|
||||||
announcementHandler *admin.AnnouncementHandler,
|
announcementHandler *admin.AnnouncementHandler,
|
||||||
dataManagementHandler *admin.DataManagementHandler,
|
dataManagementHandler *admin.DataManagementHandler,
|
||||||
|
backupHandler *admin.BackupHandler,
|
||||||
oauthHandler *admin.OAuthHandler,
|
oauthHandler *admin.OAuthHandler,
|
||||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
|||||||
Account: accountHandler,
|
Account: accountHandler,
|
||||||
Announcement: announcementHandler,
|
Announcement: announcementHandler,
|
||||||
DataManagement: dataManagementHandler,
|
DataManagement: dataManagementHandler,
|
||||||
|
Backup: backupHandler,
|
||||||
OAuth: oauthHandler,
|
OAuth: oauthHandler,
|
||||||
OpenAIOAuth: openaiOAuthHandler,
|
OpenAIOAuth: openaiOAuthHandler,
|
||||||
GeminiOAuth: geminiOAuthHandler,
|
GeminiOAuth: geminiOAuthHandler,
|
||||||
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
admin.NewAccountHandler,
|
admin.NewAccountHandler,
|
||||||
admin.NewAnnouncementHandler,
|
admin.NewAnnouncementHandler,
|
||||||
admin.NewDataManagementHandler,
|
admin.NewDataManagementHandler,
|
||||||
|
admin.NewBackupHandler,
|
||||||
admin.NewOAuthHandler,
|
admin.NewOAuthHandler,
|
||||||
admin.NewOpenAIOAuthHandler,
|
admin.NewOpenAIOAuthHandler,
|
||||||
admin.NewGeminiOAuthHandler,
|
admin.NewGeminiOAuthHandler,
|
||||||
|
|||||||
@@ -19,6 +19,16 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ForbiddenError 表示上游返回 403 Forbidden
|
||||||
|
type ForbiddenError struct {
|
||||||
|
StatusCode int
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ForbiddenError) Error() string {
|
||||||
|
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
|
||||||
|
}
|
||||||
|
|
||||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||||
@@ -515,6 +525,19 @@ type ModelQuotaInfo struct {
|
|||||||
// ModelInfo 模型信息
|
// ModelInfo 模型信息
|
||||||
type ModelInfo struct {
|
type ModelInfo struct {
|
||||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||||
|
DisplayName string `json:"displayName,omitempty"`
|
||||||
|
SupportsImages *bool `json:"supportsImages,omitempty"`
|
||||||
|
SupportsThinking *bool `json:"supportsThinking,omitempty"`
|
||||||
|
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||||
|
Recommended *bool `json:"recommended,omitempty"`
|
||||||
|
MaxTokens *int `json:"maxTokens,omitempty"`
|
||||||
|
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||||
|
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeprecatedModelInfo 废弃模型转发信息
|
||||||
|
type DeprecatedModelInfo struct {
|
||||||
|
NewModelID string `json:"newModelId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||||
@@ -525,6 +548,7 @@ type FetchAvailableModelsRequest struct {
|
|||||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||||
type FetchAvailableModelsResponse struct {
|
type FetchAvailableModelsResponse struct {
|
||||||
Models map[string]ModelInfo `json:"models"`
|
Models map[string]ModelInfo `json:"models"`
|
||||||
|
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||||
@@ -573,6 +597,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusForbidden {
|
||||||
|
return nil, nil, &ForbiddenError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Body: string(respBodyBytes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -189,6 +189,5 @@ var DefaultStopSequences = []string{
|
|||||||
"<|user|>",
|
"<|user|>",
|
||||||
"<|endoftext|>",
|
"<|endoftext|>",
|
||||||
"<|end_of_turn|>",
|
"<|end_of_turn|>",
|
||||||
"[DONE]",
|
|
||||||
"\n\nHuman:",
|
"\n\nHuman:",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
|||||||
assert.Equal(t, "assistant", items[1].Role)
|
assert.Equal(t, "assistant", items[1].Role)
|
||||||
assert.Equal(t, "function_call", items[2].Type)
|
assert.Equal(t, "function_call", items[2].Type)
|
||||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||||
|
assert.Empty(t, items[2].ID)
|
||||||
assert.Equal(t, "function_call_output", items[3].Type)
|
assert.Equal(t, "function_call_output", items[3].Type)
|
||||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||||
|
|||||||
@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
|||||||
CallID: fcID,
|
CallID: fcID,
|
||||||
Name: b.Name,
|
Name: b.Name,
|
||||||
Arguments: args,
|
Arguments: args,
|
||||||
ID: fcID,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
|||||||
// Check function_call item
|
// Check function_call item
|
||||||
assert.Equal(t, "function_call", items[1].Type)
|
assert.Equal(t, "function_call", items[1].Type)
|
||||||
assert.Equal(t, "call_1", items[1].CallID)
|
assert.Equal(t, "call_1", items[1].CallID)
|
||||||
|
assert.Empty(t, items[1].ID)
|
||||||
assert.Equal(t, "ping", items[1].Name)
|
assert.Equal(t, "ping", items[1].Name)
|
||||||
|
|
||||||
// Check function_call_output item
|
// Check function_call_output item
|
||||||
@@ -252,6 +253,55 @@ func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T)
|
|||||||
assert.Equal(t, "user", items[0].Role)
|
assert.Equal(t, "user", items[0].Role)
|
||||||
assert.Equal(t, "assistant", items[1].Role)
|
assert.Equal(t, "assistant", items[1].Role)
|
||||||
assert.Equal(t, "function_call", items[2].Type)
|
assert.Equal(t, "function_call", items[2].Type)
|
||||||
|
assert.Empty(t, items[2].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 2)
|
||||||
|
assert.Equal(t, "assistant", items[1].Role)
|
||||||
|
|
||||||
|
var parts []ResponsesContentPart
|
||||||
|
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||||
|
require.Len(t, parts, 1)
|
||||||
|
assert.Equal(t, "output_text", parts[0].Type)
|
||||||
|
assert.Equal(t, "AB", parts[0].Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
|
||||||
|
req := &ChatCompletionsRequest{
|
||||||
|
Model: "gpt-4o",
|
||||||
|
Messages: []ChatMessage{
|
||||||
|
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||||
|
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ChatCompletionsToResponses(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var items []ResponsesInputItem
|
||||||
|
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||||
|
require.Len(t, items, 2)
|
||||||
|
|
||||||
|
var parts []ResponsesContentPart
|
||||||
|
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||||
|
require.Len(t, parts, 1)
|
||||||
|
assert.Equal(t, "output_text", parts[0].Type)
|
||||||
|
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
|
||||||
|
assert.Contains(t, parts[0].Text, "final answer")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -344,8 +394,8 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
|||||||
|
|
||||||
var content string
|
var content string
|
||||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||||
// Reasoning summary is prepended to text
|
assert.Equal(t, "The answer is 42.", content)
|
||||||
assert.Equal(t, "I thought about it.The answer is 42.", content)
|
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||||
@@ -582,8 +632,35 @@ func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
|||||||
Delta: "Thinking...",
|
Delta: "Thinking...",
|
||||||
}, state)
|
}, state)
|
||||||
require.Len(t, chunks, 1)
|
require.Len(t, chunks, 1)
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||||
|
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||||
|
|
||||||
|
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.reasoning_summary_text.done",
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
|
||||||
|
state := NewResponsesEventToChatState()
|
||||||
|
state.Model = "gpt-4o"
|
||||||
|
state.SentRole = true
|
||||||
|
|
||||||
|
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.reasoning_summary_text.delta",
|
||||||
|
Delta: "plan",
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||||
|
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||||
|
|
||||||
|
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||||
|
Type: "response.output_text.delta",
|
||||||
|
Delta: "answer",
|
||||||
|
}, state)
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
|
assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFinalizeResponsesChatStream(t *testing.T) {
|
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package apicompat
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||||
@@ -174,8 +175,11 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
|||||||
|
|
||||||
// Emit assistant message with output_text if content is non-empty.
|
// Emit assistant message with output_text if content is non-empty.
|
||||||
if len(m.Content) > 0 {
|
if len(m.Content) > 0 {
|
||||||
var s string
|
s, err := parseAssistantContent(m.Content)
|
||||||
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s != "" {
|
||||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||||
partsJSON, err := json.Marshal(parts)
|
partsJSON, err := json.Marshal(parts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -196,13 +200,82 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
|||||||
CallID: tc.ID,
|
CallID: tc.ID,
|
||||||
Name: tc.Function.Name,
|
Name: tc.Function.Name,
|
||||||
Arguments: args,
|
Arguments: args,
|
||||||
ID: tc.ID,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseAssistantContent returns assistant content as plain text.
|
||||||
|
//
|
||||||
|
// Supported formats:
|
||||||
|
// - JSON string
|
||||||
|
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
|
||||||
|
//
|
||||||
|
// For structured thinking/reasoning parts, it preserves semantics by wrapping
|
||||||
|
// the text in explicit tags so downstream can still distinguish it from normal text.
|
||||||
|
func parseAssistantContent(raw json.RawMessage) (string, error) {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(raw, &s); err == nil {
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []map[string]any
|
||||||
|
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||||
|
// Keep compatibility with prior behavior: unsupported assistant content
|
||||||
|
// formats are ignored instead of failing the whole request conversion.
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
write := func(v string) error {
|
||||||
|
_, err := b.WriteString(v)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, p := range parts {
|
||||||
|
typ, _ := p["type"].(string)
|
||||||
|
text, _ := p["text"].(string)
|
||||||
|
thinking, _ := p["thinking"].(string)
|
||||||
|
|
||||||
|
switch typ {
|
||||||
|
case "thinking", "reasoning":
|
||||||
|
if thinking != "" {
|
||||||
|
if err := write("<thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write(thinking); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write("</thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
} else if text != "" {
|
||||||
|
if err := write("<thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write(text); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := write("</thinking>"); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if text != "" {
|
||||||
|
if err := write(text); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// chatToolToResponses converts a tool result message (role=tool) into a
|
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||||
// function_call_output item.
|
// function_call_output item.
|
||||||
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
|||||||
}
|
}
|
||||||
|
|
||||||
var contentText string
|
var contentText string
|
||||||
|
var reasoningText string
|
||||||
var toolCalls []ChatToolCall
|
var toolCalls []ChatToolCall
|
||||||
|
|
||||||
for _, item := range resp.Output {
|
for _, item := range resp.Output {
|
||||||
@@ -51,7 +52,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
|||||||
case "reasoning":
|
case "reasoning":
|
||||||
for _, s := range item.Summary {
|
for _, s := range item.Summary {
|
||||||
if s.Type == "summary_text" && s.Text != "" {
|
if s.Type == "summary_text" && s.Text != "" {
|
||||||
contentText += s.Text
|
reasoningText += s.Text
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "web_search_call":
|
case "web_search_call":
|
||||||
@@ -67,6 +68,9 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
|||||||
raw, _ := json.Marshal(contentText)
|
raw, _ := json.Marshal(contentText)
|
||||||
msg.Content = raw
|
msg.Content = raw
|
||||||
}
|
}
|
||||||
|
if reasoningText != "" {
|
||||||
|
msg.ReasoningContent = reasoningText
|
||||||
|
}
|
||||||
|
|
||||||
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||||
|
|
||||||
@@ -153,6 +157,8 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
|||||||
return resToChatHandleFuncArgsDelta(evt, state)
|
return resToChatHandleFuncArgsDelta(evt, state)
|
||||||
case "response.reasoning_summary_text.delta":
|
case "response.reasoning_summary_text.delta":
|
||||||
return resToChatHandleReasoningDelta(evt, state)
|
return resToChatHandleReasoningDelta(evt, state)
|
||||||
|
case "response.reasoning_summary_text.done":
|
||||||
|
return nil
|
||||||
case "response.completed", "response.incomplete", "response.failed":
|
case "response.completed", "response.incomplete", "response.failed":
|
||||||
return resToChatHandleCompleted(evt, state)
|
return resToChatHandleCompleted(evt, state)
|
||||||
default:
|
default:
|
||||||
@@ -276,8 +282,8 @@ func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEv
|
|||||||
if evt.Delta == "" {
|
if evt.Delta == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
content := evt.Delta
|
reasoning := evt.Delta
|
||||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
|
||||||
}
|
}
|
||||||
|
|
||||||
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||||
|
|||||||
@@ -363,6 +363,7 @@ type ChatStreamOptions struct {
|
|||||||
type ChatMessage struct {
|
type ChatMessage struct {
|
||||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||||
Content json.RawMessage `json:"content,omitempty"`
|
Content json.RawMessage `json:"content,omitempty"`
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
@@ -468,6 +469,7 @@ type ChatChunkChoice struct {
|
|||||||
type ChatDelta struct {
|
type ChatDelta struct {
|
||||||
Role string `json:"role,omitempty"`
|
Role string `json:"role,omitempty"`
|
||||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||||
|
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,15 @@ type ModelStat struct {
|
|||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EndpointStat represents usage statistics for a single request endpoint.
|
||||||
|
type EndpointStat struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
}
|
||||||
|
|
||||||
// GroupStat represents usage statistics for a single group
|
// GroupStat represents usage statistics for a single group
|
||||||
type GroupStat struct {
|
type GroupStat struct {
|
||||||
GroupID int64 `json:"group_id"`
|
GroupID int64 `json:"group_id"`
|
||||||
@@ -96,12 +105,28 @@ type UserUsageTrendPoint struct {
|
|||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
Username string `json:"username"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
Tokens int64 `json:"tokens"`
|
Tokens int64 `json:"tokens"`
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserSpendingRankingItem represents a user spending ranking row.
|
||||||
|
type UserSpendingRankingItem struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
Tokens int64 `json:"tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserSpendingRankingResponse represents ranking rows plus total spend for the time range.
|
||||||
|
type UserSpendingRankingResponse struct {
|
||||||
|
Ranking []UserSpendingRankingItem `json:"ranking"`
|
||||||
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
|
}
|
||||||
|
|
||||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||||
type APIKeyUsageTrendPoint struct {
|
type APIKeyUsageTrendPoint struct {
|
||||||
Date string `json:"date"`
|
Date string `json:"date"`
|
||||||
@@ -172,6 +197,9 @@ type UsageStats struct {
|
|||||||
TotalActualCost float64 `json:"total_actual_cost"`
|
TotalActualCost float64 `json:"total_actual_cost"`
|
||||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||||
|
Endpoints []EndpointStat `json:"endpoints,omitempty"`
|
||||||
|
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
|
||||||
|
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchUserUsageStats represents usage stats for a single user
|
// BatchUserUsageStats represents usage stats for a single user
|
||||||
@@ -241,4 +269,6 @@ type AccountUsageStatsResponse struct {
|
|||||||
History []AccountUsageHistory `json:"history"`
|
History []AccountUsageHistory `json:"history"`
|
||||||
Summary AccountUsageSummary `json:"summary"`
|
Summary AccountUsageSummary `json:"summary"`
|
||||||
Models []ModelStat `json:"models"`
|
Models []ModelStat `json:"models"`
|
||||||
|
Endpoints []EndpointStat `json:"endpoints"`
|
||||||
|
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -397,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
|||||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||||
}
|
}
|
||||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
// 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照,
|
||||||
|
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
|
||||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1727,8 +1727,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
|||||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||||
|
|
||||||
|
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
|
||||||
|
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
|
||||||
|
const dailyExpiredExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '24 hours'::interval <= NOW()
|
||||||
|
END
|
||||||
|
)`
|
||||||
|
|
||||||
|
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
|
||||||
|
const weeklyExpiredExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '168 hours'::interval <= NOW()
|
||||||
|
END
|
||||||
|
)`
|
||||||
|
|
||||||
|
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
|
||||||
|
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
|
||||||
|
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||||
|
const nextDailyResetAtExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN to_char((
|
||||||
|
-- Compute today's reset point in the configured timezone, then pick next future one
|
||||||
|
CASE WHEN NOW() >= (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
-- NOW() is at or past today's reset point → next reset is tomorrow
|
||||||
|
THEN (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
+ '1 day'::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
-- NOW() is before today's reset point → next reset is today
|
||||||
|
ELSE (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
END
|
||||||
|
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||||
|
ELSE NULL END
|
||||||
|
)`
|
||||||
|
|
||||||
|
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
|
||||||
|
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
|
||||||
|
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||||
|
const nextWeeklyResetAtExpr = `(
|
||||||
|
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||||
|
THEN to_char((
|
||||||
|
-- Compute this week's reset point in the configured timezone
|
||||||
|
-- Step 1: get today's date at reset hour in configured tz
|
||||||
|
-- Step 2: compute days forward to target weekday
|
||||||
|
-- Step 3: if same day but past reset hour, advance 7 days
|
||||||
|
CASE
|
||||||
|
WHEN (
|
||||||
|
-- days_forward = (target_day - current_day + 7) % 7
|
||||||
|
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||||
|
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||||
|
+ 7) % 7
|
||||||
|
) = 0 AND NOW() >= (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
-- Same weekday and past reset hour → next week
|
||||||
|
THEN (
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
+ '7 days'::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
ELSE (
|
||||||
|
-- Advance to target weekday this week (or next if days_forward > 0)
|
||||||
|
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||||
|
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||||
|
+ ((
|
||||||
|
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||||
|
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||||
|
+ 7) % 7
|
||||||
|
) || ' days')::interval
|
||||||
|
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||||
|
END
|
||||||
|
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||||
|
ELSE NULL END
|
||||||
|
)`
|
||||||
|
|
||||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||||
|
// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。
|
||||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||||
rows, err := r.sql.QueryContext(ctx,
|
rows, err := r.sql.QueryContext(ctx,
|
||||||
`UPDATE accounts SET extra = (
|
`UPDATE accounts SET extra = (
|
||||||
@@ -1739,31 +1827,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
|||||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||||
jsonb_build_object(
|
jsonb_build_object(
|
||||||
'quota_daily_used',
|
'quota_daily_used',
|
||||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+dailyExpiredExpr+`
|
||||||
+ '24 hours'::interval <= NOW()
|
|
||||||
THEN $1
|
THEN $1
|
||||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||||
'quota_daily_start',
|
'quota_daily_start',
|
||||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+dailyExpiredExpr+`
|
||||||
+ '24 hours'::interval <= NOW()
|
|
||||||
THEN `+nowUTC+`
|
THEN `+nowUTC+`
|
||||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||||
)
|
)
|
||||||
|
-- 固定模式重置时更新下次重置时间
|
||||||
|
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||||||
|
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
ELSE '{}'::jsonb END
|
ELSE '{}'::jsonb END
|
||||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||||
jsonb_build_object(
|
jsonb_build_object(
|
||||||
'quota_weekly_used',
|
'quota_weekly_used',
|
||||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+weeklyExpiredExpr+`
|
||||||
+ '168 hours'::interval <= NOW()
|
|
||||||
THEN $1
|
THEN $1
|
||||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||||
'quota_weekly_start',
|
'quota_weekly_start',
|
||||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+weeklyExpiredExpr+`
|
||||||
+ '168 hours'::interval <= NOW()
|
|
||||||
THEN `+nowUTC+`
|
THEN `+nowUTC+`
|
||||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||||
)
|
)
|
||||||
|
-- 固定模式重置时更新下次重置时间
|
||||||
|
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||||||
|
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
ELSE '{}'::jsonb END
|
ELSE '{}'::jsonb END
|
||||||
), updated_at = NOW()
|
), updated_at = NOW()
|
||||||
WHERE id = $2 AND deleted_at IS NULL
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
@@ -1796,12 +1888,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||||
|
// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间
|
||||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||||
_, err := r.sql.ExecContext(ctx,
|
_, err := r.sql.ExecContext(ctx,
|
||||||
`UPDATE accounts SET extra = (
|
`UPDATE accounts SET extra = (
|
||||||
COALESCE(extra, '{}'::jsonb)
|
COALESCE(extra, '{}'::jsonb)
|
||||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
|
||||||
WHERE id = $1 AND deleted_at IS NULL`,
|
WHERE id = $1 AND deleted_at IS NULL`,
|
||||||
id)
|
id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
|||||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||||
|
Name: "sync-credentials-update",
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5": "gpt-5.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
cacheRecorder := &schedulerCacheRecorder{}
|
||||||
|
s.repo.schedulerCache = cacheRecorder
|
||||||
|
|
||||||
|
account.Credentials = map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5": "gpt-5.2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := s.repo.Update(s.ctx, account)
|
||||||
|
s.Require().NoError(err, "Update")
|
||||||
|
|
||||||
|
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||||
|
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||||
|
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
|
||||||
|
s.Require().True(ok)
|
||||||
|
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AccountRepoSuite) TestDelete() {
|
func (s *AccountRepoSuite) TestDelete() {
|
||||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||||
|
|
||||||
|
|||||||
98
backend/internal/repository/backup_pg_dumper.go
Normal file
98
backend/internal/repository/backup_pg_dumper.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os/exec"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PgDumper implements service.DBDumper using pg_dump/psql
|
||||||
|
type PgDumper struct {
|
||||||
|
cfg *config.DatabaseConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPgDumper creates a new PgDumper
|
||||||
|
func NewPgDumper(cfg *config.Config) service.DBDumper {
|
||||||
|
return &PgDumper{cfg: &cfg.Database}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dump executes pg_dump and returns a streaming reader of the output
|
||||||
|
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||||
|
args := []string{
|
||||||
|
"-h", d.cfg.Host,
|
||||||
|
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||||
|
"-U", d.cfg.User,
|
||||||
|
"-d", d.cfg.DBName,
|
||||||
|
"--no-owner",
|
||||||
|
"--no-acl",
|
||||||
|
"--clean",
|
||||||
|
"--if-exists",
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "pg_dump", args...)
|
||||||
|
if d.cfg.Password != "" {
|
||||||
|
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||||
|
}
|
||||||
|
if d.cfg.SSLMode != "" {
|
||||||
|
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return nil, fmt.Errorf("start pg_dump: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回一个 ReadCloser:读 stdout,关闭时等待进程退出
|
||||||
|
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore executes psql to restore from a streaming reader
|
||||||
|
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
|
||||||
|
args := []string{
|
||||||
|
"-h", d.cfg.Host,
|
||||||
|
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||||
|
"-U", d.cfg.User,
|
||||||
|
"-d", d.cfg.DBName,
|
||||||
|
"--single-transaction",
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||||
|
if d.cfg.Password != "" {
|
||||||
|
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||||
|
}
|
||||||
|
if d.cfg.SSLMode != "" {
|
||||||
|
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Stdin = data
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%v: %s", err, string(output))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
|
||||||
|
type cmdReadCloser struct {
|
||||||
|
io.ReadCloser
|
||||||
|
cmd *exec.Cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cmdReadCloser) Close() error {
|
||||||
|
// Close the pipe first
|
||||||
|
_ = c.ReadCloser.Close()
|
||||||
|
// Wait for the process to exit
|
||||||
|
if err := c.cmd.Wait(); err != nil {
|
||||||
|
return fmt.Errorf("pg_dump exited with error: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
116
backend/internal/repository/backup_s3_store.go
Normal file
116
backend/internal/repository/backup_s3_store.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
|
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||||
|
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
|
||||||
|
type S3BackupStore struct {
|
||||||
|
client *s3.Client
|
||||||
|
bucket string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
|
||||||
|
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
||||||
|
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
|
||||||
|
region := cfg.Region
|
||||||
|
if region == "" {
|
||||||
|
region = "auto" // Cloudflare R2 默认 region
|
||||||
|
}
|
||||||
|
|
||||||
|
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||||
|
awsconfig.WithRegion(region),
|
||||||
|
awsconfig.WithCredentialsProvider(
|
||||||
|
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load aws config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||||
|
if cfg.Endpoint != "" {
|
||||||
|
o.BaseEndpoint = &cfg.Endpoint
|
||||||
|
}
|
||||||
|
if cfg.ForcePathStyle {
|
||||||
|
o.UsePathStyle = true
|
||||||
|
}
|
||||||
|
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||||
|
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||||
|
})
|
||||||
|
|
||||||
|
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||||
|
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("read body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||||
|
Bucket: &s.bucket,
|
||||||
|
Key: &key,
|
||||||
|
Body: bytes.NewReader(data),
|
||||||
|
ContentType: &contentType,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("S3 PutObject: %w", err)
|
||||||
|
}
|
||||||
|
return int64(len(data)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||||
|
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||||
|
Bucket: &s.bucket,
|
||||||
|
Key: &key,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("S3 GetObject: %w", err)
|
||||||
|
}
|
||||||
|
return result.Body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
|
||||||
|
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||||
|
Bucket: &s.bucket,
|
||||||
|
Key: &key,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
|
||||||
|
presignClient := s3.NewPresignClient(s.client)
|
||||||
|
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||||
|
Bucket: &s.bucket,
|
||||||
|
Key: &key,
|
||||||
|
}, s3.WithPresignExpires(expiry))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("presign url: %w", err)
|
||||||
|
}
|
||||||
|
return result.URL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
|
||||||
|
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||||
|
Bucket: &s.bucket,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
|
|||||||
sql sqlExecutor
|
sql sqlExecutor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const usageLogsCleanupBatchSize = 10000
|
||||||
|
const usageBillingDedupCleanupBatchSize = 10000
|
||||||
|
|
||||||
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||||
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||||
if sqlDB == nil {
|
if sqlDB == nil {
|
||||||
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
if r == nil || r.sql == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
loc := timezone.Location()
|
loc := timezone.Location()
|
||||||
startLocal := start.In(loc)
|
startLocal := start.In(loc)
|
||||||
endLocal := end.In(loc)
|
endLocal := end.In(loc)
|
||||||
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
|
|||||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if db, ok := r.sql.(*sql.DB); ok {
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||||
|
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||||
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
|
|||||||
if isPartitioned {
|
if isPartitioned {
|
||||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||||
}
|
}
|
||||||
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
for {
|
||||||
|
res, err := r.sql.ExecContext(ctx, `
|
||||||
|
WITH victims AS (
|
||||||
|
SELECT ctid
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at < $1
|
||||||
|
LIMIT $2
|
||||||
|
)
|
||||||
|
DELETE FROM usage_logs
|
||||||
|
WHERE ctid IN (SELECT ctid FROM victims)
|
||||||
|
`, cutoff.UTC(), usageLogsCleanupBatchSize)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected < usageLogsCleanupBatchSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||||
|
for {
|
||||||
|
res, err := r.sql.ExecContext(ctx, `
|
||||||
|
WITH victims AS (
|
||||||
|
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
|
||||||
|
FROM usage_billing_dedup
|
||||||
|
WHERE created_at < $1
|
||||||
|
LIMIT $2
|
||||||
|
), archived AS (
|
||||||
|
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
|
||||||
|
SELECT request_id, api_key_id, request_fingerprint, created_at
|
||||||
|
FROM victims
|
||||||
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
|
)
|
||||||
|
DELETE FROM usage_billing_dedup
|
||||||
|
WHERE ctid IN (SELECT ctid FROM victims)
|
||||||
|
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected < usageBillingDedupCleanupBatchSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
|
|||||||
@@ -262,6 +262,42 @@ func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
|
|||||||
SetKey(k.Key).
|
SetKey(k.Key).
|
||||||
SetName(k.Name).
|
SetName(k.Name).
|
||||||
SetStatus(k.Status)
|
SetStatus(k.Status)
|
||||||
|
if k.Quota != 0 {
|
||||||
|
create.SetQuota(k.Quota)
|
||||||
|
}
|
||||||
|
if k.QuotaUsed != 0 {
|
||||||
|
create.SetQuotaUsed(k.QuotaUsed)
|
||||||
|
}
|
||||||
|
if k.RateLimit5h != 0 {
|
||||||
|
create.SetRateLimit5h(k.RateLimit5h)
|
||||||
|
}
|
||||||
|
if k.RateLimit1d != 0 {
|
||||||
|
create.SetRateLimit1d(k.RateLimit1d)
|
||||||
|
}
|
||||||
|
if k.RateLimit7d != 0 {
|
||||||
|
create.SetRateLimit7d(k.RateLimit7d)
|
||||||
|
}
|
||||||
|
if k.Usage5h != 0 {
|
||||||
|
create.SetUsage5h(k.Usage5h)
|
||||||
|
}
|
||||||
|
if k.Usage1d != 0 {
|
||||||
|
create.SetUsage1d(k.Usage1d)
|
||||||
|
}
|
||||||
|
if k.Usage7d != 0 {
|
||||||
|
create.SetUsage7d(k.Usage7d)
|
||||||
|
}
|
||||||
|
if k.Window5hStart != nil {
|
||||||
|
create.SetWindow5hStart(*k.Window5hStart)
|
||||||
|
}
|
||||||
|
if k.Window1dStart != nil {
|
||||||
|
create.SetWindow1dStart(*k.Window1dStart)
|
||||||
|
}
|
||||||
|
if k.Window7dStart != nil {
|
||||||
|
create.SetWindow7dStart(*k.Window7dStart)
|
||||||
|
}
|
||||||
|
if k.ExpiresAt != nil {
|
||||||
|
create.SetExpiresAt(*k.ExpiresAt)
|
||||||
|
}
|
||||||
if k.GroupID != nil {
|
if k.GroupID != nil {
|
||||||
create.SetGroupID(*k.GroupID)
|
create.SetGroupID(*k.GroupID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
|||||||
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
||||||
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
||||||
|
|
||||||
|
// usage_billing_dedup: billing idempotency narrow table
|
||||||
|
var usageBillingDedupRegclass sql.NullString
|
||||||
|
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
|
||||||
|
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
|
||||||
|
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
|
||||||
|
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
|
||||||
|
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
|
||||||
|
|
||||||
|
var usageBillingDedupArchiveRegclass sql.NullString
|
||||||
|
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
|
||||||
|
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
|
||||||
|
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
|
||||||
|
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
|
||||||
|
|
||||||
// settings table should exist
|
// settings table should exist
|
||||||
var settingsRegclass sql.NullString
|
var settingsRegclass sql.NullString
|
||||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||||
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
|||||||
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
|
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var exists bool
|
||||||
|
err := tx.QueryRowContext(context.Background(), `
|
||||||
|
SELECT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_indexes
|
||||||
|
WHERE schemaname = 'public'
|
||||||
|
AND tablename = $1
|
||||||
|
AND indexname = $2
|
||||||
|
)
|
||||||
|
`, table, index).Scan(&exists)
|
||||||
|
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
|
||||||
|
require.True(t, exists, "expected index %s on %s", index, table)
|
||||||
|
}
|
||||||
|
|
||||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -73,3 +73,14 @@ func buildReqClientKey(opts reqClientOptions) string {
|
|||||||
opts.ForceHTTP2,
|
opts.ForceHTTP2,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreatePrivacyReqClient creates an HTTP client for OpenAI privacy settings API
|
||||||
|
// This is exported for use by OpenAIPrivacyService
|
||||||
|
// Uses Chrome TLS fingerprint impersonation to bypass Cloudflare checks
|
||||||
|
func CreatePrivacyReqClient(proxyURL string) (*req.Client, error) {
|
||||||
|
return getSharedReqClient(reqClientOptions{
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
Impersonate: true, // Enable Chrome TLS fingerprint impersonation
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
308
backend/internal/repository/usage_billing_repo.go
Normal file
308
backend/internal/repository/usage_billing_repo.go
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type usageBillingRepository struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
|
||||||
|
return &usageBillingRepository{db: sqlDB}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
|
||||||
|
if cmd == nil {
|
||||||
|
return &service.UsageBillingApplyResult{}, nil
|
||||||
|
}
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return nil, errors.New("usage billing repository db is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Normalize()
|
||||||
|
if cmd.RequestID == "" {
|
||||||
|
return nil, service.ErrUsageBillingRequestIDRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := r.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if tx != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !applied {
|
||||||
|
return &service.UsageBillingApplyResult{Applied: false}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &service.UsageBillingApplyResult{Applied: true}
|
||||||
|
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tx = nil
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
|
||||||
|
var id int64
|
||||||
|
err := tx.QueryRowContext(ctx, `
|
||||||
|
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
|
RETURNING id
|
||||||
|
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
var existingFingerprint string
|
||||||
|
if err := tx.QueryRowContext(ctx, `
|
||||||
|
SELECT request_fingerprint
|
||||||
|
FROM usage_billing_dedup
|
||||||
|
WHERE request_id = $1 AND api_key_id = $2
|
||||||
|
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||||
|
return false, service.ErrUsageBillingRequestConflict
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
var archivedFingerprint string
|
||||||
|
err = tx.QueryRowContext(ctx, `
|
||||||
|
SELECT request_fingerprint
|
||||||
|
FROM usage_billing_dedup_archive
|
||||||
|
WHERE request_id = $1 AND api_key_id = $2
|
||||||
|
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
|
||||||
|
if err == nil {
|
||||||
|
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||||
|
return false, service.ErrUsageBillingRequestConflict
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
|
||||||
|
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
|
||||||
|
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.BalanceCost > 0 {
|
||||||
|
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.APIKeyQuotaCost > 0 {
|
||||||
|
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result.APIKeyQuotaExhausted = exhausted
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.APIKeyRateLimitCost > 0 {
|
||||||
|
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
|
||||||
|
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
|
||||||
|
const updateSQL = `
|
||||||
|
UPDATE user_subscriptions us
|
||||||
|
SET
|
||||||
|
daily_usage_usd = us.daily_usage_usd + $1,
|
||||||
|
weekly_usage_usd = us.weekly_usage_usd + $1,
|
||||||
|
monthly_usage_usd = us.monthly_usage_usd + $1,
|
||||||
|
updated_at = NOW()
|
||||||
|
FROM groups g
|
||||||
|
WHERE us.id = $2
|
||||||
|
AND us.deleted_at IS NULL
|
||||||
|
AND us.group_id = g.id
|
||||||
|
AND g.deleted_at IS NULL
|
||||||
|
`
|
||||||
|
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return service.ErrSubscriptionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
|
||||||
|
res, err := tx.ExecContext(ctx, `
|
||||||
|
UPDATE users
|
||||||
|
SET balance = balance - $1,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
|
`, amount, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return service.ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
|
||||||
|
var exhausted bool
|
||||||
|
err := tx.QueryRowContext(ctx, `
|
||||||
|
UPDATE api_keys
|
||||||
|
SET quota_used = quota_used + $1,
|
||||||
|
status = CASE
|
||||||
|
WHEN quota > 0
|
||||||
|
AND status = $3
|
||||||
|
AND quota_used < quota
|
||||||
|
AND quota_used + $1 >= quota
|
||||||
|
THEN $4
|
||||||
|
ELSE status
|
||||||
|
END,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
|
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
|
||||||
|
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return false, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return exhausted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
|
||||||
|
res, err := tx.ExecContext(ctx, `
|
||||||
|
UPDATE api_keys SET
|
||||||
|
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||||
|
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||||
|
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||||
|
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||||
|
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||||
|
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
|
`, cost, apiKeyID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected == 0 {
|
||||||
|
return service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
|
||||||
|
rows, err := tx.QueryContext(ctx,
|
||||||
|
`UPDATE accounts SET extra = (
|
||||||
|
COALESCE(extra, '{}'::jsonb)
|
||||||
|
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||||
|
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||||
|
jsonb_build_object(
|
||||||
|
'quota_daily_used',
|
||||||
|
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '24 hours'::interval <= NOW()
|
||||||
|
THEN $1
|
||||||
|
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||||
|
'quota_daily_start',
|
||||||
|
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '24 hours'::interval <= NOW()
|
||||||
|
THEN `+nowUTC+`
|
||||||
|
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||||
|
)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
|
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||||
|
jsonb_build_object(
|
||||||
|
'quota_weekly_used',
|
||||||
|
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '168 hours'::interval <= NOW()
|
||||||
|
THEN $1
|
||||||
|
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||||
|
'quota_weekly_start',
|
||||||
|
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||||
|
+ '168 hours'::interval <= NOW()
|
||||||
|
THEN `+nowUTC+`
|
||||||
|
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||||
|
)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
|
), updated_at = NOW()
|
||||||
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
|
RETURNING
|
||||||
|
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||||
|
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
||||||
|
amount, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var newUsed, limit float64
|
||||||
|
if rows.Next() {
|
||||||
|
if err := rows.Scan(&newUsed, &limit); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return service.ErrAccountNotFound
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||||
|
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
||||||
|
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewUsageBillingRepository(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Balance: 100,
|
||||||
|
})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-usage-billing-" + uuid.NewString(),
|
||||||
|
Name: "billing",
|
||||||
|
Quota: 1,
|
||||||
|
})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{
|
||||||
|
Name: "usage-billing-account-" + uuid.NewString(),
|
||||||
|
Type: service.AccountTypeAPIKey,
|
||||||
|
})
|
||||||
|
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
cmd := &service.UsageBillingCommand{
|
||||||
|
RequestID: requestID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountType: service.AccountTypeAPIKey,
|
||||||
|
BalanceCost: 1.25,
|
||||||
|
APIKeyQuotaCost: 1.25,
|
||||||
|
APIKeyRateLimitCost: 1.25,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result1)
|
||||||
|
require.True(t, result1.Applied)
|
||||||
|
require.True(t, result1.APIKeyQuotaExhausted)
|
||||||
|
|
||||||
|
result2, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result2)
|
||||||
|
require.False(t, result2.Applied)
|
||||||
|
|
||||||
|
var balance float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||||
|
require.InDelta(t, 98.75, balance, 0.000001)
|
||||||
|
|
||||||
|
var quotaUsed float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan("aUsed))
|
||||||
|
require.InDelta(t, 1.25, quotaUsed, 0.000001)
|
||||||
|
|
||||||
|
var usage5h float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
|
||||||
|
require.InDelta(t, 1.25, usage5h, 0.000001)
|
||||||
|
|
||||||
|
var status string
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
|
||||||
|
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
|
||||||
|
|
||||||
|
var dedupCount int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
|
||||||
|
require.Equal(t, 1, dedupCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewUsageBillingRepository(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
})
|
||||||
|
group := mustCreateGroup(t, client, &service.Group{
|
||||||
|
Name: "usage-billing-group-" + uuid.NewString(),
|
||||||
|
Platform: service.PlatformAnthropic,
|
||||||
|
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||||
|
})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: &group.ID,
|
||||||
|
Key: "sk-usage-billing-sub-" + uuid.NewString(),
|
||||||
|
Name: "billing-sub",
|
||||||
|
})
|
||||||
|
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
})
|
||||||
|
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
cmd := &service.UsageBillingCommand{
|
||||||
|
RequestID: requestID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
AccountID: 0,
|
||||||
|
SubscriptionID: &subscription.ID,
|
||||||
|
SubscriptionCost: 2.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result1.Applied)
|
||||||
|
|
||||||
|
result2, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, result2.Applied)
|
||||||
|
|
||||||
|
var dailyUsage float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
|
||||||
|
require.InDelta(t, 2.5, dailyUsage, 0.000001)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewUsageBillingRepository(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Balance: 100,
|
||||||
|
})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
|
||||||
|
Name: "billing-conflict",
|
||||||
|
})
|
||||||
|
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||||
|
RequestID: requestID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
BalanceCost: 1.25,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
|
||||||
|
RequestID: requestID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
BalanceCost: 2.50,
|
||||||
|
})
|
||||||
|
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewUsageBillingRepository(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-usage-billing-account-" + uuid.NewString(),
|
||||||
|
Name: "billing-account",
|
||||||
|
})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{
|
||||||
|
Name: "usage-billing-account-quota-" + uuid.NewString(),
|
||||||
|
Type: service.AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"quota_limit": 100.0,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountType: service.AccountTypeAPIKey,
|
||||||
|
AccountQuotaCost: 3.5,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var quotaUsed float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan("aUsed))
|
||||||
|
require.InDelta(t, 3.5, quotaUsed, 0.000001)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||||
|
|
||||||
|
oldRequestID := "dedup-old-" + uuid.NewString()
|
||||||
|
newRequestID := "dedup-new-" + uuid.NewString()
|
||||||
|
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
|
||||||
|
newCreatedAt := time.Now().UTC().Add(-time.Hour)
|
||||||
|
|
||||||
|
_, err := integrationDB.ExecContext(ctx, `
|
||||||
|
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
|
||||||
|
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
|
||||||
|
`,
|
||||||
|
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
|
||||||
|
newRequestID, strings.Repeat("b", 64), newCreatedAt,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||||
|
|
||||||
|
var oldCount int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
|
||||||
|
require.Equal(t, 0, oldCount)
|
||||||
|
|
||||||
|
var newCount int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
|
||||||
|
require.Equal(t, 1, newCount)
|
||||||
|
|
||||||
|
var archivedCount int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
|
||||||
|
require.Equal(t, 1, archivedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewUsageBillingRepository(client, integrationDB)
|
||||||
|
aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Balance: 100,
|
||||||
|
})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-usage-billing-archive-" + uuid.NewString(),
|
||||||
|
Name: "billing-archive",
|
||||||
|
})
|
||||||
|
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
cmd := &service.UsageBillingCommand{
|
||||||
|
RequestID: requestID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: user.ID,
|
||||||
|
BalanceCost: 1.25,
|
||||||
|
}
|
||||||
|
|
||||||
|
result1, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result1.Applied)
|
||||||
|
|
||||||
|
_, err = integrationDB.ExecContext(ctx, `
|
||||||
|
UPDATE usage_billing_dedup
|
||||||
|
SET created_at = $1
|
||||||
|
WHERE request_id = $2 AND api_key_id = $3
|
||||||
|
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||||
|
|
||||||
|
result2, err := repo.Apply(ctx, cmd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, result2.Applied)
|
||||||
|
|
||||||
|
var balance float64
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||||
|
require.InDelta(t, 98.75, balance, 0.000001)
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,8 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -14,6 +16,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,6 +87,367 @@ func (s *UsageLogRepoSuite) TestCreate() {
|
|||||||
s.Require().NotZero(log.ID)
|
s.Require().NotZero(log.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathConcurrent(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-" + uuid.NewString()})
|
||||||
|
|
||||||
|
const total = 16
|
||||||
|
results := make([]bool, total)
|
||||||
|
errs := make([]error, total)
|
||||||
|
logs := make([]*service.UsageLog, total)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(total)
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
i := i
|
||||||
|
logs[i] = &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10 + i,
|
||||||
|
OutputTokens: 20 + i,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
results[i], errs[i] = repo.Create(ctx, logs[i])
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
require.NoError(t, errs[i])
|
||||||
|
require.True(t, results[i])
|
||||||
|
require.NotZero(t, logs[i].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE api_key_id = $1", apiKey.ID).Scan(&count))
|
||||||
|
require.Equal(t, total, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-dup-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-dup-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-dup-" + uuid.NewString()})
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
|
||||||
|
log1 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
log2 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
inserted1, err1 := repo.Create(ctx, log1)
|
||||||
|
inserted2, err2 := repo.Create(ctx, log2)
|
||||||
|
require.NoError(t, err1)
|
||||||
|
require.NoError(t, err2)
|
||||||
|
require.True(t, inserted1)
|
||||||
|
require.False(t, inserted2)
|
||||||
|
require.Equal(t, log1.ID, log2.ID)
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
|
||||||
|
const total = 8
|
||||||
|
batch := make([]usageLogCreateRequest, 0, total)
|
||||||
|
logs := make([]*service.UsageLog, 0, total)
|
||||||
|
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
log := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10 + i,
|
||||||
|
OutputTokens: 20 + i,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
logs = append(logs, log)
|
||||||
|
batch = append(batch, usageLogCreateRequest{
|
||||||
|
log: log,
|
||||||
|
prepared: prepareUsageLogInsert(log),
|
||||||
|
resultCh: make(chan usageLogCreateResult, 1),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.flushCreateBatch(integrationDB, batch)
|
||||||
|
|
||||||
|
insertedCount := 0
|
||||||
|
var firstID int64
|
||||||
|
for idx, req := range batch {
|
||||||
|
res := <-req.resultCh
|
||||||
|
require.NoError(t, res.err)
|
||||||
|
if res.inserted {
|
||||||
|
insertedCount++
|
||||||
|
}
|
||||||
|
require.NotZero(t, logs[idx].ID)
|
||||||
|
if idx == 0 {
|
||||||
|
firstID = logs[idx].ID
|
||||||
|
} else {
|
||||||
|
require.Equal(t, firstID, logs[idx].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, insertedCount)
|
||||||
|
|
||||||
|
var count int
|
||||||
|
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
|
||||||
|
requestID := uuid.NewString()
|
||||||
|
|
||||||
|
log1 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
log2 := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: requestID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, repo.CreateBestEffort(ctx, log1))
|
||||||
|
require.NoError(t, repo.CreateBestEffort(ctx, log2))
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
var count int
|
||||||
|
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
|
||||||
|
return err == nil && count == 1
|
||||||
|
}, 3*time.Second, 20*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
|
||||||
|
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
|
||||||
|
|
||||||
|
err := repo.CreateBestEffort(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateDropped(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.False(t, inserted)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||||
|
repo.createBatchCh <- usageLogCreateRequest{}
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
|
||||||
|
|
||||||
|
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.False(t, inserted)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := repo.createBatched(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
req := <-repo.createBatchCh
|
||||||
|
require.NotNil(t, req.shared)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := <-errCh
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||||
|
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
|
||||||
|
|
||||||
|
log := &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
req := usageLogCreateRequest{
|
||||||
|
log: log,
|
||||||
|
prepared: prepareUsageLogInsert(log),
|
||||||
|
shared: &usageLogCreateShared{},
|
||||||
|
resultCh: make(chan usageLogCreateResult, 1),
|
||||||
|
}
|
||||||
|
req.shared.state.Store(usageLogCreateStateCanceled)
|
||||||
|
|
||||||
|
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
|
||||||
|
|
||||||
|
res := <-req.resultCh
|
||||||
|
require.False(t, res.inserted)
|
||||||
|
require.Error(t, res.err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||||
|
|||||||
@@ -73,6 +73,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
|||||||
sqlmock.AnyArg(), // media_type
|
sqlmock.AnyArg(), // media_type
|
||||||
sqlmock.AnyArg(), // service_tier
|
sqlmock.AnyArg(), // service_tier
|
||||||
sqlmock.AnyArg(), // reasoning_effort
|
sqlmock.AnyArg(), // reasoning_effort
|
||||||
|
sqlmock.AnyArg(), // inbound_endpoint
|
||||||
|
sqlmock.AnyArg(), // upstream_endpoint
|
||||||
log.CacheTTLOverridden,
|
log.CacheTTLOverridden,
|
||||||
createdAt,
|
createdAt,
|
||||||
).
|
).
|
||||||
@@ -141,6 +143,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
|||||||
sqlmock.AnyArg(),
|
sqlmock.AnyArg(),
|
||||||
serviceTier,
|
serviceTier,
|
||||||
sqlmock.AnyArg(),
|
sqlmock.AnyArg(),
|
||||||
|
sqlmock.AnyArg(),
|
||||||
|
sqlmock.AnyArg(),
|
||||||
log.CacheTTLOverridden,
|
log.CacheTTLOverridden,
|
||||||
createdAt,
|
createdAt,
|
||||||
).
|
).
|
||||||
@@ -248,6 +252,35 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
|
|||||||
require.NoError(t, mock.ExpectationsWereMet())
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryGetUserSpendingRanking(t *testing.T) {
|
||||||
|
db, mock := newSQLMock(t)
|
||||||
|
repo := &usageLogRepository{sql: db}
|
||||||
|
|
||||||
|
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
end := start.Add(24 * time.Hour)
|
||||||
|
|
||||||
|
rows := sqlmock.NewRows([]string{"user_id", "email", "actual_cost", "requests", "tokens", "total_actual_cost"}).
|
||||||
|
AddRow(int64(2), "beta@example.com", 12.5, int64(9), int64(900), 40.0).
|
||||||
|
AddRow(int64(1), "alpha@example.com", 12.5, int64(8), int64(800), 40.0).
|
||||||
|
AddRow(int64(3), "gamma@example.com", 4.25, int64(5), int64(300), 40.0)
|
||||||
|
|
||||||
|
mock.ExpectQuery("WITH user_spend AS \\(").
|
||||||
|
WithArgs(start, end, 12).
|
||||||
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
got, err := repo.GetUserSpendingRanking(context.Background(), start, end, 12)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, &usagestats.UserSpendingRankingResponse{
|
||||||
|
Ranking: []usagestats.UserSpendingRankingItem{
|
||||||
|
{UserID: 2, Email: "beta@example.com", ActualCost: 12.5, Requests: 9, Tokens: 900},
|
||||||
|
{UserID: 1, Email: "alpha@example.com", ActualCost: 12.5, Requests: 8, Tokens: 800},
|
||||||
|
{UserID: 3, Email: "gamma@example.com", ActualCost: 4.25, Requests: 5, Tokens: 300},
|
||||||
|
},
|
||||||
|
TotalActualCost: 40.0,
|
||||||
|
}, got)
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
|
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -347,6 +380,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{Valid: true, String: "priority"},
|
sql.NullString{Valid: true, String: "priority"},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
@@ -386,6 +421,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{Valid: true, String: "flex"},
|
sql.NullString{Valid: true, String: "flex"},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
@@ -425,6 +462,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{Valid: true, String: "priority"},
|
sql.NullString{Valid: true, String: "priority"},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
|
|||||||
@@ -3,8 +3,11 @@
|
|||||||
package repository
|
package repository
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
|
||||||
|
log := &service.UsageLog{
|
||||||
|
UserID: 1,
|
||||||
|
APIKeyID: 2,
|
||||||
|
AccountID: 3,
|
||||||
|
RequestID: "req-batch-no-update",
|
||||||
|
Model: "gpt-5",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalCost: 1.2,
|
||||||
|
ActualCost: 1.2,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
prepared := prepareUsageLogInsert(log)
|
||||||
|
|
||||||
|
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
|
||||||
|
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
|
||||||
|
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
|
||||||
|
}
|
||||||
|
|||||||
@@ -95,6 +95,35 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||||
|
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
|
||||||
|
query := `
|
||||||
|
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
|
||||||
|
FROM user_group_rate_multipliers ugr
|
||||||
|
JOIN users u ON u.id = ugr.user_id
|
||||||
|
WHERE ugr.group_id = $1
|
||||||
|
ORDER BY ugr.user_id
|
||||||
|
`
|
||||||
|
rows, err := r.sql.QueryContext(ctx, query, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var result []service.UserGroupRateEntry
|
||||||
|
for rows.Next() {
|
||||||
|
var entry service.UserGroupRateEntry
|
||||||
|
if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, entry)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||||
@@ -164,6 +193,31 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
|
||||||
|
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
|
||||||
|
if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
userIDs := make([]int64, len(entries))
|
||||||
|
rates := make([]float64, len(entries))
|
||||||
|
for i, e := range entries {
|
||||||
|
userIDs[i] = e.UserID
|
||||||
|
rates[i] = e.RateMultiplier
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
_, err := r.sql.ExecContext(ctx, `
|
||||||
|
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||||
|
SELECT data.user_id, $1::bigint, data.rate_multiplier, $2::timestamptz, $2::timestamptz
|
||||||
|
FROM unnest($3::bigint[], $4::double precision[]) AS data(user_id, rate_multiplier)
|
||||||
|
ON CONFLICT (user_id, group_id)
|
||||||
|
DO UPDATE SET rate_multiplier = EXCLUDED.rate_multiplier, updated_at = EXCLUDED.updated_at
|
||||||
|
`, groupID, now, pq.Array(userIDs), pq.Array(rates))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
||||||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewAnnouncementRepository,
|
NewAnnouncementRepository,
|
||||||
NewAnnouncementReadRepository,
|
NewAnnouncementReadRepository,
|
||||||
NewUsageLogRepository,
|
NewUsageLogRepository,
|
||||||
|
NewUsageBillingRepository,
|
||||||
NewIdempotencyRepository,
|
NewIdempotencyRepository,
|
||||||
NewUsageCleanupRepository,
|
NewUsageCleanupRepository,
|
||||||
NewDashboardAggregationRepository,
|
NewDashboardAggregationRepository,
|
||||||
@@ -99,6 +100,10 @@ var ProviderSet = wire.NewSet(
|
|||||||
// Encryptors
|
// Encryptors
|
||||||
NewAESEncryptor,
|
NewAESEncryptor,
|
||||||
|
|
||||||
|
// Backup infrastructure
|
||||||
|
NewPgDumper,
|
||||||
|
NewS3BackupStoreFactory,
|
||||||
|
|
||||||
// HTTP service ports (DI Strategy A: return interface directly)
|
// HTTP service ports (DI Strategy A: return interface directly)
|
||||||
NewTurnstileVerifier,
|
NewTurnstileVerifier,
|
||||||
ProvidePricingRemoteClient,
|
ProvidePricingRemoteClient,
|
||||||
|
|||||||
@@ -493,6 +493,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"registration_email_suffix_whitelist": [],
|
"registration_email_suffix_whitelist": [],
|
||||||
"promo_code_enabled": true,
|
"promo_code_enabled": true,
|
||||||
"password_reset_enabled": false,
|
"password_reset_enabled": false,
|
||||||
|
"frontend_url": "",
|
||||||
"totp_enabled": false,
|
"totp_enabled": false,
|
||||||
"totp_encryption_key_configured": false,
|
"totp_encryption_key_configured": false,
|
||||||
"smtp_host": "smtp.example.com",
|
"smtp_host": "smtp.example.com",
|
||||||
@@ -537,6 +538,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"purchase_subscription_url": "",
|
"purchase_subscription_url": "",
|
||||||
"min_claude_code_version": "",
|
"min_claude_code_version": "",
|
||||||
"allow_ungrouped_key_scheduling": false,
|
"allow_ungrouped_key_scheduling": false,
|
||||||
|
"backend_mode_enabled": false,
|
||||||
"custom_menu_items": []
|
"custom_menu_items": []
|
||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
@@ -645,7 +647,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
@@ -1623,6 +1625,14 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
@@ -1635,6 +1645,10 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||||
logs := r.userLogs[userID]
|
logs := r.userLogs[userID]
|
||||||
if len(logs) == 0 {
|
if len(logs) == 0 {
|
||||||
|
|||||||
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
|
||||||
|
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
|
||||||
|
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
role, _ := GetUserRoleFromContext(c)
|
||||||
|
if role == "admin" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
|
||||||
|
// Allows: login, login/2fa, logout, refresh (admin needs these).
|
||||||
|
// Blocks: register, forgot-password, reset-password, OAuth, etc.
|
||||||
|
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
path := c.Request.URL.Path
|
||||||
|
// Allow login, 2FA, logout, refresh, public settings
|
||||||
|
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
|
||||||
|
for _, suffix := range allowedSuffixes {
|
||||||
|
if strings.HasSuffix(path, suffix) {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type bmSettingRepo struct {
|
||||||
|
values map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
|
||||||
|
panic("unexpected Get call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||||
|
v, ok := r.values[key]
|
||||||
|
if !ok {
|
||||||
|
return "", service.ErrSettingNotFound
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
|
||||||
|
panic("unexpected Set call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
|
||||||
|
panic("unexpected GetMultiple call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||||
|
if r.values == nil {
|
||||||
|
r.values = make(map[string]string, len(settings))
|
||||||
|
}
|
||||||
|
for key, value := range settings {
|
||||||
|
r.values[key] = value
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||||
|
panic("unexpected GetAll call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
repo := &bmSettingRepo{
|
||||||
|
values: map[string]string{
|
||||||
|
service.SettingKeyBackendModeEnabled: enabled,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := service.NewSettingService(repo, &config.Config{})
|
||||||
|
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
|
||||||
|
BackendModeEnabled: enabled == "true",
|
||||||
|
}))
|
||||||
|
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringPtr(v string) *string {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackendModeUserGuard(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
nilService bool
|
||||||
|
enabled string
|
||||||
|
role *string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "disabled_allows_all",
|
||||||
|
enabled: "false",
|
||||||
|
role: stringPtr("user"),
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil_service_allows_all",
|
||||||
|
nilService: true,
|
||||||
|
role: stringPtr("user"),
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_admin_allowed",
|
||||||
|
enabled: "true",
|
||||||
|
role: stringPtr("admin"),
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_user_blocked",
|
||||||
|
enabled: "true",
|
||||||
|
role: stringPtr("user"),
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_no_role_blocked",
|
||||||
|
enabled: "true",
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_empty_role_blocked",
|
||||||
|
enabled: "true",
|
||||||
|
role: stringPtr(""),
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
if tc.role != nil {
|
||||||
|
role := *tc.role
|
||||||
|
r.Use(func(c *gin.Context) {
|
||||||
|
c.Set(string(ContextKeyUserRole), role)
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
var svc *service.SettingService
|
||||||
|
if !tc.nilService {
|
||||||
|
svc = newBackendModeSettingService(t, tc.enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Use(BackendModeUserGuard(svc))
|
||||||
|
r.GET("/test", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, tc.wantStatus, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackendModeAuthGuard(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
nilService bool
|
||||||
|
enabled string
|
||||||
|
path string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "disabled_allows_all",
|
||||||
|
enabled: "false",
|
||||||
|
path: "/api/v1/auth/register",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil_service_allows_all",
|
||||||
|
nilService: true,
|
||||||
|
path: "/api/v1/auth/register",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_allows_login",
|
||||||
|
enabled: "true",
|
||||||
|
path: "/api/v1/auth/login",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_allows_login_2fa",
|
||||||
|
enabled: "true",
|
||||||
|
path: "/api/v1/auth/login/2fa",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_allows_logout",
|
||||||
|
enabled: "true",
|
||||||
|
path: "/api/v1/auth/logout",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_allows_refresh",
|
||||||
|
enabled: "true",
|
||||||
|
path: "/api/v1/auth/refresh",
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_blocks_register",
|
||||||
|
enabled: "true",
|
||||||
|
path: "/api/v1/auth/register",
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_blocks_forgot_password",
|
||||||
|
enabled: "true",
|
||||||
|
path: "/api/v1/auth/forgot-password",
|
||||||
|
wantStatus: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
|
||||||
|
var svc *service.SettingService
|
||||||
|
if !tc.nilService {
|
||||||
|
svc = newBackendModeSettingService(t, tc.enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Use(BackendModeAuthGuard(svc))
|
||||||
|
r.Any("/*path", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, tc.wantStatus, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -107,9 +107,9 @@ func registerRoutes(
|
|||||||
v1 := r.Group("/api/v1")
|
v1 := r.Group("/api/v1")
|
||||||
|
|
||||||
// 注册各模块路由
|
// 注册各模块路由
|
||||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
|
||||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
|
||||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
|
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
|
||||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,6 +58,9 @@ func RegisterAdminRoutes(
|
|||||||
// 数据管理
|
// 数据管理
|
||||||
registerDataManagementRoutes(admin, h)
|
registerDataManagementRoutes(admin, h)
|
||||||
|
|
||||||
|
// 数据库备份恢复
|
||||||
|
registerBackupRoutes(admin, h)
|
||||||
|
|
||||||
// 运维监控(Ops)
|
// 运维监控(Ops)
|
||||||
registerOpsRoutes(admin, h)
|
registerOpsRoutes(admin, h)
|
||||||
|
|
||||||
@@ -192,6 +195,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
|
dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
|
||||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
|
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
|
||||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||||
|
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
|
||||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||||
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
||||||
@@ -228,6 +232,9 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
groups.PUT("/:id", h.Admin.Group.Update)
|
groups.PUT("/:id", h.Admin.Group.Update)
|
||||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||||
|
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
|
||||||
|
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
|
||||||
|
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
|
||||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -436,6 +443,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
|
backup := admin.Group("/backups")
|
||||||
|
{
|
||||||
|
// S3 存储配置
|
||||||
|
backup.GET("/s3-config", h.Admin.Backup.GetS3Config)
|
||||||
|
backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config)
|
||||||
|
backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection)
|
||||||
|
|
||||||
|
// 定时备份配置
|
||||||
|
backup.GET("/schedule", h.Admin.Backup.GetSchedule)
|
||||||
|
backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule)
|
||||||
|
|
||||||
|
// 备份操作
|
||||||
|
backup.POST("", h.Admin.Backup.CreateBackup)
|
||||||
|
backup.GET("", h.Admin.Backup.ListBackups)
|
||||||
|
backup.GET("/:id", h.Admin.Backup.GetBackup)
|
||||||
|
backup.DELETE("/:id", h.Admin.Backup.DeleteBackup)
|
||||||
|
backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL)
|
||||||
|
|
||||||
|
// 恢复操作
|
||||||
|
backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||||
system := admin.Group("/system")
|
system := admin.Group("/system")
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
|
|||||||
h *handler.Handlers,
|
h *handler.Handlers,
|
||||||
jwtAuth servermiddleware.JWTAuthMiddleware,
|
jwtAuth servermiddleware.JWTAuthMiddleware,
|
||||||
redisClient *redis.Client,
|
redisClient *redis.Client,
|
||||||
|
settingService *service.SettingService,
|
||||||
) {
|
) {
|
||||||
// 创建速率限制器
|
// 创建速率限制器
|
||||||
rateLimiter := middleware.NewRateLimiter(redisClient)
|
rateLimiter := middleware.NewRateLimiter(redisClient)
|
||||||
|
|
||||||
// 公开接口
|
// 公开接口
|
||||||
auth := v1.Group("/auth")
|
auth := v1.Group("/auth")
|
||||||
|
auth.Use(servermiddleware.BackendModeAuthGuard(settingService))
|
||||||
{
|
{
|
||||||
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||||
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||||
@@ -78,6 +81,7 @@ func RegisterAuthRoutes(
|
|||||||
// 需要认证的当前用户信息
|
// 需要认证的当前用户信息
|
||||||
authenticated := v1.Group("")
|
authenticated := v1.Group("")
|
||||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||||
|
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService))
|
||||||
{
|
{
|
||||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||||
// 撤销所有会话(需要认证)
|
// 撤销所有会话(需要认证)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
|||||||
c.Next()
|
c.Next()
|
||||||
}),
|
}),
|
||||||
redisClient,
|
redisClient,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package routes
|
|||||||
import (
|
import (
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
|
|||||||
v1 *gin.RouterGroup,
|
v1 *gin.RouterGroup,
|
||||||
h *handler.Handlers,
|
h *handler.Handlers,
|
||||||
jwtAuth middleware.JWTAuthMiddleware,
|
jwtAuth middleware.JWTAuthMiddleware,
|
||||||
|
settingService *service.SettingService,
|
||||||
) {
|
) {
|
||||||
if h.SoraClient == nil {
|
if h.SoraClient == nil {
|
||||||
return
|
return
|
||||||
@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
|
|||||||
|
|
||||||
authenticated := v1.Group("/sora")
|
authenticated := v1.Group("/sora")
|
||||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||||
|
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||||
{
|
{
|
||||||
authenticated.POST("/generate", h.SoraClient.Generate)
|
authenticated.POST("/generate", h.SoraClient.Generate)
|
||||||
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package routes
|
|||||||
import (
|
import (
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -12,9 +13,11 @@ func RegisterUserRoutes(
|
|||||||
v1 *gin.RouterGroup,
|
v1 *gin.RouterGroup,
|
||||||
h *handler.Handlers,
|
h *handler.Handlers,
|
||||||
jwtAuth middleware.JWTAuthMiddleware,
|
jwtAuth middleware.JWTAuthMiddleware,
|
||||||
|
settingService *service.SettingService,
|
||||||
) {
|
) {
|
||||||
authenticated := v1.Group("")
|
authenticated := v1.Group("")
|
||||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||||
|
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||||
{
|
{
|
||||||
// 用户接口
|
// 用户接口
|
||||||
user := authenticated.Group("/user")
|
user := authenticated.Group("/user")
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
@@ -412,6 +413,7 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri
|
|||||||
if a.Platform == domain.PlatformAntigravity {
|
if a.Platform == domain.PlatformAntigravity {
|
||||||
return domain.DefaultAntigravityModelMapping
|
return domain.DefaultAntigravityModelMapping
|
||||||
}
|
}
|
||||||
|
// Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if len(rawMapping) == 0 {
|
if len(rawMapping) == 0 {
|
||||||
@@ -521,16 +523,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
|||||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
||||||
// 如果未配置 mapping,返回原始模型名
|
// 如果未配置 mapping,返回原始模型名
|
||||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||||
|
mappedModel, _ := a.ResolveMappedModel(requestedModel)
|
||||||
|
return mappedModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。
|
||||||
|
// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。
|
||||||
|
func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) {
|
||||||
mapping := a.GetModelMapping()
|
mapping := a.GetModelMapping()
|
||||||
if len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
return requestedModel
|
return requestedModel, false
|
||||||
}
|
}
|
||||||
// 精确匹配优先
|
// 精确匹配优先
|
||||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||||
return mappedModel
|
return mappedModel, true
|
||||||
}
|
}
|
||||||
// 通配符匹配(最长优先)
|
// 通配符匹配(最长优先)
|
||||||
return matchWildcardMapping(mapping, requestedModel)
|
return matchWildcardMappingResult(mapping, requestedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) GetBaseURL() string {
|
func (a *Account) GetBaseURL() string {
|
||||||
@@ -604,9 +613,7 @@ func matchWildcard(pattern, str string) bool {
|
|||||||
return matchAntigravityWildcard(pattern, str)
|
return matchAntigravityWildcard(pattern, str)
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchWildcardMapping 通配符映射匹配(最长优先)
|
func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) {
|
||||||
// 如果没有匹配,返回原始字符串
|
|
||||||
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
|
|
||||||
// 收集所有匹配的 pattern,按长度降序排序(最长优先)
|
// 收集所有匹配的 pattern,按长度降序排序(最长优先)
|
||||||
type patternMatch struct {
|
type patternMatch struct {
|
||||||
pattern string
|
pattern string
|
||||||
@@ -621,7 +628,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(matches) == 0 {
|
if len(matches) == 0 {
|
||||||
return requestedModel // 无匹配,返回原始模型名
|
return requestedModel, false // 无匹配,返回原始模型名
|
||||||
}
|
}
|
||||||
|
|
||||||
// 按 pattern 长度降序排序
|
// 按 pattern 长度降序排序
|
||||||
@@ -632,7 +639,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
|||||||
return matches[i].pattern < matches[j].pattern
|
return matches[i].pattern < matches[j].pattern
|
||||||
})
|
})
|
||||||
|
|
||||||
return matches[0].target
|
return matches[0].target, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||||
@@ -650,7 +657,7 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
|
|||||||
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
||||||
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
||||||
func (a *Account) IsPoolMode() bool {
|
func (a *Account) IsPoolMode() bool {
|
||||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
if !a.IsAPIKeyOrBedrock() || a.Credentials == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if v, ok := a.Credentials["pool_mode"]; ok {
|
if v, ok := a.Credentials["pool_mode"]; ok {
|
||||||
@@ -764,6 +771,19 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) IsBedrock() bool {
|
||||||
|
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) IsBedrockAPIKey() bool {
|
||||||
|
return a.IsBedrock() && a.GetCredential("auth_mode") == "apikey"
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAPIKeyOrBedrock 返回账号类型是否支持配额和池模式等特性
|
||||||
|
func (a *Account) IsAPIKeyOrBedrock() bool {
|
||||||
|
return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) IsOpenAI() bool {
|
func (a *Account) IsOpenAI() bool {
|
||||||
return a.Platform == PlatformOpenAI
|
return a.Platform == PlatformOpenAI
|
||||||
}
|
}
|
||||||
@@ -1260,6 +1280,240 @@ func (a *Account) getExtraTime(key string) time.Time {
|
|||||||
return time.Time{}
|
return time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getExtraString 从 Extra 中读取指定 key 的字符串值
|
||||||
|
func (a *Account) getExtraString(key string) string {
|
||||||
|
if a.Extra == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if v, ok := a.Extra[key]; ok {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// getExtraInt 从 Extra 中读取指定 key 的 int 值
|
||||||
|
func (a *Account) getExtraInt(key string) int {
|
||||||
|
if a.Extra == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if v, ok := a.Extra[key]; ok {
|
||||||
|
return int(parseExtraFloat64(v))
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuotaDailyResetMode 获取日额度重置模式:"rolling"(默认)或 "fixed"
|
||||||
|
func (a *Account) GetQuotaDailyResetMode() string {
|
||||||
|
if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" {
|
||||||
|
return "fixed"
|
||||||
|
}
|
||||||
|
return "rolling"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuotaDailyResetHour 获取固定重置的小时(0-23),默认 0
|
||||||
|
func (a *Account) GetQuotaDailyResetHour() int {
|
||||||
|
return a.getExtraInt("quota_daily_reset_hour")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuotaWeeklyResetMode 获取周额度重置模式:"rolling"(默认)或 "fixed"
|
||||||
|
func (a *Account) GetQuotaWeeklyResetMode() string {
|
||||||
|
if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" {
|
||||||
|
return "fixed"
|
||||||
|
}
|
||||||
|
return "rolling"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuotaWeeklyResetDay 获取固定重置的星期几(0=周日, 1=周一, ..., 6=周六),默认 1(周一)
|
||||||
|
func (a *Account) GetQuotaWeeklyResetDay() int {
|
||||||
|
if a.Extra == nil {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if _, ok := a.Extra["quota_weekly_reset_day"]; !ok {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return a.getExtraInt("quota_weekly_reset_day")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuotaWeeklyResetHour 获取周配额固定重置的小时(0-23),默认 0
|
||||||
|
func (a *Account) GetQuotaWeeklyResetHour() int {
|
||||||
|
return a.getExtraInt("quota_weekly_reset_hour")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuotaResetTimezone 获取固定重置的时区名(IANA),默认 "UTC"
|
||||||
|
func (a *Account) GetQuotaResetTimezone() string {
|
||||||
|
if tz := a.getExtraString("quota_reset_timezone"); tz != "" {
|
||||||
|
return tz
|
||||||
|
}
|
||||||
|
return "UTC"
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
|
||||||
|
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
|
||||||
|
t := after.In(tz)
|
||||||
|
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||||
|
if !after.Before(today) {
|
||||||
|
return today.AddDate(0, 0, 1)
|
||||||
|
}
|
||||||
|
return today
|
||||||
|
}
|
||||||
|
|
||||||
|
// lastFixedDailyReset 计算 now 之前最近一次的每日固定重置时间点
|
||||||
|
func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time {
|
||||||
|
t := now.In(tz)
|
||||||
|
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||||
|
if now.Before(today) {
|
||||||
|
return today.AddDate(0, 0, -1)
|
||||||
|
}
|
||||||
|
return today
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextFixedWeeklyReset 计算在 after 之后的下一个每周固定重置时间点
|
||||||
|
// day: 0=Sunday, 1=Monday, ..., 6=Saturday
|
||||||
|
func nextFixedWeeklyReset(day, hour int, tz *time.Location, after time.Time) time.Time {
|
||||||
|
t := after.In(tz)
|
||||||
|
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||||
|
currentDay := int(todayReset.Weekday())
|
||||||
|
|
||||||
|
daysForward := (day - currentDay + 7) % 7
|
||||||
|
if daysForward == 0 && !after.Before(todayReset) {
|
||||||
|
daysForward = 7
|
||||||
|
}
|
||||||
|
return todayReset.AddDate(0, 0, daysForward)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lastFixedWeeklyReset 计算 now 之前最近一次的每周固定重置时间点
|
||||||
|
func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time {
|
||||||
|
t := now.In(tz)
|
||||||
|
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||||
|
currentDay := int(todayReset.Weekday())
|
||||||
|
|
||||||
|
daysBack := (currentDay - day + 7) % 7
|
||||||
|
if daysBack == 0 && now.Before(todayReset) {
|
||||||
|
daysBack = 7
|
||||||
|
}
|
||||||
|
return todayReset.AddDate(0, 0, -daysBack)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isFixedDailyPeriodExpired 检查日配额是否在固定时间模式下已过期
|
||||||
|
func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool {
|
||||||
|
if periodStart.IsZero() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||||
|
if err != nil {
|
||||||
|
tz = time.UTC
|
||||||
|
}
|
||||||
|
lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now())
|
||||||
|
return periodStart.Before(lastReset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isFixedWeeklyPeriodExpired 检查周配额是否在固定时间模式下已过期
|
||||||
|
func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool {
|
||||||
|
if periodStart.IsZero() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||||
|
if err != nil {
|
||||||
|
tz = time.UTC
|
||||||
|
}
|
||||||
|
lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now())
|
||||||
|
return periodStart.Before(lastReset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ComputeQuotaResetAt 根据当前配置计算并填充 extra 中的 quota_daily_reset_at / quota_weekly_reset_at
|
||||||
|
// 在保存账号配置时调用
|
||||||
|
func ComputeQuotaResetAt(extra map[string]any) {
|
||||||
|
now := time.Now()
|
||||||
|
tzName, _ := extra["quota_reset_timezone"].(string)
|
||||||
|
if tzName == "" {
|
||||||
|
tzName = "UTC"
|
||||||
|
}
|
||||||
|
tz, err := time.LoadLocation(tzName)
|
||||||
|
if err != nil {
|
||||||
|
tz = time.UTC
|
||||||
|
}
|
||||||
|
|
||||||
|
// 日配额固定重置时间
|
||||||
|
if mode, _ := extra["quota_daily_reset_mode"].(string); mode == "fixed" {
|
||||||
|
hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"]))
|
||||||
|
if hour < 0 || hour > 23 {
|
||||||
|
hour = 0
|
||||||
|
}
|
||||||
|
resetAt := nextFixedDailyReset(hour, tz, now)
|
||||||
|
extra["quota_daily_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||||
|
} else {
|
||||||
|
delete(extra, "quota_daily_reset_at")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 周配额固定重置时间
|
||||||
|
if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" {
|
||||||
|
day := 1 // 默认周一
|
||||||
|
if d, ok := extra["quota_weekly_reset_day"]; ok {
|
||||||
|
day = int(parseExtraFloat64(d))
|
||||||
|
}
|
||||||
|
if day < 0 || day > 6 {
|
||||||
|
day = 1
|
||||||
|
}
|
||||||
|
hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"]))
|
||||||
|
if hour < 0 || hour > 23 {
|
||||||
|
hour = 0
|
||||||
|
}
|
||||||
|
resetAt := nextFixedWeeklyReset(day, hour, tz, now)
|
||||||
|
extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||||
|
} else {
|
||||||
|
delete(extra, "quota_weekly_reset_at")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateQuotaResetConfig 校验配额固定重置时间配置的合法性
|
||||||
|
func ValidateQuotaResetConfig(extra map[string]any) error {
|
||||||
|
if extra == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// 校验时区
|
||||||
|
if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" {
|
||||||
|
if _, err := time.LoadLocation(tz); err != nil {
|
||||||
|
return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 日配额重置模式
|
||||||
|
if mode, ok := extra["quota_daily_reset_mode"].(string); ok {
|
||||||
|
if mode != "rolling" && mode != "fixed" {
|
||||||
|
return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 日配额重置小时
|
||||||
|
if v, ok := extra["quota_daily_reset_hour"]; ok {
|
||||||
|
hour := int(parseExtraFloat64(v))
|
||||||
|
if hour < 0 || hour > 23 {
|
||||||
|
return errors.New("quota_daily_reset_hour must be between 0 and 23")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 周配额重置模式
|
||||||
|
if mode, ok := extra["quota_weekly_reset_mode"].(string); ok {
|
||||||
|
if mode != "rolling" && mode != "fixed" {
|
||||||
|
return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 周配额重置星期几
|
||||||
|
if v, ok := extra["quota_weekly_reset_day"]; ok {
|
||||||
|
day := int(parseExtraFloat64(v))
|
||||||
|
if day < 0 || day > 6 {
|
||||||
|
return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 周配额重置小时
|
||||||
|
if v, ok := extra["quota_weekly_reset_hour"]; ok {
|
||||||
|
hour := int(parseExtraFloat64(v))
|
||||||
|
if hour < 0 || hour > 23 {
|
||||||
|
return errors.New("quota_weekly_reset_hour must be between 0 and 23")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
||||||
func (a *Account) HasAnyQuotaLimit() bool {
|
func (a *Account) HasAnyQuotaLimit() bool {
|
||||||
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
||||||
@@ -1282,14 +1536,26 @@ func (a *Account) IsQuotaExceeded() bool {
|
|||||||
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
||||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||||
start := a.getExtraTime("quota_daily_start")
|
start := a.getExtraTime("quota_daily_start")
|
||||||
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
|
var expired bool
|
||||||
|
if a.GetQuotaDailyResetMode() == "fixed" {
|
||||||
|
expired = a.isFixedDailyPeriodExpired(start)
|
||||||
|
} else {
|
||||||
|
expired = isPeriodExpired(start, 24*time.Hour)
|
||||||
|
}
|
||||||
|
if !expired && a.GetQuotaDailyUsed() >= limit {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 周额度
|
// 周额度
|
||||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||||
start := a.getExtraTime("quota_weekly_start")
|
start := a.getExtraTime("quota_weekly_start")
|
||||||
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
|
var expired bool
|
||||||
|
if a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||||
|
expired = a.isFixedWeeklyPeriodExpired(start)
|
||||||
|
} else {
|
||||||
|
expired = isPeriodExpired(start, 7*24*time.Hour)
|
||||||
|
}
|
||||||
|
if !expired && a.GetQuotaWeeklyUsed() >= limit {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
516
backend/internal/service/account_quota_reset_test.go
Normal file
516
backend/internal/service/account_quota_reset_test.go
Normal file
@@ -0,0 +1,516 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// nextFixedDailyReset
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// 2026-03-14 06:00 UTC, reset hour = 9
|
||||||
|
after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||||
|
got := nextFixedDailyReset(9, tz, after)
|
||||||
|
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextFixedDailyReset_AtResetHour(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// Exactly at reset hour → should return tomorrow
|
||||||
|
after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||||
|
got := nextFixedDailyReset(9, tz, after)
|
||||||
|
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// After reset hour → should return tomorrow
|
||||||
|
after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz)
|
||||||
|
got := nextFixedDailyReset(9, tz, after)
|
||||||
|
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextFixedDailyReset_MidnightReset(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// Reset at hour 0 (midnight), currently 23:59
|
||||||
|
after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz)
|
||||||
|
got := nextFixedDailyReset(0, tz, after)
|
||||||
|
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) {
|
||||||
|
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
|
||||||
|
after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC)
|
||||||
|
got := nextFixedDailyReset(9, tz, after)
|
||||||
|
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
|
||||||
|
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// lastFixedDailyReset
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||||
|
got := lastFixedDailyReset(9, tz, now)
|
||||||
|
// Before today's 9:00 → yesterday 9:00
|
||||||
|
want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLastFixedDailyReset_AtResetHour(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||||
|
got := lastFixedDailyReset(9, tz, now)
|
||||||
|
// At exactly 9:00 → today 9:00
|
||||||
|
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLastFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz)
|
||||||
|
got := lastFixedDailyReset(9, tz, now)
|
||||||
|
// After 9:00 → today 9:00
|
||||||
|
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// nextFixedWeeklyReset
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
|
||||||
|
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||||
|
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||||
|
// Next Monday = 2026-03-16
|
||||||
|
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
|
||||||
|
after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||||
|
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||||
|
// Today at 9:00
|
||||||
|
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
|
||||||
|
after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||||
|
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||||
|
// Next Monday at 9:00
|
||||||
|
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
|
||||||
|
after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||||
|
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||||
|
// Next Monday at 9:00
|
||||||
|
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||||
|
after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||||
|
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||||
|
// Next Monday = 2026-03-23
|
||||||
|
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextFixedWeeklyReset_Sunday(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
|
||||||
|
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||||
|
got := nextFixedWeeklyReset(0, 0, tz, after)
|
||||||
|
// Next Sunday = 2026-03-15
|
||||||
|
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// lastFixedWeeklyReset
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
|
||||||
|
now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||||
|
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||||
|
// Today at 9:00
|
||||||
|
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
|
||||||
|
now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||||
|
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||||
|
// Last Monday at 9:00 = 2026-03-09
|
||||||
|
want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) {
|
||||||
|
tz := time.UTC
|
||||||
|
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||||
|
now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||||
|
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||||
|
// Last Monday = 2026-03-16
|
||||||
|
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// isFixedDailyPeriodExpired
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||||
|
a := &Account{Extra: map[string]any{
|
||||||
|
"quota_daily_reset_mode": "fixed",
|
||||||
|
"quota_daily_reset_hour": float64(9),
|
||||||
|
"quota_reset_timezone": "UTC",
|
||||||
|
}}
|
||||||
|
assert.True(t, a.isFixedDailyPeriodExpired(time.Time{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) {
|
||||||
|
a := &Account{Extra: map[string]any{
|
||||||
|
"quota_daily_reset_mode": "fixed",
|
||||||
|
"quota_daily_reset_hour": float64(9),
|
||||||
|
"quota_reset_timezone": "UTC",
|
||||||
|
}}
|
||||||
|
// Period started after the most recent reset → not expired
|
||||||
|
// (This test uses a time very close to "now", which is after the last reset)
|
||||||
|
periodStart := time.Now().Add(-1 * time.Minute)
|
||||||
|
assert.False(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) {
|
||||||
|
a := &Account{Extra: map[string]any{
|
||||||
|
"quota_daily_reset_mode": "fixed",
|
||||||
|
"quota_daily_reset_hour": float64(9),
|
||||||
|
"quota_reset_timezone": "UTC",
|
||||||
|
}}
|
||||||
|
// Period started 3 days ago → definitely expired
|
||||||
|
periodStart := time.Now().Add(-72 * time.Hour)
|
||||||
|
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) {
|
||||||
|
a := &Account{Extra: map[string]any{
|
||||||
|
"quota_daily_reset_mode": "fixed",
|
||||||
|
"quota_daily_reset_hour": float64(9),
|
||||||
|
"quota_reset_timezone": "Invalid/Timezone",
|
||||||
|
}}
|
||||||
|
// Invalid timezone falls back to UTC
|
||||||
|
periodStart := time.Now().Add(-72 * time.Hour)
|
||||||
|
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// isFixedWeeklyPeriodExpired
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||||
|
a := &Account{Extra: map[string]any{
|
||||||
|
"quota_weekly_reset_mode": "fixed",
|
||||||
|
"quota_weekly_reset_day": float64(1),
|
||||||
|
"quota_weekly_reset_hour": float64(9),
|
||||||
|
"quota_reset_timezone": "UTC",
|
||||||
|
}}
|
||||||
|
assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) {
|
||||||
|
a := &Account{Extra: map[string]any{
|
||||||
|
"quota_weekly_reset_mode": "fixed",
|
||||||
|
"quota_weekly_reset_day": float64(1),
|
||||||
|
"quota_weekly_reset_hour": float64(9),
|
||||||
|
"quota_reset_timezone": "UTC",
|
||||||
|
}}
|
||||||
|
// Period started 1 minute ago → not expired
|
||||||
|
periodStart := time.Now().Add(-1 * time.Minute)
|
||||||
|
assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) {
|
||||||
|
a := &Account{Extra: map[string]any{
|
||||||
|
"quota_weekly_reset_mode": "fixed",
|
||||||
|
"quota_weekly_reset_day": float64(1),
|
||||||
|
"quota_weekly_reset_hour": float64(9),
|
||||||
|
"quota_reset_timezone": "UTC",
|
||||||
|
}}
|
||||||
|
// Period started 10 days ago → definitely expired
|
||||||
|
periodStart := time.Now().Add(-240 * time.Hour)
|
||||||
|
assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ValidateQuotaResetConfig
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_NilExtra(t *testing.T) {
|
||||||
|
assert.NoError(t, ValidateQuotaResetConfig(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) {
|
||||||
|
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_mode": "fixed",
|
||||||
|
"quota_daily_reset_hour": float64(9),
|
||||||
|
"quota_weekly_reset_mode": "fixed",
|
||||||
|
"quota_weekly_reset_day": float64(1),
|
||||||
|
"quota_weekly_reset_hour": float64(0),
|
||||||
|
"quota_reset_timezone": "Asia/Shanghai",
|
||||||
|
}
|
||||||
|
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_mode": "rolling",
|
||||||
|
"quota_weekly_reset_mode": "rolling",
|
||||||
|
}
|
||||||
|
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_reset_timezone": "Not/A/Timezone",
|
||||||
|
}
|
||||||
|
err := ValidateQuotaResetConfig(extra)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "quota_reset_timezone")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_mode": "invalid",
|
||||||
|
}
|
||||||
|
err := ValidateQuotaResetConfig(extra)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "quota_daily_reset_mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_hour": float64(24),
|
||||||
|
}
|
||||||
|
err := ValidateQuotaResetConfig(extra)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_hour": float64(-1),
|
||||||
|
}
|
||||||
|
err := ValidateQuotaResetConfig(extra)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_weekly_reset_mode": "unknown",
|
||||||
|
}
|
||||||
|
err := ValidateQuotaResetConfig(extra)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "quota_weekly_reset_mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_weekly_reset_day": float64(7),
|
||||||
|
}
|
||||||
|
err := ValidateQuotaResetConfig(extra)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_weekly_reset_day": float64(-1),
|
||||||
|
}
|
||||||
|
err := ValidateQuotaResetConfig(extra)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_weekly_reset_hour": float64(25),
|
||||||
|
}
|
||||||
|
err := ValidateQuotaResetConfig(extra)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "quota_weekly_reset_hour")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) {
|
||||||
|
// All boundary values should be valid
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_hour": float64(23),
|
||||||
|
"quota_weekly_reset_day": float64(0), // Sunday
|
||||||
|
"quota_weekly_reset_hour": float64(0),
|
||||||
|
"quota_reset_timezone": "UTC",
|
||||||
|
}
|
||||||
|
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||||
|
|
||||||
|
extra2 := map[string]any{
|
||||||
|
"quota_daily_reset_hour": float64(0),
|
||||||
|
"quota_weekly_reset_day": float64(6), // Saturday
|
||||||
|
"quota_weekly_reset_hour": float64(23),
|
||||||
|
}
|
||||||
|
assert.NoError(t, ValidateQuotaResetConfig(extra2))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ComputeQuotaResetAt
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_mode": "rolling",
|
||||||
|
"quota_weekly_reset_mode": "rolling",
|
||||||
|
}
|
||||||
|
ComputeQuotaResetAt(extra)
|
||||||
|
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||||
|
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||||
|
assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at")
|
||||||
|
assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_mode": "rolling",
|
||||||
|
"quota_weekly_reset_mode": "rolling",
|
||||||
|
"quota_daily_reset_at": "2026-03-14T09:00:00Z",
|
||||||
|
"quota_weekly_reset_at": "2026-03-16T09:00:00Z",
|
||||||
|
}
|
||||||
|
ComputeQuotaResetAt(extra)
|
||||||
|
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||||
|
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||||
|
assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at")
|
||||||
|
assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_mode": "fixed",
|
||||||
|
"quota_daily_reset_hour": float64(9),
|
||||||
|
"quota_reset_timezone": "UTC",
|
||||||
|
}
|
||||||
|
ComputeQuotaResetAt(extra)
|
||||||
|
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||||
|
require.True(t, ok, "quota_daily_reset_at should be set")
|
||||||
|
|
||||||
|
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Reset time should be in the future
|
||||||
|
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||||
|
// Reset hour should be 9 UTC
|
||||||
|
assert.Equal(t, 9, resetAt.UTC().Hour())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_weekly_reset_mode": "fixed",
|
||||||
|
"quota_weekly_reset_day": float64(1), // Monday
|
||||||
|
"quota_weekly_reset_hour": float64(0),
|
||||||
|
"quota_reset_timezone": "UTC",
|
||||||
|
}
|
||||||
|
ComputeQuotaResetAt(extra)
|
||||||
|
resetAtStr, ok := extra["quota_weekly_reset_at"].(string)
|
||||||
|
require.True(t, ok, "quota_weekly_reset_at should be set")
|
||||||
|
|
||||||
|
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Reset time should be in the future
|
||||||
|
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||||
|
// Reset day should be Monday
|
||||||
|
assert.Equal(t, time.Monday, resetAt.UTC().Weekday())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) {
|
||||||
|
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_mode": "fixed",
|
||||||
|
"quota_daily_reset_hour": float64(9),
|
||||||
|
"quota_reset_timezone": "Asia/Shanghai",
|
||||||
|
}
|
||||||
|
ComputeQuotaResetAt(extra)
|
||||||
|
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// In Shanghai timezone, the hour should be 9
|
||||||
|
assert.Equal(t, 9, resetAt.In(tz).Hour())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_mode": "fixed",
|
||||||
|
"quota_daily_reset_hour": float64(12),
|
||||||
|
}
|
||||||
|
ComputeQuotaResetAt(extra)
|
||||||
|
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Default timezone is UTC
|
||||||
|
assert.Equal(t, 12, resetAt.UTC().Hour())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) {
|
||||||
|
extra := map[string]any{
|
||||||
|
"quota_daily_reset_mode": "fixed",
|
||||||
|
"quota_daily_reset_hour": float64(99),
|
||||||
|
"quota_reset_timezone": "UTC",
|
||||||
|
}
|
||||||
|
ComputeQuotaResetAt(extra)
|
||||||
|
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Invalid hour → clamped to 0
|
||||||
|
assert.Equal(t, 0, resetAt.UTC().Hour())
|
||||||
|
}
|
||||||
@@ -207,14 +207,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
|||||||
testModelID = claude.DefaultTestModel
|
testModelID = claude.DefaultTestModel
|
||||||
}
|
}
|
||||||
|
|
||||||
// For API Key accounts with model mapping, map the model
|
// API Key 账号测试连接时也需要应用通配符模型映射。
|
||||||
if account.Type == "apikey" {
|
if account.Type == "apikey" {
|
||||||
mapping := account.GetModelMapping()
|
testModelID = account.GetMappedModel(testModelID)
|
||||||
if len(mapping) > 0 {
|
|
||||||
if mappedModel, exists := mapping[testModelID]; exists {
|
|
||||||
testModelID = mappedModel
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bedrock accounts use a separate test path
|
||||||
|
if account.IsBedrock() {
|
||||||
|
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine authentication method and API URL
|
// Determine authentication method and API URL
|
||||||
@@ -312,6 +312,109 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
|||||||
return s.processClaudeStream(c, resp.Body)
|
return s.processClaudeStream(c, resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
|
||||||
|
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
|
||||||
|
region := bedrockRuntimeRegion(account)
|
||||||
|
resolvedModelID, ok := ResolveBedrockModelID(account, testModelID)
|
||||||
|
if !ok {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Bedrock model: %s", testModelID))
|
||||||
|
}
|
||||||
|
testModelID = resolvedModelID
|
||||||
|
|
||||||
|
// Set SSE headers (test UI expects SSE)
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Writer.Flush()
|
||||||
|
|
||||||
|
// Create a minimal Bedrock-compatible payload (no stream, no cache_control)
|
||||||
|
bedrockPayload := map[string]any{
|
||||||
|
"anthropic_version": "bedrock-2023-05-31",
|
||||||
|
"messages": []map[string]any{
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": []map[string]any{
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "hi",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"max_tokens": 256,
|
||||||
|
"temperature": 1,
|
||||||
|
}
|
||||||
|
bedrockBody, _ := json.Marshal(bedrockPayload)
|
||||||
|
|
||||||
|
// Use non-streaming endpoint (response is standard Claude JSON)
|
||||||
|
apiURL := BuildBedrockURL(region, testModelID, false)
|
||||||
|
|
||||||
|
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bedrockBody))
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Sign or set auth based on account type
|
||||||
|
if account.IsBedrockAPIKey() {
|
||||||
|
apiKey := account.GetCredential("api_key")
|
||||||
|
if apiKey == "" {
|
||||||
|
return s.sendErrorAndEnd(c, "No API key available")
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
} else {
|
||||||
|
signer, err := NewBedrockSignerFromAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Bedrock signer: %s", err.Error()))
|
||||||
|
}
|
||||||
|
if err := signer.SignRequest(ctx, req, bedrockBody); err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to sign request: %s", err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL := ""
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bedrock non-streaming response is standard Claude JSON, extract the text
|
||||||
|
var result struct {
|
||||||
|
Content []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"content"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
text := ""
|
||||||
|
if len(result.Content) > 0 {
|
||||||
|
text = result.Content[0].Text
|
||||||
|
}
|
||||||
|
if text == "" {
|
||||||
|
text = "(empty response)"
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||||
|
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// testOpenAIAccountConnection tests an OpenAI account's connection
|
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"log/slog"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -44,9 +45,12 @@ type UsageLogRepository interface {
|
|||||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||||
|
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||||
|
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||||
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
||||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||||
|
GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error)
|
||||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
|
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||||
|
|
||||||
@@ -99,6 +103,7 @@ type antigravityUsageCache struct {
|
|||||||
const (
|
const (
|
||||||
apiCacheTTL = 3 * time.Minute
|
apiCacheTTL = 3 * time.Minute
|
||||||
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
||||||
|
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
|
||||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||||
windowStatsCacheTTL = 1 * time.Minute
|
windowStatsCacheTTL = 1 * time.Minute
|
||||||
openAIProbeCacheTTL = 10 * time.Minute
|
openAIProbeCacheTTL = 10 * time.Minute
|
||||||
@@ -110,7 +115,8 @@ type UsageCache struct {
|
|||||||
apiCache sync.Map // accountID -> *apiUsageCache
|
apiCache sync.Map // accountID -> *apiUsageCache
|
||||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存
|
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
|
||||||
|
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
|
||||||
openAIProbeCache sync.Map // accountID -> time.Time
|
openAIProbeCache sync.Map // accountID -> time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,6 +154,18 @@ type AntigravityModelQuota struct {
|
|||||||
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
|
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AntigravityModelDetail Antigravity 单个模型的详细能力信息
|
||||||
|
type AntigravityModelDetail struct {
|
||||||
|
DisplayName string `json:"display_name,omitempty"`
|
||||||
|
SupportsImages *bool `json:"supports_images,omitempty"`
|
||||||
|
SupportsThinking *bool `json:"supports_thinking,omitempty"`
|
||||||
|
ThinkingBudget *int `json:"thinking_budget,omitempty"`
|
||||||
|
Recommended *bool `json:"recommended,omitempty"`
|
||||||
|
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||||
|
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||||
|
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// UsageInfo 账号使用量信息
|
// UsageInfo 账号使用量信息
|
||||||
type UsageInfo struct {
|
type UsageInfo struct {
|
||||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||||
@@ -163,6 +181,33 @@ type UsageInfo struct {
|
|||||||
|
|
||||||
// Antigravity 多模型配额
|
// Antigravity 多模型配额
|
||||||
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
||||||
|
|
||||||
|
// Antigravity 账号级信息
|
||||||
|
SubscriptionTier string `json:"subscription_tier,omitempty"` // 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN
|
||||||
|
SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // 上游原始订阅等级名称
|
||||||
|
|
||||||
|
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key)
|
||||||
|
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
|
||||||
|
|
||||||
|
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
||||||
|
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
||||||
|
|
||||||
|
// Antigravity 账号是否被上游禁止 (HTTP 403)
|
||||||
|
IsForbidden bool `json:"is_forbidden,omitempty"`
|
||||||
|
ForbiddenReason string `json:"forbidden_reason,omitempty"`
|
||||||
|
ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden"
|
||||||
|
ValidationURL string `json:"validation_url,omitempty"` // 验证/申诉链接
|
||||||
|
|
||||||
|
// 状态标记(从 ForbiddenType / HTTP 错误码推导)
|
||||||
|
NeedsVerify bool `json:"needs_verify,omitempty"` // 需要人工验证(forbidden_type=validation)
|
||||||
|
IsBanned bool `json:"is_banned,omitempty"` // 账号被封(forbidden_type=violation)
|
||||||
|
NeedsReauth bool `json:"needs_reauth,omitempty"` // token 失效需重新授权(401)
|
||||||
|
|
||||||
|
// 错误码(机器可读):forbidden / unauthenticated / rate_limited / network_error
|
||||||
|
ErrorCode string `json:"error_code,omitempty"`
|
||||||
|
|
||||||
|
// 获取 usage 时的错误信息(降级返回,而非 500)
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||||
@@ -647,10 +692,11 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
|
|||||||
return &UsageInfo{UpdatedAt: &now}, nil
|
return &UsageInfo{UpdatedAt: &now}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. 检查缓存(10 分钟)
|
// 1. 检查缓存
|
||||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||||
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||||
// 重新计算 RemainingSeconds
|
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||||
|
if time.Since(cache.timestamp) < ttl {
|
||||||
usage := cache.usageInfo
|
usage := cache.usageInfo
|
||||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||||
@@ -658,23 +704,145 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
|
|||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 获取代理 URL
|
|
||||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
|
|
||||||
|
|
||||||
// 3. 调用 API 获取额度
|
|
||||||
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 缓存结果
|
// 2. singleflight 防止并发击穿
|
||||||
|
flightKey := fmt.Sprintf("ag-usage:%d", account.ID)
|
||||||
|
result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) {
|
||||||
|
// 再次检查缓存(等待期间可能已被填充)
|
||||||
|
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||||
|
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||||
|
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||||
|
if time.Since(cache.timestamp) < ttl {
|
||||||
|
usage := cache.usageInfo
|
||||||
|
// 重新计算 RemainingSeconds,避免返回过时的剩余秒数
|
||||||
|
recalcAntigravityRemainingSeconds(usage)
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用独立 context,避免调用方 cancel 导致所有共享 flight 的请求失败
|
||||||
|
fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer fetchCancel()
|
||||||
|
|
||||||
|
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account)
|
||||||
|
fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
degraded := buildAntigravityDegradedUsage(err)
|
||||||
|
enrichUsageWithAccountError(degraded, account)
|
||||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||||
usageInfo: result.UsageInfo,
|
usageInfo: degraded,
|
||||||
timestamp: time.Now(),
|
timestamp: time.Now(),
|
||||||
})
|
})
|
||||||
|
return degraded, nil
|
||||||
|
}
|
||||||
|
|
||||||
return result.UsageInfo, nil
|
enrichUsageWithAccountError(fetchResult.UsageInfo, account)
|
||||||
|
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||||
|
usageInfo: fetchResult.UsageInfo,
|
||||||
|
timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
return fetchResult.UsageInfo, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if flightErr != nil {
|
||||||
|
return nil, flightErr
|
||||||
|
}
|
||||||
|
usage, ok := result.(*UsageInfo)
|
||||||
|
if !ok || usage == nil {
|
||||||
|
now := time.Now()
|
||||||
|
return &UsageInfo{UpdatedAt: &now}, nil
|
||||||
|
}
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
|
||||||
|
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
|
||||||
|
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
|
||||||
|
if info == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if info.FiveHour != nil && info.FiveHour.ResetsAt != nil {
|
||||||
|
remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds())
|
||||||
|
if remaining < 0 {
|
||||||
|
remaining = 0
|
||||||
|
}
|
||||||
|
info.FiveHour.RemainingSeconds = remaining
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL
|
||||||
|
// 403 forbidden 状态稳定,缓存与成功相同(3 分钟);
|
||||||
|
// 其他错误(401/网络)可能快速恢复,缓存 1 分钟。
|
||||||
|
func antigravityCacheTTL(info *UsageInfo) time.Duration {
|
||||||
|
if info == nil {
|
||||||
|
return antigravityErrorTTL
|
||||||
|
}
|
||||||
|
if info.IsForbidden {
|
||||||
|
return apiCacheTTL // 封号/验证状态不会很快变
|
||||||
|
}
|
||||||
|
if info.ErrorCode != "" || info.Error != "" {
|
||||||
|
return antigravityErrorTTL
|
||||||
|
}
|
||||||
|
return apiCacheTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo
|
||||||
|
func buildAntigravityDegradedUsage(err error) *UsageInfo {
|
||||||
|
now := time.Now()
|
||||||
|
errMsg := fmt.Sprintf("usage API error: %v", err)
|
||||||
|
slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err)
|
||||||
|
|
||||||
|
info := &UsageInfo{
|
||||||
|
UpdatedAt: &now,
|
||||||
|
Error: errMsg,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从错误信息推断 error_code 和状态标记
|
||||||
|
// 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..."
|
||||||
|
errStr := err.Error()
|
||||||
|
switch {
|
||||||
|
case strings.Contains(errStr, "HTTP 401") ||
|
||||||
|
strings.Contains(errStr, "UNAUTHENTICATED") ||
|
||||||
|
strings.Contains(errStr, "invalid_grant"):
|
||||||
|
info.ErrorCode = errorCodeUnauthenticated
|
||||||
|
info.NeedsReauth = true
|
||||||
|
case strings.Contains(errStr, "HTTP 429"):
|
||||||
|
info.ErrorCode = errorCodeRateLimited
|
||||||
|
default:
|
||||||
|
info.ErrorCode = errorCodeNetworkError
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo
|
||||||
|
// 场景 1(成功路径):FetchAvailableModels 正常返回,但账号已因 403 被标记为 error,
|
||||||
|
//
|
||||||
|
// 需要在正常 usage 数据上附加 forbidden/validation 信息。
|
||||||
|
//
|
||||||
|
// 场景 2(降级路径):被封号的账号 OAuth token 失效,FetchAvailableModels 返回 401,
|
||||||
|
//
|
||||||
|
// 降级逻辑设置了 needs_reauth,但账号实际是 403 封号/需验证,需覆盖为正确状态。
|
||||||
|
func enrichUsageWithAccountError(info *UsageInfo, account *Account) {
|
||||||
|
if info == nil || account == nil || account.Status != StatusError {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg := strings.ToLower(account.ErrorMessage)
|
||||||
|
if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") &&
|
||||||
|
!strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fbType := classifyForbiddenType(account.ErrorMessage)
|
||||||
|
info.IsForbidden = true
|
||||||
|
info.ForbiddenType = fbType
|
||||||
|
info.ForbiddenReason = account.ErrorMessage
|
||||||
|
info.NeedsVerify = fbType == forbiddenTypeValidation
|
||||||
|
info.IsBanned = fbType == forbiddenTypeViolation
|
||||||
|
info.ValidationURL = extractValidationURL(account.ErrorMessage)
|
||||||
|
info.ErrorCode = errorCodeForbidden
|
||||||
|
info.NeedsReauth = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// addWindowStats 为 usage 数据添加窗口期统计
|
// addWindowStats 为 usage 数据添加窗口期统计
|
||||||
|
|||||||
@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMatchWildcardMapping(t *testing.T) {
|
func TestMatchWildcardMappingResult(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
mapping map[string]string
|
mapping map[string]string
|
||||||
requestedModel string
|
requestedModel string
|
||||||
expected string
|
expected string
|
||||||
|
matched bool
|
||||||
}{
|
}{
|
||||||
// 精确匹配优先于通配符
|
// 精确匹配优先于通配符
|
||||||
{
|
{
|
||||||
@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
|||||||
},
|
},
|
||||||
requestedModel: "claude-sonnet-4-5",
|
requestedModel: "claude-sonnet-4-5",
|
||||||
expected: "claude-sonnet-4-5-exact",
|
expected: "claude-sonnet-4-5-exact",
|
||||||
|
matched: true,
|
||||||
},
|
},
|
||||||
|
|
||||||
// 最长通配符优先
|
// 最长通配符优先
|
||||||
@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
|||||||
},
|
},
|
||||||
requestedModel: "claude-sonnet-4-5",
|
requestedModel: "claude-sonnet-4-5",
|
||||||
expected: "claude-sonnet-4-series",
|
expected: "claude-sonnet-4-series",
|
||||||
|
matched: true,
|
||||||
},
|
},
|
||||||
|
|
||||||
// 单个通配符
|
// 单个通配符
|
||||||
@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
|||||||
},
|
},
|
||||||
requestedModel: "claude-opus-4-5",
|
requestedModel: "claude-opus-4-5",
|
||||||
expected: "claude-mapped",
|
expected: "claude-mapped",
|
||||||
|
matched: true,
|
||||||
},
|
},
|
||||||
|
|
||||||
// 无匹配返回原始模型
|
// 无匹配返回原始模型
|
||||||
@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
|||||||
},
|
},
|
||||||
requestedModel: "gemini-3-flash",
|
requestedModel: "gemini-3-flash",
|
||||||
expected: "gemini-3-flash",
|
expected: "gemini-3-flash",
|
||||||
|
matched: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
// 空映射返回原始模型
|
// 空映射返回原始模型
|
||||||
@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
|||||||
mapping: map[string]string{},
|
mapping: map[string]string{},
|
||||||
requestedModel: "claude-sonnet-4-5",
|
requestedModel: "claude-sonnet-4-5",
|
||||||
expected: "claude-sonnet-4-5",
|
expected: "claude-sonnet-4-5",
|
||||||
|
matched: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Gemini 模型映射
|
// Gemini 模型映射
|
||||||
@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) {
|
|||||||
},
|
},
|
||||||
requestedModel: "gemini-3-flash-preview",
|
requestedModel: "gemini-3-flash-preview",
|
||||||
expected: "gemini-3-pro-high",
|
expected: "gemini-3-pro-high",
|
||||||
|
matched: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := matchWildcardMapping(tt.mapping, tt.requestedModel)
|
result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
|
||||||
if result != tt.expected {
|
if result != tt.expected || matched != tt.matched {
|
||||||
t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected)
|
t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountResolveMappedModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
credentials map[string]any
|
||||||
|
requestedModel string
|
||||||
|
expectedModel string
|
||||||
|
expectedMatch bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no mapping reports unmatched",
|
||||||
|
credentials: nil,
|
||||||
|
requestedModel: "gpt-5.4",
|
||||||
|
expectedModel: "gpt-5.4",
|
||||||
|
expectedMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exact passthrough mapping still counts as matched",
|
||||||
|
credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5.4": "gpt-5.4",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gpt-5.4",
|
||||||
|
expectedModel: "gpt-5.4",
|
||||||
|
expectedMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard passthrough mapping still counts as matched",
|
||||||
|
credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-*": "gpt-5.4",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gpt-5.4",
|
||||||
|
expectedModel: "gpt-5.4",
|
||||||
|
expectedMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing mapping reports unmatched",
|
||||||
|
credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5.2": "gpt-5.2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gpt-5.4",
|
||||||
|
expectedModel: "gpt-5.4",
|
||||||
|
expectedMatch: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Credentials: tt.credentials,
|
||||||
|
}
|
||||||
|
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
|
||||||
|
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
|
||||||
|
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Platform: PlatformAntigravity,
|
Platform: PlatformAntigravity,
|
||||||
|
|||||||
@@ -42,6 +42,9 @@ type AdminService interface {
|
|||||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
||||||
DeleteGroup(ctx context.Context, id int64) error
|
DeleteGroup(ctx context.Context, id int64) error
|
||||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||||
|
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||||
|
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
|
||||||
|
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
||||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||||
|
|
||||||
// API Key management (admin)
|
// API Key management (admin)
|
||||||
@@ -57,6 +60,8 @@ type AdminService interface {
|
|||||||
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
||||||
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
||||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||||
|
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号 privacy_mode,未设置则尝试关闭训练数据共享并持久化。
|
||||||
|
EnsureOpenAIPrivacy(ctx context.Context, account *Account) string
|
||||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||||
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
|
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
|
||||||
@@ -433,6 +438,7 @@ type adminServiceImpl struct {
|
|||||||
settingService *SettingService
|
settingService *SettingService
|
||||||
defaultSubAssigner DefaultSubscriptionAssigner
|
defaultSubAssigner DefaultSubscriptionAssigner
|
||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
|
privacyClientFactory PrivacyClientFactory
|
||||||
}
|
}
|
||||||
|
|
||||||
type userGroupRateBatchReader interface {
|
type userGroupRateBatchReader interface {
|
||||||
@@ -461,6 +467,7 @@ func NewAdminService(
|
|||||||
settingService *SettingService,
|
settingService *SettingService,
|
||||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||||
userSubRepo UserSubscriptionRepository,
|
userSubRepo UserSubscriptionRepository,
|
||||||
|
privacyClientFactory PrivacyClientFactory,
|
||||||
) AdminService {
|
) AdminService {
|
||||||
return &adminServiceImpl{
|
return &adminServiceImpl{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
@@ -479,6 +486,7 @@ func NewAdminService(
|
|||||||
settingService: settingService,
|
settingService: settingService,
|
||||||
defaultSubAssigner: defaultSubAssigner,
|
defaultSubAssigner: defaultSubAssigner,
|
||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
|
privacyClientFactory: privacyClientFactory,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -824,7 +832,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
subscriptionType = SubscriptionTypeStandard
|
subscriptionType = SubscriptionTypeStandard
|
||||||
}
|
}
|
||||||
|
|
||||||
// 限额字段:0 和 nil 都表示"无限制"
|
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
|
||||||
dailyLimit := normalizeLimit(input.DailyLimitUSD)
|
dailyLimit := normalizeLimit(input.DailyLimitUSD)
|
||||||
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
||||||
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
||||||
@@ -936,9 +944,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
|
// normalizeLimit 将负数转换为 nil(表示无限制),0 保留(表示限额为零)
|
||||||
func normalizeLimit(limit *float64) *float64 {
|
func normalizeLimit(limit *float64) *float64 {
|
||||||
if limit == nil || *limit <= 0 {
|
if limit == nil || *limit < 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return limit
|
return limit
|
||||||
@@ -1050,16 +1058,11 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
if input.SubscriptionType != "" {
|
if input.SubscriptionType != "" {
|
||||||
group.SubscriptionType = input.SubscriptionType
|
group.SubscriptionType = input.SubscriptionType
|
||||||
}
|
}
|
||||||
// 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
|
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
|
||||||
if input.DailyLimitUSD != nil {
|
// 前端始终发送这三个字段,无需 nil 守卫
|
||||||
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
||||||
}
|
|
||||||
if input.WeeklyLimitUSD != nil {
|
|
||||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||||
}
|
|
||||||
if input.MonthlyLimitUSD != nil {
|
|
||||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||||
}
|
|
||||||
// 图片生成计费配置:负数表示清除(使用默认价格)
|
// 图片生成计费配置:负数表示清除(使用默认价格)
|
||||||
if input.ImagePrice1K != nil {
|
if input.ImagePrice1K != nil {
|
||||||
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
||||||
@@ -1244,6 +1247,27 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
|||||||
return keys, result.Total, nil
|
return keys, result.Total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||||
|
if s.userGroupRateRepo == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return s.userGroupRateRepo.GetByGroupID(ctx, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error {
|
||||||
|
if s.userGroupRateRepo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.userGroupRateRepo.DeleteByGroupID(ctx, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
|
||||||
|
if s.userGroupRateRepo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||||
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
||||||
}
|
}
|
||||||
@@ -1433,6 +1457,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
|||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
Schedulable: true,
|
Schedulable: true,
|
||||||
}
|
}
|
||||||
|
// 预计算固定时间重置的下次重置时间
|
||||||
|
if account.Extra != nil {
|
||||||
|
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ComputeQuotaResetAt(account.Extra)
|
||||||
|
}
|
||||||
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
|
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
|
||||||
expiresAt := time.Unix(*input.ExpiresAt, 0)
|
expiresAt := time.Unix(*input.ExpiresAt, 0)
|
||||||
account.ExpiresAt = &expiresAt
|
account.ExpiresAt = &expiresAt
|
||||||
@@ -1506,6 +1537,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
account.Extra = input.Extra
|
account.Extra = input.Extra
|
||||||
|
// 校验并预计算固定时间重置的下次重置时间
|
||||||
|
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ComputeQuotaResetAt(account.Extra)
|
||||||
}
|
}
|
||||||
if input.ProxyID != nil {
|
if input.ProxyID != nil {
|
||||||
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
|
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
|
||||||
@@ -2502,3 +2538,39 @@ func (e *MixedChannelError) Error() string {
|
|||||||
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
|
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||||
return s.accountRepo.ResetQuotaUsed(ctx, id)
|
return s.accountRepo.ResetQuotaUsed(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号是否已设置 privacy_mode,
|
||||||
|
// 未设置则调用 disableOpenAITraining 并持久化到 Extra,返回设置的 mode 值。
|
||||||
|
func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Account) string {
|
||||||
|
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if s.privacyClientFactory == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if account.Extra != nil {
|
||||||
|
if _, ok := account.Extra["privacy_mode"]; ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
token, _ := account.Credentials["access_token"].(string)
|
||||||
|
if token == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var proxyURL string
|
||||||
|
if account.ProxyID != nil {
|
||||||
|
if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil {
|
||||||
|
proxyURL = p.URL()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL)
|
||||||
|
if mode == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode})
|
||||||
|
return mode
|
||||||
|
}
|
||||||
|
|||||||
176
backend/internal/service/admin_service_group_rate_test.go
Normal file
176
backend/internal/service/admin_service_group_rate_test.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// userGroupRateRepoStubForGroupRate implements UserGroupRateRepository for group rate tests.
|
||||||
|
type userGroupRateRepoStubForGroupRate struct {
|
||||||
|
getByGroupIDData map[int64][]UserGroupRateEntry
|
||||||
|
getByGroupIDErr error
|
||||||
|
|
||||||
|
deletedGroupIDs []int64
|
||||||
|
deleteByGroupErr error
|
||||||
|
|
||||||
|
syncedGroupID int64
|
||||||
|
syncedEntries []GroupRateMultiplierInput
|
||||||
|
syncGroupErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
|
||||||
|
panic("unexpected GetByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, _, _ int64) (*float64, error) {
|
||||||
|
panic("unexpected GetByUserAndGroup call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||||
|
if s.getByGroupIDErr != nil {
|
||||||
|
return nil, s.getByGroupIDErr
|
||||||
|
}
|
||||||
|
return s.getByGroupIDData[groupID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) SyncUserGroupRates(_ context.Context, _ int64, _ map[int64]*float64) error {
|
||||||
|
panic("unexpected SyncUserGroupRates call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
|
||||||
|
s.syncedGroupID = groupID
|
||||||
|
s.syncedEntries = entries
|
||||||
|
return s.syncGroupErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||||
|
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
|
||||||
|
return s.deleteByGroupErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForGroupRate) DeleteByUserID(_ context.Context, _ int64) error {
|
||||||
|
panic("unexpected DeleteByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
|
||||||
|
t.Run("returns entries for group", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{
|
||||||
|
getByGroupIDData: map[int64][]UserGroupRateEntry{
|
||||||
|
10: {
|
||||||
|
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
|
||||||
|
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, entries, 2)
|
||||||
|
require.Equal(t, int64(1), entries[0].UserID)
|
||||||
|
require.Equal(t, "alice", entries[0].UserName)
|
||||||
|
require.Equal(t, 1.5, entries[0].RateMultiplier)
|
||||||
|
require.Equal(t, int64(2), entries[1].UserID)
|
||||||
|
require.Equal(t, 0.8, entries[1].RateMultiplier)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||||
|
|
||||||
|
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, entries)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns empty slice for group with no entries", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{
|
||||||
|
getByGroupIDData: map[int64][]UserGroupRateEntry{},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
entries, err := svc.GetGroupRateMultipliers(context.Background(), 99)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, entries)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates repo error", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{
|
||||||
|
getByGroupIDErr: errors.New("db error"),
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
_, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "db error")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_ClearGroupRateMultipliers(t *testing.T) {
|
||||||
|
t.Run("deletes by group ID", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []int64{42}, repo.deletedGroupIDs)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||||
|
|
||||||
|
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates repo error", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{
|
||||||
|
deleteByGroupErr: errors.New("delete failed"),
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "delete failed")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
|
||||||
|
t.Run("syncs entries to repo", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
entries := []GroupRateMultiplierInput{
|
||||||
|
{UserID: 1, RateMultiplier: 1.5},
|
||||||
|
{UserID: 2, RateMultiplier: 0.8},
|
||||||
|
}
|
||||||
|
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, entries)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(10), repo.syncedGroupID)
|
||||||
|
require.Equal(t, entries, repo.syncedEntries)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||||
|
|
||||||
|
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates repo error", func(t *testing.T) {
|
||||||
|
repo := &userGroupRateRepoStubForGroupRate{
|
||||||
|
syncGroupErr: errors.New("sync failed"),
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||||
|
|
||||||
|
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, []GroupRateMultiplierInput{
|
||||||
|
{UserID: 1, RateMultiplier: 1.0},
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "sync failed")
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -68,7 +68,15 @@ func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context
|
|||||||
panic("unexpected SyncUserGroupRates call")
|
panic("unexpected SyncUserGroupRates call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
|
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, _ int64) ([]UserGroupRateEntry, error) {
|
||||||
|
panic("unexpected GetByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.Context, _ int64, _ []GroupRateMultiplierInput) error {
|
||||||
|
panic("unexpected SyncGroupRateMultipliers call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
|
||||||
panic("unexpected DeleteByGroupID call")
|
panic("unexpected DeleteByGroupID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,12 +2,29 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
forbiddenTypeValidation = "validation"
|
||||||
|
forbiddenTypeViolation = "violation"
|
||||||
|
forbiddenTypeForbidden = "forbidden"
|
||||||
|
|
||||||
|
// 机器可读的错误码
|
||||||
|
errorCodeForbidden = "forbidden"
|
||||||
|
errorCodeUnauthenticated = "unauthenticated"
|
||||||
|
errorCodeRateLimited = "rate_limited"
|
||||||
|
errorCodeNetworkError = "network_error"
|
||||||
|
)
|
||||||
|
|
||||||
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
|
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
|
||||||
type AntigravityQuotaFetcher struct {
|
type AntigravityQuotaFetcher struct {
|
||||||
proxyRepo ProxyRepository
|
proxyRepo ProxyRepository
|
||||||
@@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
|||||||
// 调用 API 获取配额
|
// 调用 API 获取配额
|
||||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// 403 Forbidden: 不报错,返回 is_forbidden 标记
|
||||||
|
var forbiddenErr *antigravity.ForbiddenError
|
||||||
|
if errors.As(err, &forbiddenErr) {
|
||||||
|
now := time.Now()
|
||||||
|
fbType := classifyForbiddenType(forbiddenErr.Body)
|
||||||
|
return &QuotaResult{
|
||||||
|
UsageInfo: &UsageInfo{
|
||||||
|
UpdatedAt: &now,
|
||||||
|
IsForbidden: true,
|
||||||
|
ForbiddenReason: forbiddenErr.Body,
|
||||||
|
ForbiddenType: fbType,
|
||||||
|
ValidationURL: extractValidationURL(forbiddenErr.Body),
|
||||||
|
NeedsVerify: fbType == forbiddenTypeValidation,
|
||||||
|
IsBanned: fbType == forbiddenTypeViolation,
|
||||||
|
ErrorCode: errorCodeForbidden,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程)
|
||||||
|
tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken)
|
||||||
|
|
||||||
// 转换为 UsageInfo
|
// 转换为 UsageInfo
|
||||||
usageInfo := f.buildUsageInfo(modelsResp)
|
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized)
|
||||||
|
|
||||||
return &QuotaResult{
|
return &QuotaResult{
|
||||||
UsageInfo: usageInfo,
|
UsageInfo: usageInfo,
|
||||||
@@ -52,15 +90,52 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
|
||||||
|
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) {
|
||||||
|
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to fetch subscription tier", "error", err)
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
if loadResp == nil {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
raw = loadResp.GetTier() // 已有方法:paidTier > currentTier
|
||||||
|
normalized = normalizeTier(raw)
|
||||||
|
return raw, normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
|
||||||
|
func normalizeTier(raw string) string {
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(raw)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(lower, "ultra"):
|
||||||
|
return "ULTRA"
|
||||||
|
case strings.Contains(lower, "pro"):
|
||||||
|
return "PRO"
|
||||||
|
case strings.Contains(lower, "free"):
|
||||||
|
return "FREE"
|
||||||
|
default:
|
||||||
|
return "UNKNOWN"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
|
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
info := &UsageInfo{
|
info := &UsageInfo{
|
||||||
UpdatedAt: &now,
|
UpdatedAt: &now,
|
||||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||||
|
AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail),
|
||||||
|
SubscriptionTier: tierNormalized,
|
||||||
|
SubscriptionTierRaw: tierRaw,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 遍历所有模型,填充 AntigravityQuota
|
// 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails
|
||||||
for modelName, modelInfo := range modelsResp.Models {
|
for modelName, modelInfo := range modelsResp.Models {
|
||||||
if modelInfo.QuotaInfo == nil {
|
if modelInfo.QuotaInfo == nil {
|
||||||
continue
|
continue
|
||||||
@@ -73,6 +148,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
|
|||||||
Utilization: utilization,
|
Utilization: utilization,
|
||||||
ResetTime: modelInfo.QuotaInfo.ResetTime,
|
ResetTime: modelInfo.QuotaInfo.ResetTime,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 填充模型详细能力信息
|
||||||
|
detail := &AntigravityModelDetail{
|
||||||
|
DisplayName: modelInfo.DisplayName,
|
||||||
|
SupportsImages: modelInfo.SupportsImages,
|
||||||
|
SupportsThinking: modelInfo.SupportsThinking,
|
||||||
|
ThinkingBudget: modelInfo.ThinkingBudget,
|
||||||
|
Recommended: modelInfo.Recommended,
|
||||||
|
MaxTokens: modelInfo.MaxTokens,
|
||||||
|
MaxOutputTokens: modelInfo.MaxOutputTokens,
|
||||||
|
SupportedMimeTypes: modelInfo.SupportedMimeTypes,
|
||||||
|
}
|
||||||
|
info.AntigravityQuotaDetails[modelName] = detail
|
||||||
|
}
|
||||||
|
|
||||||
|
// 废弃模型转发规则
|
||||||
|
if len(modelsResp.DeprecatedModelIDs) > 0 {
|
||||||
|
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs))
|
||||||
|
for oldID, deprecated := range modelsResp.DeprecatedModelIDs {
|
||||||
|
info.ModelForwardingRules[oldID] = deprecated.NewModelID
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 同时设置 FiveHour 用于兼容展示(取主要模型)
|
// 同时设置 FiveHour 用于兼容展示(取主要模型)
|
||||||
@@ -108,3 +204,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco
|
|||||||
}
|
}
|
||||||
return proxy.URL()
|
return proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// classifyForbiddenType 根据 403 响应体判断禁止类型
|
||||||
|
func classifyForbiddenType(body string) string {
|
||||||
|
lower := strings.ToLower(body)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(lower, "validation_required") ||
|
||||||
|
strings.Contains(lower, "verify your account") ||
|
||||||
|
strings.Contains(lower, "validation_url"):
|
||||||
|
return forbiddenTypeValidation
|
||||||
|
case strings.Contains(lower, "terms of service") ||
|
||||||
|
strings.Contains(lower, "violation"):
|
||||||
|
return forbiddenTypeViolation
|
||||||
|
default:
|
||||||
|
return forbiddenTypeForbidden
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// urlPattern 用于从 403 响应体中提取 URL(降级方案)
|
||||||
|
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)
|
||||||
|
|
||||||
|
// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接
|
||||||
|
func extractValidationURL(body string) string {
|
||||||
|
// 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url
|
||||||
|
var parsed struct {
|
||||||
|
Error struct {
|
||||||
|
Details []struct {
|
||||||
|
Metadata map[string]string `json:"metadata"`
|
||||||
|
} `json:"details"`
|
||||||
|
} `json:"error"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal([]byte(body), &parsed) == nil {
|
||||||
|
for _, detail := range parsed.Error.Details {
|
||||||
|
if u := detail.Metadata["validation_url"]; u != "" {
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
if u := detail.Metadata["appeal_url"]; u != "" {
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 降级:正则匹配 URL
|
||||||
|
lower := strings.ToLower(body)
|
||||||
|
if !strings.Contains(lower, "validation") &&
|
||||||
|
!strings.Contains(lower, "verify") &&
|
||||||
|
!strings.Contains(lower, "appeal") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// 先解码常见转义再匹配
|
||||||
|
normalized := strings.ReplaceAll(body, `\u0026`, "&")
|
||||||
|
if m := urlPattern.FindString(normalized); m != "" {
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
497
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
497
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
@@ -0,0 +1,497 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// normalizeTier
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestNormalizeTier(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
raw string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{name: "empty string", raw: "", expected: ""},
|
||||||
|
{name: "free-tier", raw: "free-tier", expected: "FREE"},
|
||||||
|
{name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"},
|
||||||
|
{name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"},
|
||||||
|
{name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"},
|
||||||
|
{name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"},
|
||||||
|
{name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"},
|
||||||
|
{name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"},
|
||||||
|
{name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := normalizeTier(tt.raw)
|
||||||
|
require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// buildUsageInfo
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func aqfBoolPtr(v bool) *bool { return &v }
|
||||||
|
func aqfIntPtr(v int) *int { return &v }
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_BasicModels(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{
|
||||||
|
"claude-sonnet-4-20250514": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 0.75,
|
||||||
|
ResetTime: "2026-03-08T12:00:00Z",
|
||||||
|
},
|
||||||
|
DisplayName: "Claude Sonnet 4",
|
||||||
|
SupportsImages: aqfBoolPtr(true),
|
||||||
|
SupportsThinking: aqfBoolPtr(false),
|
||||||
|
ThinkingBudget: aqfIntPtr(0),
|
||||||
|
Recommended: aqfBoolPtr(true),
|
||||||
|
MaxTokens: aqfIntPtr(200000),
|
||||||
|
MaxOutputTokens: aqfIntPtr(16384),
|
||||||
|
SupportedMimeTypes: map[string]bool{
|
||||||
|
"image/png": true,
|
||||||
|
"image/jpeg": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"gemini-2.5-pro": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 0.50,
|
||||||
|
ResetTime: "2026-03-08T15:00:00Z",
|
||||||
|
},
|
||||||
|
DisplayName: "Gemini 2.5 Pro",
|
||||||
|
MaxTokens: aqfIntPtr(1000000),
|
||||||
|
MaxOutputTokens: aqfIntPtr(65536),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO")
|
||||||
|
|
||||||
|
// 基本字段
|
||||||
|
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
|
||||||
|
require.Equal(t, "PRO", info.SubscriptionTier)
|
||||||
|
require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw)
|
||||||
|
|
||||||
|
// AntigravityQuota
|
||||||
|
require.Len(t, info.AntigravityQuota, 2)
|
||||||
|
|
||||||
|
sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||||
|
require.NotNil(t, sonnetQuota)
|
||||||
|
require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25
|
||||||
|
require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime)
|
||||||
|
|
||||||
|
geminiQuota := info.AntigravityQuota["gemini-2.5-pro"]
|
||||||
|
require.NotNil(t, geminiQuota)
|
||||||
|
require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50
|
||||||
|
require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime)
|
||||||
|
|
||||||
|
// AntigravityQuotaDetails
|
||||||
|
require.Len(t, info.AntigravityQuotaDetails, 2)
|
||||||
|
|
||||||
|
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"]
|
||||||
|
require.NotNil(t, sonnetDetail)
|
||||||
|
require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName)
|
||||||
|
require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages)
|
||||||
|
require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking)
|
||||||
|
require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget)
|
||||||
|
require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended)
|
||||||
|
require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens)
|
||||||
|
require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens)
|
||||||
|
require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes)
|
||||||
|
|
||||||
|
geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"]
|
||||||
|
require.NotNil(t, geminiDetail)
|
||||||
|
require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName)
|
||||||
|
require.Nil(t, geminiDetail.SupportsImages)
|
||||||
|
require.Nil(t, geminiDetail.SupportsThinking)
|
||||||
|
require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens)
|
||||||
|
require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{
|
||||||
|
"claude-sonnet-4-20250514": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 1.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{
|
||||||
|
"claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"},
|
||||||
|
"claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||||
|
|
||||||
|
require.Len(t, info.ModelForwardingRules, 2)
|
||||||
|
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
|
||||||
|
require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{
|
||||||
|
"some-model": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||||
|
|
||||||
|
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_EmptyModels(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||||
|
|
||||||
|
require.NotNil(t, info)
|
||||||
|
require.NotNil(t, info.AntigravityQuota)
|
||||||
|
require.Empty(t, info.AntigravityQuota)
|
||||||
|
require.NotNil(t, info.AntigravityQuotaDetails)
|
||||||
|
require.Empty(t, info.AntigravityQuotaDetails)
|
||||||
|
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{
|
||||||
|
"model-without-quota": {
|
||||||
|
DisplayName: "No Quota Model",
|
||||||
|
// QuotaInfo is nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||||
|
|
||||||
|
require.NotNil(t, info)
|
||||||
|
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
|
||||||
|
require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
|
||||||
|
// When the first priority model exists, it should be used for FiveHour
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{
|
||||||
|
"gemini-2.5-pro": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 0.40,
|
||||||
|
ResetTime: "2026-03-08T18:00:00Z",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"claude-sonnet-4-20250514": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 0.80,
|
||||||
|
ResetTime: "2026-03-08T12:00:00Z",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||||
|
|
||||||
|
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
|
||||||
|
// claude-sonnet-4-20250514 is first in priority list, so it should be used
|
||||||
|
expectedUtilization := (1.0 - 0.80) * 100 // 20
|
||||||
|
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||||
|
require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{
|
||||||
|
"claude-sonnet-4": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 0.60,
|
||||||
|
ResetTime: "2026-03-08T14:00:00Z",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"gemini-2.5-pro": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 0.30,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||||
|
|
||||||
|
require.NotNil(t, info.FiveHour)
|
||||||
|
expectedUtilization := (1.0 - 0.60) * 100 // 40
|
||||||
|
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
// Only gemini-2.5-pro exists (third in priority list)
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{
|
||||||
|
"gemini-2.5-pro": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 0.30,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"other-model": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 0.90,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||||
|
|
||||||
|
require.NotNil(t, info.FiveHour)
|
||||||
|
expectedUtilization := (1.0 - 0.30) * 100 // 70
|
||||||
|
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
// None of the priority models exist
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{
|
||||||
|
"some-other-model": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 0.50,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||||
|
|
||||||
|
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{
|
||||||
|
"claude-sonnet-4-20250514": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 0.50,
|
||||||
|
ResetTime: "", // empty reset time
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||||
|
|
||||||
|
require.NotNil(t, info.FiveHour)
|
||||||
|
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
|
||||||
|
require.Equal(t, 0, info.FiveHour.RemainingSeconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_FullUtilization(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{
|
||||||
|
"claude-sonnet-4-20250514": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 0.0, // fully used
|
||||||
|
ResetTime: "2026-03-08T12:00:00Z",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||||
|
|
||||||
|
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||||
|
require.NotNil(t, quota)
|
||||||
|
require.Equal(t, 100, quota.Utilization)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
|
||||||
|
fetcher := &AntigravityQuotaFetcher{}
|
||||||
|
|
||||||
|
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||||
|
Models: map[string]antigravity.ModelInfo{
|
||||||
|
"claude-sonnet-4-20250514": {
|
||||||
|
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||||
|
RemainingFraction: 1.0, // fully available
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||||
|
|
||||||
|
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||||
|
require.NotNil(t, quota)
|
||||||
|
require.Equal(t, 0, quota.Utilization)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
|
||||||
|
// 模拟 FetchQuota 遇到 403 时的行为:
|
||||||
|
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
|
||||||
|
forbiddenErr := &antigravity.ForbiddenError{
|
||||||
|
StatusCode: 403,
|
||||||
|
Body: "Access denied",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 ForbiddenError 满足 errors.As
|
||||||
|
var target *antigravity.ForbiddenError
|
||||||
|
require.True(t, errors.As(forbiddenErr, &target))
|
||||||
|
require.Equal(t, 403, target.StatusCode)
|
||||||
|
require.Equal(t, "Access denied", target.Body)
|
||||||
|
require.Contains(t, forbiddenErr.Error(), "403")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// classifyForbiddenType
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestClassifyForbiddenType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "VALIDATION_REQUIRED keyword",
|
||||||
|
body: `{"error":{"message":"VALIDATION_REQUIRED"}}`,
|
||||||
|
expected: "validation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "verify your account",
|
||||||
|
body: `Please verify your account to continue`,
|
||||||
|
expected: "validation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "contains validation_url field",
|
||||||
|
body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`,
|
||||||
|
expected: "validation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "terms of service violation",
|
||||||
|
body: `Your account has been suspended for Terms of Service violation`,
|
||||||
|
expected: "violation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "violation keyword",
|
||||||
|
body: `Account suspended due to policy violation`,
|
||||||
|
expected: "violation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "generic 403",
|
||||||
|
body: `Access denied`,
|
||||||
|
expected: "forbidden",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty body",
|
||||||
|
body: "",
|
||||||
|
expected: "forbidden",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := classifyForbiddenType(tt.body)
|
||||||
|
require.Equal(t, tt.expected, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// extractValidationURL
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestExtractValidationURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "structured validation_url",
|
||||||
|
body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
|
||||||
|
expected: "https://accounts.google.com/verify?token=abc",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "structured appeal_url",
|
||||||
|
body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`,
|
||||||
|
expected: "https://support.google.com/appeal/123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "validation_url takes priority over appeal_url",
|
||||||
|
body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`,
|
||||||
|
expected: "https://v.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fallback regex with verify keyword",
|
||||||
|
body: `Please verify your account at https://accounts.google.com/verify`,
|
||||||
|
expected: "https://accounts.google.com/verify",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no URL in generic forbidden",
|
||||||
|
body: `Access denied`,
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty body",
|
||||||
|
body: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL present but no validation keywords",
|
||||||
|
body: `Error at https://example.com/something`,
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode escaped ampersand",
|
||||||
|
body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`,
|
||||||
|
expected: "https://accounts.google.com/verify?a=1&b=2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := extractValidationURL(tt.body)
|
||||||
|
require.Equal(t, tt.expected, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -22,8 +22,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
|
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
|
||||||
|
// A nil windowStart is treated as expired — no initialized window means any accumulated usage is stale.
|
||||||
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
|
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
|
||||||
return windowStart != nil && time.Since(*windowStart) >= duration
|
return windowStart == nil || time.Since(*windowStart) >= duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
|
|||||||
@@ -15,10 +15,10 @@ func TestIsWindowExpired(t *testing.T) {
|
|||||||
want bool
|
want bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "nil window start",
|
name: "nil window start (treated as expired)",
|
||||||
start: nil,
|
start: nil,
|
||||||
duration: RateLimitWindow5h,
|
duration: RateLimitWindow5h,
|
||||||
want: false,
|
want: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "active window (started 1h ago, 5h window)",
|
name: "active window (started 1h ago, 5h window)",
|
||||||
@@ -113,7 +113,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
|||||||
want7d: 0,
|
want7d: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nil window starts return raw usage",
|
name: "nil window starts return 0 (stale usage reset)",
|
||||||
key: APIKey{
|
key: APIKey{
|
||||||
Usage5h: 5.0,
|
Usage5h: 5.0,
|
||||||
Usage1d: 10.0,
|
Usage1d: 10.0,
|
||||||
@@ -122,9 +122,9 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
|||||||
Window1dStart: nil,
|
Window1dStart: nil,
|
||||||
Window7dStart: nil,
|
Window7dStart: nil,
|
||||||
},
|
},
|
||||||
want5h: 5.0,
|
want5h: 0,
|
||||||
want1d: 10.0,
|
want1d: 0,
|
||||||
want7d: 50.0,
|
want7d: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed: 5h expired, 1d active, 7d nil",
|
name: "mixed: 5h expired, 1d active, 7d nil",
|
||||||
@@ -138,7 +138,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
|||||||
},
|
},
|
||||||
want5h: 0,
|
want5h: 0,
|
||||||
want1d: 10.0,
|
want1d: 10.0,
|
||||||
want7d: 50.0,
|
want7d: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "zero usage with active windows",
|
name: "zero usage with active windows",
|
||||||
@@ -210,7 +210,7 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
|
|||||||
want7d: 0,
|
want7d: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nil window starts return raw usage",
|
name: "nil window starts return 0 (stale usage reset)",
|
||||||
data: APIKeyRateLimitData{
|
data: APIKeyRateLimitData{
|
||||||
Usage5h: 3.0,
|
Usage5h: 3.0,
|
||||||
Usage1d: 8.0,
|
Usage1d: 8.0,
|
||||||
@@ -219,9 +219,9 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
|
|||||||
Window1dStart: nil,
|
Window1dStart: nil,
|
||||||
Window7dStart: nil,
|
Window7dStart: nil,
|
||||||
},
|
},
|
||||||
want5h: 3.0,
|
want5h: 0,
|
||||||
want1d: 8.0,
|
want1d: 0,
|
||||||
want7d: 40.0,
|
want7d: 0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1087,6 +1087,12 @@ type TokenPair struct {
|
|||||||
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
|
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TokenPairWithUser extends TokenPair with user role for backend mode checks
|
||||||
|
type TokenPairWithUser struct {
|
||||||
|
TokenPair
|
||||||
|
UserRole string
|
||||||
|
}
|
||||||
|
|
||||||
// GenerateTokenPair 生成Access Token和Refresh Token对
|
// GenerateTokenPair 生成Access Token和Refresh Token对
|
||||||
// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系
|
// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系
|
||||||
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
|
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
|
||||||
@@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
|
|||||||
|
|
||||||
// RefreshTokenPair 使用Refresh Token刷新Token对
|
// RefreshTokenPair 使用Refresh Token刷新Token对
|
||||||
// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效
|
// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效
|
||||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) {
|
||||||
// 检查 refreshTokenCache 是否可用
|
// 检查 refreshTokenCache 是否可用
|
||||||
if s.refreshTokenCache == nil {
|
if s.refreshTokenCache == nil {
|
||||||
return nil, ErrRefreshTokenInvalid
|
return nil, ErrRefreshTokenInvalid
|
||||||
@@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 生成新的Token对,保持同一个家族ID
|
// 生成新的Token对,保持同一个家族ID
|
||||||
return s.GenerateTokenPair(ctx, user, data.FamilyID)
|
pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &TokenPairWithUser{
|
||||||
|
TokenPair: *pair,
|
||||||
|
UserRole: user.Role,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevokeRefreshToken 撤销单个Refresh Token
|
// RevokeRefreshToken 撤销单个Refresh Token
|
||||||
|
|||||||
770
backend/internal/service/backup_service.go
Normal file
770
backend/internal/service/backup_service.go
Normal file
@@ -0,0 +1,770 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/robfig/cron/v3"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
settingKeyBackupS3Config = "backup_s3_config"
|
||||||
|
settingKeyBackupSchedule = "backup_schedule"
|
||||||
|
settingKeyBackupRecords = "backup_records"
|
||||||
|
|
||||||
|
maxBackupRecords = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured")
|
||||||
|
ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found")
|
||||||
|
ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress")
|
||||||
|
ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress")
|
||||||
|
ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted")
|
||||||
|
ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ─── 接口定义 ───
|
||||||
|
|
||||||
|
// DBDumper abstracts database dump/restore operations
|
||||||
|
type DBDumper interface {
|
||||||
|
Dump(ctx context.Context) (io.ReadCloser, error)
|
||||||
|
Restore(ctx context.Context, data io.Reader) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackupObjectStore abstracts object storage for backup files
|
||||||
|
type BackupObjectStore interface {
|
||||||
|
Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error)
|
||||||
|
Download(ctx context.Context, key string) (io.ReadCloser, error)
|
||||||
|
Delete(ctx context.Context, key string) error
|
||||||
|
PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error)
|
||||||
|
HeadBucket(ctx context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackupObjectStoreFactory creates an object store from S3 config
|
||||||
|
type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error)
|
||||||
|
|
||||||
|
// ─── 数据模型 ───
|
||||||
|
|
||||||
|
// BackupS3Config S3 兼容存储配置(支持 Cloudflare R2)
|
||||||
|
type BackupS3Config struct {
|
||||||
|
Endpoint string `json:"endpoint"` // e.g. https://<account_id>.r2.cloudflarestorage.com
|
||||||
|
Region string `json:"region"` // R2 用 "auto"
|
||||||
|
Bucket string `json:"bucket"`
|
||||||
|
AccessKeyID string `json:"access_key_id"`
|
||||||
|
SecretAccessKey string `json:"secret_access_key,omitempty"` //nolint:revive // field name follows AWS convention
|
||||||
|
Prefix string `json:"prefix"` // S3 key 前缀,如 "backups/"
|
||||||
|
ForcePathStyle bool `json:"force_path_style"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConfigured 检查必要字段是否已配置
|
||||||
|
func (c *BackupS3Config) IsConfigured() bool {
|
||||||
|
return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackupScheduleConfig 定时备份配置
|
||||||
|
type BackupScheduleConfig struct {
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
CronExpr string `json:"cron_expr"` // cron 表达式,如 "0 2 * * *" 每天凌晨2点
|
||||||
|
RetainDays int `json:"retain_days"` // 备份文件过期天数,默认14,0=不自动清理
|
||||||
|
RetainCount int `json:"retain_count"` // 最多保留份数,0=不限制
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackupRecord 备份记录
|
||||||
|
type BackupRecord struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Status string `json:"status"` // pending, running, completed, failed
|
||||||
|
BackupType string `json:"backup_type"` // postgres
|
||||||
|
FileName string `json:"file_name"`
|
||||||
|
S3Key string `json:"s3_key"`
|
||||||
|
SizeBytes int64 `json:"size_bytes"`
|
||||||
|
TriggeredBy string `json:"triggered_by"` // manual, scheduled
|
||||||
|
ErrorMsg string `json:"error_message,omitempty"`
|
||||||
|
StartedAt string `json:"started_at"`
|
||||||
|
FinishedAt string `json:"finished_at,omitempty"`
|
||||||
|
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackupService 数据库备份恢复服务
|
||||||
|
type BackupService struct {
|
||||||
|
settingRepo SettingRepository
|
||||||
|
dbCfg *config.DatabaseConfig
|
||||||
|
encryptor SecretEncryptor
|
||||||
|
storeFactory BackupObjectStoreFactory
|
||||||
|
dumper DBDumper
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
store BackupObjectStore
|
||||||
|
s3Cfg *BackupS3Config
|
||||||
|
backingUp bool
|
||||||
|
restoring bool
|
||||||
|
|
||||||
|
recordsMu sync.Mutex // 保护 records 的 load/save 操作
|
||||||
|
|
||||||
|
cronMu sync.Mutex
|
||||||
|
cronSched *cron.Cron
|
||||||
|
cronEntryID cron.EntryID
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBackupService(
|
||||||
|
settingRepo SettingRepository,
|
||||||
|
cfg *config.Config,
|
||||||
|
encryptor SecretEncryptor,
|
||||||
|
storeFactory BackupObjectStoreFactory,
|
||||||
|
dumper DBDumper,
|
||||||
|
) *BackupService {
|
||||||
|
return &BackupService{
|
||||||
|
settingRepo: settingRepo,
|
||||||
|
dbCfg: &cfg.Database,
|
||||||
|
encryptor: encryptor,
|
||||||
|
storeFactory: storeFactory,
|
||||||
|
dumper: dumper,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 启动定时备份调度器
|
||||||
|
func (s *BackupService) Start() {
|
||||||
|
s.cronSched = cron.New()
|
||||||
|
s.cronSched.Start()
|
||||||
|
|
||||||
|
// 加载已有的定时配置
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
schedule, err := s.GetSchedule(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 加载定时备份配置失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if schedule.Enabled && schedule.CronExpr != "" {
|
||||||
|
if err := s.applyCronSchedule(schedule); err != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 应用定时备份配置失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止定时备份
|
||||||
|
func (s *BackupService) Stop() {
|
||||||
|
s.cronMu.Lock()
|
||||||
|
defer s.cronMu.Unlock()
|
||||||
|
if s.cronSched != nil {
|
||||||
|
s.cronSched.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── S3 配置管理 ───
|
||||||
|
|
||||||
|
func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||||
|
cfg, err := s.loadS3Config(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
return &BackupS3Config{}, nil
|
||||||
|
}
|
||||||
|
// 脱敏返回
|
||||||
|
cfg.SecretAccessKey = ""
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) {
|
||||||
|
// 如果没提供 secret,保留原有值
|
||||||
|
if cfg.SecretAccessKey == "" {
|
||||||
|
old, _ := s.loadS3Config(ctx)
|
||||||
|
if old != nil {
|
||||||
|
cfg.SecretAccessKey = old.SecretAccessKey
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 加密 SecretAccessKey
|
||||||
|
encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encrypt secret: %w", err)
|
||||||
|
}
|
||||||
|
cfg.SecretAccessKey = encrypted
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal s3 config: %w", err)
|
||||||
|
}
|
||||||
|
if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil {
|
||||||
|
return nil, fmt.Errorf("save s3 config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清除缓存的 S3 客户端
|
||||||
|
s.mu.Lock()
|
||||||
|
s.store = nil
|
||||||
|
s.s3Cfg = nil
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
cfg.SecretAccessKey = ""
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error {
|
||||||
|
// 如果没提供 secret,用已保存的
|
||||||
|
if cfg.SecretAccessKey == "" {
|
||||||
|
old, _ := s.loadS3Config(ctx)
|
||||||
|
if old != nil {
|
||||||
|
cfg.SecretAccessKey = old.SecretAccessKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
|
||||||
|
return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
store, err := s.storeFactory(ctx, &cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return store.HeadBucket(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 定时备份管理 ───
|
||||||
|
|
||||||
|
func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) {
|
||||||
|
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule)
|
||||||
|
if err != nil || raw == "" {
|
||||||
|
return &BackupScheduleConfig{}, nil
|
||||||
|
}
|
||||||
|
var cfg BackupScheduleConfig
|
||||||
|
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||||
|
return &BackupScheduleConfig{}, nil
|
||||||
|
}
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) {
|
||||||
|
if cfg.Enabled && cfg.CronExpr == "" {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled")
|
||||||
|
}
|
||||||
|
// 验证 cron 表达式
|
||||||
|
if cfg.CronExpr != "" {
|
||||||
|
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||||
|
if _, err := parser.Parse(cfg.CronExpr); err != nil {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal schedule config: %w", err)
|
||||||
|
}
|
||||||
|
if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil {
|
||||||
|
return nil, fmt.Errorf("save schedule config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用或停止定时任务
|
||||||
|
if cfg.Enabled {
|
||||||
|
if err := s.applyCronSchedule(&cfg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
s.removeCronSchedule()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error {
|
||||||
|
s.cronMu.Lock()
|
||||||
|
defer s.cronMu.Unlock()
|
||||||
|
|
||||||
|
if s.cronSched == nil {
|
||||||
|
return fmt.Errorf("cron scheduler not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除旧任务
|
||||||
|
if s.cronEntryID != 0 {
|
||||||
|
s.cronSched.Remove(s.cronEntryID)
|
||||||
|
s.cronEntryID = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() {
|
||||||
|
s.runScheduledBackup()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err))
|
||||||
|
}
|
||||||
|
s.cronEntryID = entryID
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已启用: %s", cfg.CronExpr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) removeCronSchedule() {
|
||||||
|
s.cronMu.Lock()
|
||||||
|
defer s.cronMu.Unlock()
|
||||||
|
if s.cronSched != nil && s.cronEntryID != 0 {
|
||||||
|
s.cronSched.Remove(s.cronEntryID)
|
||||||
|
s.cronEntryID = 0
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已停用")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) runScheduledBackup() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 读取定时备份配置中的过期天数
|
||||||
|
schedule, _ := s.GetSchedule(ctx)
|
||||||
|
expireDays := 14 // 默认14天过期
|
||||||
|
if schedule != nil && schedule.RetainDays > 0 {
|
||||||
|
expireDays = schedule.RetainDays
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
|
||||||
|
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
|
||||||
|
|
||||||
|
// 清理过期备份(复用已加载的 schedule)
|
||||||
|
if schedule == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.cleanupOldBackups(ctx, schedule); err != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 清理过期备份失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 备份/恢复核心 ───
|
||||||
|
|
||||||
|
// CreateBackup 创建全量数据库备份并上传到 S3(流式处理)
|
||||||
|
// expireDays: 备份过期天数,0=永不过期,默认14天
|
||||||
|
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.backingUp {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil, ErrBackupInProgress
|
||||||
|
}
|
||||||
|
s.backingUp = true
|
||||||
|
s.mu.Unlock()
|
||||||
|
defer func() {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.backingUp = false
|
||||||
|
s.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
s3Cfg, err := s.loadS3Config(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s3Cfg == nil || !s3Cfg.IsConfigured() {
|
||||||
|
return nil, ErrBackupS3NotConfigured
|
||||||
|
}
|
||||||
|
|
||||||
|
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("init object store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
backupID := uuid.New().String()[:8]
|
||||||
|
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
|
||||||
|
s3Key := s.buildS3Key(s3Cfg, fileName)
|
||||||
|
|
||||||
|
var expiresAt string
|
||||||
|
if expireDays > 0 {
|
||||||
|
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
record := &BackupRecord{
|
||||||
|
ID: backupID,
|
||||||
|
Status: "running",
|
||||||
|
BackupType: "postgres",
|
||||||
|
FileName: fileName,
|
||||||
|
S3Key: s3Key,
|
||||||
|
TriggeredBy: triggeredBy,
|
||||||
|
StartedAt: now.Format(time.RFC3339),
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 流式执行: pg_dump -> gzip -> S3 upload
|
||||||
|
dumpReader, err := s.dumper.Dump(ctx)
|
||||||
|
if err != nil {
|
||||||
|
record.Status = "failed"
|
||||||
|
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
|
||||||
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
_ = s.saveRecord(ctx, record)
|
||||||
|
return record, fmt.Errorf("pg_dump: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
var gzipErr error
|
||||||
|
go func() {
|
||||||
|
gzWriter := gzip.NewWriter(pw)
|
||||||
|
_, gzipErr = io.Copy(gzWriter, dumpReader)
|
||||||
|
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
|
||||||
|
gzipErr = closeErr
|
||||||
|
}
|
||||||
|
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
|
||||||
|
gzipErr = closeErr
|
||||||
|
}
|
||||||
|
if gzipErr != nil {
|
||||||
|
_ = pw.CloseWithError(gzipErr)
|
||||||
|
} else {
|
||||||
|
_ = pw.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
contentType := "application/gzip"
|
||||||
|
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
|
||||||
|
if err != nil {
|
||||||
|
record.Status = "failed"
|
||||||
|
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
||||||
|
if gzipErr != nil {
|
||||||
|
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
|
||||||
|
}
|
||||||
|
record.ErrorMsg = errMsg
|
||||||
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
_ = s.saveRecord(ctx, record)
|
||||||
|
return record, fmt.Errorf("backup upload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
record.SizeBytes = sizeBytes
|
||||||
|
record.Status = "completed"
|
||||||
|
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||||
|
if err := s.saveRecord(ctx, record); err != nil {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
|
||||||
|
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.restoring {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return ErrRestoreInProgress
|
||||||
|
}
|
||||||
|
s.restoring = true
|
||||||
|
s.mu.Unlock()
|
||||||
|
defer func() {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.restoring = false
|
||||||
|
s.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
record, err := s.GetBackupRecord(ctx, backupID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if record.Status != "completed" {
|
||||||
|
return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
|
||||||
|
}
|
||||||
|
|
||||||
|
s3Cfg, err := s.loadS3Config(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("init object store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从 S3 流式下载
|
||||||
|
body, err := objectStore.Download(ctx, record.S3Key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("S3 download failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = body.Close() }()
|
||||||
|
|
||||||
|
// 流式解压 gzip -> psql(不将全部数据加载到内存)
|
||||||
|
gzReader, err := gzip.NewReader(body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("gzip reader: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = gzReader.Close() }()
|
||||||
|
|
||||||
|
// 流式恢复
|
||||||
|
if err := s.dumper.Restore(ctx, gzReader); err != nil {
|
||||||
|
return fmt.Errorf("pg restore: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 备份记录管理 ───
|
||||||
|
|
||||||
|
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
|
||||||
|
records, err := s.loadRecords(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 倒序返回(最新在前)
|
||||||
|
sort.Slice(records, func(i, j int) bool {
|
||||||
|
return records[i].StartedAt > records[j].StartedAt
|
||||||
|
})
|
||||||
|
return records, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) {
|
||||||
|
records, err := s.loadRecords(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for i := range records {
|
||||||
|
if records[i].ID == backupID {
|
||||||
|
return &records[i], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrBackupNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error {
|
||||||
|
s.recordsMu.Lock()
|
||||||
|
defer s.recordsMu.Unlock()
|
||||||
|
|
||||||
|
records, err := s.loadRecordsLocked(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var found *BackupRecord
|
||||||
|
var remaining []BackupRecord
|
||||||
|
for i := range records {
|
||||||
|
if records[i].ID == backupID {
|
||||||
|
found = &records[i]
|
||||||
|
} else {
|
||||||
|
remaining = append(remaining, records[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if found == nil {
|
||||||
|
return ErrBackupNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从 S3 删除
|
||||||
|
if found.S3Key != "" && found.Status == "completed" {
|
||||||
|
s3Cfg, err := s.loadS3Config(ctx)
|
||||||
|
if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() {
|
||||||
|
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||||
|
if err == nil {
|
||||||
|
_ = objectStore.Delete(ctx, found.S3Key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.saveRecordsLocked(ctx, remaining)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBackupDownloadURL 获取备份文件预签名下载 URL
|
||||||
|
func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) {
|
||||||
|
record, err := s.GetBackupRecord(ctx, backupID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if record.Status != "completed" {
|
||||||
|
return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed")
|
||||||
|
}
|
||||||
|
|
||||||
|
s3Cfg, err := s.loadS3Config(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("presign url: %w", err)
|
||||||
|
}
|
||||||
|
return url, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 内部方法 ───
|
||||||
|
|
||||||
|
func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||||
|
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
|
||||||
|
if err != nil || raw == "" {
|
||||||
|
return nil, nil //nolint:nilnil // no config is a valid state
|
||||||
|
}
|
||||||
|
var cfg BackupS3Config
|
||||||
|
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||||
|
return nil, ErrBackupS3ConfigCorrupt
|
||||||
|
}
|
||||||
|
// 解密 SecretAccessKey
|
||||||
|
if cfg.SecretAccessKey != "" {
|
||||||
|
decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey)
|
||||||
|
if err != nil {
|
||||||
|
// 兼容未加密的旧数据:如果解密失败,保持原值
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err)
|
||||||
|
} else {
|
||||||
|
cfg.SecretAccessKey = decrypted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.store != nil && s.s3Cfg != nil {
|
||||||
|
return s.store, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, ErrBackupS3NotConfigured
|
||||||
|
}
|
||||||
|
|
||||||
|
store, err := s.storeFactory(ctx, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.store = store
|
||||||
|
s.s3Cfg = cfg
|
||||||
|
return store, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string {
|
||||||
|
prefix := strings.TrimRight(cfg.Prefix, "/")
|
||||||
|
if prefix == "" {
|
||||||
|
prefix = "backups"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadRecords 加载备份记录,区分"无数据"和"数据损坏"
|
||||||
|
func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) {
|
||||||
|
s.recordsMu.Lock()
|
||||||
|
defer s.recordsMu.Unlock()
|
||||||
|
return s.loadRecordsLocked(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadRecordsLocked 在已持有 recordsMu 锁的情况下加载记录
|
||||||
|
func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) {
|
||||||
|
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords)
|
||||||
|
if err != nil || raw == "" {
|
||||||
|
return nil, nil //nolint:nilnil // no records is a valid state
|
||||||
|
}
|
||||||
|
var records []BackupRecord
|
||||||
|
if err := json.Unmarshal([]byte(raw), &records); err != nil {
|
||||||
|
return nil, ErrBackupRecordsCorrupt
|
||||||
|
}
|
||||||
|
return records, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveRecordsLocked 在已持有 recordsMu 锁的情况下保存记录
|
||||||
|
func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error {
|
||||||
|
data, err := json.Marshal(records)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveRecord 保存单条记录(带互斥锁保护)
|
||||||
|
func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error {
|
||||||
|
s.recordsMu.Lock()
|
||||||
|
defer s.recordsMu.Unlock()
|
||||||
|
|
||||||
|
records, _ := s.loadRecordsLocked(ctx)
|
||||||
|
|
||||||
|
// 更新已有记录或追加
|
||||||
|
found := false
|
||||||
|
for i := range records {
|
||||||
|
if records[i].ID == record.ID {
|
||||||
|
records[i] = *record
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
records = append(records, *record)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 限制记录数量
|
||||||
|
if len(records) > maxBackupRecords {
|
||||||
|
records = records[len(records)-maxBackupRecords:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.saveRecordsLocked(ctx, records)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error {
|
||||||
|
if schedule == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.recordsMu.Lock()
|
||||||
|
defer s.recordsMu.Unlock()
|
||||||
|
|
||||||
|
records, err := s.loadRecordsLocked(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按时间倒序
|
||||||
|
sort.Slice(records, func(i, j int) bool {
|
||||||
|
return records[i].StartedAt > records[j].StartedAt
|
||||||
|
})
|
||||||
|
|
||||||
|
var toDelete []BackupRecord
|
||||||
|
var toKeep []BackupRecord
|
||||||
|
|
||||||
|
for i, r := range records {
|
||||||
|
shouldDelete := false
|
||||||
|
|
||||||
|
// 按保留份数清理
|
||||||
|
if schedule.RetainCount > 0 && i >= schedule.RetainCount {
|
||||||
|
shouldDelete = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按保留天数清理
|
||||||
|
if schedule.RetainDays > 0 && r.StartedAt != "" {
|
||||||
|
startedAt, err := time.Parse(time.RFC3339, r.StartedAt)
|
||||||
|
if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour {
|
||||||
|
shouldDelete = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldDelete && r.Status == "completed" {
|
||||||
|
toDelete = append(toDelete, r)
|
||||||
|
} else {
|
||||||
|
toKeep = append(toKeep, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除 S3 上的文件
|
||||||
|
for _, r := range toDelete {
|
||||||
|
if r.S3Key != "" {
|
||||||
|
_ = s.deleteS3Object(ctx, r.S3Key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toDelete) > 0 {
|
||||||
|
logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete))
|
||||||
|
return s.saveRecordsLocked(ctx, toKeep)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) deleteS3Object(ctx context.Context, key string) error {
|
||||||
|
s3Cfg, err := s.loadS3Config(ctx)
|
||||||
|
if err != nil || s3Cfg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return objectStore.Delete(ctx, key)
|
||||||
|
}
|
||||||
528
backend/internal/service/backup_service_test.go
Normal file
528
backend/internal/service/backup_service_test.go
Normal file
@@ -0,0 +1,528 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ─── Mocks ───
|
||||||
|
|
||||||
|
type mockSettingRepo struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
data map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockSettingRepo() *mockSettingRepo {
|
||||||
|
return &mockSettingRepo{data: make(map[string]string)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
v, ok := m.data[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrSettingNotFound
|
||||||
|
}
|
||||||
|
return &Setting{Key: key, Value: v}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
v, ok := m.data[key]
|
||||||
|
if !ok {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSettingRepo) Set(_ context.Context, key, value string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.data[key] = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
result := make(map[string]string)
|
||||||
|
for _, k := range keys {
|
||||||
|
if v, ok := m.data[k]; ok {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
for k, v := range settings {
|
||||||
|
m.data[k] = v
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
result := make(map[string]string, len(m.data))
|
||||||
|
for k, v := range m.data {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSettingRepo) Delete(_ context.Context, key string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
delete(m.data, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// plainEncryptor 仅做 base64-like 包装,用于测试
|
||||||
|
type plainEncryptor struct{}
|
||||||
|
|
||||||
|
func (e *plainEncryptor) Encrypt(plaintext string) (string, error) {
|
||||||
|
return "ENC:" + plaintext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||||
|
if strings.HasPrefix(ciphertext, "ENC:") {
|
||||||
|
return strings.TrimPrefix(ciphertext, "ENC:"), nil
|
||||||
|
}
|
||||||
|
return ciphertext, fmt.Errorf("not encrypted")
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockDumper struct {
|
||||||
|
dumpData []byte
|
||||||
|
dumpErr error
|
||||||
|
restored []byte
|
||||||
|
restErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) {
|
||||||
|
if m.dumpErr != nil {
|
||||||
|
return nil, m.dumpErr
|
||||||
|
}
|
||||||
|
return io.NopCloser(bytes.NewReader(m.dumpData)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
|
||||||
|
if m.restErr != nil {
|
||||||
|
return m.restErr
|
||||||
|
}
|
||||||
|
d, err := io.ReadAll(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.restored = d
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockObjectStore struct {
|
||||||
|
objects map[string][]byte
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockObjectStore() *mockObjectStore {
|
||||||
|
return &mockObjectStore{objects: make(map[string][]byte)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) {
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
m.objects[key] = data
|
||||||
|
m.mu.Unlock()
|
||||||
|
return int64(len(data)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
data, ok := m.objects[key]
|
||||||
|
m.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("not found: %s", key)
|
||||||
|
}
|
||||||
|
return io.NopCloser(bytes.NewReader(data)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockObjectStore) Delete(_ context.Context, key string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
delete(m.objects, key)
|
||||||
|
m.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) {
|
||||||
|
return "https://presigned.example.com/" + key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockObjectStore) HeadBucket(_ context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Database: config.DatabaseConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "test",
|
||||||
|
DBName: "testdb",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) {
|
||||||
|
return store, nil
|
||||||
|
}
|
||||||
|
return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper)
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedS3Config(t *testing.T, repo *mockSettingRepo) {
|
||||||
|
t.Helper()
|
||||||
|
cfg := BackupS3Config{
|
||||||
|
Bucket: "test-bucket",
|
||||||
|
AccessKeyID: "AKID",
|
||||||
|
SecretAccessKey: "ENC:secret123",
|
||||||
|
Prefix: "backups",
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(cfg)
|
||||||
|
require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Tests ───
|
||||||
|
|
||||||
|
func TestBackupService_S3ConfigEncryption(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
// 保存配置 -> SecretAccessKey 应被加密
|
||||||
|
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||||
|
Bucket: "my-bucket",
|
||||||
|
AccessKeyID: "AKID",
|
||||||
|
SecretAccessKey: "my-secret",
|
||||||
|
Prefix: "backups",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 直接读取数据库中存储的值,应该是加密后的
|
||||||
|
raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config)
|
||||||
|
var stored BackupS3Config
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(raw), &stored))
|
||||||
|
require.Equal(t, "ENC:my-secret", stored.SecretAccessKey)
|
||||||
|
|
||||||
|
// 通过 GetS3Config 获取应该脱敏
|
||||||
|
cfg, err := svc.GetS3Config(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, cfg.SecretAccessKey)
|
||||||
|
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||||
|
|
||||||
|
// loadS3Config 内部应解密
|
||||||
|
internal, err := svc.loadS3Config(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "my-secret", internal.SecretAccessKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
// 先保存一个有 secret 的配置
|
||||||
|
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||||
|
Bucket: "my-bucket",
|
||||||
|
AccessKeyID: "AKID",
|
||||||
|
SecretAccessKey: "original-secret",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 再更新时不提供 secret,应保留原值
|
||||||
|
_, err = svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||||
|
Bucket: "my-bucket",
|
||||||
|
AccessKeyID: "AKID-NEW",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
internal, err := svc.loadS3Config(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "original-secret", internal.SecretAccessKey)
|
||||||
|
require.Equal(t, "AKID-NEW", internal.AccessKeyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_SaveRecordConcurrency(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
n := 20
|
||||||
|
wg.Add(n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
record := &BackupRecord{
|
||||||
|
ID: fmt.Sprintf("rec-%d", idx),
|
||||||
|
Status: "completed",
|
||||||
|
StartedAt: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
_ = svc.saveRecord(context.Background(), record)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
records, err := svc.loadRecords(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, records, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_LoadRecords_Empty(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
records, err := svc.loadRecords(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, records) // 无数据时返回 nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_LoadRecords_Corrupted(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
_ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{")
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
records, err := svc.loadRecords(context.Background())
|
||||||
|
require.Error(t, err) // 损坏数据应返回错误
|
||||||
|
require.Nil(t, records)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_CreateBackup_Streaming(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||||
|
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "completed", record.Status)
|
||||||
|
require.Greater(t, record.SizeBytes, int64(0))
|
||||||
|
require.NotEmpty(t, record.S3Key)
|
||||||
|
|
||||||
|
// 验证 S3 上确实有文件
|
||||||
|
store.mu.Lock()
|
||||||
|
require.Len(t, store.objects, 1)
|
||||||
|
store.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_CreateBackup_DumpFailure(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, "failed", record.Status)
|
||||||
|
require.Contains(t, record.ErrorMsg, "pg_dump")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_CreateBackup_NoS3Config(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||||
|
require.ErrorIs(t, err, ErrBackupS3NotConfigured)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
// 使用一个慢速 dumper 来模拟正在进行的备份
|
||||||
|
dumper := &mockDumper{dumpData: []byte("data")}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
// 手动设置 backingUp 标志
|
||||||
|
svc.mu.Lock()
|
||||||
|
svc.backingUp = true
|
||||||
|
svc.mu.Unlock()
|
||||||
|
|
||||||
|
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||||
|
require.ErrorIs(t, err, ErrBackupInProgress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_RestoreBackup_Streaming(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||||
|
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
// 先创建一个备份
|
||||||
|
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 恢复
|
||||||
|
err = svc.RestoreBackup(context.Background(), record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 验证 psql 收到的数据是否与原始 dump 内容一致
|
||||||
|
require.Equal(t, dumpContent, string(dumper.restored))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_RestoreBackup_NotCompleted(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
// 手动插入一条 failed 记录
|
||||||
|
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||||
|
ID: "fail-1",
|
||||||
|
Status: "failed",
|
||||||
|
})
|
||||||
|
|
||||||
|
err := svc.RestoreBackup(context.Background(), "fail-1")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_DeleteBackup(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumpContent := "data"
|
||||||
|
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// S3 中应有文件
|
||||||
|
store.mu.Lock()
|
||||||
|
require.Len(t, store.objects, 1)
|
||||||
|
store.mu.Unlock()
|
||||||
|
|
||||||
|
// 删除
|
||||||
|
err = svc.DeleteBackup(context.Background(), record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// S3 中文件应被删除
|
||||||
|
store.mu.Lock()
|
||||||
|
require.Len(t, store.objects, 0)
|
||||||
|
store.mu.Unlock()
|
||||||
|
|
||||||
|
// 记录应不存在
|
||||||
|
_, err = svc.GetBackupRecord(context.Background(), record.ID)
|
||||||
|
require.ErrorIs(t, err, ErrBackupNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_GetDownloadURL(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
seedS3Config(t, repo)
|
||||||
|
|
||||||
|
dumper := &mockDumper{dumpData: []byte("data")}
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, dumper, store)
|
||||||
|
|
||||||
|
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
url, err := svc.GetBackupDownloadURL(context.Background(), record.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, url, "https://presigned.example.com/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_ListBackups_Sorted(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||||
|
ID: fmt.Sprintf("rec-%d", i),
|
||||||
|
Status: "completed",
|
||||||
|
StartedAt: now.Add(time.Duration(i) * time.Hour).Format(time.RFC3339),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
records, err := svc.ListBackups(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, records, 3)
|
||||||
|
// 最新在前
|
||||||
|
require.Equal(t, "rec-2", records[0].ID)
|
||||||
|
require.Equal(t, "rec-0", records[2].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_TestS3Connection(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
store := newMockObjectStore()
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, store)
|
||||||
|
|
||||||
|
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||||
|
Bucket: "test",
|
||||||
|
AccessKeyID: "ak",
|
||||||
|
SecretAccessKey: "sk",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_TestS3Connection_Incomplete(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||||
|
Bucket: "test",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "incomplete")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_Schedule_CronValidation(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
svc.cronSched = nil // 未初始化 cron
|
||||||
|
|
||||||
|
// 启用但 cron 为空
|
||||||
|
_, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||||
|
Enabled: true,
|
||||||
|
CronExpr: "",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// 无效的 cron 表达式
|
||||||
|
_, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||||
|
Enabled: true,
|
||||||
|
CronExpr: "invalid",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
_ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!")
|
||||||
|
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||||
|
|
||||||
|
cfg, err := svc.loadS3Config(context.Background())
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, cfg)
|
||||||
|
}
|
||||||
607
backend/internal/service/bedrock_request.go
Normal file
607
backend/internal/service/bedrock_request.go
Normal file
@@ -0,0 +1,607 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultBedrockRegion = "us-east-1"
|
||||||
|
|
||||||
|
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
|
||||||
|
|
||||||
|
// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀
|
||||||
|
// 参考: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
|
||||||
|
func BedrockCrossRegionPrefix(region string) string {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(region, "us-gov"):
|
||||||
|
return "us-gov" // GovCloud 使用独立的 us-gov 前缀
|
||||||
|
case strings.HasPrefix(region, "us-"):
|
||||||
|
return "us"
|
||||||
|
case strings.HasPrefix(region, "eu-"):
|
||||||
|
return "eu"
|
||||||
|
case region == "ap-northeast-1":
|
||||||
|
return "jp" // 日本区域使用独立的 jp 前缀(AWS 官方定义)
|
||||||
|
case region == "ap-southeast-2":
|
||||||
|
return "au" // 澳大利亚区域使用独立的 au 前缀(AWS 官方定义)
|
||||||
|
case strings.HasPrefix(region, "ap-"):
|
||||||
|
return "apac" // 其余亚太区域使用通用 apac 前缀
|
||||||
|
case strings.HasPrefix(region, "ca-"):
|
||||||
|
return "us" // 加拿大区域使用 us 前缀的跨区域推理
|
||||||
|
case strings.HasPrefix(region, "sa-"):
|
||||||
|
return "us" // 南美区域使用 us 前缀的跨区域推理
|
||||||
|
default:
|
||||||
|
return "us"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀
|
||||||
|
// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1"
|
||||||
|
// 特殊值 region="global" 强制使用 global. 前缀
|
||||||
|
func AdjustBedrockModelRegionPrefix(modelID, region string) string {
|
||||||
|
var targetPrefix string
|
||||||
|
if region == "global" {
|
||||||
|
targetPrefix = "global"
|
||||||
|
} else {
|
||||||
|
targetPrefix = BedrockCrossRegionPrefix(region)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range bedrockCrossRegionPrefixes {
|
||||||
|
if strings.HasPrefix(modelID, p) {
|
||||||
|
if p == targetPrefix+"." {
|
||||||
|
return modelID // 前缀已匹配,无需替换
|
||||||
|
}
|
||||||
|
return targetPrefix + "." + modelID[len(p):]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 模型 ID 没有已知区域前缀(如 "anthropic.claude-..."),不做修改
|
||||||
|
return modelID
|
||||||
|
}
|
||||||
|
|
||||||
|
func bedrockRuntimeRegion(account *Account) string {
|
||||||
|
if account == nil {
|
||||||
|
return defaultBedrockRegion
|
||||||
|
}
|
||||||
|
if region := account.GetCredential("aws_region"); region != "" {
|
||||||
|
return region
|
||||||
|
}
|
||||||
|
return defaultBedrockRegion
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldForceBedrockGlobal(account *Account) bool {
|
||||||
|
return account != nil && account.GetCredential("aws_force_global") == "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRegionalBedrockModelID(modelID string) bool {
|
||||||
|
for _, prefix := range bedrockCrossRegionPrefixes {
|
||||||
|
if strings.HasPrefix(modelID, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLikelyBedrockModelID(modelID string) bool {
|
||||||
|
lower := strings.ToLower(strings.TrimSpace(modelID))
|
||||||
|
if lower == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(lower, "arn:") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, prefix := range []string{
|
||||||
|
"anthropic.",
|
||||||
|
"amazon.",
|
||||||
|
"meta.",
|
||||||
|
"mistral.",
|
||||||
|
"cohere.",
|
||||||
|
"ai21.",
|
||||||
|
"deepseek.",
|
||||||
|
"stability.",
|
||||||
|
"writer.",
|
||||||
|
"nova.",
|
||||||
|
} {
|
||||||
|
if strings.HasPrefix(lower, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return isRegionalBedrockModelID(lower)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) {
|
||||||
|
modelID = strings.TrimSpace(modelID)
|
||||||
|
if modelID == "" {
|
||||||
|
return "", false, false
|
||||||
|
}
|
||||||
|
if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists {
|
||||||
|
return mapped, true, true
|
||||||
|
}
|
||||||
|
if isRegionalBedrockModelID(modelID) {
|
||||||
|
return modelID, true, true
|
||||||
|
}
|
||||||
|
if isLikelyBedrockModelID(modelID) {
|
||||||
|
return modelID, false, true
|
||||||
|
}
|
||||||
|
return "", false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID.
|
||||||
|
// It applies account model_mapping first, then default Bedrock aliases, and finally
|
||||||
|
// adjusts Anthropic cross-region prefixes to match the account region.
|
||||||
|
func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) {
|
||||||
|
if account == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
mappedModel := account.GetMappedModel(requestedModel)
|
||||||
|
modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if shouldAdjustRegion {
|
||||||
|
targetRegion := bedrockRuntimeRegion(account)
|
||||||
|
if shouldForceBedrockGlobal(account) {
|
||||||
|
targetRegion = "global"
|
||||||
|
}
|
||||||
|
modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion)
|
||||||
|
}
|
||||||
|
return modelID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildBedrockURL 构建 Bedrock InvokeModel 的 URL
|
||||||
|
// stream=true 时使用 invoke-with-response-stream 端点
|
||||||
|
// modelID 中的特殊字符会被 URL 编码(与 litellm 的 urllib.parse.quote(safe="") 对齐)
|
||||||
|
func BuildBedrockURL(region, modelID string, stream bool) string {
|
||||||
|
if region == "" {
|
||||||
|
region = defaultBedrockRegion
|
||||||
|
}
|
||||||
|
encodedModelID := url.PathEscape(modelID)
|
||||||
|
// url.PathEscape 不编码冒号(RFC 允许 path 中出现 ":"),
|
||||||
|
// 但 AWS Bedrock 期望模型 ID 中的冒号被编码为 %3A
|
||||||
|
encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A")
|
||||||
|
if stream {
|
||||||
|
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrepareBedrockRequestBody 处理请求体以适配 Bedrock API
|
||||||
|
// 1. 注入 anthropic_version
|
||||||
|
// 2. 注入 anthropic_beta(从客户端 anthropic-beta 头解析)
|
||||||
|
// 3. 移除 Bedrock 不支持的字段(model, stream, output_format, output_config)
|
||||||
|
// 4. 移除工具定义中的 custom 字段(Claude Code 会发送 custom: {defer_loading: true})
|
||||||
|
// 5. 清理 cache_control 中 Bedrock 不支持的字段(scope, ttl)
|
||||||
|
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
|
||||||
|
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
|
||||||
|
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
|
||||||
|
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 注入 anthropic_version(Bedrock 要求)
|
||||||
|
body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("inject anthropic_version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注入 anthropic_beta(Bedrock Invoke 通过请求体传递 beta 头,而非 HTTP 头)
|
||||||
|
// 1. 从客户端 anthropic-beta header 解析
|
||||||
|
// 2. 根据请求体内容自动补齐必要的 beta token
|
||||||
|
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock()
|
||||||
|
if len(betaTokens) > 0 {
|
||||||
|
body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("inject anthropic_beta: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除 model 字段(Bedrock 通过 URL 指定模型)
|
||||||
|
body, err = sjson.DeleteBytes(body, "model")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("remove model field: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除 stream 字段(Bedrock 通过不同端点控制流式,不接受请求体中的 stream 字段)
|
||||||
|
body, err = sjson.DeleteBytes(body, "stream")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("remove stream field: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 output_format(Bedrock Invoke 不支持此字段,但可将 schema 内联到最后一条 user message)
|
||||||
|
// 参考 litellm: _convert_output_format_to_inline_schema()
|
||||||
|
body = convertOutputFormatToInlineSchema(body)
|
||||||
|
|
||||||
|
// 移除 output_config 字段(Bedrock Invoke 不支持)
|
||||||
|
body, err = sjson.DeleteBytes(body, "output_config")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("remove output_config field: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除工具定义中的 custom 字段
|
||||||
|
// Claude Code (v2.1.69+) 在 tool 定义中发送 custom: {defer_loading: true},
|
||||||
|
// Anthropic API 接受但 Bedrock 会拒绝并报 "Extra inputs are not permitted"
|
||||||
|
body = removeCustomFieldFromTools(body)
|
||||||
|
|
||||||
|
// 清理 cache_control 中 Bedrock 不支持的字段
|
||||||
|
body = sanitizeBedrockCacheControl(body, modelID)
|
||||||
|
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering.
|
||||||
|
func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string {
|
||||||
|
betaTokens := parseAnthropicBetaHeader(betaHeader)
|
||||||
|
betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID)
|
||||||
|
return filterBedrockBetaTokens(betaTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertOutputFormatToInlineSchema 将 output_format 中的 JSON schema 内联到最后一条 user message
|
||||||
|
// Bedrock Invoke 不支持 output_format 参数,litellm 的做法是将 schema 追加到用户消息中
|
||||||
|
// 参考: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema()
|
||||||
|
func convertOutputFormatToInlineSchema(body []byte) []byte {
|
||||||
|
outputFormat := gjson.GetBytes(body, "output_format")
|
||||||
|
if !outputFormat.Exists() || !outputFormat.IsObject() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先从请求体中移除 output_format
|
||||||
|
body, _ = sjson.DeleteBytes(body, "output_format")
|
||||||
|
|
||||||
|
schema := outputFormat.Get("schema")
|
||||||
|
if !schema.Exists() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// 找到最后一条 user message
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
msgArr := messages.Array()
|
||||||
|
lastUserIdx := -1
|
||||||
|
for i := len(msgArr) - 1; i >= 0; i-- {
|
||||||
|
if msgArr[i].Get("role").String() == "user" {
|
||||||
|
lastUserIdx = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastUserIdx < 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将 schema 序列化为 JSON 文本追加到该 message 的 content 数组
|
||||||
|
schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw))
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
content := msgArr[lastUserIdx].Get("content")
|
||||||
|
basePath := fmt.Sprintf("messages.%d.content", lastUserIdx)
|
||||||
|
|
||||||
|
if content.IsArray() {
|
||||||
|
// 追加一个 text block 到 content 数组末尾
|
||||||
|
idx := len(content.Array())
|
||||||
|
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text")
|
||||||
|
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON))
|
||||||
|
} else if content.Type == gjson.String {
|
||||||
|
// content 是纯字符串,转换为数组格式
|
||||||
|
originalText := content.String()
|
||||||
|
body, _ = sjson.SetBytes(body, basePath, []map[string]string{
|
||||||
|
{"type": "text", "text": originalText},
|
||||||
|
{"type": "text", "text": string(schemaJSON)},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeCustomFieldFromTools 移除 tools 数组中每个工具定义的 custom 字段
|
||||||
|
func removeCustomFieldFromTools(body []byte) []byte {
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if !tools.Exists() || !tools.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
for i := range tools.Array() {
|
||||||
|
body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i))
|
||||||
|
if err != nil {
|
||||||
|
// 删除失败不影响整体流程,跳过
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// claudeVersionRe 匹配 Claude 模型 ID 中的版本号部分
|
||||||
|
// 支持 claude-{tier}-{major}-{minor} 和 claude-{tier}-{major}.{minor} 格式
|
||||||
|
var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`)
|
||||||
|
|
||||||
|
// isBedrockClaude45OrNewer 判断 Bedrock 模型 ID 是否为 Claude 4.5 或更新版本
|
||||||
|
// Claude 4.5+ 支持 cache_control 中的 ttl 字段("5m" 和 "1h")
|
||||||
|
func isBedrockClaude45OrNewer(modelID string) bool {
|
||||||
|
lower := strings.ToLower(modelID)
|
||||||
|
matches := claudeVersionRe.FindStringSubmatch(lower)
|
||||||
|
if matches == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
major, _ := strconv.Atoi(matches[1])
|
||||||
|
minor, _ := strconv.Atoi(matches[2])
|
||||||
|
return major > 4 || (major == 4 && minor >= 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeBedrockCacheControl 清理 system 和 messages 中 cache_control 里
|
||||||
|
// Bedrock 不支持的字段:
|
||||||
|
// - scope:Bedrock 不支持(如 "global" 跨请求缓存)
|
||||||
|
// - ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",旧模型需要移除
|
||||||
|
func sanitizeBedrockCacheControl(body []byte, modelID string) []byte {
|
||||||
|
isClaude45 := isBedrockClaude45OrNewer(modelID)
|
||||||
|
|
||||||
|
// 清理 system 数组中的 cache_control
|
||||||
|
systemArr := gjson.GetBytes(body, "system")
|
||||||
|
if systemArr.Exists() && systemArr.IsArray() {
|
||||||
|
for i, item := range systemArr.Array() {
|
||||||
|
if !item.IsObject() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cc := item.Get("cache_control")
|
||||||
|
if !cc.Exists() || !cc.IsObject() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理 messages 中的 cache_control
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
for mi, msg := range messages.Array() {
|
||||||
|
if !msg.IsObject() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := msg.Get("content")
|
||||||
|
if !content.Exists() || !content.IsArray() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for ci, block := range content.Array() {
|
||||||
|
if !block.IsObject() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cc := block.Get("cache_control")
|
||||||
|
if !cc.Exists() || !cc.IsObject() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteCacheControlUnsupportedFields 删除给定 cache_control 路径下 Bedrock 不支持的字段
|
||||||
|
func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte {
|
||||||
|
// Bedrock 不支持 scope(如 "global")
|
||||||
|
if cc.Get("scope").Exists() {
|
||||||
|
body, _ = sjson.DeleteBytes(body, basePath+".scope")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",其余情况移除
|
||||||
|
ttl := cc.Get("ttl")
|
||||||
|
if ttl.Exists() {
|
||||||
|
shouldRemove := true
|
||||||
|
if isClaude45 {
|
||||||
|
v := ttl.String()
|
||||||
|
if v == "5m" || v == "1h" {
|
||||||
|
shouldRemove = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if shouldRemove {
|
||||||
|
body, _ = sjson.DeleteBytes(body, basePath+".ttl")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAnthropicBetaHeader 解析 anthropic-beta 头的逗号分隔字符串为 token 列表
|
||||||
|
func parseAnthropicBetaHeader(header string) []string {
|
||||||
|
header = strings.TrimSpace(header)
|
||||||
|
if header == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") {
|
||||||
|
var parsed []any
|
||||||
|
if err := json.Unmarshal([]byte(header), &parsed); err == nil {
|
||||||
|
tokens := make([]string, 0, len(parsed))
|
||||||
|
for _, item := range parsed {
|
||||||
|
token := strings.TrimSpace(fmt.Sprint(item))
|
||||||
|
if token != "" {
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var tokens []string
|
||||||
|
for _, part := range strings.Split(header, ",") {
|
||||||
|
t := strings.TrimSpace(part)
|
||||||
|
if t != "" {
|
||||||
|
tokens = append(tokens, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// bedrockSupportedBetaTokens 是 Bedrock Invoke 支持的 beta 头白名单
|
||||||
|
// 参考: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
|
||||||
|
// 更新策略: 当 AWS Bedrock 新增支持的 beta token 时需同步更新此白名单
|
||||||
|
var bedrockSupportedBetaTokens = map[string]bool{
|
||||||
|
"computer-use-2025-01-24": true,
|
||||||
|
"computer-use-2025-11-24": true,
|
||||||
|
"context-1m-2025-08-07": true,
|
||||||
|
"context-management-2025-06-27": true,
|
||||||
|
"compact-2026-01-12": true,
|
||||||
|
"interleaved-thinking-2025-05-14": true,
|
||||||
|
"tool-search-tool-2025-10-19": true,
|
||||||
|
"tool-examples-2025-10-29": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则
|
||||||
|
// Anthropic 直接 API 使用通用头,Bedrock Invoke 需要特定的替代头
|
||||||
|
var bedrockBetaTokenTransforms = map[string]string{
|
||||||
|
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
|
||||||
|
}
|
||||||
|
|
||||||
|
// autoInjectBedrockBetaTokens 根据请求体内容自动补齐必要的 beta token
|
||||||
|
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() 和
|
||||||
|
// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock()
|
||||||
|
//
|
||||||
|
// 客户端(特别是非 Claude Code 客户端)可能只在 body 中启用了功能而不在 header 中带对应 beta token,
|
||||||
|
// 这里通过检测请求体特征自动补齐,确保 Bedrock Invoke 不会因缺少必要 beta 头而 400。
|
||||||
|
func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string {
|
||||||
|
seen := make(map[string]bool, len(tokens))
|
||||||
|
for _, t := range tokens {
|
||||||
|
seen[t] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
inject := func(token string) {
|
||||||
|
if !seen[token] {
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
seen[token] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检测 thinking / interleaved thinking
|
||||||
|
// 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta
|
||||||
|
if gjson.GetBytes(body, "thinking").Exists() {
|
||||||
|
inject("interleaved-thinking-2025-05-14")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检测 computer_use 工具
|
||||||
|
// tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if tools.Exists() && tools.IsArray() {
|
||||||
|
toolSearchUsed := false
|
||||||
|
programmaticToolCallingUsed := false
|
||||||
|
inputExamplesUsed := false
|
||||||
|
for _, tool := range tools.Array() {
|
||||||
|
toolType := tool.Get("type").String()
|
||||||
|
if strings.HasPrefix(toolType, "computer_20") {
|
||||||
|
inject("computer-use-2025-11-24")
|
||||||
|
}
|
||||||
|
if isBedrockToolSearchType(toolType) {
|
||||||
|
toolSearchUsed = true
|
||||||
|
}
|
||||||
|
if hasCodeExecutionAllowedCallers(tool) {
|
||||||
|
programmaticToolCallingUsed = true
|
||||||
|
}
|
||||||
|
if hasInputExamples(tool) {
|
||||||
|
inputExamplesUsed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if programmaticToolCallingUsed || inputExamplesUsed {
|
||||||
|
// programmatic tool calling 和 input examples 需要 advanced-tool-use,
|
||||||
|
// 后续 filterBedrockBetaTokens 会将其转换为 Bedrock 特定的 tool-search-tool
|
||||||
|
inject("advanced-tool-use-2025-11-20")
|
||||||
|
}
|
||||||
|
if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) {
|
||||||
|
// 纯 tool search(无 programmatic/inputExamples)时直接注入 Bedrock 特定头,
|
||||||
|
// 跳过 advanced-tool-use → tool-search-tool 的转换步骤(与 litellm 对齐)
|
||||||
|
if !programmaticToolCallingUsed && !inputExamplesUsed {
|
||||||
|
inject("tool-search-tool-2025-10-19")
|
||||||
|
} else {
|
||||||
|
inject("advanced-tool-use-2025-11-20")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBedrockToolSearchType(toolType string) bool {
|
||||||
|
return toolType == "tool_search_tool_regex_20251119" || toolType == "tool_search_tool_bm25_20251119"
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasCodeExecutionAllowedCallers(tool gjson.Result) bool {
|
||||||
|
allowedCallers := tool.Get("allowed_callers")
|
||||||
|
if containsStringInJSONArray(allowedCallers, "code_execution_20250825") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825")
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasInputExamples(tool gjson.Result) bool {
|
||||||
|
if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
arr := tool.Get("function.input_examples")
|
||||||
|
return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsStringInJSONArray(result gjson.Result, target string) bool {
|
||||||
|
if !result.Exists() || !result.IsArray() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, item := range result.Array() {
|
||||||
|
if item.String() == target {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// bedrockModelSupportsToolSearch 判断 Bedrock 模型是否支持 tool search
|
||||||
|
// 目前仅 Claude Opus/Sonnet 4.5+ 支持,Haiku 不支持
|
||||||
|
func bedrockModelSupportsToolSearch(modelID string) bool {
|
||||||
|
lower := strings.ToLower(modelID)
|
||||||
|
matches := claudeVersionRe.FindStringSubmatch(lower)
|
||||||
|
if matches == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Haiku 不支持 tool search
|
||||||
|
if strings.Contains(lower, "haiku") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
major, _ := strconv.Atoi(matches[1])
|
||||||
|
minor, _ := strconv.Atoi(matches[2])
|
||||||
|
return major > 4 || (major == 4 && minor >= 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterBedrockBetaTokens 过滤并转换 beta token 列表,仅保留 Bedrock Invoke 支持的 token
|
||||||
|
// 1. 应用转换规则(如 advanced-tool-use → tool-search-tool)
|
||||||
|
// 2. 过滤掉 Bedrock 不支持的 token(如 output-128k, files-api, structured-outputs 等)
|
||||||
|
// 3. 自动关联 tool-examples(当 tool-search-tool 存在时)
|
||||||
|
func filterBedrockBetaTokens(tokens []string) []string {
|
||||||
|
seen := make(map[string]bool, len(tokens))
|
||||||
|
var result []string
|
||||||
|
|
||||||
|
for _, t := range tokens {
|
||||||
|
// 应用转换规则
|
||||||
|
if replacement, ok := bedrockBetaTokenTransforms[t]; ok {
|
||||||
|
t = replacement
|
||||||
|
}
|
||||||
|
// 只保留白名单中的 token,且去重
|
||||||
|
if bedrockSupportedBetaTokens[t] && !seen[t] {
|
||||||
|
result = append(result, t)
|
||||||
|
seen[t] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自动关联: tool-search-tool 存在时,确保 tool-examples 也存在
|
||||||
|
if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] {
|
||||||
|
result = append(result, "tool-examples-2025-10-29")
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
659
backend/internal/service/bedrock_request_test.go
Normal file
659
backend/internal/service/bedrock_request_test.go
Normal file
@@ -0,0 +1,659 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) {
|
||||||
|
input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// anthropic_version 应被注入
|
||||||
|
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
|
||||||
|
// model 和 stream 应被移除
|
||||||
|
assert.False(t, gjson.GetBytes(result, "model").Exists())
|
||||||
|
assert.False(t, gjson.GetBytes(result, "stream").Exists())
|
||||||
|
// max_tokens 应保留
|
||||||
|
assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) {
|
||||||
|
t.Run("schema inlined into last user message array content", func(t *testing.T) {
|
||||||
|
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||||
|
// schema 应内联到最后一条 user message 的 content 数组末尾
|
||||||
|
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
|
||||||
|
require.Len(t, contentArr, 2)
|
||||||
|
assert.Equal(t, "text", contentArr[1].Get("type").String())
|
||||||
|
assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("schema inlined into string content", func(t *testing.T) {
|
||||||
|
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}`
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||||
|
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
|
||||||
|
require.Len(t, contentArr, 2)
|
||||||
|
assert.Equal(t, "compute this", contentArr[0].Get("text").String())
|
||||||
|
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no schema field just removes output_format", func(t *testing.T) {
|
||||||
|
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}`
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no messages just removes output_format", func(t *testing.T) {
|
||||||
|
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}`
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) {
|
||||||
|
input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}`
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoveCustomFieldFromTools(t *testing.T) {
|
||||||
|
input := `{
|
||||||
|
"tools": [
|
||||||
|
{"name":"tool1","custom":{"defer_loading":true},"description":"desc1"},
|
||||||
|
{"name":"tool2","description":"desc2"},
|
||||||
|
{"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
result := removeCustomFieldFromTools([]byte(input))
|
||||||
|
|
||||||
|
tools := gjson.GetBytes(result, "tools").Array()
|
||||||
|
require.Len(t, tools, 3)
|
||||||
|
// custom 应被移除
|
||||||
|
assert.False(t, tools[0].Get("custom").Exists())
|
||||||
|
// name/description 应保留
|
||||||
|
assert.Equal(t, "tool1", tools[0].Get("name").String())
|
||||||
|
assert.Equal(t, "desc1", tools[0].Get("description").String())
|
||||||
|
// 没有 custom 的工具不受影响
|
||||||
|
assert.Equal(t, "tool2", tools[1].Get("name").String())
|
||||||
|
// 第三个工具的 custom 也应被移除
|
||||||
|
assert.False(t, tools[2].Get("custom").Exists())
|
||||||
|
assert.Equal(t, "tool3", tools[2].Get("name").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) {
|
||||||
|
input := `{"messages":[{"role":"user","content":"hi"}]}`
|
||||||
|
result := removeCustomFieldFromTools([]byte(input))
|
||||||
|
// 无 tools 时不改变原始数据
|
||||||
|
assert.JSONEq(t, input, string(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) {
|
||||||
|
input := `{
|
||||||
|
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}],
|
||||||
|
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}]
|
||||||
|
}`
|
||||||
|
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
|
||||||
|
// scope 应被移除
|
||||||
|
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
|
||||||
|
assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists())
|
||||||
|
// type 应保留
|
||||||
|
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
|
||||||
|
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) {
|
||||||
|
input := `{
|
||||||
|
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
|
||||||
|
}`
|
||||||
|
// 旧模型(Claude 3.5)不支持 ttl
|
||||||
|
result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||||
|
|
||||||
|
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
|
||||||
|
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) {
|
||||||
|
input := `{
|
||||||
|
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
|
||||||
|
}`
|
||||||
|
// Claude 4.5+ 支持 "5m" 和 "1h"
|
||||||
|
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
|
||||||
|
|
||||||
|
assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
|
||||||
|
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) {
|
||||||
|
input := `{
|
||||||
|
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}]
|
||||||
|
}`
|
||||||
|
// Claude 4.5 不支持 "10m"
|
||||||
|
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
|
||||||
|
|
||||||
|
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) {
|
||||||
|
input := `{
|
||||||
|
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
|
||||||
|
}`
|
||||||
|
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
|
||||||
|
assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists())
|
||||||
|
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) {
|
||||||
|
input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
|
||||||
|
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
// 无 cache_control 时不改变原始数据
|
||||||
|
assert.JSONEq(t, input, string(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsBedrockClaude45OrNewer(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
modelID string
|
||||||
|
expect bool
|
||||||
|
}{
|
||||||
|
{"us.anthropic.claude-opus-4-6-v1", true},
|
||||||
|
{"us.anthropic.claude-sonnet-4-6", true},
|
||||||
|
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
|
||||||
|
{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
|
||||||
|
{"us.anthropic.claude-haiku-4-5-20251001-v1:0", true},
|
||||||
|
{"anthropic.claude-3-5-sonnet-20241022-v2:0", false},
|
||||||
|
{"anthropic.claude-3-opus-20240229-v1:0", false},
|
||||||
|
{"anthropic.claude-3-haiku-20240307-v1:0", false},
|
||||||
|
// 未来版本应自动支持
|
||||||
|
{"us.anthropic.claude-sonnet-5-0-v1", true},
|
||||||
|
{"us.anthropic.claude-opus-4-7-v1", true},
|
||||||
|
// 旧版本
|
||||||
|
{"anthropic.claude-opus-4-1-v1", false},
|
||||||
|
{"anthropic.claude-sonnet-4-0-v1", false},
|
||||||
|
// 非 Claude 模型
|
||||||
|
{"amazon.nova-pro-v1", false},
|
||||||
|
{"meta.llama3-70b", false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.modelID, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
|
||||||
|
// 模拟一个完整的 Claude Code 请求
|
||||||
|
input := `{
|
||||||
|
"model": "claude-opus-4-6",
|
||||||
|
"stream": true,
|
||||||
|
"max_tokens": 16384,
|
||||||
|
"output_format": {"type": "json", "schema": {"result": "string"}},
|
||||||
|
"output_config": {"max_tokens": 100},
|
||||||
|
"system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}],
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]}
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}},
|
||||||
|
{"name": "read", "description": "Read file", "input_schema": {"type": "object"}}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 基本字段
|
||||||
|
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
|
||||||
|
assert.False(t, gjson.GetBytes(result, "model").Exists())
|
||||||
|
assert.False(t, gjson.GetBytes(result, "stream").Exists())
|
||||||
|
assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int())
|
||||||
|
|
||||||
|
// anthropic_beta 应包含所有 beta tokens
|
||||||
|
betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||||
|
require.Len(t, betaArr, 3)
|
||||||
|
assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
|
||||||
|
assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
|
||||||
|
assert.Equal(t, "compact-2026-01-12", betaArr[2].String())
|
||||||
|
|
||||||
|
// output_format 应被移除,schema 内联到最后一条 user message
|
||||||
|
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||||
|
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
|
||||||
|
// content 数组:原始 text block + 内联 schema block
|
||||||
|
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
|
||||||
|
require.Len(t, contentArr, 2)
|
||||||
|
assert.Equal(t, "hello", contentArr[0].Get("text").String())
|
||||||
|
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`)
|
||||||
|
|
||||||
|
// tools 中的 custom 应被移除
|
||||||
|
assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists())
|
||||||
|
assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String())
|
||||||
|
assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String())
|
||||||
|
|
||||||
|
// cache_control: scope 应被移除,ttl 在 Claude 4.6 上保留合法值
|
||||||
|
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
|
||||||
|
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
|
||||||
|
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
|
||||||
|
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
|
||||||
|
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
|
||||||
|
|
||||||
|
t.Run("empty beta header", func(t *testing.T) {
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("single beta token", func(t *testing.T) {
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
|
||||||
|
require.NoError(t, err)
|
||||||
|
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||||
|
require.Len(t, arr, 1)
|
||||||
|
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple beta tokens with spaces", func(t *testing.T) {
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
|
||||||
|
require.NoError(t, err)
|
||||||
|
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||||
|
require.Len(t, arr, 2)
|
||||||
|
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||||
|
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("json array beta header", func(t *testing.T) {
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||||
|
require.Len(t, arr, 2)
|
||||||
|
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||||
|
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseAnthropicBetaHeader(t *testing.T) {
|
||||||
|
assert.Nil(t, parseAnthropicBetaHeader(""))
|
||||||
|
assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a"))
|
||||||
|
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b"))
|
||||||
|
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b "))
|
||||||
|
assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c"))
|
||||||
|
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterBedrockBetaTokens(t *testing.T) {
|
||||||
|
t.Run("supported tokens pass through", func(t *testing.T) {
|
||||||
|
tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
|
||||||
|
result := filterBedrockBetaTokens(tokens)
|
||||||
|
assert.Equal(t, tokens, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported tokens are filtered out", func(t *testing.T) {
|
||||||
|
tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
|
||||||
|
result := filterBedrockBetaTokens(tokens)
|
||||||
|
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
|
||||||
|
tokens := []string{"advanced-tool-use-2025-11-20"}
|
||||||
|
result := filterBedrockBetaTokens(tokens)
|
||||||
|
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||||
|
// tool-examples 自动关联
|
||||||
|
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) {
|
||||||
|
tokens := []string{"tool-search-tool-2025-10-19"}
|
||||||
|
result := filterBedrockBetaTokens(tokens)
|
||||||
|
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||||
|
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no duplication when tool-examples already present", func(t *testing.T) {
|
||||||
|
tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"}
|
||||||
|
result := filterBedrockBetaTokens(tokens)
|
||||||
|
count := 0
|
||||||
|
for _, t := range result {
|
||||||
|
if t == "tool-examples-2025-10-29" {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 1, count)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty input returns nil", func(t *testing.T) {
|
||||||
|
result := filterBedrockBetaTokens(nil)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("all unsupported returns nil", func(t *testing.T) {
|
||||||
|
result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"})
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("duplicate tokens are deduplicated", func(t *testing.T) {
|
||||||
|
tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"}
|
||||||
|
result := filterBedrockBetaTokens(tokens)
|
||||||
|
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
|
||||||
|
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
|
||||||
|
|
||||||
|
t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
|
||||||
|
"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
|
||||||
|
require.NoError(t, err)
|
||||||
|
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||||
|
require.Len(t, arr, 1)
|
||||||
|
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
|
||||||
|
"advanced-tool-use-2025-11-20")
|
||||||
|
require.NoError(t, err)
|
||||||
|
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||||
|
require.Len(t, arr, 2)
|
||||||
|
assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String())
|
||||||
|
assert.Equal(t, "tool-examples-2025-10-29", arr[1].String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBedrockCrossRegionPrefix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
region string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
// US regions
|
||||||
|
{"us-east-1", "us"},
|
||||||
|
{"us-east-2", "us"},
|
||||||
|
{"us-west-1", "us"},
|
||||||
|
{"us-west-2", "us"},
|
||||||
|
// GovCloud
|
||||||
|
{"us-gov-east-1", "us-gov"},
|
||||||
|
{"us-gov-west-1", "us-gov"},
|
||||||
|
// EU regions
|
||||||
|
{"eu-west-1", "eu"},
|
||||||
|
{"eu-west-2", "eu"},
|
||||||
|
{"eu-west-3", "eu"},
|
||||||
|
{"eu-central-1", "eu"},
|
||||||
|
{"eu-central-2", "eu"},
|
||||||
|
{"eu-north-1", "eu"},
|
||||||
|
{"eu-south-1", "eu"},
|
||||||
|
// APAC regions
|
||||||
|
{"ap-northeast-1", "jp"},
|
||||||
|
{"ap-northeast-2", "apac"},
|
||||||
|
{"ap-southeast-1", "apac"},
|
||||||
|
{"ap-southeast-2", "au"},
|
||||||
|
{"ap-south-1", "apac"},
|
||||||
|
// Canada / South America fallback to us
|
||||||
|
{"ca-central-1", "us"},
|
||||||
|
{"sa-east-1", "us"},
|
||||||
|
// Unknown defaults to us
|
||||||
|
{"me-south-1", "us"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.region, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveBedrockModelID(t *testing.T) {
|
||||||
|
t.Run("default alias resolves and adjusts region", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeBedrock,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"aws_region": "eu-west-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeBedrock,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"aws_region": "ap-southeast-2",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"claude-*": "claude-opus-4-6",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-6-thinking")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeBedrock,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"aws_region": "us-east-1",
|
||||||
|
"aws_force_global": "true",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("direct bedrock model id passes through", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeBedrock,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"aws_region": "us-east-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported alias returns false", func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeBedrock,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"aws_region": "us-east-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022")
|
||||||
|
assert.False(t, ok)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
|
||||||
|
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
|
||||||
|
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no duplicate when already present", func(t *testing.T) {
|
||||||
|
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
count := 0
|
||||||
|
for _, t := range result {
|
||||||
|
if t == "interleaved-thinking-2025-05-14" {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 1, count)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Contains(t, result, "computer-use-2025-11-24")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inject advanced-tool-use for input examples", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
|
||||||
|
// 纯 tool search 场景直接注入 Bedrock 特定头,不走 advanced-tool-use 转换
|
||||||
|
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||||
|
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
|
||||||
|
// 混合场景使用 advanced-tool-use(后续由 filter 转换为 tool-search-tool)
|
||||||
|
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||||
|
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
|
||||||
|
assert.NotContains(t, result, "tool-search-tool-2025-10-19")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no injection for regular tools", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Empty(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no injection when no features detected", func(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`)
|
||||||
|
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Empty(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves existing tokens", func(t *testing.T) {
|
||||||
|
body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"}
|
||||||
|
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Contains(t, result, "context-1m-2025-08-07")
|
||||||
|
assert.Contains(t, result, "compact-2026-01-12")
|
||||||
|
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveBedrockBetaTokens(t *testing.T) {
|
||||||
|
t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||||
|
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
|
||||||
|
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
|
||||||
|
t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
|
||||||
|
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||||
|
found := false
|
||||||
|
for _, v := range arr {
|
||||||
|
if v.String() == "interleaved-thinking-2025-05-14" {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, found, "interleaved-thinking should be auto-injected")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
|
||||||
|
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
|
||||||
|
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
|
||||||
|
require.NoError(t, err)
|
||||||
|
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||||
|
names := make([]string, len(arr))
|
||||||
|
for i, v := range arr {
|
||||||
|
names[i] = v.String()
|
||||||
|
}
|
||||||
|
assert.Contains(t, names, "context-1m-2025-08-07")
|
||||||
|
assert.Contains(t, names, "interleaved-thinking-2025-05-14")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
modelID string
|
||||||
|
region string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
// US region — no change needed
|
||||||
|
{"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"},
|
||||||
|
// EU region — replace us → eu
|
||||||
|
{"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
|
||||||
|
{"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"},
|
||||||
|
// APAC region — jp and au have dedicated prefixes per AWS docs
|
||||||
|
{"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"},
|
||||||
|
{"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"},
|
||||||
|
{"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"},
|
||||||
|
// eu → us (user manually set eu prefix, moved to us region)
|
||||||
|
{"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"},
|
||||||
|
// global prefix — replace to match region
|
||||||
|
{"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
|
||||||
|
// No known prefix — leave unchanged
|
||||||
|
{"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"},
|
||||||
|
// GovCloud — uses independent us-gov prefix
|
||||||
|
{"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"},
|
||||||
|
{"govcloud already correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"},
|
||||||
|
// Force global (special region value)
|
||||||
|
{"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"},
|
||||||
|
{"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
67
backend/internal/service/bedrock_signer.go
Normal file
67
backend/internal/service/bedrock_signer.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
|
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BedrockSigner 使用 AWS SigV4 对 Bedrock 请求签名
|
||||||
|
type BedrockSigner struct {
|
||||||
|
credentials aws.Credentials
|
||||||
|
region string
|
||||||
|
signer *v4.Signer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBedrockSigner 创建 BedrockSigner
|
||||||
|
func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner {
|
||||||
|
return &BedrockSigner{
|
||||||
|
credentials: aws.Credentials{
|
||||||
|
AccessKeyID: accessKeyID,
|
||||||
|
SecretAccessKey: secretAccessKey,
|
||||||
|
SessionToken: sessionToken,
|
||||||
|
},
|
||||||
|
region: region,
|
||||||
|
signer: v4.NewSigner(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBedrockSignerFromAccount 从 Account 凭证创建 BedrockSigner
|
||||||
|
func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) {
|
||||||
|
accessKeyID := account.GetCredential("aws_access_key_id")
|
||||||
|
if accessKeyID == "" {
|
||||||
|
return nil, fmt.Errorf("aws_access_key_id not found in credentials")
|
||||||
|
}
|
||||||
|
secretAccessKey := account.GetCredential("aws_secret_access_key")
|
||||||
|
if secretAccessKey == "" {
|
||||||
|
return nil, fmt.Errorf("aws_secret_access_key not found in credentials")
|
||||||
|
}
|
||||||
|
region := account.GetCredential("aws_region")
|
||||||
|
if region == "" {
|
||||||
|
region = defaultBedrockRegion
|
||||||
|
}
|
||||||
|
sessionToken := account.GetCredential("aws_session_token") // 可选
|
||||||
|
|
||||||
|
return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignRequest 对 HTTP 请求进行 SigV4 签名
|
||||||
|
// 重要约束:调用此方法前,req 应只包含 AWS 相关的 header(如 Content-Type、Accept)。
|
||||||
|
// 非 AWS header(如 anthropic-beta)会参与签名计算,如果 Bedrock 服务端不识别这些 header,
|
||||||
|
// 签名验证可能失败。litellm 通过 _filter_headers_for_aws_signature 实现头过滤,
|
||||||
|
// 当前实现中 buildUpstreamRequestBedrock 仅设置了 Content-Type 和 Accept,因此是安全的。
|
||||||
|
func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error {
|
||||||
|
payloadHash := sha256Hash(body)
|
||||||
|
return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
func sha256Hash(data []byte) string {
|
||||||
|
h := sha256.Sum256(data)
|
||||||
|
return hex.EncodeToString(h[:])
|
||||||
|
}
|
||||||
35
backend/internal/service/bedrock_signer_test.go
Normal file
35
backend/internal/service/bedrock_signer_test.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeBedrock,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"aws_access_key_id": "test-akid",
|
||||||
|
"aws_secret_access_key": "test-secret",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := NewBedrockSignerFromAccount(account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, signer)
|
||||||
|
assert.Equal(t, defaultBedrockRegion, signer.region)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterBetaTokens(t *testing.T) {
|
||||||
|
tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"}
|
||||||
|
filterSet := map[string]struct{}{
|
||||||
|
"tool-search-tool-2025-10-19": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet))
|
||||||
|
assert.Equal(t, tokens, filterBetaTokens(tokens, nil))
|
||||||
|
assert.Nil(t, filterBetaTokens(nil, filterSet))
|
||||||
|
}
|
||||||
414
backend/internal/service/bedrock_stream.go
Normal file
414
backend/internal/service/bedrock_stream.go
Normal file
@@ -0,0 +1,414 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash/crc32"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleBedrockStreamingResponse 处理 Bedrock InvokeModelWithResponseStream 的 EventStream 响应
|
||||||
|
// Bedrock 返回 AWS EventStream 二进制格式,每个事件的 payload 中 chunk.bytes 是 base64 编码的
|
||||||
|
// Claude SSE 事件 JSON。本方法解码后转换为标准 SSE 格式写入客户端。
|
||||||
|
func (s *GatewayService) handleBedrockStreamingResponse(
|
||||||
|
ctx context.Context,
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
startTime time.Time,
|
||||||
|
model string,
|
||||||
|
) (*streamingResult, error) {
|
||||||
|
w := c.Writer
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("streaming not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
if v := resp.Header.Get("x-amzn-requestid"); v != "" {
|
||||||
|
c.Header("x-request-id", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := &ClaudeUsage{}
|
||||||
|
var firstTokenMs *int
|
||||||
|
clientDisconnected := false
|
||||||
|
|
||||||
|
// Bedrock EventStream 使用 application/vnd.amazon.eventstream 二进制格式。
|
||||||
|
// 每个帧结构:total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4)
|
||||||
|
// 但更实用的方式是使用行扫描找 JSON chunks,因为 Bedrock 的响应在二进制帧中。
|
||||||
|
// 我们使用 EventStream decoder 来正确解析。
|
||||||
|
decoder := newBedrockEventStreamDecoder(resp.Body)
|
||||||
|
|
||||||
|
type decodeEvent struct {
|
||||||
|
payload []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
events := make(chan decodeEvent, 16)
|
||||||
|
done := make(chan struct{})
|
||||||
|
sendEvent := func(ev decodeEvent) bool {
|
||||||
|
select {
|
||||||
|
case events <- ev:
|
||||||
|
return true
|
||||||
|
case <-done:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var lastReadAt atomic.Int64
|
||||||
|
lastReadAt.Store(time.Now().UnixNano())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(events)
|
||||||
|
for {
|
||||||
|
payload, err := decoder.Decode()
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = sendEvent(decodeEvent{err: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lastReadAt.Store(time.Now().UnixNano())
|
||||||
|
if !sendEvent(decodeEvent{payload: payload}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
streamInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||||
|
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||||
|
}
|
||||||
|
var intervalTicker *time.Ticker
|
||||||
|
if streamInterval > 0 {
|
||||||
|
intervalTicker = time.NewTicker(streamInterval)
|
||||||
|
defer intervalTicker.Stop()
|
||||||
|
}
|
||||||
|
var intervalCh <-chan time.Time
|
||||||
|
if intervalTicker != nil {
|
||||||
|
intervalCh = intervalTicker.C
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev, ok := <-events:
|
||||||
|
if !ok {
|
||||||
|
if !clientDisconnected {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||||
|
}
|
||||||
|
if ev.err != nil {
|
||||||
|
if clientDisconnected {
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
|
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// payload 是 JSON,提取 chunk.bytes(base64 编码的 Claude SSE 事件数据)
|
||||||
|
sseData := extractBedrockChunkData(ev.payload)
|
||||||
|
if sseData == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if firstTokenMs == nil {
|
||||||
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
|
firstTokenMs = &ms
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式
|
||||||
|
// 同时移除该字段避免透传给客户端
|
||||||
|
sseData = transformBedrockInvocationMetrics(sseData)
|
||||||
|
|
||||||
|
// 解析 SSE 事件数据提取 usage
|
||||||
|
s.parseSSEUsagePassthrough(string(sseData), usage)
|
||||||
|
|
||||||
|
// 确定 SSE event type
|
||||||
|
eventType := gjson.GetBytes(sseData, "type").String()
|
||||||
|
|
||||||
|
// 写入标准 SSE 格式
|
||||||
|
if !clientDisconnected {
|
||||||
|
var writeErr error
|
||||||
|
if eventType != "" {
|
||||||
|
_, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData)
|
||||||
|
} else {
|
||||||
|
_, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData)
|
||||||
|
}
|
||||||
|
if writeErr != nil {
|
||||||
|
clientDisconnected = true
|
||||||
|
logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID)
|
||||||
|
} else {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-intervalCh:
|
||||||
|
lastRead := time.Unix(0, lastReadAt.Load())
|
||||||
|
if time.Since(lastRead) < streamInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if clientDisconnected {
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
|
logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
|
||||||
|
}
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据
|
||||||
|
// Bedrock payload 格式:{"bytes":"<base64-encoded-json>"}
|
||||||
|
func extractBedrockChunkData(payload []byte) []byte {
|
||||||
|
b64 := gjson.GetBytes(payload, "bytes").String()
|
||||||
|
if b64 == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(b64)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return decoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// transformBedrockInvocationMetrics 将 Bedrock 特有的 amazon-bedrock-invocationMetrics
|
||||||
|
// 转换为标准 Anthropic usage 格式,并从 SSE 数据中移除该字段。
|
||||||
|
//
|
||||||
|
// Bedrock Invoke 返回的 message_delta 事件可能包含:
|
||||||
|
//
|
||||||
|
// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}
|
||||||
|
//
|
||||||
|
// 转换为:
|
||||||
|
//
|
||||||
|
// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}}
|
||||||
|
func transformBedrockInvocationMetrics(data []byte) []byte {
|
||||||
|
metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics")
|
||||||
|
if !metrics.Exists() || !metrics.IsObject() {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除 Bedrock 特有字段
|
||||||
|
data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics")
|
||||||
|
|
||||||
|
// 如果已有标准 usage 字段,不覆盖
|
||||||
|
if gjson.GetBytes(data, "usage").Exists() {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 camelCase → snake_case 写入 usage
|
||||||
|
inputTokens := metrics.Get("inputTokenCount")
|
||||||
|
outputTokens := metrics.Get("outputTokenCount")
|
||||||
|
if inputTokens.Exists() {
|
||||||
|
data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int())
|
||||||
|
}
|
||||||
|
if outputTokens.Exists() {
|
||||||
|
data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// bedrockEventStreamDecoder 解码 AWS EventStream 二进制帧
|
||||||
|
// EventStream 帧格式:
|
||||||
|
//
|
||||||
|
// [total_byte_length: 4 bytes]
|
||||||
|
// [headers_byte_length: 4 bytes]
|
||||||
|
// [prelude_crc: 4 bytes]
|
||||||
|
// [headers: variable]
|
||||||
|
// [payload: variable]
|
||||||
|
// [message_crc: 4 bytes]
|
||||||
|
type bedrockEventStreamDecoder struct {
|
||||||
|
reader *bufio.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder {
|
||||||
|
return &bedrockEventStreamDecoder{
|
||||||
|
reader: bufio.NewReaderSize(r, 64*1024),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload
|
||||||
|
func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) {
|
||||||
|
for {
|
||||||
|
// 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes
|
||||||
|
prelude := make([]byte, 12)
|
||||||
|
if _, err := io.ReadFull(d.reader, prelude); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 prelude CRC(AWS EventStream 使用标准 CRC32 / IEEE)
|
||||||
|
preludeCRC := bedrockReadUint32(prelude[8:12])
|
||||||
|
if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC {
|
||||||
|
return nil, fmt.Errorf("eventstream prelude CRC mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
totalLength := bedrockReadUint32(prelude[0:4])
|
||||||
|
headersLength := bedrockReadUint32(prelude[4:8])
|
||||||
|
|
||||||
|
if totalLength < 16 { // minimum: 12 prelude + 4 message_crc
|
||||||
|
return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取 headers + payload + message_crc
|
||||||
|
remaining := int(totalLength) - 12
|
||||||
|
if remaining <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := make([]byte, remaining)
|
||||||
|
if _, err := io.ReadFull(d.reader, data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 message CRC(覆盖 prelude + headers + payload)
|
||||||
|
messageCRC := bedrockReadUint32(data[len(data)-4:])
|
||||||
|
h := crc32.New(crc32IEEETable)
|
||||||
|
_, _ = h.Write(prelude)
|
||||||
|
_, _ = h.Write(data[:len(data)-4])
|
||||||
|
if h.Sum32() != messageCRC {
|
||||||
|
return nil, fmt.Errorf("eventstream message CRC mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 headers
|
||||||
|
headers := data[:headersLength]
|
||||||
|
payload := data[headersLength : len(data)-4] // 去掉 message_crc
|
||||||
|
|
||||||
|
// 从 headers 中提取 :event-type
|
||||||
|
eventType := extractEventStreamHeaderValue(headers, ":event-type")
|
||||||
|
|
||||||
|
// 只处理 chunk 事件
|
||||||
|
if eventType == "chunk" {
|
||||||
|
// payload 是完整的 JSON,包含 bytes 字段
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查异常事件
|
||||||
|
exceptionType := extractEventStreamHeaderValue(headers, ":exception-type")
|
||||||
|
if exceptionType != "" {
|
||||||
|
return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
messageType := extractEventStreamHeaderValue(headers, ":message-type")
|
||||||
|
if messageType == "exception" || messageType == "error" {
|
||||||
|
return nil, fmt.Errorf("bedrock error: %s", string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 跳过其他事件类型(如 initial-response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractEventStreamHeaderValue 从 EventStream headers 二进制数据中提取指定 header 的字符串值
|
||||||
|
// EventStream header 格式:
|
||||||
|
//
|
||||||
|
// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable]
|
||||||
|
//
|
||||||
|
// value_type = 7 表示 string 类型,前 2 bytes 为长度
|
||||||
|
func extractEventStreamHeaderValue(headers []byte, targetName string) string {
|
||||||
|
pos := 0
|
||||||
|
for pos < len(headers) {
|
||||||
|
if pos >= len(headers) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
nameLen := int(headers[pos])
|
||||||
|
pos++
|
||||||
|
if pos+nameLen > len(headers) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
name := string(headers[pos : pos+nameLen])
|
||||||
|
pos += nameLen
|
||||||
|
|
||||||
|
if pos >= len(headers) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
valueType := headers[pos]
|
||||||
|
pos++
|
||||||
|
|
||||||
|
switch valueType {
|
||||||
|
case 7: // string
|
||||||
|
if pos+2 > len(headers) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
|
||||||
|
pos += 2
|
||||||
|
if pos+valueLen > len(headers) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
value := string(headers[pos : pos+valueLen])
|
||||||
|
pos += valueLen
|
||||||
|
if name == targetName {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
case 0: // bool true
|
||||||
|
if name == targetName {
|
||||||
|
return "true"
|
||||||
|
}
|
||||||
|
case 1: // bool false
|
||||||
|
if name == targetName {
|
||||||
|
return "false"
|
||||||
|
}
|
||||||
|
case 2: // byte
|
||||||
|
pos++
|
||||||
|
if name == targetName {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
case 3: // short
|
||||||
|
pos += 2
|
||||||
|
if name == targetName {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
case 4: // int
|
||||||
|
pos += 4
|
||||||
|
if name == targetName {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
case 5: // long
|
||||||
|
pos += 8
|
||||||
|
if name == targetName {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
case 6: // bytes
|
||||||
|
if pos+2 > len(headers) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
|
||||||
|
pos += 2 + valueLen
|
||||||
|
case 8: // timestamp
|
||||||
|
pos += 8
|
||||||
|
case 9: // uuid
|
||||||
|
pos += 16
|
||||||
|
default:
|
||||||
|
return "" // 未知类型,无法继续解析
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// crc32IEEETable is the CRC32 / IEEE table used by AWS EventStream.
|
||||||
|
var crc32IEEETable = crc32.MakeTable(crc32.IEEE)
|
||||||
|
|
||||||
|
func bedrockReadUint32(b []byte) uint32 {
|
||||||
|
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||||
|
}
|
||||||
|
|
||||||
|
func bedrockReadUint16(b []byte) uint16 {
|
||||||
|
return uint16(b[0])<<8 | uint16(b[1])
|
||||||
|
}
|
||||||
261
backend/internal/service/bedrock_stream_test.go
Normal file
261
backend/internal/service/bedrock_stream_test.go
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/binary"
|
||||||
|
"hash/crc32"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractBedrockChunkData(t *testing.T) {
|
||||||
|
t.Run("valid base64 payload", func(t *testing.T) {
|
||||||
|
original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`
|
||||||
|
b64 := base64.StdEncoding.EncodeToString([]byte(original))
|
||||||
|
payload := []byte(`{"bytes":"` + b64 + `"}`)
|
||||||
|
|
||||||
|
result := extractBedrockChunkData(payload)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.JSONEq(t, original, string(result))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty bytes field", func(t *testing.T) {
|
||||||
|
result := extractBedrockChunkData([]byte(`{"bytes":""}`))
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no bytes field", func(t *testing.T) {
|
||||||
|
result := extractBedrockChunkData([]byte(`{"other":"value"}`))
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid base64", func(t *testing.T) {
|
||||||
|
result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`))
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransformBedrockInvocationMetrics(t *testing.T) {
|
||||||
|
t.Run("converts metrics to usage", func(t *testing.T) {
|
||||||
|
input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
|
||||||
|
result := transformBedrockInvocationMetrics([]byte(input))
|
||||||
|
|
||||||
|
// amazon-bedrock-invocationMetrics should be removed
|
||||||
|
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
|
||||||
|
// usage should be set
|
||||||
|
assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int())
|
||||||
|
assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int())
|
||||||
|
// original fields preserved
|
||||||
|
assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String())
|
||||||
|
assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no metrics present", func(t *testing.T) {
|
||||||
|
input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}`
|
||||||
|
result := transformBedrockInvocationMetrics([]byte(input))
|
||||||
|
assert.JSONEq(t, input, string(result))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not overwrite existing usage", func(t *testing.T) {
|
||||||
|
input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
|
||||||
|
result := transformBedrockInvocationMetrics([]byte(input))
|
||||||
|
|
||||||
|
// metrics removed but existing usage preserved
|
||||||
|
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
|
||||||
|
assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractEventStreamHeaderValue(t *testing.T) {
|
||||||
|
// Build a header with :event-type = "chunk" (string type = 7)
|
||||||
|
buildStringHeader := func(name, value string) []byte {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
// name length (1 byte)
|
||||||
|
_ = buf.WriteByte(byte(len(name)))
|
||||||
|
// name
|
||||||
|
_, _ = buf.WriteString(name)
|
||||||
|
// value type (7 = string)
|
||||||
|
_ = buf.WriteByte(7)
|
||||||
|
// value length (2 bytes, big-endian)
|
||||||
|
_ = binary.Write(&buf, binary.BigEndian, uint16(len(value)))
|
||||||
|
// value
|
||||||
|
_, _ = buf.WriteString(value)
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("find string header", func(t *testing.T) {
|
||||||
|
headers := buildStringHeader(":event-type", "chunk")
|
||||||
|
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("header not found", func(t *testing.T) {
|
||||||
|
headers := buildStringHeader(":event-type", "chunk")
|
||||||
|
assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple headers", func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
_, _ = buf.Write(buildStringHeader(":content-type", "application/json"))
|
||||||
|
_, _ = buf.Write(buildStringHeader(":event-type", "chunk"))
|
||||||
|
_, _ = buf.Write(buildStringHeader(":message-type", "event"))
|
||||||
|
|
||||||
|
headers := buf.Bytes()
|
||||||
|
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
|
||||||
|
assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type"))
|
||||||
|
assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty headers", func(t *testing.T) {
|
||||||
|
assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBedrockEventStreamDecoder(t *testing.T) {
|
||||||
|
crc32IeeeTab := crc32.MakeTable(crc32.IEEE)
|
||||||
|
|
||||||
|
// Build a valid EventStream frame with correct CRC32/IEEE checksums.
|
||||||
|
buildFrame := func(eventType string, payload []byte) []byte {
|
||||||
|
// Build headers
|
||||||
|
var headersBuf bytes.Buffer
|
||||||
|
// :event-type header
|
||||||
|
_ = headersBuf.WriteByte(byte(len(":event-type")))
|
||||||
|
_, _ = headersBuf.WriteString(":event-type")
|
||||||
|
_ = headersBuf.WriteByte(7) // string type
|
||||||
|
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType)))
|
||||||
|
_, _ = headersBuf.WriteString(eventType)
|
||||||
|
// :message-type header
|
||||||
|
_ = headersBuf.WriteByte(byte(len(":message-type")))
|
||||||
|
_, _ = headersBuf.WriteString(":message-type")
|
||||||
|
_ = headersBuf.WriteByte(7)
|
||||||
|
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event")))
|
||||||
|
_, _ = headersBuf.WriteString("event")
|
||||||
|
|
||||||
|
headers := headersBuf.Bytes()
|
||||||
|
headersLen := uint32(len(headers))
|
||||||
|
// total = 12 (prelude) + headers + payload + 4 (message_crc)
|
||||||
|
totalLen := uint32(12 + len(headers) + len(payload) + 4)
|
||||||
|
|
||||||
|
// Prelude: total_length(4) + headers_length(4)
|
||||||
|
var preludeBuf bytes.Buffer
|
||||||
|
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
|
||||||
|
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
|
||||||
|
preludeBytes := preludeBuf.Bytes()
|
||||||
|
preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab)
|
||||||
|
|
||||||
|
// Build frame: prelude + prelude_crc + headers + payload
|
||||||
|
var frame bytes.Buffer
|
||||||
|
_, _ = frame.Write(preludeBytes)
|
||||||
|
_ = binary.Write(&frame, binary.BigEndian, preludeCRC)
|
||||||
|
_, _ = frame.Write(headers)
|
||||||
|
_, _ = frame.Write(payload)
|
||||||
|
|
||||||
|
// Message CRC covers everything before itself
|
||||||
|
messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab)
|
||||||
|
_ = binary.Write(&frame, binary.BigEndian, messageCRC)
|
||||||
|
return frame.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("decode chunk event", func(t *testing.T) {
|
||||||
|
payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test")
|
||||||
|
frame := buildFrame("chunk", payload)
|
||||||
|
|
||||||
|
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
|
||||||
|
result, err := decoder.Decode()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, payload, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("skip non-chunk events", func(t *testing.T) {
|
||||||
|
// Write initial-response followed by chunk
|
||||||
|
var buf bytes.Buffer
|
||||||
|
_, _ = buf.Write(buildFrame("initial-response", []byte(`{}`)))
|
||||||
|
chunkPayload := []byte(`{"bytes":"aGVsbG8="}`)
|
||||||
|
_, _ = buf.Write(buildFrame("chunk", chunkPayload))
|
||||||
|
|
||||||
|
decoder := newBedrockEventStreamDecoder(&buf)
|
||||||
|
result, err := decoder.Decode()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, chunkPayload, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EOF on empty input", func(t *testing.T) {
|
||||||
|
decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil))
|
||||||
|
_, err := decoder.Decode()
|
||||||
|
assert.Equal(t, io.EOF, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("corrupted prelude CRC", func(t *testing.T) {
|
||||||
|
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
|
||||||
|
// Corrupt the prelude CRC (bytes 8-11)
|
||||||
|
frame[8] ^= 0xFF
|
||||||
|
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
|
||||||
|
_, err := decoder.Decode()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "prelude CRC mismatch")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("corrupted message CRC", func(t *testing.T) {
|
||||||
|
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
|
||||||
|
// Corrupt the message CRC (last 4 bytes)
|
||||||
|
frame[len(frame)-1] ^= 0xFF
|
||||||
|
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
|
||||||
|
_, err := decoder.Decode()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "message CRC mismatch")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("castagnoli encoded frame is rejected", func(t *testing.T) {
|
||||||
|
castagnoliTab := crc32.MakeTable(crc32.Castagnoli)
|
||||||
|
payload := []byte(`{"bytes":"dGVzdA=="}`)
|
||||||
|
|
||||||
|
var headersBuf bytes.Buffer
|
||||||
|
_ = headersBuf.WriteByte(byte(len(":event-type")))
|
||||||
|
_, _ = headersBuf.WriteString(":event-type")
|
||||||
|
_ = headersBuf.WriteByte(7)
|
||||||
|
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk")))
|
||||||
|
_, _ = headersBuf.WriteString("chunk")
|
||||||
|
|
||||||
|
headers := headersBuf.Bytes()
|
||||||
|
headersLen := uint32(len(headers))
|
||||||
|
totalLen := uint32(12 + len(headers) + len(payload) + 4)
|
||||||
|
|
||||||
|
var preludeBuf bytes.Buffer
|
||||||
|
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
|
||||||
|
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
|
||||||
|
preludeBytes := preludeBuf.Bytes()
|
||||||
|
|
||||||
|
var frame bytes.Buffer
|
||||||
|
_, _ = frame.Write(preludeBytes)
|
||||||
|
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab))
|
||||||
|
_, _ = frame.Write(headers)
|
||||||
|
_, _ = frame.Write(payload)
|
||||||
|
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab))
|
||||||
|
|
||||||
|
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes()))
|
||||||
|
_, err := decoder.Decode()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "prelude CRC mismatch")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildBedrockURL(t *testing.T) {
|
||||||
|
t.Run("stream URL with colon in model ID", func(t *testing.T) {
|
||||||
|
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true)
|
||||||
|
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-stream URL with colon in model ID", func(t *testing.T) {
|
||||||
|
url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false)
|
||||||
|
assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("model ID without colon", func(t *testing.T) {
|
||||||
|
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true)
|
||||||
|
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
|
|||||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||||
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||||
|
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
|
||||||
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
|||||||
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||||
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||||
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||||
|
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
|
||||||
|
|
||||||
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||||
if aggErr != nil {
|
if aggErr != nil {
|
||||||
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
|||||||
if usageErr != nil {
|
if usageErr != nil {
|
||||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||||
}
|
}
|
||||||
if aggErr == nil && usageErr == nil {
|
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
|
||||||
|
if dedupErr != nil {
|
||||||
|
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
|
||||||
|
}
|
||||||
|
if aggErr == nil && usageErr == nil && dedupErr == nil {
|
||||||
s.lastRetentionCleanup.Store(now)
|
s.lastRetentionCleanup.Store(now)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,12 +12,18 @@ import (
|
|||||||
|
|
||||||
type dashboardAggregationRepoTestStub struct {
|
type dashboardAggregationRepoTestStub struct {
|
||||||
aggregateCalls int
|
aggregateCalls int
|
||||||
|
recomputeCalls int
|
||||||
|
cleanupUsageCalls int
|
||||||
|
cleanupDedupCalls int
|
||||||
|
ensurePartitionCalls int
|
||||||
lastStart time.Time
|
lastStart time.Time
|
||||||
lastEnd time.Time
|
lastEnd time.Time
|
||||||
watermark time.Time
|
watermark time.Time
|
||||||
aggregateErr error
|
aggregateErr error
|
||||||
cleanupAggregatesErr error
|
cleanupAggregatesErr error
|
||||||
cleanupUsageErr error
|
cleanupUsageErr error
|
||||||
|
cleanupDedupErr error
|
||||||
|
ensurePartitionErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
s.recomputeCalls++
|
||||||
return s.AggregateRange(ctx, start, end)
|
return s.AggregateRange(ctx, start, end)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||||
|
s.cleanupUsageCalls++
|
||||||
return s.cleanupUsageErr
|
return s.cleanupUsageErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||||
|
s.cleanupDedupCalls++
|
||||||
|
return s.cleanupDedupErr
|
||||||
|
}
|
||||||
|
|
||||||
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
return nil
|
s.ensurePartitionCalls++
|
||||||
|
return s.ensurePartitionErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||||
@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
|
|||||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||||
|
|
||||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||||
|
require.Equal(t, 1, repo.cleanupUsageCalls)
|
||||||
|
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 1,
|
||||||
|
HourlyDays: 1,
|
||||||
|
DailyDays: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||||
|
|
||||||
|
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||||
|
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
IntervalSeconds: 60,
|
||||||
|
LookbackSeconds: 120,
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 1,
|
||||||
|
UsageBillingDedupDays: 2,
|
||||||
|
HourlyDays: 1,
|
||||||
|
DailyDays: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.runScheduledAggregation()
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.ensurePartitionCalls)
|
||||||
|
require.Equal(t, 1, repo.aggregateCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user