Compare commits

...

42 Commits

Author SHA1 Message Date
shaw
6da5fa01b9 fix(frontend): 修复运维设置对话框保存按钮始终禁用的问题
后端默认 alert.enabled=true 但 recipients 为空,前端验证将其视为
错误并阻断保存按钮。移除该阻断性验证,改为保存时自动禁用无收件人
的邮件通知配置。
2026-03-14 20:39:29 +08:00
shaw
616930f9d3 refactor(frontend): 将备份和数据管理页面合并为设置页的标签页
将独立的 /admin/backup 和 /admin/data-management 页面整合到设置页,
作为「备份」和「Sora 存储」标签页,减少侧边栏条目,集中管理配置。

- 移除 BackupView 和 DataManagementView 的 AppLayout 包装
- 在 SettingsView 中以子组件形式嵌入,使用 v-show 切换标签
- 删除独立路由和侧边栏菜单入口
- 备份/数据标签页下隐藏主保存按钮(各自有独立保存)
- 优化标签栏样式适配7个标签,PC端支持细滚动条
- 清理未使用的图标组件和 i18n 键
2026-03-14 20:22:39 +08:00
Wesley Liddick
b9c31fa7c4 Merge pull request #999 from InCerryGit/fix/enc_coot
fix: handle invalid encrypted content error and retry logic.
2026-03-14 19:29:07 +08:00
Wesley Liddick
17b339972c Merge pull request #1000 from touwaeriol/fix/ops-agg-tuning
fix(ops): tune aggregation constants to prevent PG overload
2026-03-14 19:03:03 +08:00
shaw
39f8bd91b9 fix: remove unused saveRecords method to pass lint 2026-03-14 19:01:27 +08:00
Wesley Liddick
aa4e37d085 Merge pull request #966 from GuangYiDing/feat/db-backup-restore
feat: 数据库定时备份与恢复(S3 兼容存储,支持 Cloudflare R2)
2026-03-14 18:58:56 +08:00
erio
f59b66b7d4 fix(ops): tune aggregation constants to prevent PG overload
Increase MAX(bucket_start) query timeout from 3s to 5s to reduce
timeout-induced fallbacks. Shrink backfill window from 30 days to
1 hour so that fallback recomputation stays lightweight instead of
scanning the entire retention range.
2026-03-14 18:47:37 +08:00
InCerry
8f0ea7a02d Merge branch 'main' into fix/enc_coot 2026-03-14 18:46:33 +08:00
Wesley Liddick
a1dc00890e Merge pull request #944 from miraserver/feat/backend-mode
feat: add Backend Mode toggle to disable user self-service
2026-03-14 17:53:54 +08:00
Wesley Liddick
dfbcc363d1 Merge pull request #969 from wucm667/feat/quota-fixed-reset-mode
feat: 账号配额支持固定时间重置模式
2026-03-14 17:52:56 +08:00
Rose Ding
1047f973d5 fix: 按 review 意见重构数据库备份服务(安全性 + 架构 + 健壮性)
1. S3 凭证加密存储:使用 SecretEncryptor (AES-256-GCM) 加密 SecretAccessKey,
   防止备份文件中泄露 S3 凭证,兼容旧的未加密数据
2. 修复 saveRecord 竞态条件:添加 recordsMu 互斥锁保护 records 的 load/save
3. 恢复操作增加服务端验证:handler 层要求重新输入管理员密码,通过 bcrypt
   校验,前端弹出密码输入框
4. pg_dump/psql/S3 操作抽象为接口:定义 DBDumper 和 BackupObjectStore 接口,
   实现放入 repository 层,遵循项目依赖注入架构规范
5. 改为流式处理避免大数据库 OOM:备份时 pg_dump stdout -> gzip -> io.Pipe ->
   S3 upload;恢复时 S3 download -> gzip reader -> psql stdin,不再全量加载
6. loadRecords 区分"无数据"和"数据损坏"场景:JSON 解析失败返回明确错误
7. 添加 18 个核心逻辑单元测试:覆盖加密、并发、流式备份/恢复、错误处理等

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-14 17:48:21 +08:00
Wesley Liddick
e32977dd73 Merge pull request #997 from SsageParuders/refactor/bedrock-channel-merge
refactor: merge bedrock-apikey into unified bedrock channel with auth_mode
2026-03-14 17:48:01 +08:00
wucm667
b5f78ec1e8 feat: 实现固定时间重置模式的 SQL 表达式,并添加相关单元测试 2026-03-14 17:37:34 +08:00
SsageParuders
e0f290fdc8 style: fix gofmt formatting for account type constants 2026-03-14 17:32:53 +08:00
Wesley Liddick
fc00a4e3b2 Merge pull request #959 from touwaeriol/feat/antigravity-403-detection
feat(antigravity): add 403 forbidden status detection and display
2026-03-14 17:23:22 +08:00
Wesley Liddick
db1f6ded88 Merge pull request #961 from 0xObjc/codex/ops-openai-token-visibility
feat(ops): make OpenAI token stats optional
2026-03-14 17:23:01 +08:00
SsageParuders
4644af2ccc refactor: merge bedrock-apikey into bedrock with auth_mode credential
Consolidate two separate channel types (bedrock + bedrock-apikey) into
a single "AWS Bedrock" channel. Authentication mode is now distinguished
by credentials.auth_mode ("sigv4" | "apikey") instead of separate types.

Backend:
- Remove AccountTypeBedrockAPIKey constant
- IsBedrock() simplified; IsBedrockAPIKey() checks auth_mode
- Add IsAPIKeyOrBedrock() helper to eliminate repeated type checks
- Extend pool mode, quota scheduling, and billing to bedrock
- Add RetryableOnSameAccount to handleBedrockUpstreamErrors
- Add "bedrock" scope to Beta Policy for independent control

Frontend:
- Merge two buttons into one "AWS Bedrock" with auth mode radio
- Badge displays "Anthropic | AWS"
- Pool mode and quota limit UI available for bedrock
- Quota display in account list (usage bars, capacity badges, reset)
- Remove all bedrock-apikey type references
2026-03-14 17:13:30 +08:00
Wesley Liddick
2e3e8687e1 Merge pull request #993 from xvhuan/fix/codex-responses-id-prefix-hotfix-20260314
fix: 止血 Codex/Responses 原生 input id 被误改成 fc_*
2026-03-14 13:54:26 +08:00
ius
ca42a45802 fix: stop rewriting native responses input ids 2026-03-14 13:47:01 +08:00
Wesley Liddick
9350ecb62b Merge pull request #987 from Ethan0x0000/feat-chatcompletions2repsonses-fix
fix: chat compatibility model fallback and reasoning_content output
2026-03-14 13:41:47 +08:00
Wesley Liddick
a4a026e8da Merge pull request #990 from LvyuanW/admin-openai-available-models-fix
fix: respect OpenAI OAuth model mapping in admin available models
2026-03-14 13:33:18 +08:00
Wesley Liddick
342fd03e72 Merge pull request #986 from LvyuanW/openai-model-mapping-fix
fix: honor account model mapping before group fallback
2026-03-14 13:32:26 +08:00
Ethan0x0000
e3f1fd9b63 fix: handle strings.Builder write errors in assistant parsing 2026-03-14 13:12:17 +08:00
InCerry
e4a4dfd038 Merge remote-tracking branch 'origin/main' into fix/enc_coot
# Conflicts:
#	backend/internal/service/openai_gateway_service.go
2026-03-14 13:04:24 +08:00
Wang Lvyuan
a377e99088 fix: remove unused wildcard mapping helper 2026-03-14 12:56:34 +08:00
Wang Lvyuan
1d3d7a3033 fix: respect OpenAI model mapping in admin available models 2026-03-14 12:45:10 +08:00
Wesley Liddick
e7086cb3a3 Merge pull request #988 from LvyuanW/scheduler-snapshot-sync
fix: sync scheduler snapshot on account updates
2026-03-14 12:38:48 +08:00
Wang Lvyuan
01ef7340aa Merge remote-tracking branch 'origin/main' into openai-model-mapping-fix 2026-03-14 12:27:08 +08:00
Wang Lvyuan
1c960d22c1 fix: sync scheduler snapshot on account updates 2026-03-14 12:21:28 +08:00
Ethan0x0000
ece0606fed fix: consolidate chat-completions compatibility fixes
- apply default mapped model only when scheduling fallback is actually used

- preserve reasoning in OpenAI-compatible output via reasoning_content and avoid invalid input function_call ids
2026-03-14 12:12:08 +08:00
InCerry
2666422b99 fix: handle invalid encrypted content error and retry logic. 2026-03-14 11:42:42 +08:00
Wang Lvyuan
4e8615f276 fix: honor account model mapping before group fallback 2026-03-14 10:47:31 +08:00
erio
45456fa24c fix: restore OAuth 401 temp-unschedulable for Gemini, update Antigravity tests
The 403 detection PR changed the 401 handler condition from
`account.Type == AccountTypeOAuth` to
`account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI`,
which accidentally excluded Gemini OAuth from the temp-unschedulable path.

Fix: use `!= PlatformAntigravity` instead, preserving Gemini behavior
while correctly excluding Antigravity (whose 401 is handled by
applyErrorPolicy's temp_unschedulable_rules).

Update tests to reflect Antigravity's new 401 semantics:
- HandleUpstreamError: Antigravity OAuth 401 now uses SetError
- CheckErrorPolicy: Antigravity 401 second hit stays TempUnscheduled
- DB fallback: split into Gemini (escalates) and Antigravity (stays temp)
2026-03-14 02:21:22 +08:00
erio
6344fa2a86 feat(antigravity): add 403 forbidden status detection, classification and display
Backend:
- Detect and classify 403 responses into three types:
  validation (account needs Google verification),
  violation (terms of service / banned),
  forbidden (generic 403)
- Extract verification/appeal URLs from 403 response body
  (structured JSON parsing with regex fallback)
- Add needs_verify, is_banned, needs_reauth, error_code fields
  to UsageInfo (omitempty for zero impact on other platforms)
- Handle 403 in request path: classify and permanently set account error
- Save validation_url in error_message for degraded path recovery
- Enrich usage with account error on both success and degraded paths
- Add singleflight dedup for usage requests with independent context
- Differentiate cache TTL: success/403 → 3min, errors → 1min
- Return degraded UsageInfo instead of HTTP 500 on quota fetch errors

Frontend:
- Display forbidden status badges with color coding (red for banned,
  amber for needs verification, gray for generic)
- Show clickable verification/appeal URL links
- Display needs_reauth and degraded error states in usage cell
- Add Antigravity tier label badge next to platform type

Tests:
- Comprehensive unit tests for classifyForbiddenType (7 cases)
- Unit tests for extractValidationURL (8 cases including unicode escapes)
- Integration test for FetchQuota forbidden path
2026-03-13 18:22:45 +08:00
Peter
29b0e4a8a5 feat(ops): allow hiding alert events 2026-03-13 17:18:04 +08:00
Rose Ding
f7177be3b6 fix: golangci-lint 修复(gofmt 格式化 + errcheck 返回值检查)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-13 12:01:05 +08:00
Rose Ding
875b417fde fix: 补充 wire_gen_test.go 中 provideCleanup 缺少的 backupSvc 参数
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-13 11:53:46 +08:00
wucm667
2573107b32 refactor: 将 ComputeQuotaResetAt 和 ValidateQuotaResetConfig 函数中的 map 类型从 map[string]interface{} 修改为 map[string]any 2026-03-13 11:44:49 +08:00
wucm667
5b85005945 feat: 账号配额支持固定时间重置模式
- 后端新增 rolling/fixed 两种配额重置模式,支持日配额和周配额
- fixed 模式下可配置重置时刻(小时)、重置星期几(周配额)及时区(IANA)
- 在 account_repo.go 中使用 SQL 表达式适配两种模式的过期判断与重置时间推进
- 新增 ComputeQuotaResetAt / ValidateQuotaResetConfig 等辅助函数
- DTO 层新增相关字段并在 mappers 中完整映射
- 前端 QuotaLimitCard 新增 rolling/fixed 切换 UI、时区选择器
- CreateAccountModal / EditAccountModal 透传新配置字段
- i18n(zh/en)同步新增相关翻译词条
2026-03-13 11:12:37 +08:00
Rose Ding
53ad1645cf feat: 数据库定时备份与恢复(S3 兼容存储,支持 Cloudflare R2)
新增管理员专属的数据库备份与恢复功能:
- 全量 PostgreSQL 备份(pg_dump),gzip 压缩后上传到 S3 兼容存储
- 支持手动备份和 cron 定时备份
- 支持从备份恢复(psql --single-transaction)
- 备份文件自动过期清理(默认 14 天)
- 前端完整管理页面(S3 配置、定时配置、备份列表、恢复/下载/删除)
- 内置 Cloudflare R2 配置教程弹窗
- Dockerfile 从 postgres 镜像多阶段复制 pg_dump/psql,确保版本一致

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-13 10:38:19 +08:00
Peter
af9c4a7dd0 feat(ops): make openai token stats optional 2026-03-13 04:11:58 +08:00
John Doe
6826149a8f feat: add Backend Mode toggle to disable user self-service
Add a system-wide "Backend Mode" that disables user self-registration
and self-service while keeping admin panel and API gateway fully
functional. When enabled, only admin can log in; all user-facing
routes return 403.

Backend:
- New setting key `backend_mode_enabled` with atomic cached reads (60s TTL)
- BackendModeUserGuard middleware blocks non-admin authenticated routes
- BackendModeAuthGuard middleware blocks registration/password-reset auth routes
- Login/Login2FA/RefreshToken handlers reject non-admin when enabled
- TokenPairWithUser struct for role-aware token refresh
- 20 unit tests (middleware + service layer)

Frontend:
- Router guards redirect unauthenticated users to /login
- Admin toggle in Settings page
- Login page hides register link and footer in backend mode
- 9 unit tests for router guard logic
- i18n support (en/zh)

27 files changed, 833 insertions(+), 17 deletions(-)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-12 02:42:57 +03:00
102 changed files with 8170 additions and 910 deletions

View File

@@ -9,6 +9,7 @@
ARG NODE_IMAGE=node:24-alpine ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.26.1-alpine ARG GOLANG_IMAGE=golang:1.26.1-alpine
ARG ALPINE_IMAGE=alpine:3.21 ARG ALPINE_IMAGE=alpine:3.21
ARG POSTGRES_IMAGE=postgres:18-alpine
ARG GOPROXY=https://goproxy.cn,direct ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn ARG GOSUMDB=sum.golang.google.cn
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
./cmd/server ./cmd/server
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Stage 3: Final Runtime Image # Stage 3: PostgreSQL Client (version-matched with docker-compose)
# -----------------------------------------------------------------------------
FROM ${POSTGRES_IMAGE} AS pg-client
# -----------------------------------------------------------------------------
# Stage 4: Final Runtime Image
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
FROM ${ALPINE_IMAGE} FROM ${ALPINE_IMAGE}
@@ -86,8 +92,20 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
RUN apk add --no-cache \ RUN apk add --no-cache \
ca-certificates \ ca-certificates \
tzdata \ tzdata \
libpq \
zstd-libs \
lz4-libs \
krb5-libs \
libldap \
libedit \
&& rm -rf /var/cache/apk/* && rm -rf /var/cache/apk/*
# Copy pg_dump and psql from the same postgres image used in docker-compose
# This ensures version consistency between backup tools and the database server
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
# Create non-root user # Create non-root user
RUN addgroup -g 1000 sub2api && \ RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api adduser -u 1000 -G sub2api -s /bin/sh -D sub2api

View File

@@ -94,6 +94,7 @@ func provideCleanup(
antigravityOAuth *service.AntigravityOAuthService, antigravityOAuth *service.AntigravityOAuthService,
openAIGateway *service.OpenAIGatewayService, openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService, scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
) func() { ) func() {
return func() { return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -230,6 +231,12 @@ func provideCleanup(
} }
return nil return nil
}}, }},
{"BackupService", func() error {
if backupSvc != nil {
backupSvc.Stop()
}
return nil
}},
} }
infraSteps := []cleanupStep{ infraSteps := []cleanupStep{

View File

@@ -146,6 +146,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService() dataManagementService := service.NewDataManagementService()
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService) dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
dbDumper := repository.NewPgDumper(configConfig)
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
backupHandler := admin.NewBackupHandler(backupService, userService)
oAuthHandler := admin.NewOAuthHandler(oAuthService) oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -201,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -232,7 +236,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
application := &Application{ application := &Application{
Server: httpServer, Server: httpServer,
Cleanup: v, Cleanup: v,
@@ -285,6 +289,7 @@ func provideCleanup(
antigravityOAuth *service.AntigravityOAuthService, antigravityOAuth *service.AntigravityOAuthService,
openAIGateway *service.OpenAIGatewayService, openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService, scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
) func() { ) func() {
return func() { return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -420,6 +425,12 @@ func provideCleanup(
} }
return nil return nil
}}, }},
{"BackupService", func() error {
if backupSvc != nil {
backupSvc.Stop()
}
return nil
}},
} }
infraSteps := []cleanupStep{ infraSteps := []cleanupStep{

View File

@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
antigravityOAuthSvc, antigravityOAuthSvc,
nil, // openAIGateway nil, // openAIGateway
nil, // scheduledTestRunner nil, // scheduledTestRunner
nil, // backupSvc
) )
require.NotPanics(t, func() { require.NotPanics(t, func() {

View File

@@ -27,12 +27,11 @@ const (
// Account type constants // Account type constants
const ( const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号 AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock由 credentials.auth_mode 区分
AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock
) )
// Redeem type constants // Redeem type constants

View File

@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
Notes *string `json:"notes"` Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"` Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"` Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
Credentials map[string]any `json:"credentials" binding:"required"` Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"` Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct { type UpdateAccountRequest struct {
Name string `json:"name"` Name string `json:"name"`
Notes *string `json:"notes"` Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"` Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
Credentials map[string]any `json:"credentials"` Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"` Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
@@ -1718,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
// Handle OpenAI accounts // Handle OpenAI accounts
if account.IsOpenAI() { if account.IsOpenAI() {
// For OAuth accounts: return default OpenAI models // OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
if account.IsOAuth() { if account.IsOpenAIPassthroughEnabled() {
response.Success(c, openai.DefaultModels) response.Success(c, openai.DefaultModels)
return return
} }
// For API Key accounts: check model_mapping
mapping := account.GetModelMapping() mapping := account.GetModelMapping()
if len(mapping) == 0 { if len(mapping) == 0 {
response.Success(c, openai.DefaultModels) response.Success(c, openai.DefaultModels)

View File

@@ -0,0 +1,105 @@
package admin
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type availableModelsAdminService struct {
*stubAdminService
account service.Account
}
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
if s.account.ID == id {
acc := s.account
return &acc, nil
}
return s.stubAdminService.GetAccount(context.Background(), id)
}
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
return router
}
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 42,
Name: "openai-oauth",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.1",
},
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Len(t, resp.Data, 1)
require.Equal(t, "gpt-5", resp.Data[0].ID)
}
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 43,
Name: "openai-oauth-passthrough",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.1",
},
},
Extra: map[string]any{
"openai_passthrough": true,
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp.Data)
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
}

View File

@@ -0,0 +1,204 @@
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type BackupHandler struct {
backupService *service.BackupService
userService *service.UserService
}
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
return &BackupHandler{
backupService: backupService,
userService: userService,
}
}
// ─── S3 配置 ───
func (h *BackupHandler) GetS3Config(c *gin.Context) {
cfg, err := h.backupService.GetS3Config(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
var req service.BackupS3Config
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
var req service.BackupS3Config
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.backupService.TestS3Connection(c.Request.Context(), req)
if err != nil {
response.Success(c, gin.H{"ok": false, "message": err.Error()})
return
}
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
}
// ─── 定时备份 ───
func (h *BackupHandler) GetSchedule(c *gin.Context) {
cfg, err := h.backupService.GetSchedule(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
var req service.BackupScheduleConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
// ─── 备份操作 ───
type CreateBackupRequest struct {
ExpireDays *int `json:"expire_days"` // nil=使用默认值140=永不过期
}
func (h *BackupHandler) CreateBackup(c *gin.Context) {
var req CreateBackupRequest
_ = c.ShouldBindJSON(&req) // 允许空 body
expireDays := 14 // 默认14天过期
if req.ExpireDays != nil {
expireDays = *req.ExpireDays
}
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, record)
}
func (h *BackupHandler) ListBackups(c *gin.Context) {
records, err := h.backupService.ListBackups(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
if records == nil {
records = []service.BackupRecord{}
}
response.Success(c, gin.H{"items": records})
}
func (h *BackupHandler) GetBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, record)
}
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"url": url})
}
// ─── 恢复操作(需要重新输入管理员密码) ───
type RestoreBackupRequest struct {
Password string `json:"password" binding:"required"`
}
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
var req RestoreBackupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "password is required for restore operation")
return
}
// 从上下文获取当前管理员用户 ID
sub, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "unauthorized")
return
}
// 获取管理员用户并验证密码
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if !user.CheckPassword(req.Password) {
response.BadRequest(c, "incorrect admin password")
return
}
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"restored": true})
}

View File

@@ -125,6 +125,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: settings.MinClaudeCodeVersion, MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling, AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
BackendModeEnabled: settings.BackendModeEnabled,
}) })
} }
@@ -199,6 +200,9 @@ type UpdateSettingsRequest struct {
// 分组隔离 // 分组隔离
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
// Backend Mode
BackendModeEnabled bool `json:"backend_mode_enabled"`
} }
// UpdateSettings 更新系统设置 // UpdateSettings 更新系统设置
@@ -473,6 +477,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
IdentityPatchPrompt: req.IdentityPatchPrompt, IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion, MinClaudeCodeVersion: req.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling, AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
BackendModeEnabled: req.BackendModeEnabled,
OpsMonitoringEnabled: func() bool { OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil { if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled return *req.OpsMonitoringEnabled
@@ -571,6 +576,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling, AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
BackendModeEnabled: updatedSettings.BackendModeEnabled,
}) })
} }
@@ -725,6 +731,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling { if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
changed = append(changed, "allow_ungrouped_key_scheduling") changed = append(changed, "allow_ungrouped_key_scheduling")
} }
if before.BackendModeEnabled != after.BackendModeEnabled {
changed = append(changed, "backend_mode_enabled")
}
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled { if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
changed = append(changed, "purchase_subscription_enabled") changed = append(changed, "purchase_subscription_enabled")
} }

View File

@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
return return
} }
// Backend mode: only admin can login
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
h.respondWithTokenPair(c, user) h.respondWithTokenPair(c, user)
} }
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
return return
} }
// Delete the login session // Get the user (before session deletion so we can check backend mode)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
// Get the user
user, err := h.userService.GetByID(c.Request.Context(), session.UserID) user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
// Backend mode: only admin can login (check BEFORE deleting session)
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
// Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
h.respondWithTokenPair(c, user) h.respondWithTokenPair(c, user)
} }
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
return return
} }
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken) result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
// Backend mode: block non-admin token refresh
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
response.Success(c, RefreshTokenResponse{ response.Success(c, RefreshTokenResponse{
AccessToken: tokenPair.AccessToken, AccessToken: result.AccessToken,
RefreshToken: tokenPair.RefreshToken, RefreshToken: result.RefreshToken,
ExpiresIn: tokenPair.ExpiresIn, ExpiresIn: result.ExpiresIn,
TokenType: "Bearer", TokenType: "Bearer",
}) })
} }

View File

@@ -264,8 +264,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
} }
} }
// 提取 API Key 账号配额限制(apikey 类型有效) // 提取账号配额限制apikey / bedrock 类型有效)
if a.Type == service.AccountTypeAPIKey { if a.IsAPIKeyOrBedrock() {
if limit := a.GetQuotaLimit(); limit > 0 { if limit := a.GetQuotaLimit(); limit > 0 {
out.QuotaLimit = &limit out.QuotaLimit = &limit
used := a.GetQuotaUsed() used := a.GetQuotaUsed()
@@ -281,6 +281,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
used := a.GetQuotaWeeklyUsed() used := a.GetQuotaWeeklyUsed()
out.QuotaWeeklyUsed = &used out.QuotaWeeklyUsed = &used
} }
// 固定时间重置配置
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
out.QuotaDailyResetMode = &mode
hour := a.GetQuotaDailyResetHour()
out.QuotaDailyResetHour = &hour
}
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
out.QuotaWeeklyResetMode = &mode
day := a.GetQuotaWeeklyResetDay()
out.QuotaWeeklyResetDay = &day
hour := a.GetQuotaWeeklyResetHour()
out.QuotaWeeklyResetHour = &hour
}
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
tz := a.GetQuotaResetTimezone()
out.QuotaResetTimezone = &tz
}
if a.Extra != nil {
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
out.QuotaDailyResetAt = &v
}
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
out.QuotaWeeklyResetAt = &v
}
}
} }
return out return out

View File

@@ -81,6 +81,9 @@ type SystemSettings struct {
// 分组隔离 // 分组隔离
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
// Backend Mode
BackendModeEnabled bool `json:"backend_mode_enabled"`
} }
type DefaultSubscriptionSetting struct { type DefaultSubscriptionSetting struct {
@@ -111,6 +114,7 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version"` Version string `json:"version"`
} }

View File

@@ -203,6 +203,16 @@ type Account struct {
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"` QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"` QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
// 配额固定时间重置配置
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"` Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"`

View File

@@ -12,6 +12,7 @@ type AdminHandlers struct {
Account *admin.AccountHandler Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler Announcement *admin.AnnouncementHandler
DataManagement *admin.DataManagementHandler DataManagement *admin.DataManagementHandler
Backup *admin.BackupHandler
OAuth *admin.OAuthHandler OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler GeminiOAuth *admin.GeminiOAuthHandler

View File

@@ -181,13 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now() forwardStart := time.Now()
defaultMappedModel := "" defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
if apiKey.Group != nil {
defaultMappedModel = apiKey.Group.DefaultMappedModel
}
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
defaultMappedModel = fallbackModel
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()

View File

@@ -655,14 +655,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now() forwardStart := time.Now()
defaultMappedModel := "" // 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
if apiKey.Group != nil { // 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
defaultMappedModel = apiKey.Group.DefaultMappedModel defaultMappedModel := c.GetString("openai_messages_fallback_model")
}
// 如果使用了降级模型调度,强制使用降级模型
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
defaultMappedModel = fallbackModel
}
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()

View File

@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
Version: h.version, Version: h.version,
}) })
} }

View File

@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
accountHandler *admin.AccountHandler, accountHandler *admin.AccountHandler,
announcementHandler *admin.AnnouncementHandler, announcementHandler *admin.AnnouncementHandler,
dataManagementHandler *admin.DataManagementHandler, dataManagementHandler *admin.DataManagementHandler,
backupHandler *admin.BackupHandler,
oauthHandler *admin.OAuthHandler, oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler, geminiOAuthHandler *admin.GeminiOAuthHandler,
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
Account: accountHandler, Account: accountHandler,
Announcement: announcementHandler, Announcement: announcementHandler,
DataManagement: dataManagementHandler, DataManagement: dataManagementHandler,
Backup: backupHandler,
OAuth: oauthHandler, OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler, OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler, GeminiOAuth: geminiOAuthHandler,
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
admin.NewAccountHandler, admin.NewAccountHandler,
admin.NewAnnouncementHandler, admin.NewAnnouncementHandler,
admin.NewDataManagementHandler, admin.NewDataManagementHandler,
admin.NewBackupHandler,
admin.NewOAuthHandler, admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler, admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler, admin.NewGeminiOAuthHandler,

View File

@@ -19,6 +19,16 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
) )
// ForbiddenError 表示上游返回 403 Forbidden
type ForbiddenError struct {
StatusCode int
Body string
}
func (e *ForbiddenError) Error() string {
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
}
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求v1internal 端点) // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求v1internal 端点)
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
// 构建 URL流式请求添加 ?alt=sse 参数 // 构建 URL流式请求添加 ?alt=sse 参数
@@ -514,7 +524,20 @@ type ModelQuotaInfo struct {
// ModelInfo 模型信息 // ModelInfo 模型信息
type ModelInfo struct { type ModelInfo struct {
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"` QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
DisplayName string `json:"displayName,omitempty"`
SupportsImages *bool `json:"supportsImages,omitempty"`
SupportsThinking *bool `json:"supportsThinking,omitempty"`
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
Recommended *bool `json:"recommended,omitempty"`
MaxTokens *int `json:"maxTokens,omitempty"`
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
}
// DeprecatedModelInfo 废弃模型转发信息
type DeprecatedModelInfo struct {
NewModelID string `json:"newModelId"`
} }
// FetchAvailableModelsRequest fetchAvailableModels 请求 // FetchAvailableModelsRequest fetchAvailableModels 请求
@@ -524,7 +547,8 @@ type FetchAvailableModelsRequest struct {
// FetchAvailableModelsResponse fetchAvailableModels 响应 // FetchAvailableModelsResponse fetchAvailableModels 响应
type FetchAvailableModelsResponse struct { type FetchAvailableModelsResponse struct {
Models map[string]ModelInfo `json:"models"` Models map[string]ModelInfo `json:"models"`
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
} }
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON // FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
@@ -573,6 +597,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
continue continue
} }
if resp.StatusCode == http.StatusForbidden {
return nil, nil, &ForbiddenError{
StatusCode: resp.StatusCode,
Body: string(respBodyBytes),
}
}
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
} }

View File

@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
assert.Equal(t, "assistant", items[1].Role) assert.Equal(t, "assistant", items[1].Role)
assert.Equal(t, "function_call", items[2].Type) assert.Equal(t, "function_call", items[2].Type)
assert.Equal(t, "fc_call_1", items[2].CallID) assert.Equal(t, "fc_call_1", items[2].CallID)
assert.Empty(t, items[2].ID)
assert.Equal(t, "function_call_output", items[3].Type) assert.Equal(t, "function_call_output", items[3].Type)
assert.Equal(t, "fc_call_1", items[3].CallID) assert.Equal(t, "fc_call_1", items[3].CallID)
assert.Equal(t, "Sunny, 72°F", items[3].Output) assert.Equal(t, "Sunny, 72°F", items[3].Output)

View File

@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
CallID: fcID, CallID: fcID,
Name: b.Name, Name: b.Name,
Arguments: args, Arguments: args,
ID: fcID,
}) })
} }

View File

@@ -99,6 +99,7 @@ func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
// Check function_call item // Check function_call item
assert.Equal(t, "function_call", items[1].Type) assert.Equal(t, "function_call", items[1].Type)
assert.Equal(t, "call_1", items[1].CallID) assert.Equal(t, "call_1", items[1].CallID)
assert.Empty(t, items[1].ID)
assert.Equal(t, "ping", items[1].Name) assert.Equal(t, "ping", items[1].Name)
// Check function_call_output item // Check function_call_output item
@@ -252,6 +253,55 @@ func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T)
assert.Equal(t, "user", items[0].Role) assert.Equal(t, "user", items[0].Role)
assert.Equal(t, "assistant", items[1].Role) assert.Equal(t, "assistant", items[1].Role)
assert.Equal(t, "function_call", items[2].Type) assert.Equal(t, "function_call", items[2].Type)
assert.Empty(t, items[2].ID)
}
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Hi"`)},
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2)
assert.Equal(t, "assistant", items[1].Role)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "output_text", parts[0].Type)
assert.Equal(t, "AB", parts[0].Text)
}
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Hi"`)},
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "output_text", parts[0].Type)
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
assert.Contains(t, parts[0].Text, "final answer")
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -344,8 +394,8 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
var content string var content string
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
// Reasoning summary is prepended to text assert.Equal(t, "The answer is 42.", content)
assert.Equal(t, "I thought about it.The answer is 42.", content) assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
} }
func TestResponsesToChatCompletions_Incomplete(t *testing.T) { func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
@@ -582,8 +632,35 @@ func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
Delta: "Thinking...", Delta: "Thinking...",
}, state) }, state)
require.Len(t, chunks, 1) require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.reasoning_summary_text.done",
}, state)
require.Len(t, chunks, 0)
}
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.SentRole = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.reasoning_summary_text.delta",
Delta: "plan",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.output_text.delta",
Delta: "answer",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.Content) require.NotNil(t, chunks[0].Choices[0].Delta.Content)
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content) assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
} }
func TestFinalizeResponsesChatStream(t *testing.T) { func TestFinalizeResponsesChatStream(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package apicompat
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
) )
// ChatCompletionsToResponses converts a Chat Completions request into a // ChatCompletionsToResponses converts a Chat Completions request into a
@@ -174,8 +175,11 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// Emit assistant message with output_text if content is non-empty. // Emit assistant message with output_text if content is non-empty.
if len(m.Content) > 0 { if len(m.Content) > 0 {
var s string s, err := parseAssistantContent(m.Content)
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" { if err != nil {
return nil, err
}
if s != "" {
parts := []ResponsesContentPart{{Type: "output_text", Text: s}} parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
partsJSON, err := json.Marshal(parts) partsJSON, err := json.Marshal(parts)
if err != nil { if err != nil {
@@ -196,13 +200,82 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
CallID: tc.ID, CallID: tc.ID,
Name: tc.Function.Name, Name: tc.Function.Name,
Arguments: args, Arguments: args,
ID: tc.ID,
}) })
} }
return items, nil return items, nil
} }
// parseAssistantContent returns assistant content as plain text.
//
// Supported formats:
// - JSON string
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
//
// For structured thinking/reasoning parts, it preserves semantics by wrapping
// the text in explicit tags so downstream can still distinguish it from normal text.
func parseAssistantContent(raw json.RawMessage) (string, error) {
if len(raw) == 0 {
return "", nil
}
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return s, nil
}
var parts []map[string]any
if err := json.Unmarshal(raw, &parts); err != nil {
// Keep compatibility with prior behavior: unsupported assistant content
// formats are ignored instead of failing the whole request conversion.
return "", nil
}
var b strings.Builder
write := func(v string) error {
_, err := b.WriteString(v)
return err
}
for _, p := range parts {
typ, _ := p["type"].(string)
text, _ := p["text"].(string)
thinking, _ := p["thinking"].(string)
switch typ {
case "thinking", "reasoning":
if thinking != "" {
if err := write("<thinking>"); err != nil {
return "", err
}
if err := write(thinking); err != nil {
return "", err
}
if err := write("</thinking>"); err != nil {
return "", err
}
} else if text != "" {
if err := write("<thinking>"); err != nil {
return "", err
}
if err := write(text); err != nil {
return "", err
}
if err := write("</thinking>"); err != nil {
return "", err
}
}
default:
if text != "" {
if err := write(text); err != nil {
return "", err
}
}
}
}
return b.String(), nil
}
// chatToolToResponses converts a tool result message (role=tool) into a // chatToolToResponses converts a tool result message (role=tool) into a
// function_call_output item. // function_call_output item.
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) { func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {

View File

@@ -29,6 +29,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
} }
var contentText string var contentText string
var reasoningText string
var toolCalls []ChatToolCall var toolCalls []ChatToolCall
for _, item := range resp.Output { for _, item := range resp.Output {
@@ -51,7 +52,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
case "reasoning": case "reasoning":
for _, s := range item.Summary { for _, s := range item.Summary {
if s.Type == "summary_text" && s.Text != "" { if s.Type == "summary_text" && s.Text != "" {
contentText += s.Text reasoningText += s.Text
} }
} }
case "web_search_call": case "web_search_call":
@@ -67,6 +68,9 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
raw, _ := json.Marshal(contentText) raw, _ := json.Marshal(contentText)
msg.Content = raw msg.Content = raw
} }
if reasoningText != "" {
msg.ReasoningContent = reasoningText
}
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls) finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
@@ -153,6 +157,8 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
return resToChatHandleFuncArgsDelta(evt, state) return resToChatHandleFuncArgsDelta(evt, state)
case "response.reasoning_summary_text.delta": case "response.reasoning_summary_text.delta":
return resToChatHandleReasoningDelta(evt, state) return resToChatHandleReasoningDelta(evt, state)
case "response.reasoning_summary_text.done":
return nil
case "response.completed", "response.incomplete", "response.failed": case "response.completed", "response.incomplete", "response.failed":
return resToChatHandleCompleted(evt, state) return resToChatHandleCompleted(evt, state)
default: default:
@@ -276,8 +282,8 @@ func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEv
if evt.Delta == "" { if evt.Delta == "" {
return nil return nil
} }
content := evt.Delta reasoning := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})} return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
} }
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {

View File

@@ -361,11 +361,12 @@ type ChatStreamOptions struct {
// ChatMessage is a single message in the Chat Completions conversation. // ChatMessage is a single message in the Chat Completions conversation.
type ChatMessage struct { type ChatMessage struct {
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function" Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
Content json.RawMessage `json:"content,omitempty"` Content json.RawMessage `json:"content,omitempty"`
Name string `json:"name,omitempty"` ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` Name string `json:"name,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
// Legacy function calling // Legacy function calling
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"` FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
@@ -466,9 +467,10 @@ type ChatChunkChoice struct {
// ChatDelta carries incremental content in a streaming chunk. // ChatDelta carries incremental content in a streaming chunk.
type ChatDelta struct { type ChatDelta struct {
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` ReasoningContent *string `json:"reasoning_content,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------

View File

@@ -397,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
} }
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable { // 普通账号编辑(如 model_mapping / credentials也需要立即刷新单账号快照
r.syncSchedulerAccountSnapshot(ctx, account.ID) // 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
} r.syncSchedulerAccountSnapshot(ctx, account.ID)
return nil return nil
} }
@@ -1727,8 +1727,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string. // nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')` const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
const dailyExpiredExpr = `(
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
+ '24 hours'::interval <= NOW()
END
)`
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
const weeklyExpiredExpr = `(
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
+ '168 hours'::interval <= NOW()
END
)`
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
const nextDailyResetAtExpr = `(
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
THEN to_char((
-- Compute today's reset point in the configured timezone, then pick next future one
CASE WHEN NOW() >= (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- NOW() is at or past today's reset point → next reset is tomorrow
THEN (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
+ '1 day'::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- NOW() is before today's reset point → next reset is today
ELSE (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
END
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
ELSE NULL END
)`
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
const nextWeeklyResetAtExpr = `(
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
THEN to_char((
-- Compute this week's reset point in the configured timezone
-- Step 1: get today's date at reset hour in configured tz
-- Step 2: compute days forward to target weekday
-- Step 3: if same day but past reset hour, advance 7 days
CASE
WHEN (
-- days_forward = (target_day - current_day + 7) % 7
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
+ 7) % 7
) = 0 AND NOW() >= (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
-- Same weekday and past reset hour → next week
THEN (
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
+ '7 days'::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
ELSE (
-- Advance to target weekday this week (or next if days_forward > 0)
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
+ ((
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
+ 7) % 7
) || ' days')::interval
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
END
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
ELSE NULL END
)`
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度) // IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
// 日/周额度在周期过期时自动重置为 0 再递增。 // 日/周额度在周期过期时自动重置为 0 再递增。
// 支持滚动窗口rolling和固定时间fixed两种重置模式。
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error { func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
rows, err := r.sql.QueryContext(ctx, rows, err := r.sql.QueryContext(ctx,
`UPDATE accounts SET extra = ( `UPDATE accounts SET extra = (
@@ -1739,31 +1827,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN || CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object( jsonb_build_object(
'quota_daily_used', 'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) CASE WHEN `+dailyExpiredExpr+`
+ '24 hours'::interval <= NOW()
THEN $1 THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END, ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start', 'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) CASE WHEN `+dailyExpiredExpr+`
+ '24 hours'::interval <= NOW()
THEN `+nowUTC+` THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
) )
-- 固定模式重置时更新下次重置时间
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END ELSE '{}'::jsonb END
-- 周额度:仅在 quota_weekly_limit > 0 时处理 -- 周额度:仅在 quota_weekly_limit > 0 时处理
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN || CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object( jsonb_build_object(
'quota_weekly_used', 'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) CASE WHEN `+weeklyExpiredExpr+`
+ '168 hours'::interval <= NOW()
THEN $1 THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END, ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start', 'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) CASE WHEN `+weeklyExpiredExpr+`
+ '168 hours'::interval <= NOW()
THEN `+nowUTC+` THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
) )
-- 固定模式重置时更新下次重置时间
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END ELSE '{}'::jsonb END
), updated_at = NOW() ), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL WHERE id = $2 AND deleted_at IS NULL
@@ -1796,12 +1888,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
} }
// ResetQuotaUsed 重置账号所有维度的配额用量为 0 // ResetQuotaUsed 重置账号所有维度的配额用量为 0
// 保留固定重置模式的配置字段quota_daily_reset_mode 等),仅清零用量和窗口起始时间
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error { func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, _, err := r.sql.ExecContext(ctx,
`UPDATE accounts SET extra = ( `UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb) COALESCE(extra, '{}'::jsonb)
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb || '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW() ) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`, WHERE id = $1 AND deleted_at IS NULL`,
id) id)
if err != nil { if err != nil {

View File

@@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status) s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
} }
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "sync-credentials-update",
Status: service.StatusActive,
Schedulable: true,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.1",
},
},
})
cacheRecorder := &schedulerCacheRecorder{}
s.repo.schedulerCache = cacheRecorder
account.Credentials = map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.2",
},
}
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
s.Require().Len(cacheRecorder.setAccounts, 1)
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
s.Require().True(ok)
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
}
func (s *AccountRepoSuite) TestDelete() { func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})

View File

@@ -0,0 +1,98 @@
package repository
import (
"context"
"fmt"
"io"
"os/exec"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// PgDumper implements service.DBDumper using pg_dump/psql
type PgDumper struct {
cfg *config.DatabaseConfig
}
// NewPgDumper creates a new PgDumper
func NewPgDumper(cfg *config.Config) service.DBDumper {
return &PgDumper{cfg: &cfg.Database}
}
// Dump executes pg_dump and returns a streaming reader of the output
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
args := []string{
"-h", d.cfg.Host,
"-p", fmt.Sprintf("%d", d.cfg.Port),
"-U", d.cfg.User,
"-d", d.cfg.DBName,
"--no-owner",
"--no-acl",
"--clean",
"--if-exists",
}
cmd := exec.CommandContext(ctx, "pg_dump", args...)
if d.cfg.Password != "" {
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
}
if d.cfg.SSLMode != "" {
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("create stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("start pg_dump: %w", err)
}
// 返回一个 ReadCloser读 stdout关闭时等待进程退出
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
}
// Restore executes psql to restore from a streaming reader
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
args := []string{
"-h", d.cfg.Host,
"-p", fmt.Sprintf("%d", d.cfg.Port),
"-U", d.cfg.User,
"-d", d.cfg.DBName,
"--single-transaction",
}
cmd := exec.CommandContext(ctx, "psql", args...)
if d.cfg.Password != "" {
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
}
if d.cfg.SSLMode != "" {
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
}
cmd.Stdin = data
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("%v: %s", err, string(output))
}
return nil
}
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
type cmdReadCloser struct {
io.ReadCloser
cmd *exec.Cmd
}
func (c *cmdReadCloser) Close() error {
// Close the pipe first
_ = c.ReadCloser.Close()
// Wait for the process to exit
if err := c.cmd.Wait(); err != nil {
return fmt.Errorf("pg_dump exited with error: %w", err)
}
return nil
}

View File

@@ -0,0 +1,116 @@
package repository
import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
type S3BackupStore struct {
client *s3.Client
bucket string
}
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
region := cfg.Region
if region == "" {
region = "auto" // Cloudflare R2 默认 region
}
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
awsconfig.WithRegion(region),
awsconfig.WithCredentialsProvider(
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
),
)
if err != nil {
return nil, fmt.Errorf("load aws config: %w", err)
}
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
if cfg.Endpoint != "" {
o.BaseEndpoint = &cfg.Endpoint
}
if cfg.ForcePathStyle {
o.UsePathStyle = true
}
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
})
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
}
}
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
// 读取全部内容以获取大小S3 PutObject 需要知道内容长度)
data, err := io.ReadAll(body)
if err != nil {
return 0, fmt.Errorf("read body: %w", err)
}
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
Bucket: &s.bucket,
Key: &key,
Body: bytes.NewReader(data),
ContentType: &contentType,
})
if err != nil {
return 0, fmt.Errorf("S3 PutObject: %w", err)
}
return int64(len(data)), nil
}
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &s.bucket,
Key: &key,
})
if err != nil {
return nil, fmt.Errorf("S3 GetObject: %w", err)
}
return result.Body, nil
}
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: &s.bucket,
Key: &key,
})
return err
}
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
presignClient := s3.NewPresignClient(s.client)
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
Bucket: &s.bucket,
Key: &key,
}, s3.WithPresignExpires(expiry))
if err != nil {
return "", fmt.Errorf("presign url: %w", err)
}
return result.URL, nil
}
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
Bucket: &s.bucket,
})
if err != nil {
return fmt.Errorf("S3 HeadBucket failed: %w", err)
}
return nil
}

View File

@@ -100,6 +100,10 @@ var ProviderSet = wire.NewSet(
// Encryptors // Encryptors
NewAESEncryptor, NewAESEncryptor,
// Backup infrastructure
NewPgDumper,
NewS3BackupStoreFactory,
// HTTP service ports (DI Strategy A: return interface directly) // HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier, NewTurnstileVerifier,
ProvidePricingRemoteClient, ProvidePricingRemoteClient,

View File

@@ -537,6 +537,7 @@ func TestAPIContracts(t *testing.T) {
"purchase_subscription_url": "", "purchase_subscription_url": "",
"min_claude_code_version": "", "min_claude_code_version": "",
"allow_ungrouped_key_scheduling": false, "allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"custom_menu_items": [] "custom_menu_items": []
} }
}`, }`,

View File

@@ -0,0 +1,51 @@
package middleware
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
role, _ := GetUserRoleFromContext(c)
if role == "admin" {
c.Next()
return
}
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
c.Abort()
}
}
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
// Allows: login, login/2fa, logout, refresh (admin needs these).
// Blocks: register, forgot-password, reset-password, OAuth, etc.
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
path := c.Request.URL.Path
// Allow login, 2FA, logout, refresh, public settings
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
for _, suffix := range allowedSuffixes {
if strings.HasSuffix(path, suffix) {
c.Next()
return
}
}
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
c.Abort()
}
}

View File

@@ -0,0 +1,239 @@
//go:build unit
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type bmSettingRepo struct {
values map[string]string
}
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
v, ok := r.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return v, nil
}
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
panic("unexpected Set call")
}
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
if r.values == nil {
r.values = make(map[string]string, len(settings))
}
for key, value := range settings {
r.values[key] = value
}
return nil
}
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
panic("unexpected Delete call")
}
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
t.Helper()
repo := &bmSettingRepo{
values: map[string]string{
service.SettingKeyBackendModeEnabled: enabled,
},
}
svc := service.NewSettingService(repo, &config.Config{})
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
BackendModeEnabled: enabled == "true",
}))
return svc
}
func stringPtr(v string) *string {
return &v
}
func TestBackendModeUserGuard(t *testing.T) {
tests := []struct {
name string
nilService bool
enabled string
role *string
wantStatus int
}{
{
name: "disabled_allows_all",
enabled: "false",
role: stringPtr("user"),
wantStatus: http.StatusOK,
},
{
name: "nil_service_allows_all",
nilService: true,
role: stringPtr("user"),
wantStatus: http.StatusOK,
},
{
name: "enabled_admin_allowed",
enabled: "true",
role: stringPtr("admin"),
wantStatus: http.StatusOK,
},
{
name: "enabled_user_blocked",
enabled: "true",
role: stringPtr("user"),
wantStatus: http.StatusForbidden,
},
{
name: "enabled_no_role_blocked",
enabled: "true",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_empty_role_blocked",
enabled: "true",
role: stringPtr(""),
wantStatus: http.StatusForbidden,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
if tc.role != nil {
role := *tc.role
r.Use(func(c *gin.Context) {
c.Set(string(ContextKeyUserRole), role)
c.Next()
})
}
var svc *service.SettingService
if !tc.nilService {
svc = newBackendModeSettingService(t, tc.enabled)
}
r.Use(BackendModeUserGuard(svc))
r.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
r.ServeHTTP(w, req)
require.Equal(t, tc.wantStatus, w.Code)
})
}
}
func TestBackendModeAuthGuard(t *testing.T) {
tests := []struct {
name string
nilService bool
enabled string
path string
wantStatus int
}{
{
name: "disabled_allows_all",
enabled: "false",
path: "/api/v1/auth/register",
wantStatus: http.StatusOK,
},
{
name: "nil_service_allows_all",
nilService: true,
path: "/api/v1/auth/register",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_login",
enabled: "true",
path: "/api/v1/auth/login",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_login_2fa",
enabled: "true",
path: "/api/v1/auth/login/2fa",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_logout",
enabled: "true",
path: "/api/v1/auth/logout",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_refresh",
enabled: "true",
path: "/api/v1/auth/refresh",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_register",
enabled: "true",
path: "/api/v1/auth/register",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_blocks_forgot_password",
enabled: "true",
path: "/api/v1/auth/forgot-password",
wantStatus: http.StatusForbidden,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
var svc *service.SettingService
if !tc.nilService {
svc = newBackendModeSettingService(t, tc.enabled)
}
r.Use(BackendModeAuthGuard(svc))
r.Any("/*path", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
r.ServeHTTP(w, req)
require.Equal(t, tc.wantStatus, w.Code)
})
}
}

View File

@@ -107,9 +107,9 @@ func registerRoutes(
v1 := r.Group("/api/v1") v1 := r.Group("/api/v1")
// 注册各模块路由 // 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient) routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
routes.RegisterUserRoutes(v1, h, jwtAuth) routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth) routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
} }

View File

@@ -58,6 +58,9 @@ func RegisterAdminRoutes(
// 数据管理 // 数据管理
registerDataManagementRoutes(admin, h) registerDataManagementRoutes(admin, h)
// 数据库备份恢复
registerBackupRoutes(admin, h)
// 运维监控Ops // 运维监控Ops
registerOpsRoutes(admin, h) registerOpsRoutes(admin, h)
@@ -440,6 +443,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
} }
} }
func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
backup := admin.Group("/backups")
{
// S3 存储配置
backup.GET("/s3-config", h.Admin.Backup.GetS3Config)
backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config)
backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection)
// 定时备份配置
backup.GET("/schedule", h.Admin.Backup.GetSchedule)
backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule)
// 备份操作
backup.POST("", h.Admin.Backup.CreateBackup)
backup.GET("", h.Admin.Backup.ListBackups)
backup.GET("/:id", h.Admin.Backup.GetBackup)
backup.DELETE("/:id", h.Admin.Backup.DeleteBackup)
backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL)
// 恢复操作
backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup)
}
}
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
system := admin.Group("/system") system := admin.Group("/system")
{ {

View File

@@ -6,6 +6,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware" "github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
h *handler.Handlers, h *handler.Handlers,
jwtAuth servermiddleware.JWTAuthMiddleware, jwtAuth servermiddleware.JWTAuthMiddleware,
redisClient *redis.Client, redisClient *redis.Client,
settingService *service.SettingService,
) { ) {
// 创建速率限制器 // 创建速率限制器
rateLimiter := middleware.NewRateLimiter(redisClient) rateLimiter := middleware.NewRateLimiter(redisClient)
// 公开接口 // 公开接口
auth := v1.Group("/auth") auth := v1.Group("/auth")
auth.Use(servermiddleware.BackendModeAuthGuard(settingService))
{ {
// 注册/登录/2FA/验证码发送均属于高风险入口增加服务端兜底限流Redis 故障时 fail-close // 注册/登录/2FA/验证码发送均属于高风险入口增加服务端兜底限流Redis 故障时 fail-close
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{ auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
@@ -78,6 +81,7 @@ func RegisterAuthRoutes(
// 需要认证的当前用户信息 // 需要认证的当前用户信息
authenticated := v1.Group("") authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth)) authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService))
{ {
authenticated.GET("/auth/me", h.Auth.GetCurrentUser) authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证) // 撤销所有会话(需要认证)

View File

@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
c.Next() c.Next()
}), }),
redisClient, redisClient,
nil,
) )
return router return router

View File

@@ -3,6 +3,7 @@ package routes
import ( import (
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
v1 *gin.RouterGroup, v1 *gin.RouterGroup,
h *handler.Handlers, h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware, jwtAuth middleware.JWTAuthMiddleware,
settingService *service.SettingService,
) { ) {
if h.SoraClient == nil { if h.SoraClient == nil {
return return
@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
authenticated := v1.Group("/sora") authenticated := v1.Group("/sora")
authenticated.Use(gin.HandlerFunc(jwtAuth)) authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{ {
authenticated.POST("/generate", h.SoraClient.Generate) authenticated.POST("/generate", h.SoraClient.Generate)
authenticated.GET("/generations", h.SoraClient.ListGenerations) authenticated.GET("/generations", h.SoraClient.ListGenerations)

View File

@@ -3,6 +3,7 @@ package routes
import ( import (
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -12,9 +13,11 @@ func RegisterUserRoutes(
v1 *gin.RouterGroup, v1 *gin.RouterGroup,
h *handler.Handlers, h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware, jwtAuth middleware.JWTAuthMiddleware,
settingService *service.SettingService,
) { ) {
authenticated := v1.Group("") authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth)) authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{ {
// 用户接口 // 用户接口
user := authenticated.Group("/user") user := authenticated.Group("/user")

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"encoding/json" "encoding/json"
"errors"
"hash/fnv" "hash/fnv"
"reflect" "reflect"
"sort" "sort"
@@ -522,16 +523,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
// 如果未配置 mapping返回原始模型名 // 如果未配置 mapping返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string { func (a *Account) GetMappedModel(requestedModel string) string {
mappedModel, _ := a.ResolveMappedModel(requestedModel)
return mappedModel
}
// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。
// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。
func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) {
mapping := a.GetModelMapping() mapping := a.GetModelMapping()
if len(mapping) == 0 { if len(mapping) == 0 {
return requestedModel return requestedModel, false
} }
// 精确匹配优先 // 精确匹配优先
if mappedModel, exists := mapping[requestedModel]; exists { if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel return mappedModel, true
} }
// 通配符匹配(最长优先) // 通配符匹配(最长优先)
return matchWildcardMapping(mapping, requestedModel) return matchWildcardMappingResult(mapping, requestedModel)
} }
func (a *Account) GetBaseURL() string { func (a *Account) GetBaseURL() string {
@@ -605,9 +613,7 @@ func matchWildcard(pattern, str string) bool {
return matchAntigravityWildcard(pattern, str) return matchAntigravityWildcard(pattern, str)
} }
// matchWildcardMapping 通配符映射匹配(最长优先) func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) {
// 如果没有匹配,返回原始字符串
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
// 收集所有匹配的 pattern按长度降序排序最长优先 // 收集所有匹配的 pattern按长度降序排序最长优先
type patternMatch struct { type patternMatch struct {
pattern string pattern string
@@ -622,7 +628,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
} }
if len(matches) == 0 { if len(matches) == 0 {
return requestedModel // 无匹配,返回原始模型名 return requestedModel, false // 无匹配,返回原始模型名
} }
// 按 pattern 长度降序排序 // 按 pattern 长度降序排序
@@ -633,7 +639,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
return matches[i].pattern < matches[j].pattern return matches[i].pattern < matches[j].pattern
}) })
return matches[0].target return matches[0].target, true
} }
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
@@ -651,7 +657,7 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
// IsPoolMode 检查 API Key 账号是否启用池模式。 // IsPoolMode 检查 API Key 账号是否启用池模式。
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。 // 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
func (a *Account) IsPoolMode() bool { func (a *Account) IsPoolMode() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil { if !a.IsAPIKeyOrBedrock() || a.Credentials == nil {
return false return false
} }
if v, ok := a.Credentials["pool_mode"]; ok { if v, ok := a.Credentials["pool_mode"]; ok {
@@ -766,11 +772,16 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
} }
func (a *Account) IsBedrock() bool { func (a *Account) IsBedrock() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey) return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
} }
func (a *Account) IsBedrockAPIKey() bool { func (a *Account) IsBedrockAPIKey() bool {
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey return a.IsBedrock() && a.GetCredential("auth_mode") == "apikey"
}
// IsAPIKeyOrBedrock 返回账号类型是否支持配额和池模式等特性
func (a *Account) IsAPIKeyOrBedrock() bool {
return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock
} }
func (a *Account) IsOpenAI() bool { func (a *Account) IsOpenAI() bool {
@@ -1269,6 +1280,240 @@ func (a *Account) getExtraTime(key string) time.Time {
return time.Time{} return time.Time{}
} }
// getExtraString 从 Extra 中读取指定 key 的字符串值
func (a *Account) getExtraString(key string) string {
if a.Extra == nil {
return ""
}
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// getExtraInt 从 Extra 中读取指定 key 的 int 值
func (a *Account) getExtraInt(key string) int {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra[key]; ok {
return int(parseExtraFloat64(v))
}
return 0
}
// GetQuotaDailyResetMode 获取日额度重置模式:"rolling"(默认)或 "fixed"
func (a *Account) GetQuotaDailyResetMode() string {
if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" {
return "fixed"
}
return "rolling"
}
// GetQuotaDailyResetHour 获取固定重置的小时0-23默认 0
func (a *Account) GetQuotaDailyResetHour() int {
return a.getExtraInt("quota_daily_reset_hour")
}
// GetQuotaWeeklyResetMode 获取周额度重置模式:"rolling"(默认)或 "fixed"
func (a *Account) GetQuotaWeeklyResetMode() string {
if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" {
return "fixed"
}
return "rolling"
}
// GetQuotaWeeklyResetDay 获取固定重置的星期几0=周日, 1=周一, ..., 6=周六),默认 1周一
func (a *Account) GetQuotaWeeklyResetDay() int {
if a.Extra == nil {
return 1
}
if _, ok := a.Extra["quota_weekly_reset_day"]; !ok {
return 1
}
return a.getExtraInt("quota_weekly_reset_day")
}
// GetQuotaWeeklyResetHour 获取周配额固定重置的小时0-23默认 0
func (a *Account) GetQuotaWeeklyResetHour() int {
return a.getExtraInt("quota_weekly_reset_hour")
}
// GetQuotaResetTimezone 获取固定重置的时区名IANA默认 "UTC"
func (a *Account) GetQuotaResetTimezone() string {
if tz := a.getExtraString("quota_reset_timezone"); tz != "" {
return tz
}
return "UTC"
}
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz)
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
if !after.Before(today) {
return today.AddDate(0, 0, 1)
}
return today
}
// lastFixedDailyReset 计算 now 之前最近一次的每日固定重置时间点
func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time {
t := now.In(tz)
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
if now.Before(today) {
return today.AddDate(0, 0, -1)
}
return today
}
// nextFixedWeeklyReset 计算在 after 之后的下一个每周固定重置时间点
// day: 0=Sunday, 1=Monday, ..., 6=Saturday
func nextFixedWeeklyReset(day, hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz)
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
currentDay := int(todayReset.Weekday())
daysForward := (day - currentDay + 7) % 7
if daysForward == 0 && !after.Before(todayReset) {
daysForward = 7
}
return todayReset.AddDate(0, 0, daysForward)
}
// lastFixedWeeklyReset 计算 now 之前最近一次的每周固定重置时间点
func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time {
t := now.In(tz)
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
currentDay := int(todayReset.Weekday())
daysBack := (currentDay - day + 7) % 7
if daysBack == 0 && now.Before(todayReset) {
daysBack = 7
}
return todayReset.AddDate(0, 0, -daysBack)
}
// isFixedDailyPeriodExpired 检查日配额是否在固定时间模式下已过期
func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool {
if periodStart.IsZero() {
return true
}
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
if err != nil {
tz = time.UTC
}
lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now())
return periodStart.Before(lastReset)
}
// isFixedWeeklyPeriodExpired 检查周配额是否在固定时间模式下已过期
func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool {
if periodStart.IsZero() {
return true
}
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
if err != nil {
tz = time.UTC
}
lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now())
return periodStart.Before(lastReset)
}
// ComputeQuotaResetAt 根据当前配置计算并填充 extra 中的 quota_daily_reset_at / quota_weekly_reset_at
// 在保存账号配置时调用
func ComputeQuotaResetAt(extra map[string]any) {
now := time.Now()
tzName, _ := extra["quota_reset_timezone"].(string)
if tzName == "" {
tzName = "UTC"
}
tz, err := time.LoadLocation(tzName)
if err != nil {
tz = time.UTC
}
// 日配额固定重置时间
if mode, _ := extra["quota_daily_reset_mode"].(string); mode == "fixed" {
hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"]))
if hour < 0 || hour > 23 {
hour = 0
}
resetAt := nextFixedDailyReset(hour, tz, now)
extra["quota_daily_reset_at"] = resetAt.UTC().Format(time.RFC3339)
} else {
delete(extra, "quota_daily_reset_at")
}
// 周配额固定重置时间
if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" {
day := 1 // 默认周一
if d, ok := extra["quota_weekly_reset_day"]; ok {
day = int(parseExtraFloat64(d))
}
if day < 0 || day > 6 {
day = 1
}
hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"]))
if hour < 0 || hour > 23 {
hour = 0
}
resetAt := nextFixedWeeklyReset(day, hour, tz, now)
extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339)
} else {
delete(extra, "quota_weekly_reset_at")
}
}
// ValidateQuotaResetConfig 校验配额固定重置时间配置的合法性
func ValidateQuotaResetConfig(extra map[string]any) error {
if extra == nil {
return nil
}
// 校验时区
if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" {
if _, err := time.LoadLocation(tz); err != nil {
return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name")
}
}
// 日配额重置模式
if mode, ok := extra["quota_daily_reset_mode"].(string); ok {
if mode != "rolling" && mode != "fixed" {
return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'")
}
}
// 日配额重置小时
if v, ok := extra["quota_daily_reset_hour"]; ok {
hour := int(parseExtraFloat64(v))
if hour < 0 || hour > 23 {
return errors.New("quota_daily_reset_hour must be between 0 and 23")
}
}
// 周配额重置模式
if mode, ok := extra["quota_weekly_reset_mode"].(string); ok {
if mode != "rolling" && mode != "fixed" {
return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'")
}
}
// 周配额重置星期几
if v, ok := extra["quota_weekly_reset_day"]; ok {
day := int(parseExtraFloat64(v))
if day < 0 || day > 6 {
return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)")
}
}
// 周配额重置小时
if v, ok := extra["quota_weekly_reset_hour"]; ok {
hour := int(parseExtraFloat64(v))
if hour < 0 || hour > 23 {
return errors.New("quota_weekly_reset_hour must be between 0 and 23")
}
}
return nil
}
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制 // HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
func (a *Account) HasAnyQuotaLimit() bool { func (a *Account) HasAnyQuotaLimit() bool {
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0 return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
@@ -1291,14 +1536,26 @@ func (a *Account) IsQuotaExceeded() bool {
// 日额度(周期过期视为未超限,下次 increment 会重置) // 日额度(周期过期视为未超限,下次 increment 会重置)
if limit := a.GetQuotaDailyLimit(); limit > 0 { if limit := a.GetQuotaDailyLimit(); limit > 0 {
start := a.getExtraTime("quota_daily_start") start := a.getExtraTime("quota_daily_start")
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit { var expired bool
if a.GetQuotaDailyResetMode() == "fixed" {
expired = a.isFixedDailyPeriodExpired(start)
} else {
expired = isPeriodExpired(start, 24*time.Hour)
}
if !expired && a.GetQuotaDailyUsed() >= limit {
return true return true
} }
} }
// 周额度 // 周额度
if limit := a.GetQuotaWeeklyLimit(); limit > 0 { if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
start := a.getExtraTime("quota_weekly_start") start := a.getExtraTime("quota_weekly_start")
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit { var expired bool
if a.GetQuotaWeeklyResetMode() == "fixed" {
expired = a.isFixedWeeklyPeriodExpired(start)
} else {
expired = isPeriodExpired(start, 7*24*time.Hour)
}
if !expired && a.GetQuotaWeeklyUsed() >= limit {
return true return true
} }
} }

View File

@@ -0,0 +1,516 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// nextFixedDailyReset
// ---------------------------------------------------------------------------
func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) {
tz := time.UTC
// 2026-03-14 06:00 UTC, reset hour = 9
after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_AtResetHour(t *testing.T) {
tz := time.UTC
// Exactly at reset hour → should return tomorrow
after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_AfterResetHour(t *testing.T) {
tz := time.UTC
// After reset hour → should return tomorrow
after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_MidnightReset(t *testing.T) {
tz := time.UTC
// Reset at hour 0 (midnight), currently 23:59
after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz)
got := nextFixedDailyReset(0, tz, after)
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) {
tz, err := time.LoadLocation("Asia/Shanghai")
require.NoError(t, err)
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC)
got := nextFixedDailyReset(9, tz, after)
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// lastFixedDailyReset
// ---------------------------------------------------------------------------
func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// Before today's 9:00 → yesterday 9:00
want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedDailyReset_AtResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// At exactly 9:00 → today 9:00
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedDailyReset_AfterResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// After 9:00 → today 9:00
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// nextFixedWeeklyReset
// ---------------------------------------------------------------------------
func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) {
tz := time.UTC
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday = 2026-03-16
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Today at 9:00
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday at 9:00
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday at 9:00
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) {
tz := time.UTC
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday = 2026-03-23
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_Sunday(t *testing.T) {
tz := time.UTC
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(0, 0, tz, after)
// Next Sunday = 2026-03-15
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// lastFixedWeeklyReset
// ---------------------------------------------------------------------------
func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Today at 9:00
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Last Monday at 9:00 = 2026-03-09
want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) {
tz := time.UTC
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Last Monday = 2026-03-16
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// isFixedDailyPeriodExpired
// ---------------------------------------------------------------------------
func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
assert.True(t, a.isFixedDailyPeriodExpired(time.Time{}))
}
func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started after the most recent reset → not expired
// (This test uses a time very close to "now", which is after the last reset)
periodStart := time.Now().Add(-1 * time.Minute)
assert.False(t, a.isFixedDailyPeriodExpired(periodStart))
}
func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 3 days ago → definitely expired
periodStart := time.Now().Add(-72 * time.Hour)
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
}
func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "Invalid/Timezone",
}}
// Invalid timezone falls back to UTC
periodStart := time.Now().Add(-72 * time.Hour)
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
}
// ---------------------------------------------------------------------------
// isFixedWeeklyPeriodExpired
// ---------------------------------------------------------------------------
func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{}))
}
func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 1 minute ago → not expired
periodStart := time.Now().Add(-1 * time.Minute)
assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart))
}
func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 10 days ago → definitely expired
periodStart := time.Now().Add(-240 * time.Hour)
assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart))
}
// ---------------------------------------------------------------------------
// ValidateQuotaResetConfig
// ---------------------------------------------------------------------------
func TestValidateQuotaResetConfig_NilExtra(t *testing.T) {
assert.NoError(t, ValidateQuotaResetConfig(nil))
}
func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) {
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{}))
}
func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "Asia/Shanghai",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
}
func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
}
func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) {
extra := map[string]any{
"quota_reset_timezone": "Not/A/Timezone",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_reset_timezone")
}
func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "invalid",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_mode")
}
func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_hour": float64(24),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
}
func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_hour": float64(-1),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
}
func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_mode": "unknown",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_mode")
}
func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_day": float64(7),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
}
func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_day": float64(-1),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
}
func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_hour": float64(25),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_hour")
}
func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) {
// All boundary values should be valid
extra := map[string]any{
"quota_daily_reset_hour": float64(23),
"quota_weekly_reset_day": float64(0), // Sunday
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "UTC",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
extra2 := map[string]any{
"quota_daily_reset_hour": float64(0),
"quota_weekly_reset_day": float64(6), // Saturday
"quota_weekly_reset_hour": float64(23),
}
assert.NoError(t, ValidateQuotaResetConfig(extra2))
}
// ---------------------------------------------------------------------------
// ComputeQuotaResetAt
// ---------------------------------------------------------------------------
func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
}
ComputeQuotaResetAt(extra)
_, hasDailyResetAt := extra["quota_daily_reset_at"]
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at")
assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at")
}
func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
"quota_daily_reset_at": "2026-03-14T09:00:00Z",
"quota_weekly_reset_at": "2026-03-16T09:00:00Z",
}
ComputeQuotaResetAt(extra)
_, hasDailyResetAt := extra["quota_daily_reset_at"]
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at")
assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at")
}
func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok, "quota_daily_reset_at should be set")
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Reset time should be in the future
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
// Reset hour should be 9 UTC
assert.Equal(t, 9, resetAt.UTC().Hour())
}
func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1), // Monday
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_weekly_reset_at"].(string)
require.True(t, ok, "quota_weekly_reset_at should be set")
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Reset time should be in the future
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
// Reset day should be Monday
assert.Equal(t, time.Monday, resetAt.UTC().Weekday())
}
func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) {
tz, err := time.LoadLocation("Asia/Shanghai")
require.NoError(t, err)
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "Asia/Shanghai",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// In Shanghai timezone, the hour should be 9
assert.Equal(t, 9, resetAt.In(tz).Hour())
}
func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(12),
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Default timezone is UTC
assert.Equal(t, 12, resetAt.UTC().Hour())
}
func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(99),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Invalid hour → clamped to 0
assert.Equal(t, 0, resetAt.UTC().Hour())
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"log/slog"
"math/rand/v2" "math/rand/v2"
"net/http" "net/http"
"strings" "strings"
@@ -100,6 +101,7 @@ type antigravityUsageCache struct {
const ( const (
apiCacheTTL = 3 * time.Minute apiCacheTTL = 3 * time.Minute
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL429 等错误缓存 1 分钟 apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL429 等错误缓存 1 分钟
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL可恢复错误
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟 apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute openAIProbeCacheTTL = 10 * time.Minute
@@ -108,11 +110,12 @@ const (
// UsageCache 封装账户使用量相关的缓存 // UsageCache 封装账户使用量相关的缓存
type UsageCache struct { type UsageCache struct {
apiCache sync.Map // accountID -> *apiUsageCache apiCache sync.Map // accountID -> *apiUsageCache
windowStatsCache sync.Map // accountID -> *windowStatsCache windowStatsCache sync.Map // accountID -> *windowStatsCache
antigravityCache sync.Map // accountID -> *antigravityUsageCache antigravityCache sync.Map // accountID -> *antigravityUsageCache
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存 apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存Anthropic
openAIProbeCache sync.Map // accountID -> time.Time antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
openAIProbeCache sync.Map // accountID -> time.Time
} }
// NewUsageCache 创建 UsageCache 实例 // NewUsageCache 创建 UsageCache 实例
@@ -149,6 +152,18 @@ type AntigravityModelQuota struct {
ResetTime string `json:"reset_time"` // 重置时间 ISO8601 ResetTime string `json:"reset_time"` // 重置时间 ISO8601
} }
// AntigravityModelDetail Antigravity 单个模型的详细能力信息
type AntigravityModelDetail struct {
DisplayName string `json:"display_name,omitempty"`
SupportsImages *bool `json:"supports_images,omitempty"`
SupportsThinking *bool `json:"supports_thinking,omitempty"`
ThinkingBudget *int `json:"thinking_budget,omitempty"`
Recommended *bool `json:"recommended,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
}
// UsageInfo 账号使用量信息 // UsageInfo 账号使用量信息
type UsageInfo struct { type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
@@ -164,6 +179,33 @@ type UsageInfo struct {
// Antigravity 多模型配额 // Antigravity 多模型配额
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"` AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
// Antigravity 账号级信息
SubscriptionTier string `json:"subscription_tier,omitempty"` // 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN
SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // 上游原始订阅等级名称
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
// Antigravity 账号是否被上游禁止 (HTTP 403)
IsForbidden bool `json:"is_forbidden,omitempty"`
ForbiddenReason string `json:"forbidden_reason,omitempty"`
ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden"
ValidationURL string `json:"validation_url,omitempty"` // 验证/申诉链接
// 状态标记(从 ForbiddenType / HTTP 错误码推导)
NeedsVerify bool `json:"needs_verify,omitempty"` // 需要人工验证forbidden_type=validation
IsBanned bool `json:"is_banned,omitempty"` // 账号被封forbidden_type=violation
NeedsReauth bool `json:"needs_reauth,omitempty"` // token 失效需重新授权401
// 错误码机器可读forbidden / unauthenticated / rate_limited / network_error
ErrorCode string `json:"error_code,omitempty"`
// 获取 usage 时的错误信息(降级返回,而非 500
Error string `json:"error,omitempty"`
} }
// ClaudeUsageResponse Anthropic API返回的usage结构 // ClaudeUsageResponse Anthropic API返回的usage结构
@@ -648,34 +690,157 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
return &UsageInfo{UpdatedAt: &now}, nil return &UsageInfo{UpdatedAt: &now}, nil
} }
// 1. 检查缓存10 分钟) // 1. 检查缓存
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok { if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { if cache, ok := cached.(*antigravityUsageCache); ok {
// 重新计算 RemainingSeconds ttl := antigravityCacheTTL(cache.usageInfo)
usage := cache.usageInfo if time.Since(cache.timestamp) < ttl {
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil { usage := cache.usageInfo
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds()) if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
}
return usage, nil
} }
return usage, nil
} }
} }
// 2. 获取代理 URL // 2. singleflight 防止并发击穿
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account) flightKey := fmt.Sprintf("ag-usage:%d", account.ID)
result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) {
// 再次检查缓存(等待期间可能已被填充)
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
if cache, ok := cached.(*antigravityUsageCache); ok {
ttl := antigravityCacheTTL(cache.usageInfo)
if time.Since(cache.timestamp) < ttl {
usage := cache.usageInfo
// 重新计算 RemainingSeconds避免返回过时的剩余秒数
recalcAntigravityRemainingSeconds(usage)
return usage, nil
}
}
}
// 3. 调用 API 获取额度 // 使用独立 context避免调用方 cancel 导致所有共享 flight 的请求失败
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL) fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second)
if err != nil { defer fetchCancel()
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
}
// 4. 缓存结果 proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account)
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL)
usageInfo: result.UsageInfo, if err != nil {
timestamp: time.Now(), degraded := buildAntigravityDegradedUsage(err)
enrichUsageWithAccountError(degraded, account)
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
usageInfo: degraded,
timestamp: time.Now(),
})
return degraded, nil
}
enrichUsageWithAccountError(fetchResult.UsageInfo, account)
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
usageInfo: fetchResult.UsageInfo,
timestamp: time.Now(),
})
return fetchResult.UsageInfo, nil
}) })
return result.UsageInfo, nil if flightErr != nil {
return nil, flightErr
}
usage, ok := result.(*UsageInfo)
if !ok || usage == nil {
now := time.Now()
return &UsageInfo{UpdatedAt: &now}, nil
}
return usage, nil
}
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
if info == nil {
return
}
if info.FiveHour != nil && info.FiveHour.ResetsAt != nil {
remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds())
if remaining < 0 {
remaining = 0
}
info.FiveHour.RemainingSeconds = remaining
}
}
// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL
// 403 forbidden 状态稳定缓存与成功相同3 分钟);
// 其他错误401/网络)可能快速恢复,缓存 1 分钟。
func antigravityCacheTTL(info *UsageInfo) time.Duration {
if info == nil {
return antigravityErrorTTL
}
if info.IsForbidden {
return apiCacheTTL // 封号/验证状态不会很快变
}
if info.ErrorCode != "" || info.Error != "" {
return antigravityErrorTTL
}
return apiCacheTTL
}
// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo
func buildAntigravityDegradedUsage(err error) *UsageInfo {
now := time.Now()
errMsg := fmt.Sprintf("usage API error: %v", err)
slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err)
info := &UsageInfo{
UpdatedAt: &now,
Error: errMsg,
}
// 从错误信息推断 error_code 和状态标记
// 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..."
errStr := err.Error()
switch {
case strings.Contains(errStr, "HTTP 401") ||
strings.Contains(errStr, "UNAUTHENTICATED") ||
strings.Contains(errStr, "invalid_grant"):
info.ErrorCode = errorCodeUnauthenticated
info.NeedsReauth = true
case strings.Contains(errStr, "HTTP 429"):
info.ErrorCode = errorCodeRateLimited
default:
info.ErrorCode = errorCodeNetworkError
}
return info
}
// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo
// 场景 1成功路径FetchAvailableModels 正常返回,但账号已因 403 被标记为 error
//
// 需要在正常 usage 数据上附加 forbidden/validation 信息。
//
// 场景 2降级路径被封号的账号 OAuth token 失效FetchAvailableModels 返回 401
//
// 降级逻辑设置了 needs_reauth但账号实际是 403 封号/需验证,需覆盖为正确状态。
func enrichUsageWithAccountError(info *UsageInfo, account *Account) {
if info == nil || account == nil || account.Status != StatusError {
return
}
msg := strings.ToLower(account.ErrorMessage)
if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") &&
!strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") {
return
}
fbType := classifyForbiddenType(account.ErrorMessage)
info.IsForbidden = true
info.ForbiddenType = fbType
info.ForbiddenReason = account.ErrorMessage
info.NeedsVerify = fbType == forbiddenTypeValidation
info.IsBanned = fbType == forbiddenTypeViolation
info.ValidationURL = extractValidationURL(account.ErrorMessage)
info.ErrorCode = errorCodeForbidden
info.NeedsReauth = false
} }
// addWindowStats 为 usage 数据添加窗口期统计 // addWindowStats 为 usage 数据添加窗口期统计

View File

@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) {
} }
} }
func TestMatchWildcardMapping(t *testing.T) { func TestMatchWildcardMappingResult(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
mapping map[string]string mapping map[string]string
requestedModel string requestedModel string
expected string expected string
matched bool
}{ }{
// 精确匹配优先于通配符 // 精确匹配优先于通配符
{ {
@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) {
}, },
requestedModel: "claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5-exact", expected: "claude-sonnet-4-5-exact",
matched: true,
}, },
// 最长通配符优先 // 最长通配符优先
@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) {
}, },
requestedModel: "claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-series", expected: "claude-sonnet-4-series",
matched: true,
}, },
// 单个通配符 // 单个通配符
@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) {
}, },
requestedModel: "claude-opus-4-5", requestedModel: "claude-opus-4-5",
expected: "claude-mapped", expected: "claude-mapped",
matched: true,
}, },
// 无匹配返回原始模型 // 无匹配返回原始模型
@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) {
}, },
requestedModel: "gemini-3-flash", requestedModel: "gemini-3-flash",
expected: "gemini-3-flash", expected: "gemini-3-flash",
matched: false,
}, },
// 空映射返回原始模型 // 空映射返回原始模型
@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) {
mapping: map[string]string{}, mapping: map[string]string{},
requestedModel: "claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-5",
matched: false,
}, },
// Gemini 模型映射 // Gemini 模型映射
@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) {
}, },
requestedModel: "gemini-3-flash-preview", requestedModel: "gemini-3-flash-preview",
expected: "gemini-3-pro-high", expected: "gemini-3-pro-high",
matched: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := matchWildcardMapping(tt.mapping, tt.requestedModel) result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
if result != tt.expected { if result != tt.expected || matched != tt.matched {
t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected) t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
} }
}) })
} }
@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) {
} }
} }
func TestAccountResolveMappedModel(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
requestedModel string
expectedModel string
expectedMatch bool
}{
{
name: "no mapping reports unmatched",
credentials: nil,
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
{
name: "exact passthrough mapping still counts as matched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "wildcard passthrough mapping still counts as matched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-*": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "missing mapping reports unmatched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.2": "gpt-5.2",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Credentials: tt.credentials,
}
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
}
})
}
}
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) { func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
account := &Account{ account := &Account{
Platform: PlatformAntigravity, Platform: PlatformAntigravity,

View File

@@ -1462,6 +1462,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
Status: StatusActive, Status: StatusActive,
Schedulable: true, Schedulable: true,
} }
// 预计算固定时间重置的下次重置时间
if account.Extra != nil {
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
return nil, err
}
ComputeQuotaResetAt(account.Extra)
}
if input.ExpiresAt != nil && *input.ExpiresAt > 0 { if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
expiresAt := time.Unix(*input.ExpiresAt, 0) expiresAt := time.Unix(*input.ExpiresAt, 0)
account.ExpiresAt = &expiresAt account.ExpiresAt = &expiresAt
@@ -1535,6 +1542,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
} }
} }
account.Extra = input.Extra account.Extra = input.Extra
// 校验并预计算固定时间重置的下次重置时间
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
return nil, err
}
ComputeQuotaResetAt(account.Extra)
} }
if input.ProxyID != nil { if input.ProxyID != nil {
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图) // 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)

View File

@@ -2,12 +2,29 @@ package service
import ( import (
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"log/slog"
"regexp"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
) )
const (
forbiddenTypeValidation = "validation"
forbiddenTypeViolation = "violation"
forbiddenTypeForbidden = "forbidden"
// 机器可读的错误码
errorCodeForbidden = "forbidden"
errorCodeUnauthenticated = "unauthenticated"
errorCodeRateLimited = "rate_limited"
errorCodeNetworkError = "network_error"
)
// AntigravityQuotaFetcher 从 Antigravity API 获取额度 // AntigravityQuotaFetcher 从 Antigravity API 获取额度
type AntigravityQuotaFetcher struct { type AntigravityQuotaFetcher struct {
proxyRepo ProxyRepository proxyRepo ProxyRepository
@@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
// 调用 API 获取配额 // 调用 API 获取配额
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
if err != nil { if err != nil {
// 403 Forbidden: 不报错,返回 is_forbidden 标记
var forbiddenErr *antigravity.ForbiddenError
if errors.As(err, &forbiddenErr) {
now := time.Now()
fbType := classifyForbiddenType(forbiddenErr.Body)
return &QuotaResult{
UsageInfo: &UsageInfo{
UpdatedAt: &now,
IsForbidden: true,
ForbiddenReason: forbiddenErr.Body,
ForbiddenType: fbType,
ValidationURL: extractValidationURL(forbiddenErr.Body),
NeedsVerify: fbType == forbiddenTypeValidation,
IsBanned: fbType == forbiddenTypeViolation,
ErrorCode: errorCodeForbidden,
},
}, nil
}
return nil, err return nil, err
} }
// 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程)
tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken)
// 转换为 UsageInfo // 转换为 UsageInfo
usageInfo := f.buildUsageInfo(modelsResp) usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized)
return &QuotaResult{ return &QuotaResult{
UsageInfo: usageInfo, UsageInfo: usageInfo,
@@ -52,15 +90,52 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
}, nil }, nil
} }
// buildUsageInfo 将 API 响应转换为 UsageInfo // fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo { func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) {
now := time.Now() loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
info := &UsageInfo{ if err != nil {
UpdatedAt: &now, slog.Warn("failed to fetch subscription tier", "error", err)
AntigravityQuota: make(map[string]*AntigravityModelQuota), return "", ""
}
if loadResp == nil {
return "", ""
} }
// 遍历所有模型,填充 AntigravityQuota raw = loadResp.GetTier() // 已有方法paidTier > currentTier
normalized = normalizeTier(raw)
return raw, normalized
}
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
func normalizeTier(raw string) string {
if raw == "" {
return ""
}
lower := strings.ToLower(raw)
switch {
case strings.Contains(lower, "ultra"):
return "ULTRA"
case strings.Contains(lower, "pro"):
return "PRO"
case strings.Contains(lower, "free"):
return "FREE"
default:
return "UNKNOWN"
}
}
// buildUsageInfo 将 API 响应转换为 UsageInfo
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo {
now := time.Now()
info := &UsageInfo{
UpdatedAt: &now,
AntigravityQuota: make(map[string]*AntigravityModelQuota),
AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail),
SubscriptionTier: tierNormalized,
SubscriptionTierRaw: tierRaw,
}
// 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails
for modelName, modelInfo := range modelsResp.Models { for modelName, modelInfo := range modelsResp.Models {
if modelInfo.QuotaInfo == nil { if modelInfo.QuotaInfo == nil {
continue continue
@@ -73,6 +148,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
Utilization: utilization, Utilization: utilization,
ResetTime: modelInfo.QuotaInfo.ResetTime, ResetTime: modelInfo.QuotaInfo.ResetTime,
} }
// 填充模型详细能力信息
detail := &AntigravityModelDetail{
DisplayName: modelInfo.DisplayName,
SupportsImages: modelInfo.SupportsImages,
SupportsThinking: modelInfo.SupportsThinking,
ThinkingBudget: modelInfo.ThinkingBudget,
Recommended: modelInfo.Recommended,
MaxTokens: modelInfo.MaxTokens,
MaxOutputTokens: modelInfo.MaxOutputTokens,
SupportedMimeTypes: modelInfo.SupportedMimeTypes,
}
info.AntigravityQuotaDetails[modelName] = detail
}
// 废弃模型转发规则
if len(modelsResp.DeprecatedModelIDs) > 0 {
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs))
for oldID, deprecated := range modelsResp.DeprecatedModelIDs {
info.ModelForwardingRules[oldID] = deprecated.NewModelID
}
} }
// 同时设置 FiveHour 用于兼容展示(取主要模型) // 同时设置 FiveHour 用于兼容展示(取主要模型)
@@ -108,3 +204,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco
} }
return proxy.URL() return proxy.URL()
} }
// classifyForbiddenType 根据 403 响应体判断禁止类型
func classifyForbiddenType(body string) string {
lower := strings.ToLower(body)
switch {
case strings.Contains(lower, "validation_required") ||
strings.Contains(lower, "verify your account") ||
strings.Contains(lower, "validation_url"):
return forbiddenTypeValidation
case strings.Contains(lower, "terms of service") ||
strings.Contains(lower, "violation"):
return forbiddenTypeViolation
default:
return forbiddenTypeForbidden
}
}
// urlPattern 用于从 403 响应体中提取 URL降级方案
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)
// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接
func extractValidationURL(body string) string {
// 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url
var parsed struct {
Error struct {
Details []struct {
Metadata map[string]string `json:"metadata"`
} `json:"details"`
} `json:"error"`
}
if json.Unmarshal([]byte(body), &parsed) == nil {
for _, detail := range parsed.Error.Details {
if u := detail.Metadata["validation_url"]; u != "" {
return u
}
if u := detail.Metadata["appeal_url"]; u != "" {
return u
}
}
}
// 2. 降级:正则匹配 URL
lower := strings.ToLower(body)
if !strings.Contains(lower, "validation") &&
!strings.Contains(lower, "verify") &&
!strings.Contains(lower, "appeal") {
return ""
}
// 先解码常见转义再匹配
normalized := strings.ReplaceAll(body, `\u0026`, "&")
if m := urlPattern.FindString(normalized); m != "" {
return m
}
return ""
}

View File

@@ -0,0 +1,497 @@
//go:build unit
package service
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// ---------------------------------------------------------------------------
// normalizeTier
// ---------------------------------------------------------------------------
func TestNormalizeTier(t *testing.T) {
tests := []struct {
name string
raw string
expected string
}{
{name: "empty string", raw: "", expected: ""},
{name: "free-tier", raw: "free-tier", expected: "FREE"},
{name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"},
{name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"},
{name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"},
{name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"},
{name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"},
{name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"},
{name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeTier(tt.raw)
require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw)
})
}
}
// ---------------------------------------------------------------------------
// buildUsageInfo
// ---------------------------------------------------------------------------
func aqfBoolPtr(v bool) *bool { return &v }
func aqfIntPtr(v int) *int { return &v }
func TestBuildUsageInfo_BasicModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.75,
ResetTime: "2026-03-08T12:00:00Z",
},
DisplayName: "Claude Sonnet 4",
SupportsImages: aqfBoolPtr(true),
SupportsThinking: aqfBoolPtr(false),
ThinkingBudget: aqfIntPtr(0),
Recommended: aqfBoolPtr(true),
MaxTokens: aqfIntPtr(200000),
MaxOutputTokens: aqfIntPtr(16384),
SupportedMimeTypes: map[string]bool{
"image/png": true,
"image/jpeg": true,
},
},
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
ResetTime: "2026-03-08T15:00:00Z",
},
DisplayName: "Gemini 2.5 Pro",
MaxTokens: aqfIntPtr(1000000),
MaxOutputTokens: aqfIntPtr(65536),
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO")
// 基本字段
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
require.Equal(t, "PRO", info.SubscriptionTier)
require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw)
// AntigravityQuota
require.Len(t, info.AntigravityQuota, 2)
sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, sonnetQuota)
require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25
require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime)
geminiQuota := info.AntigravityQuota["gemini-2.5-pro"]
require.NotNil(t, geminiQuota)
require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50
require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime)
// AntigravityQuotaDetails
require.Len(t, info.AntigravityQuotaDetails, 2)
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"]
require.NotNil(t, sonnetDetail)
require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName)
require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages)
require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking)
require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget)
require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended)
require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens)
require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens)
require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes)
geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"]
require.NotNil(t, geminiDetail)
require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName)
require.Nil(t, geminiDetail.SupportsImages)
require.Nil(t, geminiDetail.SupportsThinking)
require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens)
require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens)
}
func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 1.0,
},
},
},
DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{
"claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"},
"claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.Len(t, info.ModelForwardingRules, 2)
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"])
}
func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"some-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
}
func TestBuildUsageInfo_EmptyModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info)
require.NotNil(t, info.AntigravityQuota)
require.Empty(t, info.AntigravityQuota)
require.NotNil(t, info.AntigravityQuotaDetails)
require.Empty(t, info.AntigravityQuotaDetails)
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
}
func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"model-without-quota": {
DisplayName: "No Quota Model",
// QuotaInfo is nil
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info)
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too")
}
func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
// When the first priority model exists, it should be used for FiveHour
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.40,
ResetTime: "2026-03-08T18:00:00Z",
},
},
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.80,
ResetTime: "2026-03-08T12:00:00Z",
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
// claude-sonnet-4-20250514 is first in priority list, so it should be used
expectedUtilization := (1.0 - 0.80) * 100 // 20
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime")
}
func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.60,
ResetTime: "2026-03-08T14:00:00Z",
},
},
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.30,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info.FiveHour)
expectedUtilization := (1.0 - 0.60) * 100 // 40
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
}
func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// Only gemini-2.5-pro exists (third in priority list)
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.30,
},
},
"other-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.90,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info.FiveHour)
expectedUtilization := (1.0 - 0.30) * 100 // 70
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
}
func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// None of the priority models exist
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"some-other-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
}
func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
ResetTime: "", // empty reset time
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
require.NotNil(t, info.FiveHour)
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
require.Equal(t, 0, info.FiveHour.RemainingSeconds)
}
func TestBuildUsageInfo_FullUtilization(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.0, // fully used
ResetTime: "2026-03-08T12:00:00Z",
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, quota)
require.Equal(t, 100, quota.Utilization)
}
func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 1.0, // fully available
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "")
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, quota)
require.Equal(t, 0, quota.Utilization)
}
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
// 模拟 FetchQuota 遇到 403 时的行为:
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
forbiddenErr := &antigravity.ForbiddenError{
StatusCode: 403,
Body: "Access denied",
}
// 验证 ForbiddenError 满足 errors.As
var target *antigravity.ForbiddenError
require.True(t, errors.As(forbiddenErr, &target))
require.Equal(t, 403, target.StatusCode)
require.Equal(t, "Access denied", target.Body)
require.Contains(t, forbiddenErr.Error(), "403")
}
// ---------------------------------------------------------------------------
// classifyForbiddenType
// ---------------------------------------------------------------------------
func TestClassifyForbiddenType(t *testing.T) {
tests := []struct {
name string
body string
expected string
}{
{
name: "VALIDATION_REQUIRED keyword",
body: `{"error":{"message":"VALIDATION_REQUIRED"}}`,
expected: "validation",
},
{
name: "verify your account",
body: `Please verify your account to continue`,
expected: "validation",
},
{
name: "contains validation_url field",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`,
expected: "validation",
},
{
name: "terms of service violation",
body: `Your account has been suspended for Terms of Service violation`,
expected: "violation",
},
{
name: "violation keyword",
body: `Account suspended due to policy violation`,
expected: "violation",
},
{
name: "generic 403",
body: `Access denied`,
expected: "forbidden",
},
{
name: "empty body",
body: "",
expected: "forbidden",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := classifyForbiddenType(tt.body)
require.Equal(t, tt.expected, got)
})
}
}
// ---------------------------------------------------------------------------
// extractValidationURL
// ---------------------------------------------------------------------------
func TestExtractValidationURL(t *testing.T) {
tests := []struct {
name string
body string
expected string
}{
{
name: "structured validation_url",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
expected: "https://accounts.google.com/verify?token=abc",
},
{
name: "structured appeal_url",
body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`,
expected: "https://support.google.com/appeal/123",
},
{
name: "validation_url takes priority over appeal_url",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`,
expected: "https://v.com",
},
{
name: "fallback regex with verify keyword",
body: `Please verify your account at https://accounts.google.com/verify`,
expected: "https://accounts.google.com/verify",
},
{
name: "no URL in generic forbidden",
body: `Access denied`,
expected: "",
},
{
name: "empty body",
body: "",
expected: "",
},
{
name: "URL present but no validation keywords",
body: `Error at https://example.com/something`,
expected: "",
},
{
name: "unicode escaped ampersand",
body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`,
expected: "https://accounts.google.com/verify?a=1&b=2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractValidationURL(tt.body)
require.Equal(t, tt.expected, got)
})
}
}

View File

@@ -1087,6 +1087,12 @@ type TokenPair struct {
ExpiresIn int `json:"expires_in"` // Access Token有效期 ExpiresIn int `json:"expires_in"` // Access Token有效期
} }
// TokenPairWithUser extends TokenPair with user role for backend mode checks
type TokenPairWithUser struct {
TokenPair
UserRole string
}
// GenerateTokenPair 生成Access Token和Refresh Token对 // GenerateTokenPair 生成Access Token和Refresh Token对
// familyID: 可选的Token家族ID用于Token轮转时保持家族关系 // familyID: 可选的Token家族ID用于Token轮转时保持家族关系
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) { func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
@@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
// RefreshTokenPair 使用Refresh Token刷新Token对 // RefreshTokenPair 使用Refresh Token刷新Token对
// 实现Token轮转每次刷新都会生成新的Refresh Token旧Token立即失效 // 实现Token轮转每次刷新都会生成新的Refresh Token旧Token立即失效
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) { func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) {
// 检查 refreshTokenCache 是否可用 // 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil { if s.refreshTokenCache == nil {
return nil, ErrRefreshTokenInvalid return nil, ErrRefreshTokenInvalid
@@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
} }
// 生成新的Token对保持同一个家族ID // 生成新的Token对保持同一个家族ID
return s.GenerateTokenPair(ctx, user, data.FamilyID) pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID)
if err != nil {
return nil, err
}
return &TokenPairWithUser{
TokenPair: *pair,
UserRole: user.Role,
}, nil
} }
// RevokeRefreshToken 撤销单个Refresh Token // RevokeRefreshToken 撤销单个Refresh Token

View File

@@ -0,0 +1,770 @@
package service
import (
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"sort"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/robfig/cron/v3"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
settingKeyBackupS3Config = "backup_s3_config"
settingKeyBackupSchedule = "backup_schedule"
settingKeyBackupRecords = "backup_records"
maxBackupRecords = 100
)
var (
ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured")
ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found")
ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress")
ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress")
ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted")
ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted")
)
// ─── 接口定义 ───
// DBDumper abstracts database dump/restore operations
type DBDumper interface {
Dump(ctx context.Context) (io.ReadCloser, error)
Restore(ctx context.Context, data io.Reader) error
}
// BackupObjectStore abstracts object storage for backup files
type BackupObjectStore interface {
Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error)
Download(ctx context.Context, key string) (io.ReadCloser, error)
Delete(ctx context.Context, key string) error
PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error)
HeadBucket(ctx context.Context) error
}
// BackupObjectStoreFactory creates an object store from S3 config
type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error)
// ─── 数据模型 ───
// BackupS3Config S3 兼容存储配置(支持 Cloudflare R2
type BackupS3Config struct {
Endpoint string `json:"endpoint"` // e.g. https://<account_id>.r2.cloudflarestorage.com
Region string `json:"region"` // R2 用 "auto"
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key,omitempty"` //nolint:revive // field name follows AWS convention
Prefix string `json:"prefix"` // S3 key 前缀,如 "backups/"
ForcePathStyle bool `json:"force_path_style"`
}
// IsConfigured 检查必要字段是否已配置
func (c *BackupS3Config) IsConfigured() bool {
return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != ""
}
// BackupScheduleConfig 定时备份配置
type BackupScheduleConfig struct {
Enabled bool `json:"enabled"`
CronExpr string `json:"cron_expr"` // cron 表达式,如 "0 2 * * *" 每天凌晨2点
RetainDays int `json:"retain_days"` // 备份文件过期天数默认140=不自动清理
RetainCount int `json:"retain_count"` // 最多保留份数0=不限制
}
// BackupRecord 备份记录
type BackupRecord struct {
ID string `json:"id"`
Status string `json:"status"` // pending, running, completed, failed
BackupType string `json:"backup_type"` // postgres
FileName string `json:"file_name"`
S3Key string `json:"s3_key"`
SizeBytes int64 `json:"size_bytes"`
TriggeredBy string `json:"triggered_by"` // manual, scheduled
ErrorMsg string `json:"error_message,omitempty"`
StartedAt string `json:"started_at"`
FinishedAt string `json:"finished_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
}
// BackupService 数据库备份恢复服务
type BackupService struct {
settingRepo SettingRepository
dbCfg *config.DatabaseConfig
encryptor SecretEncryptor
storeFactory BackupObjectStoreFactory
dumper DBDumper
mu sync.Mutex
store BackupObjectStore
s3Cfg *BackupS3Config
backingUp bool
restoring bool
recordsMu sync.Mutex // 保护 records 的 load/save 操作
cronMu sync.Mutex
cronSched *cron.Cron
cronEntryID cron.EntryID
}
func NewBackupService(
settingRepo SettingRepository,
cfg *config.Config,
encryptor SecretEncryptor,
storeFactory BackupObjectStoreFactory,
dumper DBDumper,
) *BackupService {
return &BackupService{
settingRepo: settingRepo,
dbCfg: &cfg.Database,
encryptor: encryptor,
storeFactory: storeFactory,
dumper: dumper,
}
}
// Start 启动定时备份调度器
func (s *BackupService) Start() {
s.cronSched = cron.New()
s.cronSched.Start()
// 加载已有的定时配置
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
schedule, err := s.GetSchedule(ctx)
if err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 加载定时备份配置失败: %v", err)
return
}
if schedule.Enabled && schedule.CronExpr != "" {
if err := s.applyCronSchedule(schedule); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 应用定时备份配置失败: %v", err)
}
}
}
// Stop 停止定时备份
func (s *BackupService) Stop() {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched != nil {
s.cronSched.Stop()
}
}
// ─── S3 配置管理 ───
func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) {
cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
if cfg == nil {
return &BackupS3Config{}, nil
}
// 脱敏返回
cfg.SecretAccessKey = ""
return cfg, nil
}
func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) {
// 如果没提供 secret保留原有值
if cfg.SecretAccessKey == "" {
old, _ := s.loadS3Config(ctx)
if old != nil {
cfg.SecretAccessKey = old.SecretAccessKey
}
} else {
// 加密 SecretAccessKey
encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey)
if err != nil {
return nil, fmt.Errorf("encrypt secret: %w", err)
}
cfg.SecretAccessKey = encrypted
}
data, err := json.Marshal(cfg)
if err != nil {
return nil, fmt.Errorf("marshal s3 config: %w", err)
}
if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil {
return nil, fmt.Errorf("save s3 config: %w", err)
}
// 清除缓存的 S3 客户端
s.mu.Lock()
s.store = nil
s.s3Cfg = nil
s.mu.Unlock()
cfg.SecretAccessKey = ""
return &cfg, nil
}
func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error {
// 如果没提供 secret用已保存的
if cfg.SecretAccessKey == "" {
old, _ := s.loadS3Config(ctx)
if old != nil {
cfg.SecretAccessKey = old.SecretAccessKey
}
}
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required")
}
store, err := s.storeFactory(ctx, &cfg)
if err != nil {
return err
}
return store.HeadBucket(ctx)
}
// ─── 定时备份管理 ───
func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule)
if err != nil || raw == "" {
return &BackupScheduleConfig{}, nil
}
var cfg BackupScheduleConfig
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
return &BackupScheduleConfig{}, nil
}
return &cfg, nil
}
func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) {
if cfg.Enabled && cfg.CronExpr == "" {
return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled")
}
// 验证 cron 表达式
if cfg.CronExpr != "" {
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
if _, err := parser.Parse(cfg.CronExpr); err != nil {
return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err))
}
}
data, err := json.Marshal(cfg)
if err != nil {
return nil, fmt.Errorf("marshal schedule config: %w", err)
}
if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil {
return nil, fmt.Errorf("save schedule config: %w", err)
}
// 应用或停止定时任务
if cfg.Enabled {
if err := s.applyCronSchedule(&cfg); err != nil {
return nil, err
}
} else {
s.removeCronSchedule()
}
return &cfg, nil
}
func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched == nil {
return fmt.Errorf("cron scheduler not initialized")
}
// 移除旧任务
if s.cronEntryID != 0 {
s.cronSched.Remove(s.cronEntryID)
s.cronEntryID = 0
}
entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() {
s.runScheduledBackup()
})
if err != nil {
return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err))
}
s.cronEntryID = entryID
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已启用: %s", cfg.CronExpr)
return nil
}
func (s *BackupService) removeCronSchedule() {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched != nil && s.cronEntryID != 0 {
s.cronSched.Remove(s.cronEntryID)
s.cronEntryID = 0
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已停用")
}
}
func (s *BackupService) runScheduledBackup() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
// 读取定时备份配置中的过期天数
schedule, _ := s.GetSchedule(ctx)
expireDays := 14 // 默认14天过期
if schedule != nil && schedule.RetainDays > 0 {
expireDays = schedule.RetainDays
}
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
if err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
return
}
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
// 清理过期备份(复用已加载的 schedule
if schedule == nil {
return
}
if err := s.cleanupOldBackups(ctx, schedule); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 清理过期备份失败: %v", err)
}
}
// ─── 备份/恢复核心 ───
// CreateBackup 创建全量数据库备份并上传到 S3流式处理
// expireDays: 备份过期天数0=永不过期默认14天
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
s.mu.Lock()
if s.backingUp {
s.mu.Unlock()
return nil, ErrBackupInProgress
}
s.backingUp = true
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.backingUp = false
s.mu.Unlock()
}()
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
if s3Cfg == nil || !s3Cfg.IsConfigured() {
return nil, ErrBackupS3NotConfigured
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return nil, fmt.Errorf("init object store: %w", err)
}
now := time.Now()
backupID := uuid.New().String()[:8]
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
s3Key := s.buildS3Key(s3Cfg, fileName)
var expiresAt string
if expireDays > 0 {
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
}
record := &BackupRecord{
ID: backupID,
Status: "running",
BackupType: "postgres",
FileName: fileName,
S3Key: s3Key,
TriggeredBy: triggeredBy,
StartedAt: now.Format(time.RFC3339),
ExpiresAt: expiresAt,
}
// 流式执行: pg_dump -> gzip -> S3 upload
dumpReader, err := s.dumper.Dump(ctx)
if err != nil {
record.Status = "failed"
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, record)
return record, fmt.Errorf("pg_dump: %w", err)
}
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
pr, pw := io.Pipe()
var gzipErr error
go func() {
gzWriter := gzip.NewWriter(pw)
_, gzipErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
}
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
}
if gzipErr != nil {
_ = pw.CloseWithError(gzipErr)
} else {
_ = pw.Close()
}
}()
contentType := "application/gzip"
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
if err != nil {
record.Status = "failed"
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
if gzipErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
}
record.ErrorMsg = errMsg
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, record)
return record, fmt.Errorf("backup upload: %w", err)
}
record.SizeBytes = sizeBytes
record.Status = "completed"
record.FinishedAt = time.Now().Format(time.RFC3339)
if err := s.saveRecord(ctx, record); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
}
return record, nil
}
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
s.mu.Lock()
if s.restoring {
s.mu.Unlock()
return ErrRestoreInProgress
}
s.restoring = true
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.restoring = false
s.mu.Unlock()
}()
record, err := s.GetBackupRecord(ctx, backupID)
if err != nil {
return err
}
if record.Status != "completed" {
return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
}
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return err
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return fmt.Errorf("init object store: %w", err)
}
// 从 S3 流式下载
body, err := objectStore.Download(ctx, record.S3Key)
if err != nil {
return fmt.Errorf("S3 download failed: %w", err)
}
defer func() { _ = body.Close() }()
// 流式解压 gzip -> psql不将全部数据加载到内存
gzReader, err := gzip.NewReader(body)
if err != nil {
return fmt.Errorf("gzip reader: %w", err)
}
defer func() { _ = gzReader.Close() }()
// 流式恢复
if err := s.dumper.Restore(ctx, gzReader); err != nil {
return fmt.Errorf("pg restore: %w", err)
}
return nil
}
// ─── 备份记录管理 ───
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
records, err := s.loadRecords(ctx)
if err != nil {
return nil, err
}
// 倒序返回(最新在前)
sort.Slice(records, func(i, j int) bool {
return records[i].StartedAt > records[j].StartedAt
})
return records, nil
}
func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) {
records, err := s.loadRecords(ctx)
if err != nil {
return nil, err
}
for i := range records {
if records[i].ID == backupID {
return &records[i], nil
}
}
return nil, ErrBackupNotFound
}
func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, err := s.loadRecordsLocked(ctx)
if err != nil {
return err
}
var found *BackupRecord
var remaining []BackupRecord
for i := range records {
if records[i].ID == backupID {
found = &records[i]
} else {
remaining = append(remaining, records[i])
}
}
if found == nil {
return ErrBackupNotFound
}
// 从 S3 删除
if found.S3Key != "" && found.Status == "completed" {
s3Cfg, err := s.loadS3Config(ctx)
if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() {
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err == nil {
_ = objectStore.Delete(ctx, found.S3Key)
}
}
}
return s.saveRecordsLocked(ctx, remaining)
}
// GetBackupDownloadURL 获取备份文件预签名下载 URL
func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) {
record, err := s.GetBackupRecord(ctx, backupID)
if err != nil {
return "", err
}
if record.Status != "completed" {
return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed")
}
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return "", err
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return "", err
}
url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour)
if err != nil {
return "", fmt.Errorf("presign url: %w", err)
}
return url, nil
}
// ─── 内部方法 ───
func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
if err != nil || raw == "" {
return nil, nil //nolint:nilnil // no config is a valid state
}
var cfg BackupS3Config
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
return nil, ErrBackupS3ConfigCorrupt
}
// 解密 SecretAccessKey
if cfg.SecretAccessKey != "" {
decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey)
if err != nil {
// 兼容未加密的旧数据:如果解密失败,保持原值
logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err)
} else {
cfg.SecretAccessKey = decrypted
}
}
return &cfg, nil
}
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.store != nil && s.s3Cfg != nil {
return s.store, nil
}
if cfg == nil {
return nil, ErrBackupS3NotConfigured
}
store, err := s.storeFactory(ctx, cfg)
if err != nil {
return nil, err
}
s.store = store
s.s3Cfg = cfg
return store, nil
}
func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string {
prefix := strings.TrimRight(cfg.Prefix, "/")
if prefix == "" {
prefix = "backups"
}
return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName)
}
// loadRecords 加载备份记录,区分"无数据"和"数据损坏"
func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
return s.loadRecordsLocked(ctx)
}
// loadRecordsLocked 在已持有 recordsMu 锁的情况下加载记录
func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords)
if err != nil || raw == "" {
return nil, nil //nolint:nilnil // no records is a valid state
}
var records []BackupRecord
if err := json.Unmarshal([]byte(raw), &records); err != nil {
return nil, ErrBackupRecordsCorrupt
}
return records, nil
}
// saveRecordsLocked 在已持有 recordsMu 锁的情况下保存记录
func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error {
data, err := json.Marshal(records)
if err != nil {
return err
}
return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data))
}
// saveRecord 保存单条记录(带互斥锁保护)
func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, _ := s.loadRecordsLocked(ctx)
// 更新已有记录或追加
found := false
for i := range records {
if records[i].ID == record.ID {
records[i] = *record
found = true
break
}
}
if !found {
records = append(records, *record)
}
// 限制记录数量
if len(records) > maxBackupRecords {
records = records[len(records)-maxBackupRecords:]
}
return s.saveRecordsLocked(ctx, records)
}
func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error {
if schedule == nil {
return nil
}
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, err := s.loadRecordsLocked(ctx)
if err != nil {
return err
}
// 按时间倒序
sort.Slice(records, func(i, j int) bool {
return records[i].StartedAt > records[j].StartedAt
})
var toDelete []BackupRecord
var toKeep []BackupRecord
for i, r := range records {
shouldDelete := false
// 按保留份数清理
if schedule.RetainCount > 0 && i >= schedule.RetainCount {
shouldDelete = true
}
// 按保留天数清理
if schedule.RetainDays > 0 && r.StartedAt != "" {
startedAt, err := time.Parse(time.RFC3339, r.StartedAt)
if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour {
shouldDelete = true
}
}
if shouldDelete && r.Status == "completed" {
toDelete = append(toDelete, r)
} else {
toKeep = append(toKeep, r)
}
}
// 删除 S3 上的文件
for _, r := range toDelete {
if r.S3Key != "" {
_ = s.deleteS3Object(ctx, r.S3Key)
}
}
if len(toDelete) > 0 {
logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete))
return s.saveRecordsLocked(ctx, toKeep)
}
return nil
}
func (s *BackupService) deleteS3Object(ctx context.Context, key string) error {
s3Cfg, err := s.loadS3Config(ctx)
if err != nil || s3Cfg == nil {
return nil
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return err
}
return objectStore.Delete(ctx, key)
}

View File

@@ -0,0 +1,528 @@
//go:build unit
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// ─── Mocks ───
type mockSettingRepo struct {
mu sync.Mutex
data map[string]string
}
func newMockSettingRepo() *mockSettingRepo {
return &mockSettingRepo{data: make(map[string]string)}
}
func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) {
m.mu.Lock()
defer m.mu.Unlock()
v, ok := m.data[key]
if !ok {
return nil, ErrSettingNotFound
}
return &Setting{Key: key, Value: v}, nil
}
func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
v, ok := m.data[key]
if !ok {
return "", nil
}
return v, nil
}
func (m *mockSettingRepo) Set(_ context.Context, key, value string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.data[key] = value
return nil
}
func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
result := make(map[string]string)
for _, k := range keys {
if v, ok := m.data[k]; ok {
result[k] = v
}
}
return result, nil
}
func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
m.mu.Lock()
defer m.mu.Unlock()
for k, v := range settings {
m.data[k] = v
}
return nil
}
func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
result := make(map[string]string, len(m.data))
for k, v := range m.data {
result[k] = v
}
return result, nil
}
func (m *mockSettingRepo) Delete(_ context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.data, key)
return nil
}
// plainEncryptor 仅做 base64-like 包装,用于测试
type plainEncryptor struct{}
func (e *plainEncryptor) Encrypt(plaintext string) (string, error) {
return "ENC:" + plaintext, nil
}
func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) {
if strings.HasPrefix(ciphertext, "ENC:") {
return strings.TrimPrefix(ciphertext, "ENC:"), nil
}
return ciphertext, fmt.Errorf("not encrypted")
}
type mockDumper struct {
dumpData []byte
dumpErr error
restored []byte
restErr error
}
func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) {
if m.dumpErr != nil {
return nil, m.dumpErr
}
return io.NopCloser(bytes.NewReader(m.dumpData)), nil
}
func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
if m.restErr != nil {
return m.restErr
}
d, err := io.ReadAll(data)
if err != nil {
return err
}
m.restored = d
return nil
}
type mockObjectStore struct {
objects map[string][]byte
mu sync.Mutex
}
func newMockObjectStore() *mockObjectStore {
return &mockObjectStore{objects: make(map[string][]byte)}
}
func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) {
data, err := io.ReadAll(body)
if err != nil {
return 0, err
}
m.mu.Lock()
m.objects[key] = data
m.mu.Unlock()
return int64(len(data)), nil
}
func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) {
m.mu.Lock()
data, ok := m.objects[key]
m.mu.Unlock()
if !ok {
return nil, fmt.Errorf("not found: %s", key)
}
return io.NopCloser(bytes.NewReader(data)), nil
}
func (m *mockObjectStore) Delete(_ context.Context, key string) error {
m.mu.Lock()
delete(m.objects, key)
m.mu.Unlock()
return nil
}
func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) {
return "https://presigned.example.com/" + key, nil
}
func (m *mockObjectStore) HeadBucket(_ context.Context) error {
return nil
}
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "test",
DBName: "testdb",
},
}
factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) {
return store, nil
}
return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper)
}
func seedS3Config(t *testing.T, repo *mockSettingRepo) {
t.Helper()
cfg := BackupS3Config{
Bucket: "test-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "ENC:secret123",
Prefix: "backups",
}
data, _ := json.Marshal(cfg)
require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data)))
}
// ─── Tests ───
func TestBackupService_S3ConfigEncryption(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// 保存配置 -> SecretAccessKey 应被加密
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "my-secret",
Prefix: "backups",
})
require.NoError(t, err)
// 直接读取数据库中存储的值,应该是加密后的
raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config)
var stored BackupS3Config
require.NoError(t, json.Unmarshal([]byte(raw), &stored))
require.Equal(t, "ENC:my-secret", stored.SecretAccessKey)
// 通过 GetS3Config 获取应该脱敏
cfg, err := svc.GetS3Config(context.Background())
require.NoError(t, err)
require.Empty(t, cfg.SecretAccessKey)
require.Equal(t, "my-bucket", cfg.Bucket)
// loadS3Config 内部应解密
internal, err := svc.loadS3Config(context.Background())
require.NoError(t, err)
require.Equal(t, "my-secret", internal.SecretAccessKey)
}
func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// 先保存一个有 secret 的配置
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "original-secret",
})
require.NoError(t, err)
// 再更新时不提供 secret应保留原值
_, err = svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID-NEW",
})
require.NoError(t, err)
internal, err := svc.loadS3Config(context.Background())
require.NoError(t, err)
require.Equal(t, "original-secret", internal.SecretAccessKey)
require.Equal(t, "AKID-NEW", internal.AccessKeyID)
}
func TestBackupService_SaveRecordConcurrency(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
var wg sync.WaitGroup
n := 20
wg.Add(n)
for i := 0; i < n; i++ {
go func(idx int) {
defer wg.Done()
record := &BackupRecord{
ID: fmt.Sprintf("rec-%d", idx),
Status: "completed",
StartedAt: time.Now().Format(time.RFC3339),
}
_ = svc.saveRecord(context.Background(), record)
}(i)
}
wg.Wait()
records, err := svc.loadRecords(context.Background())
require.NoError(t, err)
require.Len(t, records, n)
}
func TestBackupService_LoadRecords_Empty(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
records, err := svc.loadRecords(context.Background())
require.NoError(t, err)
require.Nil(t, records) // 无数据时返回 nil
}
func TestBackupService_LoadRecords_Corrupted(t *testing.T) {
repo := newMockSettingRepo()
_ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{")
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
records, err := svc.loadRecords(context.Background())
require.Error(t, err) // 损坏数据应返回错误
require.Nil(t, records)
}
func TestBackupService_CreateBackup_Streaming(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
require.Equal(t, "completed", record.Status)
require.Greater(t, record.SizeBytes, int64(0))
require.NotEmpty(t, record.S3Key)
// 验证 S3 上确实有文件
store.mu.Lock()
require.Len(t, store.objects, 1)
store.mu.Unlock()
}
func TestBackupService_CreateBackup_DumpFailure(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.Error(t, err)
require.Equal(t, "failed", record.Status)
require.Contains(t, record.ErrorMsg, "pg_dump")
}
func TestBackupService_CreateBackup_NoS3Config(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
_, err := svc.CreateBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupS3NotConfigured)
}
func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
// 使用一个慢速 dumper 来模拟正在进行的备份
dumper := &mockDumper{dumpData: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// 手动设置 backingUp 标志
svc.mu.Lock()
svc.backingUp = true
svc.mu.Unlock()
_, err := svc.CreateBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupInProgress)
}
func TestBackupService_RestoreBackup_Streaming(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// 先创建一个备份
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// 恢复
err = svc.RestoreBackup(context.Background(), record.ID)
require.NoError(t, err)
// 验证 psql 收到的数据是否与原始 dump 内容一致
require.Equal(t, dumpContent, string(dumper.restored))
}
func TestBackupService_RestoreBackup_NotCompleted(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// 手动插入一条 failed 记录
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: "fail-1",
Status: "failed",
})
err := svc.RestoreBackup(context.Background(), "fail-1")
require.Error(t, err)
}
func TestBackupService_DeleteBackup(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "data"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// S3 中应有文件
store.mu.Lock()
require.Len(t, store.objects, 1)
store.mu.Unlock()
// 删除
err = svc.DeleteBackup(context.Background(), record.ID)
require.NoError(t, err)
// S3 中文件应被删除
store.mu.Lock()
require.Len(t, store.objects, 0)
store.mu.Unlock()
// 记录应不存在
_, err = svc.GetBackupRecord(context.Background(), record.ID)
require.ErrorIs(t, err, ErrBackupNotFound)
}
func TestBackupService_GetDownloadURL(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &mockDumper{dumpData: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
url, err := svc.GetBackupDownloadURL(context.Background(), record.ID)
require.NoError(t, err)
require.Contains(t, url, "https://presigned.example.com/")
}
func TestBackupService_ListBackups_Sorted(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
now := time.Now()
for i := 0; i < 3; i++ {
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: fmt.Sprintf("rec-%d", i),
Status: "completed",
StartedAt: now.Add(time.Duration(i) * time.Hour).Format(time.RFC3339),
})
}
records, err := svc.ListBackups(context.Background())
require.NoError(t, err)
require.Len(t, records, 3)
// 最新在前
require.Equal(t, "rec-2", records[0].ID)
require.Equal(t, "rec-0", records[2].ID)
}
func TestBackupService_TestS3Connection(t *testing.T) {
repo := newMockSettingRepo()
store := newMockObjectStore()
svc := newTestBackupService(repo, &mockDumper{}, store)
err := svc.TestS3Connection(context.Background(), BackupS3Config{
Bucket: "test",
AccessKeyID: "ak",
SecretAccessKey: "sk",
})
require.NoError(t, err)
}
func TestBackupService_TestS3Connection_Incomplete(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
err := svc.TestS3Connection(context.Background(), BackupS3Config{
Bucket: "test",
})
require.Error(t, err)
require.Contains(t, err.Error(), "incomplete")
}
func TestBackupService_Schedule_CronValidation(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
svc.cronSched = nil // 未初始化 cron
// 启用但 cron 为空
_, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
Enabled: true,
CronExpr: "",
})
require.Error(t, err)
// 无效的 cron 表达式
_, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
Enabled: true,
CronExpr: "invalid",
})
require.Error(t, err)
}
func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
repo := newMockSettingRepo()
_ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!")
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
cfg, err := svc.loadS3Config(context.Background())
require.Error(t, err)
require.Nil(t, cfg)
}

View File

@@ -29,12 +29,11 @@ const (
// Account type constants // Account type constants
const ( const (
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号full scope: profile + inference AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号inference only scope AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号inference only scope
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock由 credentials.auth_mode 区分
AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock
) )
// Redeem type constants // Redeem type constants
@@ -221,6 +220,9 @@ const (
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false未分组 Key 返回 403 // SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false未分组 Key 返回 403
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling" SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
// SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录
SettingKeyBackendModeEnabled = "backend_mode_enabled"
) )
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) {
expected: ErrorPolicyTempUnscheduled, expected: ErrorPolicyTempUnscheduled,
}, },
{ {
name: "temp_unschedulable_401_second_hit_upgrades_to_none", // Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
// second hit 仍然返回 TempUnscheduled。
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
account: &Account{ account: &Account{
ID: 15, ID: 15,
Type: AccountTypeOAuth, Type: AccountTypeOAuth,
@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) {
}, },
statusCode: 401, statusCode: 401,
body: []byte(`unauthorized`), body: []byte(`unauthorized`),
expected: ErrorPolicyNone, expected: ErrorPolicyTempUnscheduled,
}, },
{ {
name: "temp_unschedulable_body_miss_returns_none", name: "temp_unschedulable_body_miss_returns_none",

View File

@@ -2173,10 +2173,10 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
return context.WithValue(ctx, windowCostPrefetchContextKey, costs) return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
} }
// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内 // isAccountSchedulableForQuota 检查账号是否在配额限制内
// 适用于配置了 quota_limit 的 apikey 类型账号 // 适用于配置了 quota_limit 的 apikey 和 bedrock 类型账号
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool { func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
if account.Type != AccountTypeAPIKey { if !account.IsAPIKeyOrBedrock() {
return true return true
} }
return !account.IsQuotaExceeded() return !account.IsQuotaExceeded()
@@ -3532,9 +3532,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
} }
return apiKey, "apikey", nil return apiKey, "apikey", nil
case AccountTypeBedrock: case AccountTypeBedrock:
return "", "bedrock", nil // Bedrock 使用 SigV4 签名,不需要 token return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key由 forwardBedrock 处理
case AccountTypeBedrockAPIKey:
return "", "bedrock-apikey", nil // Bedrock API Key 使用 Bearer Token由 forwardBedrock 处理
default: default:
return "", "", fmt.Errorf("unsupported account type: %s", account.Type) return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
} }
@@ -5186,7 +5184,7 @@ func (s *GatewayService) forwardBedrock(
if account.IsBedrockAPIKey() { if account.IsBedrockAPIKey() {
bedrockAPIKey = account.GetCredential("api_key") bedrockAPIKey = account.GetCredential("api_key")
if bedrockAPIKey == "" { if bedrockAPIKey == "" {
return nil, fmt.Errorf("api_key not found in bedrock-apikey credentials") return nil, fmt.Errorf("api_key not found in bedrock credentials")
} }
} else { } else {
signer, err = NewBedrockSignerFromAccount(account) signer, err = NewBedrockSignerFromAccount(account)
@@ -5375,8 +5373,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
Message: extractUpstreamErrorMessage(respBody), Message: extractUpstreamErrorMessage(respBody),
}) })
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
} }
} }
return s.handleRetryExhaustedError(ctx, resp, c, account) return s.handleRetryExhaustedError(ctx, resp, c, account)
@@ -5398,8 +5397,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
Message: extractUpstreamErrorMessage(respBody), Message: extractUpstreamErrorMessage(respBody),
}) })
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
ResponseBody: respBody, ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
} }
} }
@@ -5808,9 +5808,10 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
return betaPolicyResult{} return betaPolicyResult{}
} }
isOAuth := account.IsOAuth() isOAuth := account.IsOAuth()
isBedrock := account.IsBedrock()
var result betaPolicyResult var result betaPolicyResult
for _, rule := range settings.Rules { for _, rule := range settings.Rules {
if !betaPolicyScopeMatches(rule.Scope, isOAuth) { if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue continue
} }
switch rule.Action { switch rule.Action {
@@ -5870,14 +5871,16 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
} }
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type. // betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
func betaPolicyScopeMatches(scope string, isOAuth bool) bool { func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
switch scope { switch scope {
case BetaPolicyScopeAll: case BetaPolicyScopeAll:
return true return true
case BetaPolicyScopeOAuth: case BetaPolicyScopeOAuth:
return isOAuth return isOAuth
case BetaPolicyScopeAPIKey: case BetaPolicyScopeAPIKey:
return !isOAuth return !isOAuth && !isBedrock
case BetaPolicyScopeBedrock:
return isBedrock
default: default:
return true // unknown scope → match all (fail-open) return true // unknown scope → match all (fail-open)
} }
@@ -5959,12 +5962,13 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
return nil return nil
} }
isOAuth := account.IsOAuth() isOAuth := account.IsOAuth()
isBedrock := account.IsBedrock()
tokenSet := buildBetaTokenSet(tokens) tokenSet := buildBetaTokenSet(tokens)
for _, rule := range settings.Rules { for _, rule := range settings.Rules {
if rule.Action != BetaPolicyActionBlock { if rule.Action != BetaPolicyActionBlock {
continue continue
} }
if !betaPolicyScopeMatches(rule.Scope, isOAuth) { if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue continue
} }
if _, present := tokenSet[rule.BetaToken]; present { if _, present := tokenSet[rule.BetaToken]; present {
@@ -6125,6 +6129,29 @@ func extractUpstreamErrorMessage(body []byte) string {
return gjson.GetBytes(body, "message").String() return gjson.GetBytes(body, "message").String()
} }
func extractUpstreamErrorCode(body []byte) string {
if code := strings.TrimSpace(gjson.GetBytes(body, "error.code").String()); code != "" {
return code
}
inner := strings.TrimSpace(gjson.GetBytes(body, "error.message").String())
if !strings.HasPrefix(inner, "{") {
return ""
}
if code := strings.TrimSpace(gjson.Get(inner, "error.code").String()); code != "" {
return code
}
if lastBrace := strings.LastIndex(inner, "}"); lastBrace >= 0 {
if code := strings.TrimSpace(gjson.Get(inner[:lastBrace+1], "error.code").String()); code != "" {
return code
}
}
return ""
}
func isCountTokensUnsupported404(statusCode int, body []byte) bool { func isCountTokensUnsupported404(statusCode int, body []byte) bool {
if statusCode != http.StatusNotFound { if statusCode != http.StatusNotFound {
return false return false
@@ -7176,7 +7203,7 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
} }
// 4. 账号配额用量账号口径TotalCost × 账号计费倍率) // 4. 账号配额用量账号口径TotalCost × 账号计费倍率)
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
accountCost := cost.TotalCost * p.AccountRateMultiplier accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
@@ -7264,7 +7291,7 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
cmd.APIKeyRateLimitCost = p.Cost.ActualCost cmd.APIKeyRateLimitCost = p.Cost.ActualCost
} }
if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
} }

View File

@@ -339,8 +339,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
} }
typ, _ := m["type"].(string) typ, _ := m["type"].(string)
// 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'" // 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id
fixIDPrefix := func(id string) string { // 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
fixCallIDPrefix := func(id string) string {
if id == "" || strings.HasPrefix(id, "fc") { if id == "" || strings.HasPrefix(id, "fc") {
return id return id
} }
@@ -358,8 +359,8 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
for key, value := range m { for key, value := range m {
newItem[key] = value newItem[key] = value
} }
if id, ok := newItem["id"].(string); ok && id != "" { if id, ok := newItem["id"].(string); ok && strings.HasPrefix(id, "call_") {
newItem["id"] = fixIDPrefix(id) newItem["id"] = fixCallIDPrefix(id)
} }
filtered = append(filtered, newItem) filtered = append(filtered, newItem)
continue continue
@@ -390,7 +391,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
} }
if callID != "" { if callID != "" {
fixedCallID := fixIDPrefix(callID) fixedCallID := fixCallIDPrefix(callID)
if fixedCallID != callID { if fixedCallID != callID {
ensureCopy() ensureCopy()
newItem["call_id"] = fixedCallID newItem["call_id"] = fixedCallID
@@ -404,14 +405,6 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
if !isCodexToolCallItemType(typ) { if !isCodexToolCallItemType(typ) {
delete(newItem, "call_id") delete(newItem, "call_id")
} }
} else {
if id, ok := newItem["id"].(string); ok && id != "" {
fixedID := fixIDPrefix(id)
if fixedID != id {
ensureCopy()
newItem["id"] = fixedID
}
}
} }
filtered = append(filtered, newItem) filtered = append(filtered, newItem)

View File

@@ -33,12 +33,63 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
first, ok := input[0].(map[string]any) first, ok := input[0].(map[string]any)
require.True(t, ok) require.True(t, ok)
require.Equal(t, "item_reference", first["type"]) require.Equal(t, "item_reference", first["type"])
require.Equal(t, "fc_ref1", first["id"]) require.Equal(t, "ref1", first["id"])
// 校验 input[1] 为 map确保后续字段断言安全。 // 校验 input[1] 为 map确保后续字段断言安全。
second, ok := input[1].(map[string]any) second, ok := input[1].(map[string]any)
require.True(t, ok) require.True(t, ok)
require.Equal(t, "fc_o1", second["id"]) require.Equal(t, "o1", second["id"])
require.Equal(t, "fc1", second["call_id"])
}
func TestApplyCodexOAuthTransform_ToolContinuationPreservesNativeMessageAndReasoningIDs(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
map[string]any{"type": "item_reference", "id": "rs_123"},
},
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "msg_0", first["id"])
second, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "rs_123", second["id"])
}
func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.2",
"input": []any{
map[string]any{"type": "item_reference", "id": "call_1"},
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok"},
},
"tool_choice": "auto",
}
applyCodexOAuthTransform(reqBody, false, false)
input, ok := reqBody["input"].([]any)
require.True(t, ok)
require.Len(t, input, 2)
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "fc1", first["id"])
second, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "fc1", second["call_id"])
} }
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {

View File

@@ -51,10 +51,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
} }
// 3. Model mapping // 3. Model mapping
mappedModel := account.GetMappedModel(originalModel) mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
if mappedModel == originalModel && defaultMappedModel != "" {
mappedModel = defaultMappedModel
}
responsesReq.Model = mappedModel responsesReq.Model = mappedModel
logger.L().Debug("openai chat_completions: model mapping applied", logger.L().Debug("openai chat_completions: model mapping applied",

View File

@@ -59,11 +59,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
} }
// 3. Model mapping // 3. Model mapping
mappedModel := account.GetMappedModel(originalModel) mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
// 分组级降级:账号未映射时使用分组默认映射模型
if mappedModel == originalModel && defaultMappedModel != "" {
mappedModel = defaultMappedModel
}
responsesReq.Model = mappedModel responsesReq.Model = mappedModel
logger.L().Debug("openai messages: model mapping applied", logger.L().Debug("openai messages: model mapping applied",

View File

@@ -480,6 +480,7 @@ func classifyOpenAIWSReconnectReason(err error) (string, bool) {
"upgrade_required", "upgrade_required",
"ws_unsupported", "ws_unsupported",
"auth_failed", "auth_failed",
"invalid_encrypted_content",
"previous_response_not_found": "previous_response_not_found":
return reason, false return reason, false
} }
@@ -530,6 +531,14 @@ func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType st
} }
switch reason { switch reason {
case "invalid_encrypted_content":
if statusCode == 0 {
statusCode = http.StatusBadRequest
}
errType = "invalid_request_error"
if upstreamMessage == "" {
upstreamMessage = "encrypted content could not be verified"
}
case "previous_response_not_found": case "previous_response_not_found":
if statusCode == 0 { if statusCode == 0 {
statusCode = http.StatusBadRequest statusCode = http.StatusBadRequest
@@ -1924,6 +1933,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
var wsErr error var wsErr error
wsLastFailureReason := "" wsLastFailureReason := ""
wsPrevResponseRecoveryTried := false wsPrevResponseRecoveryTried := false
wsInvalidEncryptedContentRecoveryTried := false
recoverPrevResponseNotFound := func(attempt int) bool { recoverPrevResponseNotFound := func(attempt int) bool {
if wsPrevResponseRecoveryTried { if wsPrevResponseRecoveryTried {
return false return false
@@ -1956,6 +1966,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
) )
return true return true
} }
recoverInvalidEncryptedContent := func(attempt int) bool {
if wsInvalidEncryptedContentRecoveryTried {
return false
}
removedReasoningItems := trimOpenAIEncryptedReasoningItems(wsReqBody)
if !removedReasoningItems {
logOpenAIWSModeInfo(
"reconnect_invalid_encrypted_content_recovery_skip account_id=%d attempt=%d reason=missing_encrypted_reasoning_items",
account.ID,
attempt,
)
return false
}
previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id")
hasFunctionCallOutput := HasFunctionCallOutput(wsReqBody)
if previousResponseID != "" && !hasFunctionCallOutput {
delete(wsReqBody, "previous_response_id")
}
wsInvalidEncryptedContentRecoveryTried = true
logOpenAIWSModeInfo(
"reconnect_invalid_encrypted_content_recovery account_id=%d attempt=%d action=drop_encrypted_reasoning_items retry=1 previous_response_id_present=%v previous_response_id=%s previous_response_id_kind=%s has_function_call_output=%v dropped_previous_response_id=%v",
account.ID,
attempt,
previousResponseID != "",
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)),
hasFunctionCallOutput,
previousResponseID != "" && !hasFunctionCallOutput,
)
return true
}
retryBudget := s.openAIWSRetryTotalBudget() retryBudget := s.openAIWSRetryTotalBudget()
retryStartedAt := time.Now() retryStartedAt := time.Now()
wsRetryLoop: wsRetryLoop:
@@ -1992,6 +2033,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) { if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) {
continue continue
} }
if reason == "invalid_encrypted_content" && recoverInvalidEncryptedContent(attempt) {
continue
}
if retryable && attempt < maxAttempts { if retryable && attempt < maxAttempts {
backoff := s.openAIWSRetryBackoff(attempt) backoff := s.openAIWSRetryBackoff(attempt)
if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget { if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget {
@@ -2075,126 +2119,143 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return nil, wsErr return nil, wsErr
} }
// Build upstream request httpInvalidEncryptedContentRetryTried := false
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) for {
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) // Build upstream request
releaseUpstreamCtx() upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
if err != nil { upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
return nil, err releaseUpstreamCtx()
} if err != nil {
return nil, err
}
// Get proxy URL // Get proxy URL
proxyURL := "" proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
// Send request // Send request
upstreamStart := time.Now() upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
if err != nil { if err != nil {
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error()) safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "") setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// Handle error response
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name, AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: 0,
UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: "request_error",
Kind: "failover", Message: safeErr,
Message: upstreamMsg,
Detail: upstreamDetail,
}) })
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
s.handleFailoverSideEffects(ctx, resp, account) // Handle error response
return nil, &UpstreamFailoverError{ if resp.StatusCode >= 400 {
StatusCode: resp.StatusCode, respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
ResponseBody: respBody, _ = resp.Body.Close()
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamCode := extractUpstreamErrorCode(respBody)
if !httpInvalidEncryptedContentRetryTried && resp.StatusCode == http.StatusBadRequest && upstreamCode == "invalid_encrypted_content" {
if trimOpenAIEncryptedReasoningItems(reqBody) {
body, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("serialize invalid_encrypted_content retry body: %w", err)
}
setOpsUpstreamRequestBody(c, body)
httpInvalidEncryptedContentRetryTried = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Retrying non-WSv2 request once after invalid_encrypted_content (account: %s)", account.Name)
continue
}
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Skip non-WSv2 invalid_encrypted_content retry because encrypted reasoning items are missing (account: %s)", account.Name)
}
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
}
defer func() { _ = resp.Body.Close() }()
// Handle normal response
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
if err != nil {
return nil, err
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
if err != nil {
return nil, err
} }
} }
return s.handleErrorResponse(ctx, resp, c, account, body)
}
// Handle normal response // Extract and save Codex usage snapshot from response headers (for OAuth accounts)
var usage *OpenAIUsage if account.Type == AccountTypeOAuth {
var firstTokenMs *int if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
if reqStream { s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel) }
if err != nil {
return nil, err
} }
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs if usage == nil {
} else { usage = &OpenAIUsage{}
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
if err != nil {
return nil, err
} }
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
serviceTier := extractOpenAIServiceTier(reqBody)
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
ServiceTier: serviceTier,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
OpenAIWSMode: false,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
} }
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
if usage == nil {
usage = &OpenAIUsage{}
}
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
serviceTier := extractOpenAIServiceTier(reqBody)
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
ServiceTier: serviceTier,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
OpenAIWSMode: false,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
} }
func (s *OpenAIGatewayService) forwardOpenAIPassthrough( func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
@@ -3756,6 +3817,109 @@ func buildOpenAIResponsesURL(base string) string {
return normalized + "/v1/responses" return normalized + "/v1/responses"
} }
func trimOpenAIEncryptedReasoningItems(reqBody map[string]any) bool {
if len(reqBody) == 0 {
return false
}
inputValue, has := reqBody["input"]
if !has {
return false
}
switch input := inputValue.(type) {
case []any:
filtered := input[:0]
changed := false
for _, item := range input {
nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item)
if itemChanged {
changed = true
}
if !keep {
continue
}
filtered = append(filtered, nextItem)
}
if !changed {
return false
}
if len(filtered) == 0 {
delete(reqBody, "input")
return true
}
reqBody["input"] = filtered
return true
case []map[string]any:
filtered := input[:0]
changed := false
for _, item := range input {
nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item)
if itemChanged {
changed = true
}
if !keep {
continue
}
nextMap, ok := nextItem.(map[string]any)
if !ok {
filtered = append(filtered, item)
continue
}
filtered = append(filtered, nextMap)
}
if !changed {
return false
}
if len(filtered) == 0 {
delete(reqBody, "input")
return true
}
reqBody["input"] = filtered
return true
case map[string]any:
nextItem, changed, keep := sanitizeEncryptedReasoningInputItem(input)
if !changed {
return false
}
if !keep {
delete(reqBody, "input")
return true
}
nextMap, ok := nextItem.(map[string]any)
if !ok {
return false
}
reqBody["input"] = nextMap
return true
default:
return false
}
}
func sanitizeEncryptedReasoningInputItem(item any) (next any, changed bool, keep bool) {
inputItem, ok := item.(map[string]any)
if !ok {
return item, false, true
}
itemType, _ := inputItem["type"].(string)
if strings.TrimSpace(itemType) != "reasoning" {
return item, false, true
}
_, hasEncryptedContent := inputItem["encrypted_content"]
if !hasEncryptedContent {
return item, false, true
}
delete(inputItem, "encrypted_content")
if len(inputItem) == 1 {
return nil, true, false
}
return inputItem, true, true
}
func IsOpenAIResponsesCompactPathForTest(c *gin.Context) bool { func IsOpenAIResponsesCompactPathForTest(c *gin.Context) bool {
return isOpenAIResponsesCompactPath(c) return isOpenAIResponsesCompactPath(c)
} }

View File

@@ -0,0 +1,19 @@
package service
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule.
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
if account == nil {
if defaultMappedModel != "" {
return defaultMappedModel
}
return requestedModel
}
mappedModel, matched := account.ResolveMappedModel(requestedModel)
if !matched && defaultMappedModel != "" {
return defaultMappedModel
}
return mappedModel
}

View File

@@ -0,0 +1,70 @@
package service
import "testing"
func TestResolveOpenAIForwardModel(t *testing.T) {
tests := []struct {
name string
account *Account
requestedModel string
defaultMappedModel string
expectedModel string
}{
{
name: "falls back to group default when account has no mapping",
account: &Account{
Credentials: map[string]any{},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-4o-mini",
},
{
name: "preserves exact passthrough mapping instead of group default",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "preserves wildcard passthrough mapping instead of group default",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-*": "gpt-5.4",
},
},
},
requestedModel: "gpt-5.4",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
{
name: "uses account remap when explicit target differs",
account: &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.4",
},
},
},
requestedModel: "gpt-5",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveOpenAIForwardModel(tt.account, tt.requestedModel, tt.defaultMappedModel); got != tt.expectedModel {
t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", got, tt.expectedModel)
}
})
}
}

View File

@@ -3922,6 +3922,8 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
return "ws_unsupported", true return "ws_unsupported", true
case "websocket_connection_limit_reached": case "websocket_connection_limit_reached":
return "ws_connection_limit_reached", true return "ws_connection_limit_reached", true
case "invalid_encrypted_content":
return "invalid_encrypted_content", true
case "previous_response_not_found": case "previous_response_not_found":
return "previous_response_not_found", true return "previous_response_not_found", true
} }
@@ -3940,6 +3942,10 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") { if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") {
return "ws_connection_limit_reached", true return "ws_connection_limit_reached", true
} }
if strings.Contains(msg, "invalid_encrypted_content") ||
(strings.Contains(msg, "encrypted content") && strings.Contains(msg, "could not be verified")) {
return "invalid_encrypted_content", true
}
if strings.Contains(msg, "previous_response_not_found") || if strings.Contains(msg, "previous_response_not_found") ||
(strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) { (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) {
return "previous_response_not_found", true return "previous_response_not_found", true
@@ -3964,6 +3970,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
case strings.Contains(errType, "invalid_request"), case strings.Contains(errType, "invalid_request"),
strings.Contains(code, "invalid_request"), strings.Contains(code, "invalid_request"),
strings.Contains(code, "bad_request"), strings.Contains(code, "bad_request"),
code == "invalid_encrypted_content",
code == "previous_response_not_found": code == "previous_response_not_found":
return http.StatusBadRequest return http.StatusBadRequest
case strings.Contains(errType, "authentication"), case strings.Contains(errType, "authentication"),

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"io" "io"
@@ -19,6 +20,47 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
type httpUpstreamSequenceRecorder struct {
mu sync.Mutex
bodies [][]byte
reqs []*http.Request
responses []*http.Response
errs []error
callCount int
}
func (u *httpUpstreamSequenceRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
u.mu.Lock()
defer u.mu.Unlock()
idx := u.callCount
u.callCount++
u.reqs = append(u.reqs, req)
if req != nil && req.Body != nil {
b, _ := io.ReadAll(req.Body)
u.bodies = append(u.bodies, b)
_ = req.Body.Close()
req.Body = io.NopCloser(bytes.NewReader(b))
} else {
u.bodies = append(u.bodies, nil)
}
if idx < len(u.errs) && u.errs[idx] != nil {
return nil, u.errs[idx]
}
if idx < len(u.responses) {
return u.responses[idx], nil
}
if len(u.responses) == 0 {
return nil, nil
}
return u.responses[len(u.responses)-1], nil
}
func (u *httpUpstreamSequenceRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) { func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -143,6 +185,176 @@ func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testi
require.Equal(t, "client_protocol_http", reason) require.Equal(t, "client_protocol_http", reason)
} }
func TestOpenAIGatewayService_Forward_HTTPIngressRetriesInvalidEncryptedContentOnce(t *testing.T) {
gin.SetMode(gin.TestMode)
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
defer wsFallbackServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
upstream := &httpUpstreamSequenceRecorder{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(
`{"error":{"code":"invalid_encrypted_content","type":"invalid_request_error","message":"The encrypted content could not be verified."}}`,
)),
},
{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(
`{"id":"resp_http_retry_ok","usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
)),
},
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
}
account := &Account{
ID: 102,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsFallbackServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me"}]},{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发")
require.Equal(t, 2, upstream.callCount, "命中 invalid_encrypted_content 后应只在 HTTP 路径重试一次")
require.Len(t, upstream.bodies, 2)
firstBody := upstream.bodies[0]
secondBody := upstream.bodies[1]
require.False(t, gjson.GetBytes(firstBody, "previous_response_id").Exists(), "HTTP 首次请求仍应沿用原逻辑移除 previous_response_id")
require.True(t, gjson.GetBytes(firstBody, "input.0.encrypted_content").Exists(), "首次请求不应做发送前预清理")
require.Equal(t, "keep me", gjson.GetBytes(firstBody, "input.0.summary.0.text").String())
require.False(t, gjson.GetBytes(secondBody, "previous_response_id").Exists(), "HTTP 精确重试不应重新带回 previous_response_id")
require.False(t, gjson.GetBytes(secondBody, "input.0.encrypted_content").Exists(), "精确重试应移除 reasoning.encrypted_content")
require.Equal(t, "keep me", gjson.GetBytes(secondBody, "input.0.summary.0.text").String(), "精确重试应保留有效 reasoning summary")
require.Equal(t, "input_text", gjson.GetBytes(secondBody, "input.1.type").String(), "非 reasoning input 应保持原样")
decision, _ := c.Get("openai_ws_transport_decision")
reason, _ := c.Get("openai_ws_transport_reason")
require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision)
require.Equal(t, "client_protocol_http", reason)
}
func TestOpenAIGatewayService_Forward_HTTPIngressRetriesWrappedInvalidEncryptedContentOnce(t *testing.T) {
gin.SetMode(gin.TestMode)
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
defer wsFallbackServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
upstream := &httpUpstreamSequenceRecorder{
responses: []*http.Response{
{
StatusCode: http.StatusBadRequest,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(
`{"error":{"code":null,"message":"{\"error\":{\"message\":\"The encrypted content could not be verified.\",\"type\":\"invalid_request_error\",\"param\":null,\"code\":\"invalid_encrypted_content\"}}traceid: fb7ad1dbc7699c18f8a02f258f1af5ab","param":null,"type":"invalid_request_error"}}`,
)),
},
{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
"x-request-id": []string{"req_http_retry_wrapped_ok"},
},
Body: io.NopCloser(strings.NewReader(
`{"id":"resp_http_retry_wrapped_ok","usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
)),
},
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
}
account := &Account{
ID: 103,
Name: "openai-apikey-wrapped",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsFallbackServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry_wrapped","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me too"}]},{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发")
require.Equal(t, 2, upstream.callCount, "wrapped invalid_encrypted_content 也应只在 HTTP 路径重试一次")
require.Len(t, upstream.bodies, 2)
firstBody := upstream.bodies[0]
secondBody := upstream.bodies[1]
require.True(t, gjson.GetBytes(firstBody, "input.0.encrypted_content").Exists(), "首次请求不应做发送前预清理")
require.False(t, gjson.GetBytes(secondBody, "input.0.encrypted_content").Exists(), "wrapped exact retry 应移除 reasoning.encrypted_content")
require.Equal(t, "keep me too", gjson.GetBytes(secondBody, "input.0.summary.0.text").String(), "wrapped exact retry 应保留有效 reasoning summary")
decision, _ := c.Get("openai_ws_transport_decision")
reason, _ := c.Get("openai_ws_transport_reason")
require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision)
require.Equal(t, "client_protocol_http", reason)
}
func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) { func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -1218,3 +1430,460 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOn
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id") require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id")
require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
} }
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentRecoversOnce(t *testing.T) {
gin.SetMode(gin.TestMode)
var wsAttempts atomic.Int32
var wsRequestPayloads [][]byte
var wsRequestMu sync.Mutex
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempt := wsAttempts.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
reqRaw, _ := json.Marshal(req)
wsRequestMu.Lock()
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
wsRequestMu.Unlock()
if attempt == 1 {
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "invalid_encrypted_content",
"type": "invalid_request_error",
"message": "The encrypted content could not be verified.",
},
})
return
}
_ = conn.WriteJSON(map[string]any{
"type": "response.completed",
"response": map[string]any{
"id": "resp_ws_invalid_encrypted_content_recover_ok",
"model": "gpt-5.3-codex",
"usage": map[string]any{
"input_tokens": 1,
"output_tokens": 1,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
},
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 95,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":[{"type":"reasoning","encrypted_content":"gAAA"},{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_ws_invalid_encrypted_content_recover_ok", result.RequestID)
require.Nil(t, upstream.lastReq, "invalid_encrypted_content 不应回退 HTTP")
require.Equal(t, int32(2), wsAttempts.Load(), "invalid_encrypted_content 应触发一次清洗后重试")
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "resp_ws_invalid_encrypted_content_recover_ok", gjson.Get(rec.Body.String(), "id").String())
wsRequestMu.Lock()
requests := append([][]byte(nil), wsRequestPayloads...)
wsRequestMu.Unlock()
require.Len(t, requests, 2)
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id")
require.True(t, gjson.GetBytes(requests[0], `input.0.encrypted_content`).Exists(), "首轮请求应保留 encrypted reasoning")
require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
require.False(t, gjson.GetBytes(requests[1], `input.0.encrypted_content`).Exists(), "恢复重试应移除 encrypted reasoning item")
require.Equal(t, "input_text", gjson.GetBytes(requests[1], `input.0.type`).String())
}
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentSkipsRecoveryWithoutReasoningItem(t *testing.T) {
gin.SetMode(gin.TestMode)
var wsAttempts atomic.Int32
var wsRequestPayloads [][]byte
var wsRequestMu sync.Mutex
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsAttempts.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
reqRaw, _ := json.Marshal(req)
wsRequestMu.Lock()
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
wsRequestMu.Unlock()
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "invalid_encrypted_content",
"type": "invalid_request_error",
"message": "The encrypted content could not be verified.",
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 96,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.Error(t, err)
require.Nil(t, result)
require.Nil(t, upstream.lastReq, "invalid_encrypted_content 不应回退 HTTP")
require.Equal(t, int32(1), wsAttempts.Load(), "缺少 reasoning encrypted item 时应跳过自动恢复重试")
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, strings.ToLower(rec.Body.String()), "encrypted content")
wsRequestMu.Lock()
requests := append([][]byte(nil), wsRequestPayloads...)
wsRequestMu.Unlock()
require.Len(t, requests, 1)
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists())
require.False(t, gjson.GetBytes(requests[0], `input.0.encrypted_content`).Exists())
}
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentRecoversSingleObjectInputAndKeepsSummary(t *testing.T) {
gin.SetMode(gin.TestMode)
var wsAttempts atomic.Int32
var wsRequestPayloads [][]byte
var wsRequestMu sync.Mutex
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempt := wsAttempts.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
reqRaw, _ := json.Marshal(req)
wsRequestMu.Lock()
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
wsRequestMu.Unlock()
if attempt == 1 {
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "invalid_encrypted_content",
"type": "invalid_request_error",
"message": "The encrypted content could not be verified.",
},
})
return
}
_ = conn.WriteJSON(map[string]any{
"type": "response.completed",
"response": map[string]any{
"id": "resp_ws_invalid_encrypted_content_object_ok",
"model": "gpt-5.3-codex",
"usage": map[string]any{
"input_tokens": 1,
"output_tokens": 1,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
},
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 97,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me"}]}}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_ws_invalid_encrypted_content_object_ok", result.RequestID)
require.Nil(t, upstream.lastReq, "invalid_encrypted_content 单对象 input 不应回退 HTTP")
require.Equal(t, int32(2), wsAttempts.Load(), "单对象 reasoning input 也应触发一次清洗后重试")
wsRequestMu.Lock()
requests := append([][]byte(nil), wsRequestPayloads...)
wsRequestMu.Unlock()
require.Len(t, requests, 2)
require.True(t, gjson.GetBytes(requests[0], `input.encrypted_content`).Exists(), "首轮单对象应保留 encrypted_content")
require.True(t, gjson.GetBytes(requests[1], `input.summary.0.text`).Exists(), "恢复重试应保留 reasoning summary")
require.False(t, gjson.GetBytes(requests[1], `input.encrypted_content`).Exists(), "恢复重试只应移除 encrypted_content")
require.Equal(t, "reasoning", gjson.GetBytes(requests[1], `input.type`).String())
require.False(t, gjson.GetBytes(requests[1], `previous_response_id`).Exists(), "恢复重试应移除 previous_response_id")
}
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentKeepsPreviousResponseIDForFunctionCallOutput(t *testing.T) {
gin.SetMode(gin.TestMode)
var wsAttempts atomic.Int32
var wsRequestPayloads [][]byte
var wsRequestMu sync.Mutex
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempt := wsAttempts.Add(1)
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
var req map[string]any
if err := conn.ReadJSON(&req); err != nil {
t.Errorf("read ws request failed: %v", err)
return
}
reqRaw, _ := json.Marshal(req)
wsRequestMu.Lock()
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
wsRequestMu.Unlock()
if attempt == 1 {
_ = conn.WriteJSON(map[string]any{
"type": "error",
"error": map[string]any{
"code": "invalid_encrypted_content",
"type": "invalid_request_error",
"message": "The encrypted content could not be verified.",
},
})
return
}
_ = conn.WriteJSON(map[string]any{
"type": "response.completed",
"response": map[string]any{
"id": "resp_ws_invalid_encrypted_content_function_call_output_ok",
"model": "gpt-5.3-codex",
"usage": map[string]any{
"input_tokens": 1,
"output_tokens": 1,
"input_tokens_details": map[string]any{
"cached_tokens": 0,
},
},
},
})
}))
defer wsServer.Close()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: upstream,
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
}
account := &Account{
ID: 98,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": wsServer.URL,
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_function_call","input":[{"type":"reasoning","encrypted_content":"gAAA"},{"type":"function_call_output","call_id":"call_123","output":"ok"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "resp_ws_invalid_encrypted_content_function_call_output_ok", result.RequestID)
require.Nil(t, upstream.lastReq, "function_call_output + invalid_encrypted_content 不应回退 HTTP")
require.Equal(t, int32(2), wsAttempts.Load(), "应只做一次保锚点的清洗后重试")
wsRequestMu.Lock()
requests := append([][]byte(nil), wsRequestPayloads...)
wsRequestMu.Unlock()
require.Len(t, requests, 2)
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id")
require.True(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "function_call_output 恢复重试不应移除 previous_response_id")
require.False(t, gjson.GetBytes(requests[1], `input.0.encrypted_content`).Exists(), "恢复重试应移除 reasoning encrypted_content")
require.Equal(t, "function_call_output", gjson.GetBytes(requests[1], `input.0.type`).String(), "清洗后应保留 function_call_output 作为首个输入项")
require.Equal(t, "call_123", gjson.GetBytes(requests[1], `input.0.call_id`).String())
require.Equal(t, "ok", gjson.GetBytes(requests[1], `input.0.output`).String())
require.Equal(t, "resp_prev_function_call", gjson.GetBytes(requests[1], "previous_response_id").String())
}

View File

@@ -23,7 +23,7 @@ const (
opsAggDailyInterval = 1 * time.Hour opsAggDailyInterval = 1 * time.Hour
// Keep in sync with ops retention target (vNext default 30d). // Keep in sync with ops retention target (vNext default 30d).
opsAggBackfillWindow = 30 * 24 * time.Hour opsAggBackfillWindow = 1 * time.Hour
// Recompute overlap to absorb late-arriving rows near boundaries. // Recompute overlap to absorb late-arriving rows near boundaries.
opsAggHourlyOverlap = 2 * time.Hour opsAggHourlyOverlap = 2 * time.Hour
@@ -36,7 +36,7 @@ const (
// that may still receive late inserts. // that may still receive late inserts.
opsAggSafeDelay = 5 * time.Minute opsAggSafeDelay = 5 * time.Minute
opsAggMaxQueryTimeout = 3 * time.Second opsAggMaxQueryTimeout = 5 * time.Second
opsAggHourlyTimeout = 5 * time.Minute opsAggHourlyTimeout = 5 * time.Minute
opsAggDailyTimeout = 2 * time.Minute opsAggDailyTimeout = 2 * time.Minute

View File

@@ -371,6 +371,8 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略 IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
DisplayOpenAITokenStats: false,
DisplayAlertEvents: true,
AutoRefreshEnabled: false, AutoRefreshEnabled: false,
AutoRefreshIntervalSec: 30, AutoRefreshIntervalSec: 30,
} }
@@ -438,7 +440,7 @@ func (s *OpsService) GetOpsAdvancedSettings(ctx context.Context) (*OpsAdvancedSe
return nil, err return nil, err
} }
cfg := &OpsAdvancedSettings{} cfg := defaultOpsAdvancedSettings()
if err := json.Unmarshal([]byte(raw), cfg); err != nil { if err := json.Unmarshal([]byte(raw), cfg); err != nil {
return defaultCfg, nil return defaultCfg, nil
} }

View File

@@ -0,0 +1,97 @@
package service
import (
"context"
"encoding/json"
"testing"
)
func TestGetOpsAdvancedSettings_DefaultHidesOpenAITokenStats(t *testing.T) {
repo := newRuntimeSettingRepoStub()
svc := &OpsService{settingRepo: repo}
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
if err != nil {
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
}
if cfg.DisplayOpenAITokenStats {
t.Fatalf("DisplayOpenAITokenStats = true, want false by default")
}
if !cfg.DisplayAlertEvents {
t.Fatalf("DisplayAlertEvents = false, want true by default")
}
if repo.setCalls != 1 {
t.Fatalf("expected defaults to be persisted once, got %d", repo.setCalls)
}
}
func TestUpdateOpsAdvancedSettings_PersistsOpenAITokenStatsVisibility(t *testing.T) {
repo := newRuntimeSettingRepoStub()
svc := &OpsService{settingRepo: repo}
cfg := defaultOpsAdvancedSettings()
cfg.DisplayOpenAITokenStats = true
cfg.DisplayAlertEvents = false
updated, err := svc.UpdateOpsAdvancedSettings(context.Background(), cfg)
if err != nil {
t.Fatalf("UpdateOpsAdvancedSettings() error = %v", err)
}
if !updated.DisplayOpenAITokenStats {
t.Fatalf("DisplayOpenAITokenStats = false, want true")
}
if updated.DisplayAlertEvents {
t.Fatalf("DisplayAlertEvents = true, want false")
}
reloaded, err := svc.GetOpsAdvancedSettings(context.Background())
if err != nil {
t.Fatalf("GetOpsAdvancedSettings() after update error = %v", err)
}
if !reloaded.DisplayOpenAITokenStats {
t.Fatalf("reloaded DisplayOpenAITokenStats = false, want true")
}
if reloaded.DisplayAlertEvents {
t.Fatalf("reloaded DisplayAlertEvents = true, want false")
}
}
func TestGetOpsAdvancedSettings_BackfillsNewDisplayFlagsFromDefaults(t *testing.T) {
repo := newRuntimeSettingRepoStub()
svc := &OpsService{settingRepo: repo}
legacyCfg := map[string]any{
"data_retention": map[string]any{
"cleanup_enabled": false,
"cleanup_schedule": "0 2 * * *",
"error_log_retention_days": 30,
"minute_metrics_retention_days": 30,
"hourly_metrics_retention_days": 30,
},
"aggregation": map[string]any{
"aggregation_enabled": false,
},
"ignore_count_tokens_errors": true,
"ignore_context_canceled": true,
"ignore_no_available_accounts": false,
"ignore_invalid_api_key_errors": false,
"auto_refresh_enabled": false,
"auto_refresh_interval_seconds": 30,
}
raw, err := json.Marshal(legacyCfg)
if err != nil {
t.Fatalf("marshal legacy config: %v", err)
}
repo.values[SettingKeyOpsAdvancedSettings] = string(raw)
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
if err != nil {
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
}
if cfg.DisplayOpenAITokenStats {
t.Fatalf("DisplayOpenAITokenStats = true, want false default backfill")
}
if !cfg.DisplayAlertEvents {
t.Fatalf("DisplayAlertEvents = false, want true default backfill")
}
}

View File

@@ -98,6 +98,8 @@ type OpsAdvancedSettings struct {
IgnoreContextCanceled bool `json:"ignore_context_canceled"` IgnoreContextCanceled bool `json:"ignore_context_canceled"`
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"` IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"` IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
DisplayAlertEvents bool `json:"display_alert_events"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"` AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"` AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
} }

View File

@@ -149,8 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} }
// 其他 400 错误(如参数问题)不处理,不禁用账号 // 其他 400 错误(如参数问题)不处理,不禁用账号
case 401: case 401:
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新 // OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。
if account.Type == AccountTypeOAuth { // Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。
if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity {
// 1. 失效缓存 // 1. 失效缓存
if s.tokenCacheInvalidator != nil { if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
@@ -182,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} }
shouldDisable = true shouldDisable = true
} else { } else {
// 非 OAuth 账号APIKey:保持原有 SetError 行为 // 非 OAuth / Antigravity OAuth:保持 SetError 行为
msg := "Authentication failed (401): invalid or expired credentials" msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" { if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg msg = "Authentication failed (401): " + upstreamMsg
@@ -199,11 +200,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
s.handleAuthError(ctx, account, msg) s.handleAuthError(ctx, account, msg)
shouldDisable = true shouldDisable = true
case 403: case 403:
// 禁止访问:停止调度,记录错误
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
logger.LegacyPrintf( logger.LegacyPrintf(
"service.ratelimit", "service.ratelimit",
"[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s", "[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s",
@@ -215,8 +211,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
upstreamMsg, upstreamMsg,
truncateForLog(responseBody, 1024), truncateForLog(responseBody, 1024),
) )
s.handleAuthError(ctx, account, msg) shouldDisable = s.handle403(ctx, account, upstreamMsg, responseBody)
shouldDisable = true
case 429: case 429:
s.handle429(ctx, account, headers, responseBody) s.handle429(ctx, account, headers, responseBody)
shouldDisable = false shouldDisable = false
@@ -621,6 +616,62 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg) slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
} }
// handle403 处理 403 Forbidden 错误
// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用;
// 其他平台保持原有 SetError 行为。
func (s *RateLimitService) handle403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
if account.Platform == PlatformAntigravity {
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
}
// 非 Antigravity 平台:保持原有行为
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
return true
}
// handleAntigravity403 处理 Antigravity 平台的 403 错误
// validation需要验证→ 永久 SetError需人工去 Google 验证后恢复)
// violation违规封号→ 永久 SetError需人工处理
// generic通用禁止→ 永久 SetError
func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
fbType := classifyForbiddenType(string(responseBody))
switch fbType {
case forbiddenTypeValidation:
// VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复
msg := "Validation required (403): account needs Google verification"
if upstreamMsg != "" {
msg = "Validation required (403): " + upstreamMsg
}
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
msg += " | validation_url: " + validationURL
}
s.handleAuthError(ctx, account, msg)
return true
case forbiddenTypeViolation:
// 违规封号: 永久禁用,需人工处理
msg := "Account violation (403): terms of service violation"
if upstreamMsg != "" {
msg = "Account violation (403): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
return true
default:
// 通用 403: 保持原有行为
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
return true
}
}
// handleCustomErrorCode 处理自定义错误码,停止账号调度 // handleCustomErrorCode 处理自定义错误码,停止账号调度
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) { func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
@@ -1213,7 +1264,8 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
} }
// 401 首次命中可临时不可调度(给 token 刷新窗口); // 401 首次命中可临时不可调度(给 token 刷新窗口);
// 若历史上已因 401 进入过临时不可调度,则本次应升级为 error返回 false 交由默认错误逻辑处理)。 // 若历史上已因 401 进入过临时不可调度,则本次应升级为 error返回 false 交由默认错误逻辑处理)。
if statusCode == http.StatusUnauthorized { // Antigravity 跳过:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制,无需升级逻辑。
if statusCode == http.StatusUnauthorized && account.Platform != PlatformAntigravity {
reason := account.TempUnschedulableReason reason := account.TempUnschedulableReason
// 缓存可能没有 reason从 DB 回退读取 // 缓存可能没有 reason从 DB 回退读取
if reason == "" { if reason == "" {

View File

@@ -27,34 +27,68 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) { func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason (cache miss), // Scenario: cache account has empty TempUnschedulableReason (cache miss),
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone. // but DB account has a previous 401 record.
repo := &dbFallbackRepoStub{ // Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error).
dbAccount: &Account{ // Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules).
ID: 20, t.Run("gemini_escalates", func(t *testing.T) {
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, repo := &dbFallbackRepoStub{
}, dbAccount: &Account{
} ID: 20,
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{ account := &Account{
ID: 20, ID: 20,
Type: AccountTypeOAuth, Type: AccountTypeOAuth,
Platform: PlatformAntigravity, Platform: PlatformGemini,
TempUnschedulableReason: "", // cache miss — reason is empty TempUnschedulableReason: "",
Credentials: map[string]any{ Credentials: map[string]any{
"temp_unschedulable_enabled": true, "temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{ "temp_unschedulable_rules": []any{
map[string]any{ map[string]any{
"error_code": float64(401), "error_code": float64(401),
"keywords": []any{"unauthorized"}, "keywords": []any{"unauthorized"},
"duration_minutes": float64(10), "duration_minutes": float64(10),
},
}, },
}, },
}, }
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone") require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate")
})
t.Run("antigravity_stays_temp", func(t *testing.T) {
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled")
})
} }
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) { func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {

View File

@@ -42,45 +42,56 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
} }
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) { func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
tests := []struct { t.Run("gemini", func(t *testing.T) {
name string repo := &rateLimitAccountRepoStub{}
platform string invalidator := &tokenCacheInvalidatorRecorder{}
}{ service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
{name: "gemini", platform: PlatformGemini}, service.SetTokenCacheInvalidator(invalidator)
{name: "antigravity", platform: PlatformAntigravity}, account := &Account{
} ID: 100,
Platform: PlatformGemini,
for _, tt := range tests { Type: AccountTypeOAuth,
t.Run(tt.name, func(t *testing.T) { Credentials: map[string]any{
repo := &rateLimitAccountRepoStub{} "temp_unschedulable_enabled": true,
invalidator := &tokenCacheInvalidatorRecorder{} "temp_unschedulable_rules": []any{
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) map[string]any{
service.SetTokenCacheInvalidator(invalidator) "error_code": 401,
account := &Account{ "keywords": []any{"unauthorized"},
ID: 100, "duration_minutes": 30,
Platform: tt.platform, "description": "custom rule",
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
}, },
}, },
} },
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable) require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls) require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls) require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1) require.Len(t, invalidator.accounts, 1)
}) })
}
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
// HandleUpstreamError 中走 SetError 路径。
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Empty(t, invalidator.accounts)
})
} }
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {

View File

@@ -65,6 +65,19 @@ const minVersionErrorTTL = 5 * time.Second
// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context // minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context
const minVersionDBTimeout = 5 * time.Second const minVersionDBTimeout = 5 * time.Second
// cachedBackendMode Backend Mode cache (in-process, 60s TTL)
type cachedBackendMode struct {
value bool
expiresAt int64 // unix nano
}
var backendModeCache atomic.Value // *cachedBackendMode
var backendModeSF singleflight.Group
const backendModeCacheTTL = 60 * time.Second
const backendModeErrorTTL = 5 * time.Second
const backendModeDBTimeout = 5 * time.Second
// DefaultSubscriptionGroupReader validates group references used by default subscriptions. // DefaultSubscriptionGroupReader validates group references used by default subscriptions.
type DefaultSubscriptionGroupReader interface { type DefaultSubscriptionGroupReader interface {
GetByID(ctx context.Context, id int64) (*Group, error) GetByID(ctx context.Context, id int64) (*Group, error)
@@ -128,6 +141,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySoraClientEnabled, SettingKeySoraClientEnabled,
SettingKeyCustomMenuItems, SettingKeyCustomMenuItems,
SettingKeyLinuxDoConnectEnabled, SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -172,6 +186,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomMenuItems: settings[SettingKeyCustomMenuItems],
LinuxDoOAuthEnabled: linuxDoEnabled, LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
}, nil }, nil
} }
@@ -223,6 +238,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomMenuItems json.RawMessage `json:"custom_menu_items"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
}{ }{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
@@ -247,6 +263,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
Version: s.version, Version: s.version,
}, nil }, nil
} }
@@ -461,6 +478,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// 分组隔离 // 分组隔离
updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling) updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling)
// Backend Mode
updates[SettingKeyBackendModeEnabled] = strconv.FormatBool(settings.BackendModeEnabled)
err = s.settingRepo.SetMultiple(ctx, updates) err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil { if err == nil {
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
@@ -469,6 +489,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
value: settings.MinClaudeCodeVersion, value: settings.MinClaudeCodeVersion,
expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(), expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(),
}) })
backendModeSF.Forget("backend_mode")
backendModeCache.Store(&cachedBackendMode{
value: settings.BackendModeEnabled,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
if s.onUpdate != nil { if s.onUpdate != nil {
s.onUpdate() // Invalidate cache after settings update s.onUpdate() // Invalidate cache after settings update
} }
@@ -525,6 +550,52 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
return value == "true" return value == "true"
} }
// IsBackendModeEnabled checks if backend mode is enabled
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path
func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.value
}
}
result, _, _ := backendModeSF.Do("backend_mode", func() (any, error) {
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.value, nil
}
}
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), backendModeDBTimeout)
defer cancel()
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyBackendModeEnabled)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
// Setting not yet created (fresh install) - default to disabled with full TTL
backendModeCache.Store(&cachedBackendMode{
value: false,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
return false, nil
}
slog.Warn("failed to get backend_mode_enabled setting", "error", err)
backendModeCache.Store(&cachedBackendMode{
value: false,
expiresAt: time.Now().Add(backendModeErrorTTL).UnixNano(),
})
return false, nil
}
enabled := value == "true"
backendModeCache.Store(&cachedBackendMode{
value: enabled,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
return enabled, nil
})
if val, ok := result.(bool); ok {
return val
}
return false
}
// IsEmailVerifyEnabled 检查是否开启邮件验证 // IsEmailVerifyEnabled 检查是否开启邮件验证
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool { func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled) value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled)
@@ -719,6 +790,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomMenuItems: settings[SettingKeyCustomMenuItems],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
} }
// 解析整数类型 // 解析整数类型
@@ -1278,7 +1350,7 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true, BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
} }
validScopes := map[string]bool{ validScopes := map[string]bool{
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
} }
for i, rule := range settings.Rules { for i, rule := range settings.Rules {

View File

@@ -0,0 +1,199 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type bmRepoStub struct {
getValueFn func(ctx context.Context, key string) (string, error)
calls int
}
func (s *bmRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *bmRepoStub) GetValue(ctx context.Context, key string) (string, error) {
s.calls++
if s.getValueFn == nil {
panic("unexpected GetValue call")
}
return s.getValueFn(ctx, key)
}
func (s *bmRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *bmRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *bmRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *bmRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *bmRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
type bmUpdateRepoStub struct {
updates map[string]string
getValueFn func(ctx context.Context, key string) (string, error)
}
func (s *bmUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *bmUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
if s.getValueFn == nil {
panic("unexpected GetValue call")
}
return s.getValueFn(ctx, key)
}
func (s *bmUpdateRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *bmUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *bmUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
s.updates = make(map[string]string, len(settings))
for k, v := range settings {
s.updates[k] = v
}
return nil
}
func (s *bmUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *bmUpdateRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func resetBackendModeTestCache(t *testing.T) {
t.Helper()
backendModeCache.Store((*cachedBackendMode)(nil))
t.Cleanup(func() {
backendModeCache.Store((*cachedBackendMode)(nil))
})
}
func TestIsBackendModeEnabled_ReturnsTrue(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "true", nil
},
}
svc := NewSettingService(repo, &config.Config{})
require.True(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_ReturnsFalse(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "false", nil
},
}
svc := NewSettingService(repo, &config.Config{})
require.False(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_ReturnsFalseOnNotFound(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "", ErrSettingNotFound
},
}
svc := NewSettingService(repo, &config.Config{})
require.False(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_ReturnsFalseOnDBError(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "", errors.New("db down")
},
}
svc := NewSettingService(repo, &config.Config{})
require.False(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_CachesResult(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "true", nil
},
}
svc := NewSettingService(repo, &config.Config{})
require.True(t, svc.IsBackendModeEnabled(context.Background()))
require.True(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestUpdateSettings_InvalidatesBackendModeCache(t *testing.T) {
resetBackendModeTestCache(t)
backendModeCache.Store(&cachedBackendMode{
value: true,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
repo := &bmUpdateRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "true", nil
},
}
svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateSettings(context.Background(), &SystemSettings{
BackendModeEnabled: false,
})
require.NoError(t, err)
require.Equal(t, "false", repo.updates[SettingKeyBackendModeEnabled])
require.False(t, svc.IsBackendModeEnabled(context.Background()))
}

View File

@@ -69,6 +69,9 @@ type SystemSettings struct {
// 分组隔离:允许未分组 Key 调度(默认 false → 403 // 分组隔离:允许未分组 Key 调度(默认 false → 403
AllowUngroupedKeyScheduling bool AllowUngroupedKeyScheduling bool
// Backend 模式:禁用用户注册和自助服务,仅管理员可登录
BackendModeEnabled bool
} }
type DefaultSubscriptionSetting struct { type DefaultSubscriptionSetting struct {
@@ -101,6 +104,7 @@ type PublicSettings struct {
CustomMenuItems string // JSON array of custom menu items CustomMenuItems string // JSON array of custom menu items
LinuxDoOAuthEnabled bool LinuxDoOAuthEnabled bool
BackendModeEnabled bool
Version string Version string
} }
@@ -198,16 +202,17 @@ const (
BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token
BetaPolicyActionBlock = "block" // 拦截,直接返回错误 BetaPolicyActionBlock = "block" // 拦截,直接返回错误
BetaPolicyScopeAll = "all" // 所有账号类型 BetaPolicyScopeAll = "all" // 所有账号类型
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号 BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号 BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
BetaPolicyScopeBedrock = "bedrock" // 仅 AWS Bedrock 账号
) )
// BetaPolicyRule 单条 Beta 策略规则 // BetaPolicyRule 单条 Beta 策略规则
type BetaPolicyRule struct { type BetaPolicyRule struct {
BetaToken string `json:"beta_token"` // beta token 值 BetaToken string `json:"beta_token"` // beta token 值
Action string `json:"action"` // "pass" | "filter" | "block" Action string `json:"action"` // "pass" | "filter" | "block"
Scope string `json:"scope"` // "all" | "oauth" | "apikey" Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock"
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效) ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
} }

View File

@@ -322,6 +322,19 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC
return apiKeyService return apiKeyService
} }
// ProvideBackupService creates and starts BackupService
func ProvideBackupService(
settingRepo SettingRepository,
cfg *config.Config,
encryptor SecretEncryptor,
storeFactory BackupObjectStoreFactory,
dumper DBDumper,
) *BackupService {
svc := NewBackupService(settingRepo, cfg, encryptor, storeFactory, dumper)
svc.Start()
return svc
}
// ProvideSettingService wires SettingService with group reader for default subscription validation. // ProvideSettingService wires SettingService with group reader for default subscription validation.
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService { func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService {
svc := NewSettingService(settingRepo, cfg) svc := NewSettingService(settingRepo, cfg)
@@ -373,6 +386,7 @@ var ProviderSet = wire.NewSet(
NewAccountTestService, NewAccountTestService,
ProvideSettingService, ProvideSettingService,
NewDataManagementService, NewDataManagementService,
ProvideBackupService,
ProvideOpsSystemLogSink, ProvideOpsSystemLogSink,
NewOpsService, NewOpsService,
ProvideOpsMetricsCollector, ProvideOpsMetricsCollector,

View File

@@ -0,0 +1,105 @@
# =============================================================================
# Sub2API Docker Compose - Local Development Build
# =============================================================================
# Build from local source code for testing changes.
#
# Usage:
# cd deploy
# docker compose -f docker-compose.dev.yml up --build
# =============================================================================
services:
sub2api:
build:
context: ..
dockerfile: Dockerfile
container_name: sub2api-dev
restart: unless-stopped
ports:
- "${BIND_HOST:-127.0.0.1}:${SERVER_PORT:-8080}:8080"
volumes:
- ./data:/app/data
environment:
- AUTO_SETUP=true
- SERVER_HOST=0.0.0.0
- SERVER_PORT=8080
- SERVER_MODE=debug
- RUN_MODE=${RUN_MODE:-standard}
- DATABASE_HOST=postgres
- DATABASE_PORT=5432
- DATABASE_USER=${POSTGRES_USER:-sub2api}
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
- DATABASE_SSLMODE=disable
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
- REDIS_DB=${REDIS_DB:-0}
- ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local}
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-}
- JWT_SECRET=${JWT_SECRET:-}
- TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-}
- TZ=${TZ:-Asia/Shanghai}
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
networks:
- sub2api-network
healthcheck:
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 30s
postgres:
image: postgres:18-alpine
container_name: sub2api-postgres-dev
restart: unless-stopped
volumes:
- ./postgres_data:/var/lib/postgresql/data
environment:
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
- PGDATA=/var/lib/postgresql/data
- TZ=${TZ:-Asia/Shanghai}
networks:
- sub2api-network
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"]
interval: 10s
timeout: 5s
retries: 5
start_period: 10s
redis:
image: redis:8-alpine
container_name: sub2api-redis-dev
restart: unless-stopped
volumes:
- ./redis_data:/data
command: >
sh -c '
redis-server
--save 60 1
--appendonly yes
--appendfsync everysec
${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}'
environment:
- TZ=${TZ:-Asia/Shanghai}
- REDISCLI_AUTH=${REDIS_PASSWORD:-}
networks:
- sub2api-network
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
start_period: 5s
networks:
sub2api-network:
driver: bridge

View File

@@ -0,0 +1,114 @@
import { apiClient } from '../client'
export interface BackupS3Config {
endpoint: string
region: string
bucket: string
access_key_id: string
secret_access_key?: string
prefix: string
force_path_style: boolean
}
export interface BackupScheduleConfig {
enabled: boolean
cron_expr: string
retain_days: number
retain_count: number
}
export interface BackupRecord {
id: string
status: 'pending' | 'running' | 'completed' | 'failed'
backup_type: string
file_name: string
s3_key: string
size_bytes: number
triggered_by: string
error_message?: string
started_at: string
finished_at?: string
expires_at?: string
}
export interface CreateBackupRequest {
expire_days?: number
}
export interface TestS3Response {
ok: boolean
message: string
}
// S3 Config
export async function getS3Config(): Promise<BackupS3Config> {
const { data } = await apiClient.get<BackupS3Config>('/admin/backups/s3-config')
return data
}
export async function updateS3Config(config: BackupS3Config): Promise<BackupS3Config> {
const { data } = await apiClient.put<BackupS3Config>('/admin/backups/s3-config', config)
return data
}
export async function testS3Connection(config: BackupS3Config): Promise<TestS3Response> {
const { data } = await apiClient.post<TestS3Response>('/admin/backups/s3-config/test', config)
return data
}
// Schedule
export async function getSchedule(): Promise<BackupScheduleConfig> {
const { data } = await apiClient.get<BackupScheduleConfig>('/admin/backups/schedule')
return data
}
export async function updateSchedule(config: BackupScheduleConfig): Promise<BackupScheduleConfig> {
const { data } = await apiClient.put<BackupScheduleConfig>('/admin/backups/schedule', config)
return data
}
// Backup operations
export async function createBackup(req?: CreateBackupRequest): Promise<BackupRecord> {
const { data } = await apiClient.post<BackupRecord>('/admin/backups', req || {}, { timeout: 600000 })
return data
}
export async function listBackups(): Promise<{ items: BackupRecord[] }> {
const { data } = await apiClient.get<{ items: BackupRecord[] }>('/admin/backups')
return data
}
export async function getBackup(id: string): Promise<BackupRecord> {
const { data } = await apiClient.get<BackupRecord>(`/admin/backups/${id}`)
return data
}
export async function deleteBackup(id: string): Promise<void> {
await apiClient.delete(`/admin/backups/${id}`)
}
export async function getDownloadURL(id: string): Promise<{ url: string }> {
const { data } = await apiClient.get<{ url: string }>(`/admin/backups/${id}/download-url`)
return data
}
// Restore
export async function restoreBackup(id: string, password: string): Promise<void> {
await apiClient.post(`/admin/backups/${id}/restore`, { password }, { timeout: 600000 })
}
export const backupAPI = {
getS3Config,
updateS3Config,
testS3Connection,
getSchedule,
updateSchedule,
createBackup,
listBackups,
getBackup,
deleteBackup,
getDownloadURL,
restoreBackup,
}
export default backupAPI

View File

@@ -23,6 +23,7 @@ import errorPassthroughAPI from './errorPassthrough'
import dataManagementAPI from './dataManagement' import dataManagementAPI from './dataManagement'
import apiKeysAPI from './apiKeys' import apiKeysAPI from './apiKeys'
import scheduledTestsAPI from './scheduledTests' import scheduledTestsAPI from './scheduledTests'
import backupAPI from './backup'
/** /**
* Unified admin API object for convenient access * Unified admin API object for convenient access
@@ -47,7 +48,8 @@ export const adminAPI = {
errorPassthrough: errorPassthroughAPI, errorPassthrough: errorPassthroughAPI,
dataManagement: dataManagementAPI, dataManagement: dataManagementAPI,
apiKeys: apiKeysAPI, apiKeys: apiKeysAPI,
scheduledTests: scheduledTestsAPI scheduledTests: scheduledTestsAPI,
backup: backupAPI
} }
export { export {
@@ -70,7 +72,8 @@ export {
errorPassthroughAPI, errorPassthroughAPI,
dataManagementAPI, dataManagementAPI,
apiKeysAPI, apiKeysAPI,
scheduledTestsAPI scheduledTestsAPI,
backupAPI
} }
export default adminAPI export default adminAPI

View File

@@ -841,6 +841,8 @@ export interface OpsAdvancedSettings {
ignore_context_canceled: boolean ignore_context_canceled: boolean
ignore_no_available_accounts: boolean ignore_no_available_accounts: boolean
ignore_invalid_api_key_errors: boolean ignore_invalid_api_key_errors: boolean
display_openai_token_stats: boolean
display_alert_events: boolean
auto_refresh_enabled: boolean auto_refresh_enabled: boolean
auto_refresh_interval_seconds: number auto_refresh_interval_seconds: number
} }

View File

@@ -40,6 +40,7 @@ export interface SystemSettings {
purchase_subscription_enabled: boolean purchase_subscription_enabled: boolean
purchase_subscription_url: string purchase_subscription_url: string
sora_client_enabled: boolean sora_client_enabled: boolean
backend_mode_enabled: boolean
custom_menu_items: CustomMenuItem[] custom_menu_items: CustomMenuItem[]
// SMTP settings // SMTP settings
smtp_host: string smtp_host: string
@@ -106,6 +107,7 @@ export interface UpdateSettingsRequest {
purchase_subscription_enabled?: boolean purchase_subscription_enabled?: boolean
purchase_subscription_url?: string purchase_subscription_url?: string
sora_client_enabled?: boolean sora_client_enabled?: boolean
backend_mode_enabled?: boolean
custom_menu_items?: CustomMenuItem[] custom_menu_items?: CustomMenuItem[]
smtp_host?: string smtp_host?: string
smtp_port?: number smtp_port?: number
@@ -316,7 +318,7 @@ export async function updateRectifierSettings(
export interface BetaPolicyRule { export interface BetaPolicyRule {
beta_token: string beta_token: string
action: 'pass' | 'filter' | 'block' action: 'pass' | 'filter' | 'block'
scope: 'all' | 'oauth' | 'apikey' scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
error_message?: string error_message?: string
} }

View File

@@ -292,17 +292,19 @@ const rpmTooltip = computed(() => {
} }
}) })
// 是否显示各维度配额(apikey 类型) // 是否显示各维度配额apikey / bedrock 类型)
const isQuotaEligible = computed(() => props.account.type === 'apikey' || props.account.type === 'bedrock')
const showDailyQuota = computed(() => { const showDailyQuota = computed(() => {
return props.account.type === 'apikey' && (props.account.quota_daily_limit ?? 0) > 0 return isQuotaEligible.value && (props.account.quota_daily_limit ?? 0) > 0
}) })
const showWeeklyQuota = computed(() => { const showWeeklyQuota = computed(() => {
return props.account.type === 'apikey' && (props.account.quota_weekly_limit ?? 0) > 0 return isQuotaEligible.value && (props.account.quota_weekly_limit ?? 0) > 0
}) })
const showTotalQuota = computed(() => { const showTotalQuota = computed(() => {
return props.account.type === 'apikey' && (props.account.quota_limit ?? 0) > 0 return isQuotaEligible.value && (props.account.quota_limit ?? 0) > 0
}) })
// 格式化费用显示 // 格式化费用显示

View File

@@ -36,6 +36,10 @@
<!-- Usage data --> <!-- Usage data -->
<div v-else-if="usageInfo" class="space-y-1"> <div v-else-if="usageInfo" class="space-y-1">
<!-- API error (degraded response) -->
<div v-if="usageInfo.error" class="text-xs text-amber-600 dark:text-amber-400 truncate max-w-[200px]" :title="usageInfo.error">
{{ usageInfo.error }}
</div>
<!-- 5h Window --> <!-- 5h Window -->
<UsageProgressBar <UsageProgressBar
v-if="usageInfo.five_hour" v-if="usageInfo.five_hour"
@@ -189,8 +193,53 @@
</span> </span>
</div> </div>
<!-- Forbidden state (403) -->
<div v-if="isForbidden" class="space-y-1">
<span
:class="[
'inline-block rounded px-1.5 py-0.5 text-[10px] font-medium',
forbiddenBadgeClass
]"
>
{{ forbiddenLabel }}
</span>
<div v-if="validationURL" class="flex items-center gap-1">
<a
:href="validationURL"
target="_blank"
rel="noopener noreferrer"
class="text-[10px] text-blue-600 hover:text-blue-800 hover:underline dark:text-blue-400 dark:hover:text-blue-300"
:title="t('admin.accounts.openVerification')"
>
{{ t('admin.accounts.openVerification') }}
</a>
<button
type="button"
class="text-[10px] text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200"
:title="t('admin.accounts.copyLink')"
@click="copyValidationURL"
>
{{ linkCopied ? t('admin.accounts.linkCopied') : t('admin.accounts.copyLink') }}
</button>
</div>
</div>
<!-- Needs reauth (401) -->
<div v-else-if="needsReauth" class="space-y-1">
<span class="inline-block rounded px-1.5 py-0.5 text-[10px] font-medium bg-orange-100 text-orange-700 dark:bg-orange-900/40 dark:text-orange-300">
{{ t('admin.accounts.needsReauth') }}
</span>
</div>
<!-- Degraded error (non-403, non-401) -->
<div v-else-if="usageInfo?.error" class="space-y-1">
<span class="inline-block rounded px-1.5 py-0.5 text-[10px] font-medium bg-amber-100 text-amber-700 dark:bg-amber-900/40 dark:text-amber-300">
{{ usageErrorLabel }}
</span>
</div>
<!-- Loading state --> <!-- Loading state -->
<div v-if="loading" class="space-y-1.5"> <div v-else-if="loading" class="space-y-1.5">
<div class="flex items-center gap-1"> <div class="flex items-center gap-1">
<div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div> <div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
<div class="h-1.5 w-8 animate-pulse rounded-full bg-gray-200 dark:bg-gray-700"></div> <div class="h-1.5 w-8 animate-pulse rounded-full bg-gray-200 dark:bg-gray-700"></div>
@@ -816,6 +865,51 @@ const hasIneligibleTiers = computed(() => {
return Array.isArray(ineligibleTiers) && ineligibleTiers.length > 0 return Array.isArray(ineligibleTiers) && ineligibleTiers.length > 0
}) })
// Antigravity 403 forbidden 状态
const isForbidden = computed(() => !!usageInfo.value?.is_forbidden)
const forbiddenType = computed(() => usageInfo.value?.forbidden_type || 'forbidden')
const validationURL = computed(() => usageInfo.value?.validation_url || '')
// 需要重新授权401
const needsReauth = computed(() => !!usageInfo.value?.needs_reauth)
// 降级错误标签rate_limited / network_error
const usageErrorLabel = computed(() => {
const code = usageInfo.value?.error_code
if (code === 'rate_limited') return t('admin.accounts.rateLimited')
return t('admin.accounts.usageError')
})
const forbiddenLabel = computed(() => {
switch (forbiddenType.value) {
case 'validation':
return t('admin.accounts.forbiddenValidation')
case 'violation':
return t('admin.accounts.forbiddenViolation')
default:
return t('admin.accounts.forbidden')
}
})
const forbiddenBadgeClass = computed(() => {
if (forbiddenType.value === 'validation') {
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/40 dark:text-yellow-300'
}
return 'bg-red-100 text-red-700 dark:bg-red-900/40 dark:text-red-300'
})
const linkCopied = ref(false)
const copyValidationURL = async () => {
if (!validationURL.value) return
try {
await navigator.clipboard.writeText(validationURL.value)
linkCopied.value = true
setTimeout(() => { linkCopied.value = false }, 2000)
} catch {
// fallback: ignore
}
}
const loadUsage = async () => { const loadUsage = async () => {
if (!shouldFetchUsage.value) return if (!shouldFetchUsage.value) return
@@ -848,18 +942,30 @@ const makeQuotaBar = (
let resetsAt: string | null = null let resetsAt: string | null = null
if (startKey) { if (startKey) {
const extra = props.account.extra as Record<string, unknown> | undefined const extra = props.account.extra as Record<string, unknown> | undefined
const startStr = extra?.[startKey] as string | undefined const isDaily = startKey.includes('daily')
if (startStr) { const mode = isDaily
const startDate = new Date(startStr) ? (extra?.quota_daily_reset_mode as string) || 'rolling'
const periodMs = startKey.includes('daily') ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000 : (extra?.quota_weekly_reset_mode as string) || 'rolling'
resetsAt = new Date(startDate.getTime() + periodMs).toISOString()
if (mode === 'fixed') {
// Use pre-computed next reset time for fixed mode
const resetAtKey = isDaily ? 'quota_daily_reset_at' : 'quota_weekly_reset_at'
resetsAt = (extra?.[resetAtKey] as string) || null
} else {
// Rolling mode: compute from start + period
const startStr = extra?.[startKey] as string | undefined
if (startStr) {
const startDate = new Date(startStr)
const periodMs = isDaily ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000
resetsAt = new Date(startDate.getTime() + periodMs).toISOString()
}
} }
} }
return { utilization, resetsAt } return { utilization, resetsAt }
} }
const hasApiKeyQuota = computed(() => { const hasApiKeyQuota = computed(() => {
if (props.account.type !== 'apikey') return false if (props.account.type !== 'apikey' && props.account.type !== 'bedrock') return false
return ( return (
(props.account.quota_daily_limit ?? 0) > 0 || (props.account.quota_daily_limit ?? 0) > 0 ||
(props.account.quota_weekly_limit ?? 0) > 0 || (props.account.quota_weekly_limit ?? 0) > 0 ||

View File

@@ -323,35 +323,6 @@
</div> </div>
</button> </button>
<button
type="button"
@click="accountCategory = 'bedrock-apikey'"
:class="[
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
accountCategory === 'bedrock-apikey'
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
]"
>
<div
:class="[
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
accountCategory === 'bedrock-apikey'
? 'bg-amber-500 text-white'
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
]"
>
<Icon name="key" size="sm" />
</div>
<div>
<span class="block text-sm font-medium text-gray-900 dark:text-white">{{
t('admin.accounts.bedrockApiKeyLabel')
}}</span>
<span class="text-xs text-gray-500 dark:text-gray-400">{{
t('admin.accounts.bedrockApiKeyDesc')
}}</span>
</div>
</button>
</div> </div>
</div> </div>
@@ -956,7 +927,7 @@
</div> </div>
<!-- API Key input (only for apikey type, excluding Antigravity which has its own fields) --> <!-- API Key input (only for apikey type, excluding Antigravity which has its own fields) -->
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity' && accountCategory !== 'bedrock-apikey'" class="space-y-4"> <div v-if="form.type === 'apikey' && form.platform !== 'antigravity'" class="space-y-4">
<div> <div>
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label> <label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
<input <input
@@ -1341,34 +1312,75 @@
<!-- Bedrock credentials (only for Anthropic Bedrock type) --> <!-- Bedrock credentials (only for Anthropic Bedrock type) -->
<div v-if="form.platform === 'anthropic' && accountCategory === 'bedrock'" class="space-y-4"> <div v-if="form.platform === 'anthropic' && accountCategory === 'bedrock'" class="space-y-4">
<!-- Auth Mode Radio -->
<div> <div>
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label> <label class="input-label">{{ t('admin.accounts.bedrockAuthMode') }}</label>
<input <div class="mt-2 flex gap-4">
v-model="bedrockAccessKeyId" <label class="flex cursor-pointer items-center">
type="text" <input
required v-model="bedrockAuthMode"
class="input font-mono" type="radio"
placeholder="AKIA..." value="sigv4"
/> class="mr-2 text-primary-600 focus:ring-primary-500"
/>
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockAuthModeSigv4') }}</span>
</label>
<label class="flex cursor-pointer items-center">
<input
v-model="bedrockAuthMode"
type="radio"
value="apikey"
class="mr-2 text-primary-600 focus:ring-primary-500"
/>
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockAuthModeApikey') }}</span>
</label>
</div>
</div> </div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label> <!-- SigV4 fields -->
<template v-if="bedrockAuthMode === 'sigv4'">
<div>
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
<input
v-model="bedrockAccessKeyId"
type="text"
required
class="input font-mono"
placeholder="AKIA..."
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
<input
v-model="bedrockSecretAccessKey"
type="password"
required
class="input font-mono"
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
<input
v-model="bedrockSessionToken"
type="password"
class="input font-mono"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
</div>
</template>
<!-- API Key field -->
<div v-if="bedrockAuthMode === 'apikey'">
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
<input <input
v-model="bedrockSecretAccessKey" v-model="bedrockApiKeyValue"
type="password" type="password"
required required
class="input font-mono" class="input font-mono"
/> />
</div> </div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label> <!-- Shared: Region -->
<input
v-model="bedrockSessionToken"
type="password"
class="input font-mono"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
</div>
<div> <div>
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label> <label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
<select v-model="bedrockRegion" class="input"> <select v-model="bedrockRegion" class="input">
@@ -1408,6 +1420,8 @@
</select> </select>
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p> <p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
</div> </div>
<!-- Shared: Force Global -->
<div> <div>
<label class="flex items-center gap-2 cursor-pointer"> <label class="flex items-center gap-2 cursor-pointer">
<input <input
@@ -1488,142 +1502,62 @@
</div> </div>
</div> </div>
</div> </div>
</div>
<!-- Bedrock API Key credentials (only for Anthropic Bedrock API Key type) --> <!-- Pool Mode Section for Bedrock -->
<div v-if="form.platform === 'anthropic' && accountCategory === 'bedrock-apikey'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
<input
v-model="bedrockApiKeyValue"
type="password"
required
class="input font-mono"
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
<select v-model="bedrockApiKeyRegion" class="input">
<optgroup label="US">
<option value="us-east-1">us-east-1 (N. Virginia)</option>
<option value="us-east-2">us-east-2 (Ohio)</option>
<option value="us-west-1">us-west-1 (N. California)</option>
<option value="us-west-2">us-west-2 (Oregon)</option>
<option value="us-gov-east-1">us-gov-east-1 (GovCloud US-East)</option>
<option value="us-gov-west-1">us-gov-west-1 (GovCloud US-West)</option>
</optgroup>
<optgroup label="Europe">
<option value="eu-west-1">eu-west-1 (Ireland)</option>
<option value="eu-west-2">eu-west-2 (London)</option>
<option value="eu-west-3">eu-west-3 (Paris)</option>
<option value="eu-central-1">eu-central-1 (Frankfurt)</option>
<option value="eu-central-2">eu-central-2 (Zurich)</option>
<option value="eu-south-1">eu-south-1 (Milan)</option>
<option value="eu-south-2">eu-south-2 (Spain)</option>
<option value="eu-north-1">eu-north-1 (Stockholm)</option>
</optgroup>
<optgroup label="Asia Pacific">
<option value="ap-northeast-1">ap-northeast-1 (Tokyo)</option>
<option value="ap-northeast-2">ap-northeast-2 (Seoul)</option>
<option value="ap-northeast-3">ap-northeast-3 (Osaka)</option>
<option value="ap-south-1">ap-south-1 (Mumbai)</option>
<option value="ap-south-2">ap-south-2 (Hyderabad)</option>
<option value="ap-southeast-1">ap-southeast-1 (Singapore)</option>
<option value="ap-southeast-2">ap-southeast-2 (Sydney)</option>
</optgroup>
<optgroup label="Canada">
<option value="ca-central-1">ca-central-1 (Canada)</option>
</optgroup>
<optgroup label="South America">
<option value="sa-east-1">sa-east-1 (São Paulo)</option>
</optgroup>
</select>
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
</div>
<div>
<label class="flex items-center gap-2 cursor-pointer">
<input
v-model="bedrockApiKeyForceGlobal"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
/>
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
</label>
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
</div>
<!-- Model Restriction Section for Bedrock API Key -->
<div class="border-t border-gray-200 pt-4 dark:border-dark-600"> <div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label> <div class="mb-3 flex items-center justify-between">
<div>
<!-- Mode Toggle --> <label class="input-label mb-0">{{ t('admin.accounts.poolMode') }}</label>
<div class="mb-4 flex gap-2"> <p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.poolModeHint') }}
</p>
</div>
<button <button
type="button" type="button"
@click="modelRestrictionMode = 'whitelist'" @click="poolModeEnabled = !poolModeEnabled"
:class="[ :class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all', 'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
modelRestrictionMode === 'whitelist' poolModeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]" ]"
> >
{{ t('admin.accounts.modelWhitelist') }} <span
</button> :class="[
<button 'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
type="button" poolModeEnabled ? 'translate-x-5' : 'translate-x-0'
@click="modelRestrictionMode = 'mapping'" ]"
:class="[ />
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'mapping'
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
{{ t('admin.accounts.modelMapping') }}
</button> </button>
</div> </div>
<div v-if="poolModeEnabled" class="rounded-lg bg-blue-50 p-3 dark:bg-blue-900/20">
<!-- Whitelist Mode --> <p class="text-xs text-blue-700 dark:text-blue-400">
<div v-if="modelRestrictionMode === 'whitelist'"> <Icon name="exclamationCircle" size="sm" class="mr-1 inline" :stroke-width="2" />
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" /> {{ t('admin.accounts.poolModeInfo') }}
<p class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
</p> </p>
</div> </div>
<div v-if="poolModeEnabled" class="mt-3">
<!-- Mapping Mode --> <label class="input-label">{{ t('admin.accounts.poolModeRetryCount') }}</label>
<div v-else class="space-y-3"> <input
<div v-for="(mapping, index) in modelMappings" :key="index" class="flex items-center gap-2"> v-model.number="poolModeRetryCount"
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" /> type="number"
<span class="text-gray-400"></span> min="0"
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" /> :max="MAX_POOL_MODE_RETRY_COUNT"
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700"> step="1"
<Icon name="trash" size="sm" /> class="input"
</button> />
</div> <p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm"> {{
+ {{ t('admin.accounts.addMapping') }} t('admin.accounts.poolModeRetryCountHint', {
</button> default: DEFAULT_POOL_MODE_RETRY_COUNT,
<!-- Bedrock Preset Mappings --> max: MAX_POOL_MODE_RETRY_COUNT
<div class="flex flex-wrap gap-2"> })
<button }}
v-for="preset in bedrockPresets" </p>
:key="preset.from"
type="button"
@click="addPresetMapping(preset.from, preset.to)"
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
>
+ {{ preset.label }}
</button>
</div>
</div> </div>
</div> </div>
</div> </div>
<!-- API Key 账号配额限制 --> <!-- API Key / Bedrock 账号配额限制 -->
<div v-if="form.type === 'apikey'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4"> <div v-if="form.type === 'apikey' || form.type === 'bedrock'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
<div class="mb-3"> <div class="mb-3">
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3> <h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400"> <p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
@@ -1634,9 +1568,21 @@
:totalLimit="editQuotaLimit" :totalLimit="editQuotaLimit"
:dailyLimit="editQuotaDailyLimit" :dailyLimit="editQuotaDailyLimit"
:weeklyLimit="editQuotaWeeklyLimit" :weeklyLimit="editQuotaWeeklyLimit"
:dailyResetMode="editDailyResetMode"
:dailyResetHour="editDailyResetHour"
:weeklyResetMode="editWeeklyResetMode"
:weeklyResetDay="editWeeklyResetDay"
:weeklyResetHour="editWeeklyResetHour"
:resetTimezone="editResetTimezone"
@update:totalLimit="editQuotaLimit = $event" @update:totalLimit="editQuotaLimit = $event"
@update:dailyLimit="editQuotaDailyLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event"
@update:weeklyLimit="editQuotaWeeklyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event"
@update:dailyResetMode="editDailyResetMode = $event"
@update:dailyResetHour="editDailyResetHour = $event"
@update:weeklyResetMode="editWeeklyResetMode = $event"
@update:weeklyResetDay="editWeeklyResetDay = $event"
@update:weeklyResetHour="editWeeklyResetHour = $event"
@update:resetTimezone="editResetTimezone = $event"
/> />
</div> </div>
@@ -3014,13 +2960,19 @@ interface TempUnschedRuleForm {
// State // State
const step = ref(1) const step = ref(1)
const submitting = ref(false) const submitting = ref(false)
const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'bedrock-apikey'>('oauth-based') // UI selection for account category const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock'>('oauth-based') // UI selection for account category
const addMethod = ref<AddMethod>('oauth') // For oauth-based: 'oauth' or 'setup-token' const addMethod = ref<AddMethod>('oauth') // For oauth-based: 'oauth' or 'setup-token'
const apiKeyBaseUrl = ref('https://api.anthropic.com') const apiKeyBaseUrl = ref('https://api.anthropic.com')
const apiKeyValue = ref('') const apiKeyValue = ref('')
const editQuotaLimit = ref<number | null>(null) const editQuotaLimit = ref<number | null>(null)
const editQuotaDailyLimit = ref<number | null>(null) const editQuotaDailyLimit = ref<number | null>(null)
const editQuotaWeeklyLimit = ref<number | null>(null) const editQuotaWeeklyLimit = ref<number | null>(null)
const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null)
const editDailyResetHour = ref<number | null>(null)
const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null)
const editWeeklyResetDay = ref<number | null>(null)
const editWeeklyResetHour = ref<number | null>(null)
const editResetTimezone = ref<string | null>(null)
const modelMappings = ref<ModelMapping[]>([]) const modelMappings = ref<ModelMapping[]>([])
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref<string[]>([]) const allowedModels = ref<string[]>([])
@@ -3050,16 +3002,13 @@ const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('an
const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock')) const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock'))
// Bedrock credentials // Bedrock credentials
const bedrockAuthMode = ref<'sigv4' | 'apikey'>('sigv4')
const bedrockAccessKeyId = ref('') const bedrockAccessKeyId = ref('')
const bedrockSecretAccessKey = ref('') const bedrockSecretAccessKey = ref('')
const bedrockSessionToken = ref('') const bedrockSessionToken = ref('')
const bedrockRegion = ref('us-east-1') const bedrockRegion = ref('us-east-1')
const bedrockForceGlobal = ref(false) const bedrockForceGlobal = ref(false)
// Bedrock API Key credentials
const bedrockApiKeyValue = ref('') const bedrockApiKeyValue = ref('')
const bedrockApiKeyRegion = ref('us-east-1')
const bedrockApiKeyForceGlobal = ref(false)
const tempUnschedEnabled = ref(false) const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref<TempUnschedRuleForm[]>([]) const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping') const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
@@ -3343,7 +3292,8 @@ watch(
bedrockSessionToken.value = '' bedrockSessionToken.value = ''
bedrockRegion.value = 'us-east-1' bedrockRegion.value = 'us-east-1'
bedrockForceGlobal.value = false bedrockForceGlobal.value = false
bedrockApiKeyForceGlobal.value = false bedrockAuthMode.value = 'sigv4'
bedrockApiKeyValue.value = ''
// Reset Anthropic/Antigravity-specific settings when switching to other platforms // Reset Anthropic/Antigravity-specific settings when switching to other platforms
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') { if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
interceptWarmupRequests.value = false interceptWarmupRequests.value = false
@@ -3719,6 +3669,12 @@ const resetForm = () => {
editQuotaLimit.value = null editQuotaLimit.value = null
editQuotaDailyLimit.value = null editQuotaDailyLimit.value = null
editQuotaWeeklyLimit.value = null editQuotaWeeklyLimit.value = null
editDailyResetMode.value = null
editDailyResetHour.value = null
editWeeklyResetMode.value = null
editWeeklyResetDay.value = null
editWeeklyResetHour.value = null
editResetTimezone.value = null
modelMappings.value = [] modelMappings.value = []
modelRestrictionMode.value = 'whitelist' modelRestrictionMode.value = 'whitelist'
allowedModels.value = [...claudeModels] // Default fill related models allowedModels.value = [...claudeModels] // Default fill related models
@@ -3919,27 +3875,34 @@ const handleSubmit = async () => {
appStore.showError(t('admin.accounts.pleaseEnterAccountName')) appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
return return
} }
if (!bedrockAccessKeyId.value.trim()) {
appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired'))
return
}
if (!bedrockSecretAccessKey.value.trim()) {
appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired'))
return
}
if (!bedrockRegion.value.trim()) {
appStore.showError(t('admin.accounts.bedrockRegionRequired'))
return
}
const credentials: Record<string, unknown> = { const credentials: Record<string, unknown> = {
aws_access_key_id: bedrockAccessKeyId.value.trim(), auth_mode: bedrockAuthMode.value,
aws_secret_access_key: bedrockSecretAccessKey.value.trim(), aws_region: bedrockRegion.value.trim() || 'us-east-1',
aws_region: bedrockRegion.value.trim(),
} }
if (bedrockSessionToken.value.trim()) {
credentials.aws_session_token = bedrockSessionToken.value.trim() if (bedrockAuthMode.value === 'sigv4') {
if (!bedrockAccessKeyId.value.trim()) {
appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired'))
return
}
if (!bedrockSecretAccessKey.value.trim()) {
appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired'))
return
}
credentials.aws_access_key_id = bedrockAccessKeyId.value.trim()
credentials.aws_secret_access_key = bedrockSecretAccessKey.value.trim()
if (bedrockSessionToken.value.trim()) {
credentials.aws_session_token = bedrockSessionToken.value.trim()
}
} else {
if (!bedrockApiKeyValue.value.trim()) {
appStore.showError(t('admin.accounts.bedrockApiKeyRequired'))
return
}
credentials.api_key = bedrockApiKeyValue.value.trim()
} }
if (bedrockForceGlobal.value) { if (bedrockForceGlobal.value) {
credentials.aws_force_global = 'true' credentials.aws_force_global = 'true'
} }
@@ -3952,45 +3915,18 @@ const handleSubmit = async () => {
credentials.model_mapping = modelMapping credentials.model_mapping = modelMapping
} }
// Pool mode
if (poolModeEnabled.value) {
credentials.pool_mode = true
credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value)
}
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
await createAccountAndFinish('anthropic', 'bedrock' as AccountType, credentials) await createAccountAndFinish('anthropic', 'bedrock' as AccountType, credentials)
return return
} }
// For Bedrock API Key type, create directly
if (form.platform === 'anthropic' && accountCategory.value === 'bedrock-apikey') {
if (!form.name.trim()) {
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
return
}
if (!bedrockApiKeyValue.value.trim()) {
appStore.showError(t('admin.accounts.bedrockApiKeyRequired'))
return
}
const credentials: Record<string, unknown> = {
api_key: bedrockApiKeyValue.value.trim(),
aws_region: bedrockApiKeyRegion.value.trim() || 'us-east-1',
}
if (bedrockApiKeyForceGlobal.value) {
credentials.aws_force_global = 'true'
}
// Model mapping
const modelMapping = buildModelMappingObject(
modelRestrictionMode.value, allowedModels.value, modelMappings.value
)
if (modelMapping) {
credentials.model_mapping = modelMapping
}
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
await createAccountAndFinish('anthropic', 'bedrock-apikey' as AccountType, credentials)
return
}
// For Antigravity upstream type, create directly // For Antigravity upstream type, create directly
if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') {
if (!form.name.trim()) { if (!form.name.trim()) {
@@ -4233,9 +4169,9 @@ const createAccountAndFinish = async (
if (!applyTempUnschedConfig(credentials)) { if (!applyTempUnschedConfig(credentials)) {
return return
} }
// Inject quota limits for apikey accounts // Inject quota limits for apikey/bedrock accounts
let finalExtra = extra let finalExtra = extra
if (type === 'apikey') { if (type === 'apikey' || type === 'bedrock') {
const quotaExtra: Record<string, unknown> = { ...(extra || {}) } const quotaExtra: Record<string, unknown> = { ...(extra || {}) }
if (editQuotaLimit.value != null && editQuotaLimit.value > 0) { if (editQuotaLimit.value != null && editQuotaLimit.value > 0) {
quotaExtra.quota_limit = editQuotaLimit.value quotaExtra.quota_limit = editQuotaLimit.value
@@ -4246,6 +4182,19 @@ const createAccountAndFinish = async (
if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) { if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) {
quotaExtra.quota_weekly_limit = editQuotaWeeklyLimit.value quotaExtra.quota_weekly_limit = editQuotaWeeklyLimit.value
} }
// Quota reset mode config
if (editDailyResetMode.value === 'fixed') {
quotaExtra.quota_daily_reset_mode = 'fixed'
quotaExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0
}
if (editWeeklyResetMode.value === 'fixed') {
quotaExtra.quota_weekly_reset_mode = 'fixed'
quotaExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1
quotaExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 0
}
if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') {
quotaExtra.quota_reset_timezone = editResetTimezone.value || 'UTC'
}
if (Object.keys(quotaExtra).length > 0) { if (Object.keys(quotaExtra).length > 0) {
finalExtra = quotaExtra finalExtra = quotaExtra
} }

View File

@@ -563,37 +563,54 @@
</div> </div>
</div> </div>
<!-- Bedrock fields (only for bedrock type) --> <!-- Bedrock fields (for bedrock type, both SigV4 and API Key modes) -->
<div v-if="account.type === 'bedrock'" class="space-y-4"> <div v-if="account.type === 'bedrock'" class="space-y-4">
<div> <!-- SigV4 fields -->
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label> <template v-if="!isBedrockAPIKeyMode">
<div>
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
<input
v-model="editBedrockAccessKeyId"
type="text"
class="input font-mono"
placeholder="AKIA..."
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
<input
v-model="editBedrockSecretAccessKey"
type="password"
class="input font-mono"
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockSecretKeyLeaveEmpty') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
<input
v-model="editBedrockSessionToken"
type="password"
class="input font-mono"
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
</div>
</template>
<!-- API Key field -->
<div v-if="isBedrockAPIKeyMode">
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
<input <input
v-model="editBedrockAccessKeyId" v-model="editBedrockApiKeyValue"
type="text"
class="input font-mono"
placeholder="AKIA..."
/>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
<input
v-model="editBedrockSecretAccessKey"
type="password" type="password"
class="input font-mono" class="input font-mono"
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')" :placeholder="t('admin.accounts.bedrockApiKeyLeaveEmpty')"
/> />
<p class="input-hint">{{ t('admin.accounts.bedrockSecretKeyLeaveEmpty') }}</p> <p class="input-hint">{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
<input
v-model="editBedrockSessionToken"
type="password"
class="input font-mono"
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
</div> </div>
<!-- Shared: Region -->
<div> <div>
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label> <label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
<input <input
@@ -604,6 +621,8 @@
/> />
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p> <p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
</div> </div>
<!-- Shared: Force Global -->
<div> <div>
<label class="flex items-center gap-2 cursor-pointer"> <label class="flex items-center gap-2 cursor-pointer">
<input <input
@@ -684,108 +703,56 @@
</div> </div>
</div> </div>
</div> </div>
</div>
<!-- Bedrock API Key fields (only for bedrock-apikey type) --> <!-- Pool Mode Section for Bedrock -->
<div v-if="account.type === 'bedrock-apikey'" class="space-y-4">
<div>
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
<input
v-model="editBedrockApiKeyValue"
type="password"
class="input font-mono"
:placeholder="t('admin.accounts.bedrockApiKeyLeaveEmpty')"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
<input
v-model="editBedrockApiKeyRegion"
type="text"
class="input"
placeholder="us-east-1"
/>
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
</div>
<div>
<label class="flex items-center gap-2 cursor-pointer">
<input
v-model="editBedrockApiKeyForceGlobal"
type="checkbox"
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
/>
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
</label>
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
</div>
<!-- Model Restriction for Bedrock API Key -->
<div class="border-t border-gray-200 pt-4 dark:border-dark-600"> <div class="border-t border-gray-200 pt-4 dark:border-dark-600">
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label> <div class="mb-3 flex items-center justify-between">
<div>
<!-- Mode Toggle --> <label class="input-label mb-0">{{ t('admin.accounts.poolMode') }}</label>
<div class="mb-4 flex gap-2"> <p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.poolModeHint') }}
</p>
</div>
<button <button
type="button" type="button"
@click="modelRestrictionMode = 'whitelist'" @click="poolModeEnabled = !poolModeEnabled"
:class="[ :class="[
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all', 'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
modelRestrictionMode === 'whitelist' poolModeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]" ]"
> >
{{ t('admin.accounts.modelWhitelist') }} <span
</button> :class="[
<button 'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
type="button" poolModeEnabled ? 'translate-x-5' : 'translate-x-0'
@click="modelRestrictionMode = 'mapping'" ]"
:class="[ />
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
modelRestrictionMode === 'mapping'
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
]"
>
{{ t('admin.accounts.modelMapping') }}
</button> </button>
</div> </div>
<div v-if="poolModeEnabled" class="rounded-lg bg-blue-50 p-3 dark:bg-blue-900/20">
<!-- Whitelist Mode --> <p class="text-xs text-blue-700 dark:text-blue-400">
<div v-if="modelRestrictionMode === 'whitelist'"> <Icon name="exclamationCircle" size="sm" class="mr-1 inline" :stroke-width="2" />
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" /> {{ t('admin.accounts.poolModeInfo') }}
<p class="text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
</p> </p>
</div> </div>
<div v-if="poolModeEnabled" class="mt-3">
<!-- Mapping Mode --> <label class="input-label">{{ t('admin.accounts.poolModeRetryCount') }}</label>
<div v-else class="space-y-3"> <input
<div v-for="(mapping, index) in modelMappings" :key="getModelMappingKey(mapping)" class="flex items-center gap-2"> v-model.number="poolModeRetryCount"
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" /> type="number"
<span class="text-gray-400"></span> min="0"
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" /> :max="MAX_POOL_MODE_RETRY_COUNT"
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700"> step="1"
<Icon name="trash" size="sm" /> class="input"
</button> />
</div> <p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm"> {{
+ {{ t('admin.accounts.addMapping') }} t('admin.accounts.poolModeRetryCountHint', {
</button> default: DEFAULT_POOL_MODE_RETRY_COUNT,
<!-- Bedrock Preset Mappings --> max: MAX_POOL_MODE_RETRY_COUNT
<div class="flex flex-wrap gap-2"> })
<button }}
v-for="preset in bedrockPresets" </p>
:key="preset.from"
type="button"
@click="modelMappings.push({ from: preset.from, to: preset.to })"
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
>
+ {{ preset.label }}
</button>
</div>
</div> </div>
</div> </div>
</div> </div>
@@ -1182,8 +1149,8 @@
</div> </div>
</div> </div>
<!-- API Key 账号配额限制 --> <!-- API Key / Bedrock 账号配额限制 -->
<div v-if="account?.type === 'apikey'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4"> <div v-if="account?.type === 'apikey' || account?.type === 'bedrock'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
<div class="mb-3"> <div class="mb-3">
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3> <h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400"> <p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
@@ -1194,9 +1161,21 @@
:totalLimit="editQuotaLimit" :totalLimit="editQuotaLimit"
:dailyLimit="editQuotaDailyLimit" :dailyLimit="editQuotaDailyLimit"
:weeklyLimit="editQuotaWeeklyLimit" :weeklyLimit="editQuotaWeeklyLimit"
:dailyResetMode="editDailyResetMode"
:dailyResetHour="editDailyResetHour"
:weeklyResetMode="editWeeklyResetMode"
:weeklyResetDay="editWeeklyResetDay"
:weeklyResetHour="editWeeklyResetHour"
:resetTimezone="editResetTimezone"
@update:totalLimit="editQuotaLimit = $event" @update:totalLimit="editQuotaLimit = $event"
@update:dailyLimit="editQuotaDailyLimit = $event" @update:dailyLimit="editQuotaDailyLimit = $event"
@update:weeklyLimit="editQuotaWeeklyLimit = $event" @update:weeklyLimit="editQuotaWeeklyLimit = $event"
@update:dailyResetMode="editDailyResetMode = $event"
@update:dailyResetHour="editDailyResetHour = $event"
@update:weeklyResetMode="editWeeklyResetMode = $event"
@update:weeklyResetDay="editWeeklyResetDay = $event"
@update:weeklyResetHour="editWeeklyResetHour = $event"
@update:resetTimezone="editResetTimezone = $event"
/> />
</div> </div>
@@ -1781,11 +1760,11 @@ const editBedrockSecretAccessKey = ref('')
const editBedrockSessionToken = ref('') const editBedrockSessionToken = ref('')
const editBedrockRegion = ref('') const editBedrockRegion = ref('')
const editBedrockForceGlobal = ref(false) const editBedrockForceGlobal = ref(false)
// Bedrock API Key credentials
const editBedrockApiKeyValue = ref('') const editBedrockApiKeyValue = ref('')
const editBedrockApiKeyRegion = ref('') const isBedrockAPIKeyMode = computed(() =>
const editBedrockApiKeyForceGlobal = ref(false) props.account?.type === 'bedrock' &&
(props.account?.credentials as Record<string, unknown>)?.auth_mode === 'apikey'
)
const modelMappings = ref<ModelMapping[]>([]) const modelMappings = ref<ModelMapping[]>([])
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref<string[]>([]) const allowedModels = ref<string[]>([])
@@ -1847,6 +1826,12 @@ const anthropicPassthroughEnabled = ref(false)
const editQuotaLimit = ref<number | null>(null) const editQuotaLimit = ref<number | null>(null)
const editQuotaDailyLimit = ref<number | null>(null) const editQuotaDailyLimit = ref<number | null>(null)
const editQuotaWeeklyLimit = ref<number | null>(null) const editQuotaWeeklyLimit = ref<number | null>(null)
const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null)
const editDailyResetHour = ref<number | null>(null)
const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null)
const editWeeklyResetDay = ref<number | null>(null)
const editWeeklyResetHour = ref<number | null>(null)
const editResetTimezone = ref<string | null>(null)
const openAIWSModeOptions = computed(() => [ const openAIWSModeOptions = computed(() => [
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
// TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
@@ -2026,18 +2011,31 @@ watch(
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
} }
// Load quota limit for apikey accounts // Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above)
if (newAccount.type === 'apikey') { if (newAccount.type === 'apikey' || newAccount.type === 'bedrock') {
const quotaVal = extra?.quota_limit as number | undefined const quotaVal = extra?.quota_limit as number | undefined
editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null
const dailyVal = extra?.quota_daily_limit as number | undefined const dailyVal = extra?.quota_daily_limit as number | undefined
editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null
const weeklyVal = extra?.quota_weekly_limit as number | undefined const weeklyVal = extra?.quota_weekly_limit as number | undefined
editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null
// Load quota reset mode config
editDailyResetMode.value = (extra?.quota_daily_reset_mode as 'rolling' | 'fixed') || null
editDailyResetHour.value = (extra?.quota_daily_reset_hour as number) ?? null
editWeeklyResetMode.value = (extra?.quota_weekly_reset_mode as 'rolling' | 'fixed') || null
editWeeklyResetDay.value = (extra?.quota_weekly_reset_day as number) ?? null
editWeeklyResetHour.value = (extra?.quota_weekly_reset_hour as number) ?? null
editResetTimezone.value = (extra?.quota_reset_timezone as string) || null
} else { } else {
editQuotaLimit.value = null editQuotaLimit.value = null
editQuotaDailyLimit.value = null editQuotaDailyLimit.value = null
editQuotaWeeklyLimit.value = null editQuotaWeeklyLimit.value = null
editDailyResetMode.value = null
editDailyResetHour.value = null
editWeeklyResetMode.value = null
editWeeklyResetDay.value = null
editWeeklyResetHour.value = null
editResetTimezone.value = null
} }
// Load antigravity model mapping (Antigravity 只支持映射模式) // Load antigravity model mapping (Antigravity 只支持映射模式)
@@ -2130,11 +2128,28 @@ watch(
} }
} else if (newAccount.type === 'bedrock' && newAccount.credentials) { } else if (newAccount.type === 'bedrock' && newAccount.credentials) {
const bedrockCreds = newAccount.credentials as Record<string, unknown> const bedrockCreds = newAccount.credentials as Record<string, unknown>
editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || '' const authMode = (bedrockCreds.auth_mode as string) || 'sigv4'
editBedrockRegion.value = (bedrockCreds.aws_region as string) || '' editBedrockRegion.value = (bedrockCreds.aws_region as string) || ''
editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true' editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true'
editBedrockSecretAccessKey.value = ''
editBedrockSessionToken.value = '' if (authMode === 'apikey') {
editBedrockApiKeyValue.value = ''
} else {
editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || ''
editBedrockSecretAccessKey.value = ''
editBedrockSessionToken.value = ''
}
// Load pool mode for bedrock
poolModeEnabled.value = bedrockCreds.pool_mode === true
const retryCount = bedrockCreds.pool_mode_retry_count
poolModeRetryCount.value = (typeof retryCount === 'number' && retryCount >= 0) ? retryCount : DEFAULT_POOL_MODE_RETRY_COUNT
// Load quota limits for bedrock
const bedrockExtra = (newAccount.extra as Record<string, unknown>) || {}
editQuotaLimit.value = typeof bedrockExtra.quota_limit === 'number' ? bedrockExtra.quota_limit : null
editQuotaDailyLimit.value = typeof bedrockExtra.quota_daily_limit === 'number' ? bedrockExtra.quota_daily_limit : null
editQuotaWeeklyLimit.value = typeof bedrockExtra.quota_weekly_limit === 'number' ? bedrockExtra.quota_weekly_limit : null
// Load model mappings for bedrock // Load model mappings for bedrock
const existingMappings = bedrockCreds.model_mapping as Record<string, string> | undefined const existingMappings = bedrockCreds.model_mapping as Record<string, string> | undefined
@@ -2155,31 +2170,6 @@ watch(
modelMappings.value = [] modelMappings.value = []
allowedModels.value = [] allowedModels.value = []
} }
} else if (newAccount.type === 'bedrock-apikey' && newAccount.credentials) {
const bedrockApiKeyCreds = newAccount.credentials as Record<string, unknown>
editBedrockApiKeyRegion.value = (bedrockApiKeyCreds.aws_region as string) || 'us-east-1'
editBedrockApiKeyForceGlobal.value = (bedrockApiKeyCreds.aws_force_global as string) === 'true'
editBedrockApiKeyValue.value = ''
// Load model mappings for bedrock-apikey
const existingMappings = bedrockApiKeyCreds.model_mapping as Record<string, string> | undefined
if (existingMappings && typeof existingMappings === 'object') {
const entries = Object.entries(existingMappings)
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
if (isWhitelistMode) {
modelRestrictionMode.value = 'whitelist'
allowedModels.value = entries.map(([from]) => from)
modelMappings.value = []
} else {
modelRestrictionMode.value = 'mapping'
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
allowedModels.value = []
}
} else {
modelRestrictionMode.value = 'whitelist'
modelMappings.value = []
allowedModels.value = []
}
} else if (newAccount.type === 'upstream' && newAccount.credentials) { } else if (newAccount.type === 'upstream' && newAccount.credentials) {
const credentials = newAccount.credentials as Record<string, unknown> const credentials = newAccount.credentials as Record<string, unknown>
editBaseUrl.value = (credentials.base_url as string) || '' editBaseUrl.value = (credentials.base_url as string) || ''
@@ -2727,7 +2717,6 @@ const handleSubmit = async () => {
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {} const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
const newCredentials: Record<string, unknown> = { ...currentCredentials } const newCredentials: Record<string, unknown> = { ...currentCredentials }
newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim()
newCredentials.aws_region = editBedrockRegion.value.trim() newCredentials.aws_region = editBedrockRegion.value.trim()
if (editBedrockForceGlobal.value) { if (editBedrockForceGlobal.value) {
newCredentials.aws_force_global = 'true' newCredentials.aws_force_global = 'true'
@@ -2735,42 +2724,29 @@ const handleSubmit = async () => {
delete newCredentials.aws_force_global delete newCredentials.aws_force_global
} }
// Only update secrets if user provided new values if (isBedrockAPIKeyMode.value) {
if (editBedrockSecretAccessKey.value.trim()) { // API Key mode: only update api_key if user provided new value
newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim() if (editBedrockApiKeyValue.value.trim()) {
} newCredentials.api_key = editBedrockApiKeyValue.value.trim()
if (editBedrockSessionToken.value.trim()) { }
newCredentials.aws_session_token = editBedrockSessionToken.value.trim()
}
// Model mapping
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
if (modelMapping) {
newCredentials.model_mapping = modelMapping
} else { } else {
delete newCredentials.model_mapping // SigV4 mode
newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim()
if (editBedrockSecretAccessKey.value.trim()) {
newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim()
}
if (editBedrockSessionToken.value.trim()) {
newCredentials.aws_session_token = editBedrockSessionToken.value.trim()
}
} }
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') // Pool mode
if (!applyTempUnschedConfig(newCredentials)) { if (poolModeEnabled.value) {
return newCredentials.pool_mode = true
} newCredentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value)
updatePayload.credentials = newCredentials
} else if (props.account.type === 'bedrock-apikey') {
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
const newCredentials: Record<string, unknown> = { ...currentCredentials }
newCredentials.aws_region = editBedrockApiKeyRegion.value.trim() || 'us-east-1'
if (editBedrockApiKeyForceGlobal.value) {
newCredentials.aws_force_global = 'true'
} else { } else {
delete newCredentials.aws_force_global delete newCredentials.pool_mode
} delete newCredentials.pool_mode_retry_count
// Only update API key if user provided new value
if (editBedrockApiKeyValue.value.trim()) {
newCredentials.api_key = editBedrockApiKeyValue.value.trim()
} }
// Model mapping // Model mapping
@@ -2980,8 +2956,8 @@ const handleSubmit = async () => {
updatePayload.extra = newExtra updatePayload.extra = newExtra
} }
// For apikey accounts, handle quota_limit in extra // For apikey/bedrock accounts, handle quota_limit in extra
if (props.account.type === 'apikey') { if (props.account.type === 'apikey' || props.account.type === 'bedrock') {
const currentExtra = (updatePayload.extra as Record<string, unknown>) || const currentExtra = (updatePayload.extra as Record<string, unknown>) ||
(props.account.extra as Record<string, unknown>) || {} (props.account.extra as Record<string, unknown>) || {}
const newExtra: Record<string, unknown> = { ...currentExtra } const newExtra: Record<string, unknown> = { ...currentExtra }
@@ -3000,6 +2976,28 @@ const handleSubmit = async () => {
} else { } else {
delete newExtra.quota_weekly_limit delete newExtra.quota_weekly_limit
} }
// Quota reset mode config
if (editDailyResetMode.value === 'fixed') {
newExtra.quota_daily_reset_mode = 'fixed'
newExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0
} else {
delete newExtra.quota_daily_reset_mode
delete newExtra.quota_daily_reset_hour
}
if (editWeeklyResetMode.value === 'fixed') {
newExtra.quota_weekly_reset_mode = 'fixed'
newExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1
newExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 0
} else {
delete newExtra.quota_weekly_reset_mode
delete newExtra.quota_weekly_reset_day
delete newExtra.quota_weekly_reset_hour
}
if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') {
newExtra.quota_reset_timezone = editResetTimezone.value || 'UTC'
} else {
delete newExtra.quota_reset_timezone
}
updatePayload.extra = newExtra updatePayload.extra = newExtra
} }

View File

@@ -8,12 +8,24 @@ const props = defineProps<{
totalLimit: number | null totalLimit: number | null
dailyLimit: number | null dailyLimit: number | null
weeklyLimit: number | null weeklyLimit: number | null
dailyResetMode: 'rolling' | 'fixed' | null
dailyResetHour: number | null
weeklyResetMode: 'rolling' | 'fixed' | null
weeklyResetDay: number | null
weeklyResetHour: number | null
resetTimezone: string | null
}>() }>()
const emit = defineEmits<{ const emit = defineEmits<{
'update:totalLimit': [value: number | null] 'update:totalLimit': [value: number | null]
'update:dailyLimit': [value: number | null] 'update:dailyLimit': [value: number | null]
'update:weeklyLimit': [value: number | null] 'update:weeklyLimit': [value: number | null]
'update:dailyResetMode': [value: 'rolling' | 'fixed' | null]
'update:dailyResetHour': [value: number | null]
'update:weeklyResetMode': [value: 'rolling' | 'fixed' | null]
'update:weeklyResetDay': [value: number | null]
'update:weeklyResetHour': [value: number | null]
'update:resetTimezone': [value: string | null]
}>() }>()
const enabled = computed(() => const enabled = computed(() =>
@@ -35,9 +47,56 @@ watch(localEnabled, (val) => {
emit('update:totalLimit', null) emit('update:totalLimit', null)
emit('update:dailyLimit', null) emit('update:dailyLimit', null)
emit('update:weeklyLimit', null) emit('update:weeklyLimit', null)
emit('update:dailyResetMode', null)
emit('update:dailyResetHour', null)
emit('update:weeklyResetMode', null)
emit('update:weeklyResetDay', null)
emit('update:weeklyResetHour', null)
emit('update:resetTimezone', null)
} }
}) })
// Whether any fixed mode is active (to show timezone selector)
const hasFixedMode = computed(() =>
props.dailyResetMode === 'fixed' || props.weeklyResetMode === 'fixed'
)
// Common timezone options
const timezoneOptions = [
'UTC',
'Asia/Shanghai',
'Asia/Tokyo',
'Asia/Seoul',
'Asia/Singapore',
'Asia/Kolkata',
'Asia/Dubai',
'Europe/London',
'Europe/Paris',
'Europe/Berlin',
'Europe/Moscow',
'America/New_York',
'America/Chicago',
'America/Denver',
'America/Los_Angeles',
'America/Sao_Paulo',
'Australia/Sydney',
'Pacific/Auckland',
]
// Hours for dropdown (0-23)
const hourOptions = Array.from({ length: 24 }, (_, i) => i)
// Day of week options
const dayOptions = [
{ value: 1, key: 'monday' },
{ value: 2, key: 'tuesday' },
{ value: 3, key: 'wednesday' },
{ value: 4, key: 'thursday' },
{ value: 5, key: 'friday' },
{ value: 6, key: 'saturday' },
{ value: 0, key: 'sunday' },
]
const onTotalInput = (e: Event) => { const onTotalInput = (e: Event) => {
const raw = (e.target as HTMLInputElement).valueAsNumber const raw = (e.target as HTMLInputElement).valueAsNumber
emit('update:totalLimit', Number.isNaN(raw) ? null : raw) emit('update:totalLimit', Number.isNaN(raw) ? null : raw)
@@ -50,6 +109,25 @@ const onWeeklyInput = (e: Event) => {
const raw = (e.target as HTMLInputElement).valueAsNumber const raw = (e.target as HTMLInputElement).valueAsNumber
emit('update:weeklyLimit', Number.isNaN(raw) ? null : raw) emit('update:weeklyLimit', Number.isNaN(raw) ? null : raw)
} }
const onDailyModeChange = (e: Event) => {
const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed'
emit('update:dailyResetMode', val)
if (val === 'fixed') {
if (props.dailyResetHour == null) emit('update:dailyResetHour', 0)
if (!props.resetTimezone) emit('update:resetTimezone', 'UTC')
}
}
const onWeeklyModeChange = (e: Event) => {
const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed'
emit('update:weeklyResetMode', val)
if (val === 'fixed') {
if (props.weeklyResetDay == null) emit('update:weeklyResetDay', 1)
if (props.weeklyResetHour == null) emit('update:weeklyResetHour', 0)
if (!props.resetTimezone) emit('update:resetTimezone', 'UTC')
}
}
</script> </script>
<template> <template>
@@ -94,7 +172,37 @@ const onWeeklyInput = (e: Event) => {
:placeholder="t('admin.accounts.quotaLimitPlaceholder')" :placeholder="t('admin.accounts.quotaLimitPlaceholder')"
/> />
</div> </div>
<p class="input-hint">{{ t('admin.accounts.quotaDailyLimitHint') }}</p> <!-- 日配额重置模式 -->
<div class="mt-2 flex items-center gap-2">
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetMode') }}</label>
<select
:value="dailyResetMode || 'rolling'"
@change="onDailyModeChange"
class="input py-1 text-xs"
>
<option value="rolling">{{ t('admin.accounts.quotaResetModeRolling') }}</option>
<option value="fixed">{{ t('admin.accounts.quotaResetModeFixed') }}</option>
</select>
</div>
<!-- 固定模式小时选择 -->
<div v-if="dailyResetMode === 'fixed'" class="mt-2 flex items-center gap-2">
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetHour') }}</label>
<select
:value="dailyResetHour ?? 0"
@change="emit('update:dailyResetHour', Number(($event.target as HTMLSelectElement).value))"
class="input py-1 text-xs w-24"
>
<option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option>
</select>
</div>
<p class="input-hint">
<template v-if="dailyResetMode === 'fixed'">
{{ t('admin.accounts.quotaDailyLimitHintFixed', { hour: String(dailyResetHour ?? 0).padStart(2, '0'), timezone: resetTimezone || 'UTC' }) }}
</template>
<template v-else>
{{ t('admin.accounts.quotaDailyLimitHint') }}
</template>
</p>
</div> </div>
<!-- 周配额 --> <!-- 周配额 -->
@@ -112,7 +220,57 @@ const onWeeklyInput = (e: Event) => {
:placeholder="t('admin.accounts.quotaLimitPlaceholder')" :placeholder="t('admin.accounts.quotaLimitPlaceholder')"
/> />
</div> </div>
<p class="input-hint">{{ t('admin.accounts.quotaWeeklyLimitHint') }}</p> <!-- 周配额重置模式 -->
<div class="mt-2 flex items-center gap-2">
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetMode') }}</label>
<select
:value="weeklyResetMode || 'rolling'"
@change="onWeeklyModeChange"
class="input py-1 text-xs"
>
<option value="rolling">{{ t('admin.accounts.quotaResetModeRolling') }}</option>
<option value="fixed">{{ t('admin.accounts.quotaResetModeFixed') }}</option>
</select>
</div>
<!-- 固定模式星期几 + 小时 -->
<div v-if="weeklyResetMode === 'fixed'" class="mt-2 flex items-center gap-2 flex-wrap">
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaWeeklyResetDay') }}</label>
<select
:value="weeklyResetDay ?? 1"
@change="emit('update:weeklyResetDay', Number(($event.target as HTMLSelectElement).value))"
class="input py-1 text-xs w-28"
>
<option v-for="d in dayOptions" :key="d.value" :value="d.value">{{ t('admin.accounts.dayOfWeek.' + d.key) }}</option>
</select>
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetHour') }}</label>
<select
:value="weeklyResetHour ?? 0"
@change="emit('update:weeklyResetHour', Number(($event.target as HTMLSelectElement).value))"
class="input py-1 text-xs w-24"
>
<option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option>
</select>
</div>
<p class="input-hint">
<template v-if="weeklyResetMode === 'fixed'">
{{ t('admin.accounts.quotaWeeklyLimitHintFixed', { day: t('admin.accounts.dayOfWeek.' + (dayOptions.find(d => d.value === (weeklyResetDay ?? 1))?.key || 'monday')), hour: String(weeklyResetHour ?? 0).padStart(2, '0'), timezone: resetTimezone || 'UTC' }) }}
</template>
<template v-else>
{{ t('admin.accounts.quotaWeeklyLimitHint') }}
</template>
</p>
</div>
<!-- 时区选择当任一维度使用固定模式时显示 -->
<div v-if="hasFixedMode">
<label class="input-label">{{ t('admin.accounts.quotaResetTimezone') }}</label>
<select
:value="resetTimezone || 'UTC'"
@change="emit('update:resetTimezone', ($event.target as HTMLSelectElement).value)"
class="input text-sm"
>
<option v-for="tz in timezoneOptions" :key="tz" :value="tz">{{ tz }}</option>
</select>
</div> </div>
<!-- 总配额 --> <!-- 总配额 -->

View File

@@ -76,7 +76,7 @@ const hasRecoverableState = computed(() => {
return props.account?.status === 'error' || Boolean(isRateLimited.value) || Boolean(isOverloaded.value) || Boolean(isTempUnschedulable.value) return props.account?.status === 'error' || Boolean(isRateLimited.value) || Boolean(isOverloaded.value) || Boolean(isTempUnschedulable.value)
}) })
const hasQuotaLimit = computed(() => { const hasQuotaLimit = computed(() => {
return props.account?.type === 'apikey' && ( return (props.account?.type === 'apikey' || props.account?.type === 'bedrock') && (
(props.account?.quota_limit ?? 0) > 0 || (props.account?.quota_limit ?? 0) > 0 ||
(props.account?.quota_daily_limit ?? 0) > 0 || (props.account?.quota_daily_limit ?? 0) > 0 ||
(props.account?.quota_weekly_limit ?? 0) > 0 (props.account?.quota_weekly_limit ?? 0) > 0

View File

@@ -83,7 +83,7 @@ const typeLabel = computed(() => {
case 'apikey': case 'apikey':
return 'Key' return 'Key'
case 'bedrock': case 'bedrock':
return 'Bedrock' return 'AWS'
default: default:
return props.type return props.type
} }

View File

@@ -82,7 +82,7 @@
</template> </template>
<!-- Regular User View --> <!-- Regular User View -->
<template v-else> <template v-else-if="!appStore.backendModeEnabled">
<div class="sidebar-section"> <div class="sidebar-section">
<router-link <router-link
v-for="item in userNavItems" v-for="item in userNavItems"
@@ -357,36 +357,6 @@ const ServerIcon = {
) )
} }
const DatabaseIcon = {
render: () =>
h(
'svg',
{ fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
[
h('path', {
'stroke-linecap': 'round',
'stroke-linejoin': 'round',
d: 'M3.75 5.25C3.75 4.007 7.443 3 12 3s8.25 1.007 8.25 2.25S16.557 7.5 12 7.5 3.75 6.493 3.75 5.25z'
}),
h('path', {
'stroke-linecap': 'round',
'stroke-linejoin': 'round',
d: 'M3.75 5.25v4.5C3.75 10.993 7.443 12 12 12s8.25-1.007 8.25-2.25v-4.5'
}),
h('path', {
'stroke-linecap': 'round',
'stroke-linejoin': 'round',
d: 'M3.75 9.75v4.5c0 1.243 3.693 2.25 8.25 2.25s8.25-1.007 8.25-2.25v-4.5'
}),
h('path', {
'stroke-linecap': 'round',
'stroke-linejoin': 'round',
d: 'M3.75 14.25v4.5C3.75 19.993 7.443 21 12 21s8.25-1.007 8.25-2.25v-4.5'
})
]
)
}
const BellIcon = { const BellIcon = {
render: () => render: () =>
h( h(
@@ -611,7 +581,6 @@ const adminNavItems = computed((): NavItem[] => {
if (authStore.isSimpleMode) { if (authStore.isSimpleMode) {
const filtered = baseItems.filter(item => !item.hideInSimpleMode) const filtered = baseItems.filter(item => !item.hideInSimpleMode)
filtered.push({ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon }) filtered.push({ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon })
filtered.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon })
filtered.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon }) filtered.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
// Add admin custom menu items after settings // Add admin custom menu items after settings
for (const cm of customMenuItemsForAdmin.value) { for (const cm of customMenuItemsForAdmin.value) {
@@ -620,7 +589,6 @@ const adminNavItems = computed((): NavItem[] => {
return filtered return filtered
} }
baseItems.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon })
baseItems.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon }) baseItems.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
// Add admin custom menu items after settings // Add admin custom menu items after settings
for (const cm of customMenuItemsForAdmin.value) { for (const cm of customMenuItemsForAdmin.value) {

View File

@@ -84,9 +84,7 @@ onUnmounted(() => {
} }
.table-scroll-container :deep(th) { .table-scroll-container :deep(th) {
/* 表头高度和文字加粗优化 */ @apply px-5 py-4 text-left text-sm font-medium text-gray-600 dark:text-dark-300 border-b border-gray-200 dark:border-dark-700;
@apply px-5 py-4 text-left text-sm font-bold text-gray-900 dark:text-white border-b border-gray-200 dark:border-dark-700;
@apply uppercase tracking-wider; /* 让表头更有设计感 */
} }
.table-scroll-container :deep(td) { .table-scroll-container :deep(td) {

View File

@@ -412,7 +412,7 @@ export function getPresetMappingsByPlatform(platform: string) {
if (platform === 'gemini') return geminiPresetMappings if (platform === 'gemini') return geminiPresetMappings
if (platform === 'sora') return soraPresetMappings if (platform === 'sora') return soraPresetMappings
if (platform === 'antigravity') return antigravityPresetMappings if (platform === 'antigravity') return antigravityPresetMappings
if (platform === 'bedrock' || platform === 'bedrock-apikey') return bedrockPresetMappings if (platform === 'bedrock') return bedrockPresetMappings
return anthropicPresetMappings return anthropicPresetMappings
} }

View File

@@ -340,7 +340,6 @@ export default {
redeemCodes: 'Redeem Codes', redeemCodes: 'Redeem Codes',
ops: 'Ops', ops: 'Ops',
promoCodes: 'Promo Codes', promoCodes: 'Promo Codes',
dataManagement: 'Data Management',
settings: 'Settings', settings: 'Settings',
myAccount: 'My Account', myAccount: 'My Account',
lightMode: 'Light Mode', lightMode: 'Light Mode',
@@ -978,6 +977,111 @@ export default {
failedToLoad: 'Failed to load dashboard statistics' failedToLoad: 'Failed to load dashboard statistics'
}, },
backup: {
title: 'Database Backup',
description: 'Full database backup to S3-compatible storage with scheduled backup and restore',
s3: {
title: 'S3 Storage Configuration',
description: 'Configure S3-compatible storage (supports Cloudflare R2)',
descriptionPrefix: 'Configure S3-compatible storage (supports',
descriptionSuffix: ')',
enabled: 'Enable S3 Storage',
endpoint: 'Endpoint',
region: 'Region',
bucket: 'Bucket',
prefix: 'Key Prefix',
accessKeyId: 'Access Key ID',
secretAccessKey: 'Secret Access Key',
secretConfigured: 'Already configured, leave empty to keep',
forcePathStyle: 'Force Path Style',
testConnection: 'Test Connection',
testSuccess: 'S3 connection test successful',
testFailed: 'S3 connection test failed',
saved: 'S3 configuration saved'
},
schedule: {
title: 'Scheduled Backup',
description: 'Configure automatic scheduled backups',
enabled: 'Enable Scheduled Backup',
cronExpr: 'Cron Expression',
cronHint: 'e.g. "0 2 * * *" means every day at 2:00 AM',
retainDays: 'Backup Expire Days',
retainDaysHint: 'Backup files auto-delete after this many days, 0 = never expire',
retainCount: 'Max Retain Count',
retainCountHint: 'Maximum number of backups to keep, 0 = unlimited',
saved: 'Schedule configuration saved'
},
operations: {
title: 'Backup Records',
description: 'Create manual backups and manage existing backup records',
createBackup: 'Create Backup',
backing: 'Backing up...',
backupCreated: 'Backup created successfully',
expireDays: 'Expire Days'
},
columns: {
status: 'Status',
fileName: 'File Name',
size: 'Size',
expiresAt: 'Expires At',
triggeredBy: 'Triggered By',
startedAt: 'Started At',
actions: 'Actions'
},
status: {
pending: 'Pending',
running: 'Running',
completed: 'Completed',
failed: 'Failed'
},
trigger: {
manual: 'Manual',
scheduled: 'Scheduled'
},
neverExpire: 'Never',
empty: 'No backup records',
actions: {
download: 'Download',
restore: 'Restore',
restoreConfirm: 'Are you sure you want to restore from this backup? This will overwrite the current database!',
restorePasswordPrompt: 'Please enter your admin password to confirm the restore operation',
restoreSuccess: 'Database restored successfully',
deleteConfirm: 'Are you sure you want to delete this backup?',
deleted: 'Backup deleted'
},
r2Guide: {
title: 'Cloudflare R2 Setup Guide',
intro: 'Cloudflare R2 provides S3-compatible object storage with a free tier of 10GB storage + 1M Class A requests/month, ideal for database backups.',
step1: {
title: 'Create an R2 Bucket',
line1: 'Log in to the Cloudflare Dashboard (dash.cloudflare.com), select "R2 Object Storage" from the sidebar',
line2: 'Click "Create bucket", enter a name (e.g. sub2api-backups), choose a region',
line3: 'Click create to finish'
},
step2: {
title: 'Create an API Token',
line1: 'On the R2 page, click "Manage R2 API Tokens" in the top right',
line2: 'Click "Create API token", set permission to "Object Read & Write"',
line3: 'Recommended: restrict to specific bucket for better security',
line4: 'After creation, you will see the Access Key ID and Secret Access Key',
warning: 'The Secret Access Key is only shown once — copy and save it immediately!'
},
step3: {
title: 'Get the S3 Endpoint',
desc: 'Find your Account ID on the R2 overview page (in the URL or the right panel). The endpoint format is:',
accountId: 'your_account_id'
},
step4: {
title: 'Fill in the Configuration',
checkEnabled: 'Checked',
bucketValue: 'Your bucket name',
fromStep2: 'Value from Step 2',
unchecked: 'Unchecked'
},
freeTier: 'R2 Free Tier: 10GB storage + 1M Class A requests + 10M Class B requests per month — more than enough for database backups.'
}
},
dataManagement: { dataManagement: {
title: 'Data Management', title: 'Data Management',
description: 'Manage data management agent status, object storage settings, and backup jobs in one place', description: 'Manage data management agent status, object storage settings, and backup jobs in one place',
@@ -1866,6 +1970,23 @@ export default {
quotaWeeklyLimitHint: 'Automatically resets every 7 days from first usage.', quotaWeeklyLimitHint: 'Automatically resets every 7 days from first usage.',
quotaTotalLimit: 'Total Limit', quotaTotalLimit: 'Total Limit',
quotaTotalLimitHint: 'Cumulative spending limit. Does not auto-reset — use "Reset Quota" to clear.', quotaTotalLimitHint: 'Cumulative spending limit. Does not auto-reset — use "Reset Quota" to clear.',
quotaResetMode: 'Reset Mode',
quotaResetModeRolling: 'Rolling Window',
quotaResetModeFixed: 'Fixed Time',
quotaResetHour: 'Reset Hour',
quotaWeeklyResetDay: 'Reset Day',
quotaResetTimezone: 'Reset Timezone',
quotaDailyLimitHintFixed: 'Resets daily at {hour}:00 ({timezone}).',
quotaWeeklyLimitHintFixed: 'Resets every {day} at {hour}:00 ({timezone}).',
dayOfWeek: {
monday: 'Monday',
tuesday: 'Tuesday',
wednesday: 'Wednesday',
thursday: 'Thursday',
friday: 'Friday',
saturday: 'Saturday',
sunday: 'Sunday',
},
quotaLimitAmount: 'Total Limit', quotaLimitAmount: 'Total Limit',
quotaLimitAmountHint: 'Cumulative spending limit. Does not auto-reset.', quotaLimitAmountHint: 'Cumulative spending limit. Does not auto-reset.',
testConnection: 'Test Connection', testConnection: 'Test Connection',
@@ -1934,7 +2055,7 @@ export default {
claudeCode: 'Claude Code', claudeCode: 'Claude Code',
claudeConsole: 'Claude Console', claudeConsole: 'Claude Console',
bedrockLabel: 'AWS Bedrock', bedrockLabel: 'AWS Bedrock',
bedrockDesc: 'SigV4 Signing', bedrockDesc: 'SigV4 / API Key',
oauthSetupToken: 'OAuth / Setup Token', oauthSetupToken: 'OAuth / Setup Token',
addMethod: 'Add Method', addMethod: 'Add Method',
setupTokenLongLived: 'Setup Token (Long-lived)', setupTokenLongLived: 'Setup Token (Long-lived)',
@@ -2136,6 +2257,9 @@ export default {
bedrockRegionRequired: 'Please select AWS Region', bedrockRegionRequired: 'Please select AWS Region',
bedrockSessionTokenHint: 'Optional, for temporary credentials', bedrockSessionTokenHint: 'Optional, for temporary credentials',
bedrockSecretKeyLeaveEmpty: 'Leave empty to keep current key', bedrockSecretKeyLeaveEmpty: 'Leave empty to keep current key',
bedrockAuthMode: 'Authentication Mode',
bedrockAuthModeSigv4: 'SigV4 Signing',
bedrockAuthModeApikey: 'Bedrock API Key',
bedrockApiKeyLabel: 'Bedrock API Key', bedrockApiKeyLabel: 'Bedrock API Key',
bedrockApiKeyDesc: 'Bearer Token', bedrockApiKeyDesc: 'Bearer Token',
bedrockApiKeyInput: 'API Key', bedrockApiKeyInput: 'API Key',
@@ -2555,7 +2679,16 @@ export default {
unlimited: 'Unlimited' unlimited: 'Unlimited'
}, },
ineligibleWarning: ineligibleWarning:
'This account is not eligible for Antigravity, but API forwarding still works. Use at your own risk.' 'This account is not eligible for Antigravity, but API forwarding still works. Use at your own risk.',
forbidden: 'Forbidden',
forbiddenValidation: 'Verification Required',
forbiddenViolation: 'Violation Ban',
openVerification: 'Open Verification Link',
copyLink: 'Copy Link',
linkCopied: 'Link Copied',
needsReauth: 'Re-auth Required',
rateLimited: 'Rate Limited',
usageError: 'Fetch Error'
}, },
// Scheduled Tests // Scheduled Tests
@@ -3709,6 +3842,11 @@ export default {
refreshInterval15s: '15 seconds', refreshInterval15s: '15 seconds',
refreshInterval30s: '30 seconds', refreshInterval30s: '30 seconds',
refreshInterval60s: '60 seconds', refreshInterval60s: '60 seconds',
dashboardCards: 'Dashboard Cards',
displayAlertEvents: 'Display alert events',
displayAlertEventsHint: 'Show or hide the recent alert events card on the ops dashboard. Enabled by default.',
displayOpenAITokenStats: 'Display OpenAI token request stats',
displayOpenAITokenStatsHint: 'Show or hide the OpenAI token request stats card on the ops dashboard. Hidden by default.',
autoRefreshCountdown: 'Auto refresh: {seconds}s', autoRefreshCountdown: 'Auto refresh: {seconds}s',
validation: { validation: {
title: 'Please fix the following issues', title: 'Please fix the following issues',
@@ -3798,6 +3936,8 @@ export default {
users: 'Users', users: 'Users',
gateway: 'Gateway', gateway: 'Gateway',
email: 'Email', email: 'Email',
backup: 'Backup',
data: 'Sora Storage',
}, },
emailTabDisabledTitle: 'Email Verification Not Enabled', emailTabDisabledTitle: 'Email Verification Not Enabled',
emailTabDisabledHint: 'Enable email verification in the Security tab to configure SMTP settings.', emailTabDisabledHint: 'Enable email verification in the Security tab to configure SMTP settings.',
@@ -3888,6 +4028,9 @@ export default {
site: { site: {
title: 'Site Settings', title: 'Site Settings',
description: 'Customize site branding', description: 'Customize site branding',
backendMode: 'Backend Mode',
backendModeDescription:
'Disables user registration, public site, and self-service features. Only admin can log in and manage the platform.',
siteName: 'Site Name', siteName: 'Site Name',
siteNamePlaceholder: 'Sub2API', siteNamePlaceholder: 'Sub2API',
siteNameHint: 'Displayed in emails and page titles', siteNameHint: 'Displayed in emails and page titles',
@@ -4127,6 +4270,7 @@ export default {
scopeAll: 'All accounts', scopeAll: 'All accounts',
scopeOAuth: 'OAuth only', scopeOAuth: 'OAuth only',
scopeAPIKey: 'API Key only', scopeAPIKey: 'API Key only',
scopeBedrock: 'Bedrock only',
errorMessage: 'Error message', errorMessage: 'Error message',
errorMessagePlaceholder: 'Custom error message when blocked', errorMessagePlaceholder: 'Custom error message when blocked',
errorMessageHint: 'Leave empty for default message', errorMessageHint: 'Leave empty for default message',

View File

@@ -340,7 +340,6 @@ export default {
redeemCodes: '兑换码', redeemCodes: '兑换码',
ops: '运维监控', ops: '运维监控',
promoCodes: '优惠码', promoCodes: '优惠码',
dataManagement: '数据管理',
settings: '系统设置', settings: '系统设置',
myAccount: '我的账户', myAccount: '我的账户',
lightMode: '浅色模式', lightMode: '浅色模式',
@@ -1000,6 +999,111 @@ export default {
failedToLoad: '加载仪表盘数据失败' failedToLoad: '加载仪表盘数据失败'
}, },
backup: {
title: '数据库备份',
description: '全量数据库备份到 S3 兼容存储,支持定时备份与恢复',
s3: {
title: 'S3 存储配置',
description: '配置 S3 兼容存储(支持 Cloudflare R2',
descriptionPrefix: '配置 S3 兼容存储(支持',
descriptionSuffix: '',
enabled: '启用 S3 存储',
endpoint: '端点地址',
region: '区域',
bucket: '存储桶',
prefix: 'Key 前缀',
accessKeyId: 'Access Key ID',
secretAccessKey: 'Secret Access Key',
secretConfigured: '已配置,留空保持不变',
forcePathStyle: '强制路径风格',
testConnection: '测试连接',
testSuccess: 'S3 连接测试成功',
testFailed: 'S3 连接测试失败',
saved: 'S3 配置已保存'
},
schedule: {
title: '定时备份',
description: '配置自动定时备份',
enabled: '启用定时备份',
cronExpr: 'Cron 表达式',
cronHint: '例如 "0 2 * * *" 表示每天凌晨 2 点',
retainDays: '备份过期天数',
retainDaysHint: '备份文件超过此天数后自动删除0 = 永不过期',
retainCount: '最大保留份数',
retainCountHint: '最多保留的备份数量0 = 不限制',
saved: '定时备份配置已保存'
},
operations: {
title: '备份记录',
description: '创建手动备份和管理已有备份记录',
createBackup: '创建备份',
backing: '备份中...',
backupCreated: '备份创建成功',
expireDays: '过期天数'
},
columns: {
status: '状态',
fileName: '文件名',
size: '大小',
expiresAt: '过期时间',
triggeredBy: '触发方式',
startedAt: '开始时间',
actions: '操作'
},
status: {
pending: '等待中',
running: '执行中',
completed: '已完成',
failed: '失败'
},
trigger: {
manual: '手动',
scheduled: '定时'
},
neverExpire: '永不过期',
empty: '暂无备份记录',
actions: {
download: '下载',
restore: '恢复',
restoreConfirm: '确定要从此备份恢复吗?这将覆盖当前数据库!',
restorePasswordPrompt: '请输入管理员密码以确认恢复操作',
restoreSuccess: '数据库恢复成功',
deleteConfirm: '确定要删除此备份吗?',
deleted: '备份已删除'
},
r2Guide: {
title: 'Cloudflare R2 配置教程',
intro: 'Cloudflare R2 提供 S3 兼容的对象存储,免费额度为 10GB 存储 + 每月 100 万次 A 类请求,非常适合数据库备份。',
step1: {
title: '创建 R2 存储桶',
line1: '登录 Cloudflare Dashboard (dash.cloudflare.com)左侧菜单选择「R2 对象存储」',
line2: '点击「创建存储桶」,输入名称(如 sub2api-backups选择区域',
line3: '点击创建完成'
},
step2: {
title: '创建 API 令牌',
line1: '在 R2 页面,点击右上角「管理 R2 API 令牌」',
line2: '点击「创建 API 令牌」,权限选择「对象读和写」',
line3: '建议指定存储桶范围(仅允许访问备份桶,更安全)',
line4: '创建后会显示 Access Key ID 和 Secret Access Key',
warning: 'Secret Access Key 只会显示一次,请立即复制保存!'
},
step3: {
title: '获取 S3 端点地址',
desc: '在 R2 概览页面找到你的账户 ID在 URL 或右侧面板中),端点格式为:',
accountId: '你的账户 ID'
},
step4: {
title: '填写以下配置',
checkEnabled: '勾选',
bucketValue: '你创建的存储桶名称',
fromStep2: '第 2 步获取的值',
unchecked: '不勾选'
},
freeTier: 'R2 免费额度10GB 存储 + 每月 100 万次 A 类请求 + 1000 万次 B 类请求,对数据库备份完全够用。'
}
},
dataManagement: { dataManagement: {
title: '数据管理', title: '数据管理',
description: '统一管理数据管理代理状态、对象存储配置和备份任务', description: '统一管理数据管理代理状态、对象存储配置和备份任务',
@@ -1872,6 +1976,23 @@ export default {
quotaWeeklyLimitHint: '从首次使用起每 7 天自动重置。', quotaWeeklyLimitHint: '从首次使用起每 7 天自动重置。',
quotaTotalLimit: '总限额', quotaTotalLimit: '总限额',
quotaTotalLimitHint: '累计消费上限,不会自动重置 — 使用「重置配额」手动清零。', quotaTotalLimitHint: '累计消费上限,不会自动重置 — 使用「重置配额」手动清零。',
quotaResetMode: '重置方式',
quotaResetModeRolling: '滚动窗口',
quotaResetModeFixed: '固定时间',
quotaResetHour: '重置时间',
quotaWeeklyResetDay: '重置日',
quotaResetTimezone: '重置时区',
quotaDailyLimitHintFixed: '每天 {hour}:00{timezone})重置。',
quotaWeeklyLimitHintFixed: '每{day} {hour}:00{timezone})重置。',
dayOfWeek: {
monday: '周一',
tuesday: '周二',
wednesday: '周三',
thursday: '周四',
friday: '周五',
saturday: '周六',
sunday: '周日',
},
quotaLimitAmount: '总限额', quotaLimitAmount: '总限额',
quotaLimitAmountHint: '累计消费上限,不会自动重置。', quotaLimitAmountHint: '累计消费上限,不会自动重置。',
testConnection: '测试连接', testConnection: '测试连接',
@@ -1992,6 +2113,15 @@ export default {
}, },
ineligibleWarning: ineligibleWarning:
'该账号无 Antigravity 使用权限,但仍能进行 API 转发。继续使用请自行承担风险。', '该账号无 Antigravity 使用权限,但仍能进行 API 转发。继续使用请自行承担风险。',
forbidden: '已封禁',
forbiddenValidation: '需要验证',
forbiddenViolation: '违规封禁',
openVerification: '打开验证链接',
copyLink: '复制链接',
linkCopied: '链接已复制',
needsReauth: '需要重新授权',
rateLimited: '限流中',
usageError: '获取失败',
form: { form: {
nameLabel: '账号名称', nameLabel: '账号名称',
namePlaceholder: '请输入账号名称', namePlaceholder: '请输入账号名称',
@@ -2082,7 +2212,7 @@ export default {
claudeCode: 'Claude Code', claudeCode: 'Claude Code',
claudeConsole: 'Claude Console', claudeConsole: 'Claude Console',
bedrockLabel: 'AWS Bedrock', bedrockLabel: 'AWS Bedrock',
bedrockDesc: 'SigV4 签名', bedrockDesc: 'SigV4 / API Key',
oauthSetupToken: 'OAuth / Setup Token', oauthSetupToken: 'OAuth / Setup Token',
addMethod: '添加方式', addMethod: '添加方式',
setupTokenLongLived: 'Setup Token长期有效', setupTokenLongLived: 'Setup Token长期有效',
@@ -2277,6 +2407,9 @@ export default {
bedrockRegionRequired: '请选择 AWS Region', bedrockRegionRequired: '请选择 AWS Region',
bedrockSessionTokenHint: '可选,用于临时凭证', bedrockSessionTokenHint: '可选,用于临时凭证',
bedrockSecretKeyLeaveEmpty: '留空以保持当前密钥', bedrockSecretKeyLeaveEmpty: '留空以保持当前密钥',
bedrockAuthMode: '认证方式',
bedrockAuthModeSigv4: 'SigV4 签名',
bedrockAuthModeApikey: 'Bedrock API Key',
bedrockApiKeyLabel: 'Bedrock API Key', bedrockApiKeyLabel: 'Bedrock API Key',
bedrockApiKeyDesc: 'Bearer Token 认证', bedrockApiKeyDesc: 'Bearer Token 认证',
bedrockApiKeyInput: 'API Key', bedrockApiKeyInput: 'API Key',
@@ -3883,6 +4016,11 @@ export default {
refreshInterval15s: '15 秒', refreshInterval15s: '15 秒',
refreshInterval30s: '30 秒', refreshInterval30s: '30 秒',
refreshInterval60s: '60 秒', refreshInterval60s: '60 秒',
dashboardCards: '仪表盘卡片',
displayAlertEvents: '展示告警事件',
displayAlertEventsHint: '控制运维监控仪表盘中告警事件卡片是否显示,默认开启。',
displayOpenAITokenStats: '展示 OpenAI Token 请求统计',
displayOpenAITokenStatsHint: '控制运维监控仪表盘中 OpenAI Token 请求统计卡片是否显示,默认关闭。',
autoRefreshCountdown: '自动刷新:{seconds}s', autoRefreshCountdown: '自动刷新:{seconds}s',
validation: { validation: {
title: '请先修正以下问题', title: '请先修正以下问题',
@@ -3972,6 +4110,8 @@ export default {
users: '用户默认值', users: '用户默认值',
gateway: '网关服务', gateway: '网关服务',
email: '邮件设置', email: '邮件设置',
backup: '数据备份',
data: 'Sora 存储',
}, },
emailTabDisabledTitle: '邮箱验证未启用', emailTabDisabledTitle: '邮箱验证未启用',
emailTabDisabledHint: '请在「安全与认证」选项卡中启用邮箱验证后,再配置 SMTP 设置。', emailTabDisabledHint: '请在「安全与认证」选项卡中启用邮箱验证后,再配置 SMTP 设置。',
@@ -4060,6 +4200,9 @@ export default {
site: { site: {
title: '站点设置', title: '站点设置',
description: '自定义站点品牌', description: '自定义站点品牌',
backendMode: 'Backend 模式',
backendModeDescription:
'禁用用户注册、公开页面和自助服务功能。仅管理员可以登录和管理平台。',
siteName: '站点名称', siteName: '站点名称',
siteNameHint: '显示在邮件和页面标题中', siteNameHint: '显示在邮件和页面标题中',
siteNamePlaceholder: 'Sub2API', siteNamePlaceholder: 'Sub2API',
@@ -4300,6 +4443,7 @@ export default {
scopeAll: '全部账号', scopeAll: '全部账号',
scopeOAuth: '仅 OAuth 账号', scopeOAuth: '仅 OAuth 账号',
scopeAPIKey: '仅 API Key 账号', scopeAPIKey: '仅 API Key 账号',
scopeBedrock: '仅 Bedrock 账号',
errorMessage: '错误消息', errorMessage: '错误消息',
errorMessagePlaceholder: '拦截时返回的自定义错误消息', errorMessagePlaceholder: '拦截时返回的自定义错误消息',
errorMessageHint: '留空则使用默认错误消息', errorMessageHint: '留空则使用默认错误消息',

View File

@@ -51,6 +51,7 @@ interface MockAuthState {
isAuthenticated: boolean isAuthenticated: boolean
isAdmin: boolean isAdmin: boolean
isSimpleMode: boolean isSimpleMode: boolean
backendModeEnabled: boolean
} }
/** /**
@@ -70,8 +71,17 @@ function simulateGuard(
authState.isAuthenticated && authState.isAuthenticated &&
(toPath === '/login' || toPath === '/register') (toPath === '/login' || toPath === '/register')
) { ) {
if (authState.backendModeEnabled && !authState.isAdmin) {
return null
}
return authState.isAdmin ? '/admin/dashboard' : '/dashboard' return authState.isAdmin ? '/admin/dashboard' : '/dashboard'
} }
if (authState.backendModeEnabled && !authState.isAuthenticated) {
const allowed = ['/login', '/key-usage', '/setup']
if (!allowed.some((path) => toPath === path || toPath.startsWith(path))) {
return '/login'
}
}
return null // 允许通过 return null // 允许通过
} }
@@ -99,6 +109,17 @@ function simulateGuard(
} }
} }
// Backend mode: admin gets full access, non-admin blocked
if (authState.backendModeEnabled) {
if (authState.isAuthenticated && authState.isAdmin) {
return null
}
const allowed = ['/login', '/key-usage', '/setup']
if (!allowed.some((path) => toPath === path || toPath.startsWith(path))) {
return '/login'
}
}
return null // 允许通过 return null // 允许通过
} }
@@ -114,6 +135,7 @@ describe('路由守卫逻辑', () => {
isAuthenticated: false, isAuthenticated: false,
isAdmin: false, isAdmin: false,
isSimpleMode: false, isSimpleMode: false,
backendModeEnabled: false,
} }
it('访问需要认证的页面重定向到 /login', () => { it('访问需要认证的页面重定向到 /login', () => {
@@ -144,6 +166,7 @@ describe('路由守卫逻辑', () => {
isAuthenticated: true, isAuthenticated: true,
isAdmin: false, isAdmin: false,
isSimpleMode: false, isSimpleMode: false,
backendModeEnabled: false,
} }
it('访问 /login 重定向到 /dashboard', () => { it('访问 /login 重定向到 /dashboard', () => {
@@ -179,6 +202,7 @@ describe('路由守卫逻辑', () => {
isAuthenticated: true, isAuthenticated: true,
isAdmin: true, isAdmin: true,
isSimpleMode: false, isSimpleMode: false,
backendModeEnabled: false,
} }
it('访问 /login 重定向到 /admin/dashboard', () => { it('访问 /login 重定向到 /admin/dashboard', () => {
@@ -205,6 +229,7 @@ describe('路由守卫逻辑', () => {
isAuthenticated: true, isAuthenticated: true,
isAdmin: false, isAdmin: false,
isSimpleMode: true, isSimpleMode: true,
backendModeEnabled: false,
} }
const redirect = simulateGuard('/subscriptions', {}, authState) const redirect = simulateGuard('/subscriptions', {}, authState)
expect(redirect).toBe('/dashboard') expect(redirect).toBe('/dashboard')
@@ -215,6 +240,7 @@ describe('路由守卫逻辑', () => {
isAuthenticated: true, isAuthenticated: true,
isAdmin: false, isAdmin: false,
isSimpleMode: true, isSimpleMode: true,
backendModeEnabled: false,
} }
const redirect = simulateGuard('/redeem', {}, authState) const redirect = simulateGuard('/redeem', {}, authState)
expect(redirect).toBe('/dashboard') expect(redirect).toBe('/dashboard')
@@ -225,6 +251,7 @@ describe('路由守卫逻辑', () => {
isAuthenticated: true, isAuthenticated: true,
isAdmin: true, isAdmin: true,
isSimpleMode: true, isSimpleMode: true,
backendModeEnabled: false,
} }
const redirect = simulateGuard('/admin/groups', { requiresAdmin: true }, authState) const redirect = simulateGuard('/admin/groups', { requiresAdmin: true }, authState)
expect(redirect).toBe('/admin/dashboard') expect(redirect).toBe('/admin/dashboard')
@@ -235,6 +262,7 @@ describe('路由守卫逻辑', () => {
isAuthenticated: true, isAuthenticated: true,
isAdmin: true, isAdmin: true,
isSimpleMode: true, isSimpleMode: true,
backendModeEnabled: false,
} }
const redirect = simulateGuard( const redirect = simulateGuard(
'/admin/subscriptions', '/admin/subscriptions',
@@ -249,6 +277,7 @@ describe('路由守卫逻辑', () => {
isAuthenticated: true, isAuthenticated: true,
isAdmin: false, isAdmin: false,
isSimpleMode: true, isSimpleMode: true,
backendModeEnabled: false,
} }
const redirect = simulateGuard('/dashboard', {}, authState) const redirect = simulateGuard('/dashboard', {}, authState)
expect(redirect).toBeNull() expect(redirect).toBeNull()
@@ -259,9 +288,111 @@ describe('路由守卫逻辑', () => {
isAuthenticated: true, isAuthenticated: true,
isAdmin: false, isAdmin: false,
isSimpleMode: true, isSimpleMode: true,
backendModeEnabled: false,
} }
const redirect = simulateGuard('/keys', {}, authState) const redirect = simulateGuard('/keys', {}, authState)
expect(redirect).toBeNull() expect(redirect).toBeNull()
}) })
}) })
describe('Backend Mode', () => {
it('unauthenticated: /home redirects to /login', () => {
const authState: MockAuthState = {
isAuthenticated: false,
isAdmin: false,
isSimpleMode: false,
backendModeEnabled: true,
}
const redirect = simulateGuard('/home', { requiresAuth: false }, authState)
expect(redirect).toBe('/login')
})
it('unauthenticated: /login is allowed', () => {
const authState: MockAuthState = {
isAuthenticated: false,
isAdmin: false,
isSimpleMode: false,
backendModeEnabled: true,
}
const redirect = simulateGuard('/login', { requiresAuth: false }, authState)
expect(redirect).toBeNull()
})
it('unauthenticated: /key-usage is allowed', () => {
const authState: MockAuthState = {
isAuthenticated: false,
isAdmin: false,
isSimpleMode: false,
backendModeEnabled: true,
}
const redirect = simulateGuard('/key-usage', { requiresAuth: false }, authState)
expect(redirect).toBeNull()
})
it('unauthenticated: /setup is allowed', () => {
const authState: MockAuthState = {
isAuthenticated: false,
isAdmin: false,
isSimpleMode: false,
backendModeEnabled: true,
}
const redirect = simulateGuard('/setup', { requiresAuth: false }, authState)
expect(redirect).toBeNull()
})
it('admin: /admin/dashboard is allowed', () => {
const authState: MockAuthState = {
isAuthenticated: true,
isAdmin: true,
isSimpleMode: false,
backendModeEnabled: true,
}
const redirect = simulateGuard('/admin/dashboard', { requiresAdmin: true }, authState)
expect(redirect).toBeNull()
})
it('admin: /login redirects to /admin/dashboard', () => {
const authState: MockAuthState = {
isAuthenticated: true,
isAdmin: true,
isSimpleMode: false,
backendModeEnabled: true,
}
const redirect = simulateGuard('/login', { requiresAuth: false }, authState)
expect(redirect).toBe('/admin/dashboard')
})
it('non-admin authenticated: /dashboard redirects to /login', () => {
const authState: MockAuthState = {
isAuthenticated: true,
isAdmin: false,
isSimpleMode: false,
backendModeEnabled: true,
}
const redirect = simulateGuard('/dashboard', {}, authState)
expect(redirect).toBe('/login')
})
it('non-admin authenticated: /login is allowed (no redirect loop)', () => {
const authState: MockAuthState = {
isAuthenticated: true,
isAdmin: false,
isSimpleMode: false,
backendModeEnabled: true,
}
const redirect = simulateGuard('/login', { requiresAuth: false }, authState)
expect(redirect).toBeNull()
})
it('non-admin authenticated: /key-usage is allowed', () => {
const authState: MockAuthState = {
isAuthenticated: true,
isAdmin: false,
isSimpleMode: false,
backendModeEnabled: true,
}
const redirect = simulateGuard('/key-usage', { requiresAuth: false }, authState)
expect(redirect).toBeNull()
})
})
}) })

View File

@@ -350,18 +350,6 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'admin.promo.description' descriptionKey: 'admin.promo.description'
} }
}, },
{
path: '/admin/data-management',
name: 'AdminDataManagement',
component: () => import('@/views/admin/DataManagementView.vue'),
meta: {
requiresAuth: true,
requiresAdmin: true,
title: 'Data Management',
titleKey: 'admin.dataManagement.title',
descriptionKey: 'admin.dataManagement.description'
}
},
{ {
path: '/admin/settings', path: '/admin/settings',
name: 'AdminSettings', name: 'AdminSettings',
@@ -423,6 +411,7 @@ let authInitialized = false
const navigationLoading = useNavigationLoadingState() const navigationLoading = useNavigationLoadingState()
// 延迟初始化预加载,传入 router 实例 // 延迟初始化预加载,传入 router 实例
let routePrefetch: ReturnType<typeof useRoutePrefetch> | null = null let routePrefetch: ReturnType<typeof useRoutePrefetch> | null = null
const BACKEND_MODE_ALLOWED_PATHS = ['/login', '/key-usage', '/setup']
router.beforeEach((to, _from, next) => { router.beforeEach((to, _from, next) => {
// 开始导航加载状态 // 开始导航加载状态
@@ -463,10 +452,24 @@ router.beforeEach((to, _from, next) => {
if (!requiresAuth) { if (!requiresAuth) {
// If already authenticated and trying to access login/register, redirect to appropriate dashboard // If already authenticated and trying to access login/register, redirect to appropriate dashboard
if (authStore.isAuthenticated && (to.path === '/login' || to.path === '/register')) { if (authStore.isAuthenticated && (to.path === '/login' || to.path === '/register')) {
// In backend mode, non-admin users should NOT be redirected away from login
// (they are blocked from all protected routes, so redirecting would cause a loop)
if (appStore.backendModeEnabled && !authStore.isAdmin) {
next()
return
}
// Admin users go to admin dashboard, regular users go to user dashboard // Admin users go to admin dashboard, regular users go to user dashboard
next(authStore.isAdmin ? '/admin/dashboard' : '/dashboard') next(authStore.isAdmin ? '/admin/dashboard' : '/dashboard')
return return
} }
// Backend mode: block public pages for unauthenticated users (except login, key-usage, setup)
if (appStore.backendModeEnabled && !authStore.isAuthenticated) {
const isAllowed = BACKEND_MODE_ALLOWED_PATHS.some((p) => to.path === p || to.path.startsWith(p))
if (!isAllowed) {
next('/login')
return
}
}
next() next()
return return
} }
@@ -505,6 +508,19 @@ router.beforeEach((to, _from, next) => {
} }
} }
// Backend mode: admin gets full access, non-admin blocked
if (appStore.backendModeEnabled) {
if (authStore.isAuthenticated && authStore.isAdmin) {
next()
return
}
const isAllowed = BACKEND_MODE_ALLOWED_PATHS.some((p) => to.path === p || to.path.startsWith(p))
if (!isAllowed) {
next('/login')
return
}
}
// All checks passed, allow navigation // All checks passed, allow navigation
next() next()
}) })

View File

@@ -47,6 +47,7 @@ export const useAppStore = defineStore('app', () => {
// ==================== Computed ==================== // ==================== Computed ====================
const hasActiveToasts = computed(() => toasts.value.length > 0) const hasActiveToasts = computed(() => toasts.value.length > 0)
const backendModeEnabled = computed(() => cachedPublicSettings.value?.backend_mode_enabled ?? false)
const loadingCount = ref<number>(0) const loadingCount = ref<number>(0)
@@ -331,6 +332,7 @@ export const useAppStore = defineStore('app', () => {
custom_menu_items: [], custom_menu_items: [],
linuxdo_oauth_enabled: false, linuxdo_oauth_enabled: false,
sora_client_enabled: false, sora_client_enabled: false,
backend_mode_enabled: false,
version: siteVersion.value version: siteVersion.value
} }
} }
@@ -404,6 +406,7 @@ export const useAppStore = defineStore('app', () => {
// Computed // Computed
hasActiveToasts, hasActiveToasts,
backendModeEnabled,
// Actions // Actions
toggleSidebar, toggleSidebar,

View File

@@ -106,6 +106,7 @@ export interface PublicSettings {
custom_menu_items: CustomMenuItem[] custom_menu_items: CustomMenuItem[]
linuxdo_oauth_enabled: boolean linuxdo_oauth_enabled: boolean
sora_client_enabled: boolean sora_client_enabled: boolean
backend_mode_enabled: boolean
version: string version: string
} }
@@ -531,7 +532,7 @@ export interface UpdateGroupRequest {
// ==================== Account & Proxy Types ==================== // ==================== Account & Proxy Types ====================
export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' | 'sora' export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' | 'sora'
export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock' | 'bedrock-apikey' export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock'
export type OAuthAddMethod = 'oauth' | 'setup-token' export type OAuthAddMethod = 'oauth' | 'setup-token'
export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h' export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h'
@@ -727,6 +728,16 @@ export interface Account {
quota_weekly_limit?: number | null quota_weekly_limit?: number | null
quota_weekly_used?: number | null quota_weekly_used?: number | null
// 配额固定时间重置配置
quota_daily_reset_mode?: 'rolling' | 'fixed' | null
quota_daily_reset_hour?: number | null
quota_weekly_reset_mode?: 'rolling' | 'fixed' | null
quota_weekly_reset_day?: number | null
quota_weekly_reset_hour?: number | null
quota_reset_timezone?: string | null
quota_daily_reset_at?: string | null
quota_weekly_reset_at?: string | null
// 运行时状态(仅当启用对应限制时返回) // 运行时状态(仅当启用对应限制时返回)
current_window_cost?: number | null // 当前窗口费用 current_window_cost?: number | null // 当前窗口费用
active_sessions?: number | null // 当前活跃会话数 active_sessions?: number | null // 当前活跃会话数
@@ -769,6 +780,21 @@ export interface AccountUsageInfo {
gemini_pro_minute?: UsageProgress | null gemini_pro_minute?: UsageProgress | null
gemini_flash_minute?: UsageProgress | null gemini_flash_minute?: UsageProgress | null
antigravity_quota?: Record<string, AntigravityModelQuota> | null antigravity_quota?: Record<string, AntigravityModelQuota> | null
// Antigravity 403 forbidden 状态
is_forbidden?: boolean
forbidden_reason?: string
forbidden_type?: string // "validation" | "violation" | "forbidden"
validation_url?: string // 验证/申诉链接
// 状态标记(后端自动推导)
needs_verify?: boolean // 需要人工验证forbidden_type=validation
is_banned?: boolean // 账号被封forbidden_type=violation
needs_reauth?: boolean // token 失效需重新授权401
// 机器可读错误码forbidden / unauthenticated / rate_limited / network_error
error_code?: string
error?: string // usage 获取失败时的错误信息
} }
// OpenAI Codex usage snapshot (from response headers) // OpenAI Codex usage snapshot (from response headers)

View File

@@ -171,7 +171,15 @@
<span v-else class="text-sm text-gray-400 dark:text-dark-500">-</span> <span v-else class="text-sm text-gray-400 dark:text-dark-500">-</span>
</template> </template>
<template #cell-platform_type="{ row }"> <template #cell-platform_type="{ row }">
<PlatformTypeBadge :platform="row.platform" :type="row.type" :plan-type="row.credentials?.plan_type" :privacy-mode="row.extra?.privacy_mode" /> <div class="flex flex-wrap items-center gap-1">
<PlatformTypeBadge :platform="row.platform" :type="row.type" :plan-type="row.credentials?.plan_type" :privacy-mode="row.extra?.privacy_mode" />
<span
v-if="getAntigravityTierLabel(row)"
:class="['inline-block rounded px-1.5 py-0.5 text-[10px] font-medium', getAntigravityTierClass(row)]"
>
{{ getAntigravityTierLabel(row) }}
</span>
</div>
</template> </template>
<template #cell-capacity="{ row }"> <template #cell-capacity="{ row }">
<AccountCapacityCell :account="row" /> <AccountCapacityCell :account="row" />
@@ -794,6 +802,40 @@ const { pause: pauseAutoRefresh, resume: resumeAutoRefresh } = useIntervalFn(
{ immediate: false } { immediate: false }
) )
// Antigravity 订阅等级辅助函数
function getAntigravityTierFromRow(row: any): string | null {
if (row.platform !== 'antigravity') return null
const extra = row.extra as Record<string, unknown> | undefined
if (!extra) return null
const lca = extra.load_code_assist as Record<string, unknown> | undefined
if (!lca) return null
const paid = lca.paidTier as Record<string, unknown> | undefined
if (paid && typeof paid.id === 'string') return paid.id
const current = lca.currentTier as Record<string, unknown> | undefined
if (current && typeof current.id === 'string') return current.id
return null
}
function getAntigravityTierLabel(row: any): string | null {
const tier = getAntigravityTierFromRow(row)
switch (tier) {
case 'free-tier': return t('admin.accounts.tier.free')
case 'g1-pro-tier': return t('admin.accounts.tier.pro')
case 'g1-ultra-tier': return t('admin.accounts.tier.ultra')
default: return null
}
}
function getAntigravityTierClass(row: any): string {
const tier = getAntigravityTierFromRow(row)
switch (tier) {
case 'free-tier': return 'bg-gray-100 text-gray-600 dark:bg-gray-700 dark:text-gray-300'
case 'g1-pro-tier': return 'bg-blue-100 text-blue-600 dark:bg-blue-900/40 dark:text-blue-300'
case 'g1-ultra-tier': return 'bg-purple-100 text-purple-600 dark:bg-purple-900/40 dark:text-purple-300'
default: return ''
}
}
// All available columns // All available columns
const allColumns = computed(() => { const allColumns = computed(() => {
const c = [ const c = [

View File

@@ -0,0 +1,505 @@
<template>
<div class="space-y-6">
<!-- S3 Storage Config -->
<div class="card p-6">
<div class="mb-4 flex flex-wrap items-center justify-between gap-3">
<div>
<h3 class="text-base font-semibold text-gray-900 dark:text-white">
{{ t('admin.backup.s3.title') }}
</h3>
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.backup.s3.descriptionPrefix') }}
<button type="button" class="text-primary-600 underline hover:text-primary-700 dark:text-primary-400 dark:hover:text-primary-300" @click="showR2Guide = true">Cloudflare R2</button>
{{ t('admin.backup.s3.descriptionSuffix') }}
</p>
</div>
</div>
<div class="grid grid-cols-1 gap-3 md:grid-cols-2">
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.endpoint') }}</label>
<input v-model="s3Form.endpoint" class="input w-full" placeholder="https://<account_id>.r2.cloudflarestorage.com" />
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.region') }}</label>
<input v-model="s3Form.region" class="input w-full" placeholder="auto" />
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.bucket') }}</label>
<input v-model="s3Form.bucket" class="input w-full" />
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.prefix') }}</label>
<input v-model="s3Form.prefix" class="input w-full" placeholder="backups/" />
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.accessKeyId') }}</label>
<input v-model="s3Form.access_key_id" class="input w-full" />
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.secretAccessKey') }}</label>
<input v-model="s3Form.secret_access_key" type="password" class="input w-full" :placeholder="s3SecretConfigured ? t('admin.backup.s3.secretConfigured') : ''" />
</div>
<label class="inline-flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300 md:col-span-2">
<input v-model="s3Form.force_path_style" type="checkbox" />
<span>{{ t('admin.backup.s3.forcePathStyle') }}</span>
</label>
</div>
<div class="mt-4 flex flex-wrap gap-2">
<button type="button" class="btn btn-secondary btn-sm" :disabled="testingS3" @click="testS3">
{{ testingS3 ? t('common.loading') : t('admin.backup.s3.testConnection') }}
</button>
<button type="button" class="btn btn-primary btn-sm" :disabled="savingS3" @click="saveS3Config">
{{ savingS3 ? t('common.loading') : t('common.save') }}
</button>
</div>
</div>
<!-- Schedule Config -->
<div class="card p-6">
<div class="mb-4">
<h3 class="text-base font-semibold text-gray-900 dark:text-white">
{{ t('admin.backup.schedule.title') }}
</h3>
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.backup.schedule.description') }}
</p>
</div>
<div class="grid grid-cols-1 gap-3 md:grid-cols-2">
<label class="inline-flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300 md:col-span-2">
<input v-model="scheduleForm.enabled" type="checkbox" />
<span>{{ t('admin.backup.schedule.enabled') }}</span>
</label>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.schedule.cronExpr') }}</label>
<input v-model="scheduleForm.cron_expr" class="input w-full" placeholder="0 2 * * *" />
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">{{ t('admin.backup.schedule.cronHint') }}</p>
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.schedule.retainDays') }}</label>
<input v-model.number="scheduleForm.retain_days" type="number" min="0" class="input w-full" />
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">{{ t('admin.backup.schedule.retainDaysHint') }}</p>
</div>
<div>
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.schedule.retainCount') }}</label>
<input v-model.number="scheduleForm.retain_count" type="number" min="0" class="input w-full" />
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">{{ t('admin.backup.schedule.retainCountHint') }}</p>
</div>
</div>
<div class="mt-4">
<button type="button" class="btn btn-primary btn-sm" :disabled="savingSchedule" @click="saveSchedule">
{{ savingSchedule ? t('common.loading') : t('common.save') }}
</button>
</div>
</div>
<!-- Backup Operations -->
<div class="card p-6">
<div class="mb-4 flex flex-wrap items-center justify-between gap-3">
<div>
<h3 class="text-base font-semibold text-gray-900 dark:text-white">
{{ t('admin.backup.operations.title') }}
</h3>
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.backup.operations.description') }}
</p>
</div>
<div class="flex flex-wrap items-center gap-2">
<div class="flex items-center gap-1">
<label class="text-xs text-gray-600 dark:text-gray-400">{{ t('admin.backup.operations.expireDays') }}</label>
<input v-model.number="manualExpireDays" type="number" min="0" class="input w-20 text-xs" />
</div>
<button type="button" class="btn btn-primary btn-sm" :disabled="creatingBackup" @click="createBackup">
{{ creatingBackup ? t('admin.backup.operations.backing') : t('admin.backup.operations.createBackup') }}
</button>
<button type="button" class="btn btn-secondary btn-sm" :disabled="loadingBackups" @click="loadBackups">
{{ loadingBackups ? t('common.loading') : t('common.refresh') }}
</button>
</div>
</div>
<div class="overflow-x-auto">
<table class="w-full min-w-[800px] text-sm">
<thead>
<tr class="border-b border-gray-200 text-left text-xs uppercase tracking-wide text-gray-500 dark:border-dark-700 dark:text-gray-400">
<th class="py-2 pr-4">ID</th>
<th class="py-2 pr-4">{{ t('admin.backup.columns.status') }}</th>
<th class="py-2 pr-4">{{ t('admin.backup.columns.fileName') }}</th>
<th class="py-2 pr-4">{{ t('admin.backup.columns.size') }}</th>
<th class="py-2 pr-4">{{ t('admin.backup.columns.expiresAt') }}</th>
<th class="py-2 pr-4">{{ t('admin.backup.columns.triggeredBy') }}</th>
<th class="py-2 pr-4">{{ t('admin.backup.columns.startedAt') }}</th>
<th class="py-2">{{ t('admin.backup.columns.actions') }}</th>
</tr>
</thead>
<tbody>
<tr v-for="record in backups" :key="record.id" class="border-b border-gray-100 align-top dark:border-dark-800">
<td class="py-3 pr-4 font-mono text-xs">{{ record.id }}</td>
<td class="py-3 pr-4">
<span
class="rounded px-2 py-0.5 text-xs"
:class="statusClass(record.status)"
>
{{ t(`admin.backup.status.${record.status}`) }}
</span>
</td>
<td class="py-3 pr-4 text-xs">{{ record.file_name }}</td>
<td class="py-3 pr-4 text-xs">{{ formatSize(record.size_bytes) }}</td>
<td class="py-3 pr-4 text-xs">
{{ record.expires_at ? formatDate(record.expires_at) : t('admin.backup.neverExpire') }}
</td>
<td class="py-3 pr-4 text-xs">
{{ record.triggered_by === 'scheduled' ? t('admin.backup.trigger.scheduled') : t('admin.backup.trigger.manual') }}
</td>
<td class="py-3 pr-4 text-xs">{{ formatDate(record.started_at) }}</td>
<td class="py-3 text-xs">
<div class="flex flex-wrap gap-1">
<button
v-if="record.status === 'completed'"
type="button"
class="btn btn-secondary btn-xs"
@click="downloadBackup(record.id)"
>
{{ t('admin.backup.actions.download') }}
</button>
<button
v-if="record.status === 'completed'"
type="button"
class="btn btn-secondary btn-xs"
:disabled="restoringId === record.id"
@click="restoreBackup(record.id)"
>
{{ restoringId === record.id ? t('common.loading') : t('admin.backup.actions.restore') }}
</button>
<button
type="button"
class="btn btn-danger btn-xs"
@click="removeBackup(record.id)"
>
{{ t('common.delete') }}
</button>
</div>
</td>
</tr>
<tr v-if="backups.length === 0">
<td colspan="8" class="py-6 text-center text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.backup.empty') }}
</td>
</tr>
</tbody>
</table>
</div>
</div>
</div>
<!-- Cloudflare R2 Setup Guide Modal -->
<teleport to="body">
<transition name="modal">
<div v-if="showR2Guide" class="fixed inset-0 z-50 flex items-center justify-center p-4" @mousedown.self="showR2Guide = false">
<div class="fixed inset-0 bg-black/50" @click="showR2Guide = false"></div>
<div class="relative max-h-[85vh] w-full max-w-2xl overflow-y-auto rounded-xl bg-white p-6 shadow-2xl dark:bg-dark-800">
<button type="button" class="absolute right-4 top-4 text-gray-400 hover:text-gray-600 dark:hover:text-gray-200" @click="showR2Guide = false">
<svg class="h-5 w-5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2"><path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" /></svg>
</button>
<h2 class="mb-4 text-lg font-bold text-gray-900 dark:text-white">{{ t('admin.backup.r2Guide.title') }}</h2>
<p class="mb-4 text-sm text-gray-500 dark:text-gray-400">{{ t('admin.backup.r2Guide.intro') }}</p>
<!-- Step 1 -->
<div class="mb-5">
<h3 class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-900 dark:text-white">
<span class="flex h-6 w-6 items-center justify-center rounded-full bg-primary-100 text-xs font-bold text-primary-700 dark:bg-primary-900/40 dark:text-primary-300">1</span>
{{ t('admin.backup.r2Guide.step1.title') }}
</h3>
<ol class="ml-8 list-decimal space-y-1 text-sm text-gray-600 dark:text-gray-300">
<li>{{ t('admin.backup.r2Guide.step1.line1') }}</li>
<li>{{ t('admin.backup.r2Guide.step1.line2') }}</li>
<li>{{ t('admin.backup.r2Guide.step1.line3') }}</li>
</ol>
</div>
<!-- Step 2 -->
<div class="mb-5">
<h3 class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-900 dark:text-white">
<span class="flex h-6 w-6 items-center justify-center rounded-full bg-primary-100 text-xs font-bold text-primary-700 dark:bg-primary-900/40 dark:text-primary-300">2</span>
{{ t('admin.backup.r2Guide.step2.title') }}
</h3>
<ol class="ml-8 list-decimal space-y-1 text-sm text-gray-600 dark:text-gray-300">
<li>{{ t('admin.backup.r2Guide.step2.line1') }}</li>
<li>{{ t('admin.backup.r2Guide.step2.line2') }}</li>
<li>{{ t('admin.backup.r2Guide.step2.line3') }}</li>
<li>{{ t('admin.backup.r2Guide.step2.line4') }}</li>
</ol>
<div class="mt-2 rounded-lg bg-amber-50 p-3 text-xs text-amber-700 dark:bg-amber-900/20 dark:text-amber-300">
{{ t('admin.backup.r2Guide.step2.warning') }}
</div>
</div>
<!-- Step 3 -->
<div class="mb-5">
<h3 class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-900 dark:text-white">
<span class="flex h-6 w-6 items-center justify-center rounded-full bg-primary-100 text-xs font-bold text-primary-700 dark:bg-primary-900/40 dark:text-primary-300">3</span>
{{ t('admin.backup.r2Guide.step3.title') }}
</h3>
<p class="ml-8 text-sm text-gray-600 dark:text-gray-300">{{ t('admin.backup.r2Guide.step3.desc') }}</p>
<code class="ml-8 mt-1 block rounded bg-gray-100 px-3 py-2 text-xs text-gray-800 dark:bg-dark-700 dark:text-gray-200">https://&lt;{{ t('admin.backup.r2Guide.step3.accountId') }}&gt;.r2.cloudflarestorage.com</code>
</div>
<!-- Step 4: Fill form -->
<div class="mb-5">
<h3 class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-900 dark:text-white">
<span class="flex h-6 w-6 items-center justify-center rounded-full bg-primary-100 text-xs font-bold text-primary-700 dark:bg-primary-900/40 dark:text-primary-300">4</span>
{{ t('admin.backup.r2Guide.step4.title') }}
</h3>
<div class="ml-8 overflow-hidden rounded-lg border border-gray-200 dark:border-dark-600">
<table class="w-full text-sm">
<tbody>
<tr v-for="(row, i) in r2ConfigRows" :key="i" class="border-b border-gray-100 dark:border-dark-700 last:border-0">
<td class="whitespace-nowrap bg-gray-50 px-3 py-2 font-medium text-gray-700 dark:bg-dark-700 dark:text-gray-300">{{ row.field }}</td>
<td class="px-3 py-2 text-gray-600 dark:text-gray-400"><code class="text-xs">{{ row.value }}</code></td>
</tr>
</tbody>
</table>
</div>
</div>
<!-- Free tier note -->
<div class="rounded-lg bg-green-50 p-3 text-xs text-green-700 dark:bg-green-900/20 dark:text-green-300">
{{ t('admin.backup.r2Guide.freeTier') }}
</div>
<div class="mt-4 text-right">
<button type="button" class="btn btn-primary btn-sm" @click="showR2Guide = false">{{ t('common.close') }}</button>
</div>
</div>
</div>
</transition>
</teleport>
</template>
<script setup lang="ts">
import { computed, onMounted, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import { adminAPI } from '@/api'
import { useAppStore } from '@/stores'
import type { BackupS3Config, BackupScheduleConfig, BackupRecord } from '@/api/admin/backup'
const { t } = useI18n()
const appStore = useAppStore()
// S3 config
const s3Form = ref<BackupS3Config>({
endpoint: '',
region: 'auto',
bucket: '',
access_key_id: '',
secret_access_key: '',
prefix: 'backups/',
force_path_style: false,
})
const s3SecretConfigured = ref(false)
const savingS3 = ref(false)
const testingS3 = ref(false)
// Schedule config
const scheduleForm = ref<BackupScheduleConfig>({
enabled: false,
cron_expr: '0 2 * * *',
retain_days: 14,
retain_count: 10,
})
const savingSchedule = ref(false)
// Backups
const backups = ref<BackupRecord[]>([])
const loadingBackups = ref(false)
const creatingBackup = ref(false)
const restoringId = ref('')
const manualExpireDays = ref(14)
// R2 guide
const showR2Guide = ref(false)
const r2ConfigRows = computed(() => [
{ field: t('admin.backup.s3.endpoint'), value: 'https://<account_id>.r2.cloudflarestorage.com' },
{ field: t('admin.backup.s3.region'), value: 'auto' },
{ field: t('admin.backup.s3.bucket'), value: t('admin.backup.r2Guide.step4.bucketValue') },
{ field: t('admin.backup.s3.prefix'), value: 'backups/' },
{ field: 'Access Key ID', value: t('admin.backup.r2Guide.step4.fromStep2') },
{ field: 'Secret Access Key', value: t('admin.backup.r2Guide.step4.fromStep2') },
{ field: t('admin.backup.s3.forcePathStyle'), value: t('admin.backup.r2Guide.step4.unchecked') },
])
async function loadS3Config() {
try {
const cfg = await adminAPI.backup.getS3Config()
s3Form.value = {
endpoint: cfg.endpoint || '',
region: cfg.region || 'auto',
bucket: cfg.bucket || '',
access_key_id: cfg.access_key_id || '',
secret_access_key: '',
prefix: cfg.prefix || 'backups/',
force_path_style: cfg.force_path_style,
}
s3SecretConfigured.value = Boolean(cfg.access_key_id)
} catch (error) {
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
}
}
async function saveS3Config() {
savingS3.value = true
try {
await adminAPI.backup.updateS3Config(s3Form.value)
appStore.showSuccess(t('admin.backup.s3.saved'))
await loadS3Config()
} catch (error) {
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
} finally {
savingS3.value = false
}
}
async function testS3() {
testingS3.value = true
try {
const result = await adminAPI.backup.testS3Connection(s3Form.value)
if (result.ok) {
appStore.showSuccess(result.message || t('admin.backup.s3.testSuccess'))
} else {
appStore.showError(result.message || t('admin.backup.s3.testFailed'))
}
} catch (error) {
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
} finally {
testingS3.value = false
}
}
async function loadSchedule() {
try {
const cfg = await adminAPI.backup.getSchedule()
scheduleForm.value = {
enabled: cfg.enabled,
cron_expr: cfg.cron_expr || '0 2 * * *',
retain_days: cfg.retain_days || 14,
retain_count: cfg.retain_count || 10,
}
} catch (error) {
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
}
}
async function saveSchedule() {
savingSchedule.value = true
try {
await adminAPI.backup.updateSchedule(scheduleForm.value)
appStore.showSuccess(t('admin.backup.schedule.saved'))
} catch (error) {
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
} finally {
savingSchedule.value = false
}
}
async function loadBackups() {
loadingBackups.value = true
try {
const result = await adminAPI.backup.listBackups()
backups.value = result.items || []
} catch (error) {
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
} finally {
loadingBackups.value = false
}
}
async function createBackup() {
creatingBackup.value = true
try {
await adminAPI.backup.createBackup({ expire_days: manualExpireDays.value })
appStore.showSuccess(t('admin.backup.operations.backupCreated'))
await loadBackups()
} catch (error) {
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
} finally {
creatingBackup.value = false
}
}
async function downloadBackup(id: string) {
try {
const result = await adminAPI.backup.getDownloadURL(id)
window.open(result.url, '_blank')
} catch (error) {
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
}
}
async function restoreBackup(id: string) {
if (!window.confirm(t('admin.backup.actions.restoreConfirm'))) return
const password = window.prompt(t('admin.backup.actions.restorePasswordPrompt'))
if (!password) return
restoringId.value = id
try {
await adminAPI.backup.restoreBackup(id, password)
appStore.showSuccess(t('admin.backup.actions.restoreSuccess'))
} catch (error) {
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
} finally {
restoringId.value = ''
}
}
async function removeBackup(id: string) {
if (!window.confirm(t('admin.backup.actions.deleteConfirm'))) return
try {
await adminAPI.backup.deleteBackup(id)
appStore.showSuccess(t('admin.backup.actions.deleted'))
await loadBackups()
} catch (error) {
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
}
}
function statusClass(status: string): string {
switch (status) {
case 'completed':
return 'bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-300'
case 'running':
return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300'
case 'failed':
return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300'
default:
return 'bg-gray-100 text-gray-700 dark:bg-dark-800 dark:text-gray-300'
}
}
function formatSize(bytes: number): string {
if (!bytes || bytes <= 0) return '-'
if (bytes < 1024) return `${bytes} B`
if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)} KB`
return `${(bytes / (1024 * 1024)).toFixed(1)} MB`
}
function formatDate(value?: string): string {
if (!value) return '-'
const date = new Date(value)
if (Number.isNaN(date.getTime())) return value
return date.toLocaleString()
}
onMounted(async () => {
await Promise.all([loadS3Config(), loadSchedule(), loadBackups()])
})
</script>
<style scoped>
.modal-enter-active,
.modal-leave-active {
transition: opacity 0.2s ease;
}
.modal-enter-from,
.modal-leave-to {
opacity: 0;
}
</style>

View File

@@ -1,5 +1,4 @@
<template> <template>
<AppLayout>
<div class="space-y-6"> <div class="space-y-6">
<div class="card p-6"> <div class="card p-6">
<div class="mb-4 flex flex-wrap items-center justify-between gap-3"> <div class="mb-4 flex flex-wrap items-center justify-between gap-3">
@@ -183,13 +182,11 @@
</div> </div>
</Transition> </Transition>
</Teleport> </Teleport>
</AppLayout>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { onMounted, ref } from 'vue' import { onMounted, ref } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import AppLayout from '@/components/layout/AppLayout.vue'
import type { SoraS3Profile } from '@/api/admin/settings' import type { SoraS3Profile } from '@/api/admin/settings'
import { adminAPI } from '@/api' import { adminAPI } from '@/api'
import { useAppStore } from '@/stores' import { useAppStore } from '@/stores'

View File

@@ -9,7 +9,7 @@
<!-- Settings Form --> <!-- Settings Form -->
<form v-else @submit.prevent="saveSettings" class="space-y-6"> <form v-else @submit.prevent="saveSettings" class="space-y-6">
<!-- Tab Navigation --> <!-- Tab Navigation -->
<div class="sticky top-0 z-10 overflow-x-auto scrollbar-hide"> <div class="sticky top-0 z-10 overflow-x-auto settings-tabs-scroll">
<nav class="settings-tabs"> <nav class="settings-tabs">
<button <button
v-for="tab in settingsTabs" v-for="tab in settingsTabs"
@@ -1070,6 +1070,21 @@
</p> </p>
</div> </div>
<div class="space-y-6 p-6"> <div class="space-y-6 p-6">
<!-- Backend Mode -->
<div
class="flex items-center justify-between rounded-lg border border-amber-200 bg-amber-50 p-4 dark:border-amber-800 dark:bg-amber-900/20"
>
<div>
<h3 class="text-sm font-medium text-gray-900 dark:text-white">
{{ t('admin.settings.site.backendMode') }}
</h3>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.settings.site.backendModeDescription') }}
</p>
</div>
<Toggle v-model="form.backend_mode_enabled" />
</div>
<div class="grid grid-cols-1 gap-6 md:grid-cols-2"> <div class="grid grid-cols-1 gap-6 md:grid-cols-2">
<div> <div>
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300"> <label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
@@ -1634,8 +1649,18 @@
</div> </div>
</div><!-- /Tab: Email --> </div><!-- /Tab: Email -->
<!-- Tab: Backup -->
<div v-show="activeTab === 'backup'">
<BackupSettings />
</div>
<!-- Tab: Data Management -->
<div v-show="activeTab === 'data'">
<DataManagementSettings />
</div>
<!-- Save Button --> <!-- Save Button -->
<div class="flex justify-end"> <div v-show="activeTab !== 'backup' && activeTab !== 'data'" class="flex justify-end">
<button type="submit" :disabled="saving" class="btn btn-primary"> <button type="submit" :disabled="saving" class="btn btn-primary">
<svg v-if="saving" class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24"> <svg v-if="saving" class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24">
<circle <circle
@@ -1677,6 +1702,8 @@ import GroupBadge from '@/components/common/GroupBadge.vue'
import GroupOptionItem from '@/components/common/GroupOptionItem.vue' import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import Toggle from '@/components/common/Toggle.vue' import Toggle from '@/components/common/Toggle.vue'
import ImageUpload from '@/components/common/ImageUpload.vue' import ImageUpload from '@/components/common/ImageUpload.vue'
import BackupSettings from '@/views/admin/BackupView.vue'
import DataManagementSettings from '@/views/admin/DataManagementView.vue'
import { useClipboard } from '@/composables/useClipboard' import { useClipboard } from '@/composables/useClipboard'
import { useAppStore } from '@/stores' import { useAppStore } from '@/stores'
import { useAdminSettingsStore } from '@/stores/adminSettings' import { useAdminSettingsStore } from '@/stores/adminSettings'
@@ -1691,7 +1718,7 @@ const { t } = useI18n()
const appStore = useAppStore() const appStore = useAppStore()
const adminSettingsStore = useAdminSettingsStore() const adminSettingsStore = useAdminSettingsStore()
type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'email' type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'email' | 'backup' | 'data'
const activeTab = ref<SettingsTab>('general') const activeTab = ref<SettingsTab>('general')
const settingsTabs = [ const settingsTabs = [
{ key: 'general' as SettingsTab, icon: 'home' as const }, { key: 'general' as SettingsTab, icon: 'home' as const },
@@ -1699,6 +1726,8 @@ const settingsTabs = [
{ key: 'users' as SettingsTab, icon: 'user' as const }, { key: 'users' as SettingsTab, icon: 'user' as const },
{ key: 'gateway' as SettingsTab, icon: 'server' as const }, { key: 'gateway' as SettingsTab, icon: 'server' as const },
{ key: 'email' as SettingsTab, icon: 'mail' as const }, { key: 'email' as SettingsTab, icon: 'mail' as const },
{ key: 'backup' as SettingsTab, icon: 'database' as const },
{ key: 'data' as SettingsTab, icon: 'cube' as const },
] ]
const { copyToClipboard } = useClipboard() const { copyToClipboard } = useClipboard()
@@ -1745,7 +1774,7 @@ const betaPolicyForm = reactive({
rules: [] as Array<{ rules: [] as Array<{
beta_token: string beta_token: string
action: 'pass' | 'filter' | 'block' action: 'pass' | 'filter' | 'block'
scope: 'all' | 'oauth' | 'apikey' scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
error_message?: string error_message?: string
}> }>
}) })
@@ -1785,6 +1814,7 @@ const form = reactive<SettingsForm>({
contact_info: '', contact_info: '',
doc_url: '', doc_url: '',
home_content: '', home_content: '',
backend_mode_enabled: false,
hide_ccs_import_button: false, hide_ccs_import_button: false,
purchase_subscription_enabled: false, purchase_subscription_enabled: false,
purchase_subscription_url: '', purchase_subscription_url: '',
@@ -1962,6 +1992,7 @@ async function loadSettings() {
try { try {
const settings = await adminAPI.settings.getSettings() const settings = await adminAPI.settings.getSettings()
Object.assign(form, settings) Object.assign(form, settings)
form.backend_mode_enabled = settings.backend_mode_enabled
form.default_subscriptions = Array.isArray(settings.default_subscriptions) form.default_subscriptions = Array.isArray(settings.default_subscriptions)
? settings.default_subscriptions ? settings.default_subscriptions
.filter((item) => item.group_id > 0 && item.validity_days > 0) .filter((item) => item.group_id > 0 && item.validity_days > 0)
@@ -2060,6 +2091,7 @@ async function saveSettings() {
contact_info: form.contact_info, contact_info: form.contact_info,
doc_url: form.doc_url, doc_url: form.doc_url,
home_content: form.home_content, home_content: form.home_content,
backend_mode_enabled: form.backend_mode_enabled,
hide_ccs_import_button: form.hide_ccs_import_button, hide_ccs_import_button: form.hide_ccs_import_button,
purchase_subscription_enabled: form.purchase_subscription_enabled, purchase_subscription_enabled: form.purchase_subscription_enabled,
purchase_subscription_url: form.purchase_subscription_url, purchase_subscription_url: form.purchase_subscription_url,
@@ -2297,7 +2329,8 @@ const betaPolicyActionOptions = computed(() => [
const betaPolicyScopeOptions = computed(() => [ const betaPolicyScopeOptions = computed(() => [
{ value: 'all', label: t('admin.settings.betaPolicy.scopeAll') }, { value: 'all', label: t('admin.settings.betaPolicy.scopeAll') },
{ value: 'oauth', label: t('admin.settings.betaPolicy.scopeOAuth') }, { value: 'oauth', label: t('admin.settings.betaPolicy.scopeOAuth') },
{ value: 'apikey', label: t('admin.settings.betaPolicy.scopeAPIKey') } { value: 'apikey', label: t('admin.settings.betaPolicy.scopeAPIKey') },
{ value: 'bedrock', label: t('admin.settings.betaPolicy.scopeBedrock') }
]) ])
// Beta Policy 方法 // Beta Policy 方法
@@ -2359,9 +2392,38 @@ onMounted(() => {
} }
/* ============ Settings Tab Navigation ============ */ /* ============ Settings Tab Navigation ============ */
/* Scroll container: thin scrollbar on PC, auto-hide on mobile */
.settings-tabs-scroll {
scrollbar-width: thin;
scrollbar-color: transparent transparent;
}
.settings-tabs-scroll:hover {
scrollbar-color: rgb(0 0 0 / 0.15) transparent;
}
:root.dark .settings-tabs-scroll:hover {
scrollbar-color: rgb(255 255 255 / 0.2) transparent;
}
.settings-tabs-scroll::-webkit-scrollbar {
height: 3px;
}
.settings-tabs-scroll::-webkit-scrollbar-track {
background: transparent;
}
.settings-tabs-scroll::-webkit-scrollbar-thumb {
background: transparent;
border-radius: 3px;
}
.settings-tabs-scroll:hover::-webkit-scrollbar-thumb {
background: rgb(0 0 0 / 0.15);
}
:root.dark .settings-tabs-scroll:hover::-webkit-scrollbar-thumb {
background: rgb(255 255 255 / 0.2);
}
.settings-tabs { .settings-tabs {
@apply inline-flex min-w-full gap-1 rounded-2xl @apply inline-flex min-w-full gap-0.5 rounded-2xl
border border-gray-100 bg-white/80 p-1.5 backdrop-blur-sm border border-gray-100 bg-white/80 p-1 backdrop-blur-sm
dark:border-dark-700/50 dark:bg-dark-800/80; dark:border-dark-700/50 dark:bg-dark-800/80;
box-shadow: 0 1px 3px rgb(0 0 0 / 0.04), 0 1px 2px rgb(0 0 0 / 0.02); box-shadow: 0 1px 3px rgb(0 0 0 / 0.04), 0 1px 2px rgb(0 0 0 / 0.02);
} }
@@ -2373,8 +2435,8 @@ onMounted(() => {
} }
.settings-tab { .settings-tab {
@apply relative flex flex-1 items-center justify-center gap-2 @apply relative flex flex-1 items-center justify-center gap-1.5
whitespace-nowrap rounded-xl px-4 py-2.5 whitespace-nowrap rounded-xl px-2.5 py-2
text-sm font-medium text-sm font-medium
text-gray-500 dark:text-dark-400 text-gray-500 dark:text-dark-400
transition-all duration-200 ease-out; transition-all duration-200 ease-out;
@@ -2401,7 +2463,7 @@ onMounted(() => {
} }
.settings-tab-icon { .settings-tab-icon {
@apply flex h-7 w-7 items-center justify-center rounded-lg @apply flex h-6 w-6 items-center justify-center rounded-lg
transition-all duration-200; transition-all duration-200;
} }

View File

@@ -85,7 +85,7 @@
</div> </div>
<!-- Row: OpenAI Token Stats --> <!-- Row: OpenAI Token Stats -->
<div v-if="opsEnabled && !(loading && !hasLoadedOnce)" class="grid grid-cols-1 gap-6"> <div v-if="opsEnabled && showOpenAITokenStats && !(loading && !hasLoadedOnce)" class="grid grid-cols-1 gap-6">
<OpsOpenAITokenStatsCard <OpsOpenAITokenStatsCard
:platform-filter="platform" :platform-filter="platform"
:group-id-filter="groupId" :group-id-filter="groupId"
@@ -94,7 +94,7 @@
</div> </div>
<!-- Alert Events --> <!-- Alert Events -->
<OpsAlertEventsCard v-if="opsEnabled && !(loading && !hasLoadedOnce)" /> <OpsAlertEventsCard v-if="opsEnabled && showAlertEvents && !(loading && !hasLoadedOnce)" />
<!-- System Logs --> <!-- System Logs -->
<OpsSystemLogTable <OpsSystemLogTable
@@ -381,6 +381,8 @@ const showSettingsDialog = ref(false)
const showAlertRulesCard = ref(false) const showAlertRulesCard = ref(false)
// Auto refresh settings // Auto refresh settings
const showAlertEvents = ref(true)
const showOpenAITokenStats = ref(false)
const autoRefreshEnabled = ref(false) const autoRefreshEnabled = ref(false)
const autoRefreshIntervalMs = ref(30000) // default 30 seconds const autoRefreshIntervalMs = ref(30000) // default 30 seconds
const autoRefreshCountdown = ref(0) const autoRefreshCountdown = ref(0)
@@ -408,15 +410,22 @@ const { pause: pauseCountdown, resume: resumeCountdown } = useIntervalFn(
{ immediate: false } { immediate: false }
) )
// Load auto refresh settings from backend // Load ops dashboard presentation settings from backend.
async function loadAutoRefreshSettings() { async function loadDashboardAdvancedSettings() {
try { try {
const settings = await opsAPI.getAdvancedSettings() const settings = await opsAPI.getAdvancedSettings()
showAlertEvents.value = settings.display_alert_events
showOpenAITokenStats.value = settings.display_openai_token_stats
autoRefreshEnabled.value = settings.auto_refresh_enabled autoRefreshEnabled.value = settings.auto_refresh_enabled
autoRefreshIntervalMs.value = settings.auto_refresh_interval_seconds * 1000 autoRefreshIntervalMs.value = settings.auto_refresh_interval_seconds * 1000
autoRefreshCountdown.value = settings.auto_refresh_interval_seconds autoRefreshCountdown.value = settings.auto_refresh_interval_seconds
} catch (err) { } catch (err) {
console.error('[OpsDashboard] Failed to load auto refresh settings', err) console.error('[OpsDashboard] Failed to load dashboard advanced settings', err)
showAlertEvents.value = true
showOpenAITokenStats.value = false
autoRefreshEnabled.value = false
autoRefreshIntervalMs.value = 30000
autoRefreshCountdown.value = 0
} }
} }
@@ -464,7 +473,8 @@ function onCustomTimeRangeChange(startTime: string, endTime: string) {
customEndTime.value = endTime customEndTime.value = endTime
} }
function onSettingsSaved() { async function onSettingsSaved() {
await loadDashboardAdvancedSettings()
loadThresholds() loadThresholds()
fetchData() fetchData()
} }
@@ -774,7 +784,7 @@ onMounted(async () => {
loadThresholds() loadThresholds()
// Load auto refresh settings // Load auto refresh settings
await loadAutoRefreshSettings() await loadDashboardAdvancedSettings()
if (opsEnabled.value) { if (opsEnabled.value) {
await fetchData() await fetchData()
@@ -816,7 +826,7 @@ watch(autoRefreshEnabled, (enabled) => {
// Reload auto refresh settings after settings dialog is closed // Reload auto refresh settings after settings dialog is closed
watch(showSettingsDialog, async (show) => { watch(showSettingsDialog, async (show) => {
if (!show) { if (!show) {
await loadAutoRefreshSettings() await loadDashboardAdvancedSettings()
} }
}) })
</script> </script>

View File

@@ -208,35 +208,39 @@ function onNextPage() {
:description="t('admin.ops.openaiTokenStats.empty')" :description="t('admin.ops.openaiTokenStats.empty')"
/> />
<div v-else class="overflow-x-auto"> <div v-else class="space-y-3">
<table class="min-w-full text-left text-xs md:text-sm"> <div class="overflow-hidden rounded-xl border border-gray-200 dark:border-dark-700">
<thead> <div class="max-h-[420px] overflow-auto">
<tr class="border-b border-gray-200 text-gray-500 dark:border-dark-700 dark:text-gray-400"> <table class="min-w-full text-left text-xs md:text-sm">
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.model') }}</th> <thead class="sticky top-0 z-10 bg-white dark:bg-dark-800">
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.requestCount') }}</th> <tr class="border-b border-gray-200 text-gray-500 dark:border-dark-700 dark:text-gray-400">
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgTokensPerSec') }}</th> <th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.model') }}</th>
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgFirstTokenMs') }}</th> <th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.requestCount') }}</th>
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.totalOutputTokens') }}</th> <th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgTokensPerSec') }}</th>
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgDurationMs') }}</th> <th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgFirstTokenMs') }}</th>
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.requestsWithFirstToken') }}</th> <th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.totalOutputTokens') }}</th>
</tr> <th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgDurationMs') }}</th>
</thead> <th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.requestsWithFirstToken') }}</th>
<tbody> </tr>
<tr </thead>
v-for="row in items" <tbody>
:key="row.model" <tr
class="border-b border-gray-100 text-gray-700 dark:border-dark-800 dark:text-gray-200" v-for="row in items"
> :key="row.model"
<td class="px-2 py-2 font-medium">{{ row.model }}</td> class="border-b border-gray-100 text-gray-700 last:border-b-0 dark:border-dark-800 dark:text-gray-200"
<td class="px-2 py-2">{{ formatInt(row.request_count) }}</td> >
<td class="px-2 py-2">{{ formatRate(row.avg_tokens_per_sec) }}</td> <td class="px-2 py-2 font-medium">{{ row.model }}</td>
<td class="px-2 py-2">{{ formatRate(row.avg_first_token_ms) }}</td> <td class="px-2 py-2">{{ formatInt(row.request_count) }}</td>
<td class="px-2 py-2">{{ formatInt(row.total_output_tokens) }}</td> <td class="px-2 py-2">{{ formatRate(row.avg_tokens_per_sec) }}</td>
<td class="px-2 py-2">{{ formatInt(row.avg_duration_ms) }}</td> <td class="px-2 py-2">{{ formatRate(row.avg_first_token_ms) }}</td>
<td class="px-2 py-2">{{ formatInt(row.requests_with_first_token) }}</td> <td class="px-2 py-2">{{ formatInt(row.total_output_tokens) }}</td>
</tr> <td class="px-2 py-2">{{ formatInt(row.avg_duration_ms) }}</td>
</tbody> <td class="px-2 py-2">{{ formatInt(row.requests_with_first_token) }}</td>
</table> </tr>
</tbody>
</table>
</div>
</div>
<div v-if="viewMode === 'topn'" class="mt-3 text-xs text-gray-500 dark:text-gray-400"> <div v-if="viewMode === 'topn'" class="mt-3 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.ops.openaiTokenStats.totalModels', { total }) }} {{ t('admin.ops.openaiTokenStats.totalModels', { total }) }}
</div> </div>

View File

@@ -131,15 +131,7 @@ const validation = computed(() => {
} }
} }
// 验证邮件配置 // 邮件配置: 启用但无收件人时不阻断保存, 保存时会自动禁用
if (emailConfig.value) {
if (emailConfig.value.alert.enabled && emailConfig.value.alert.recipients.length === 0) {
errors.push(t('admin.ops.email.validation.alertRecipientsRequired'))
}
if (emailConfig.value.report.enabled && emailConfig.value.report.recipients.length === 0) {
errors.push(t('admin.ops.email.validation.reportRecipientsRequired'))
}
}
// 验证高级设置 // 验证高级设置
if (advancedSettings.value) { if (advancedSettings.value) {
@@ -181,6 +173,15 @@ async function saveAllSettings() {
saving.value = true saving.value = true
try { try {
// 无收件人时自动禁用邮件通知
if (emailConfig.value) {
if (emailConfig.value.alert.enabled && emailConfig.value.alert.recipients.length === 0) {
emailConfig.value.alert.enabled = false
}
if (emailConfig.value.report.enabled && emailConfig.value.report.recipients.length === 0) {
emailConfig.value.report.enabled = false
}
}
await Promise.all([ await Promise.all([
runtimeSettings.value ? opsAPI.updateAlertRuntimeSettings(runtimeSettings.value) : Promise.resolve(), runtimeSettings.value ? opsAPI.updateAlertRuntimeSettings(runtimeSettings.value) : Promise.resolve(),
emailConfig.value ? opsAPI.updateEmailNotificationConfig(emailConfig.value) : Promise.resolve(), emailConfig.value ? opsAPI.updateEmailNotificationConfig(emailConfig.value) : Promise.resolve(),
@@ -543,6 +544,31 @@ async function saveAllSettings() {
/> />
</div> </div>
</div> </div>
<!-- Dashboard Cards -->
<div class="space-y-3">
<h5 class="text-xs font-semibold text-gray-700 dark:text-gray-300">{{ t('admin.ops.settings.dashboardCards') }}</h5>
<div class="flex items-center justify-between">
<div>
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">{{ t('admin.ops.settings.displayAlertEvents') }}</label>
<p class="mt-1 text-xs text-gray-500">
{{ t('admin.ops.settings.displayAlertEventsHint') }}
</p>
</div>
<Toggle v-model="advancedSettings.display_alert_events" />
</div>
<div class="flex items-center justify-between">
<div>
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">{{ t('admin.ops.settings.displayOpenAITokenStats') }}</label>
<p class="mt-1 text-xs text-gray-500">
{{ t('admin.ops.settings.displayOpenAITokenStatsHint') }}
</p>
</div>
<Toggle v-model="advancedSettings.display_openai_token_stats" />
</div>
</div>
</div> </div>
</details> </details>
</div> </div>

Some files were not shown because too many files have changed in this diff Show More