mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-06 08:20:23 +08:00
Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6da5fa01b9 | ||
|
|
616930f9d3 | ||
|
|
b9c31fa7c4 | ||
|
|
17b339972c | ||
|
|
39f8bd91b9 | ||
|
|
aa4e37d085 | ||
|
|
f59b66b7d4 | ||
|
|
8f0ea7a02d | ||
|
|
a1dc00890e | ||
|
|
dfbcc363d1 | ||
|
|
1047f973d5 | ||
|
|
e32977dd73 | ||
|
|
b5f78ec1e8 | ||
|
|
e0f290fdc8 | ||
|
|
fc00a4e3b2 | ||
|
|
db1f6ded88 | ||
|
|
4644af2ccc | ||
|
|
2e3e8687e1 | ||
|
|
ca42a45802 | ||
|
|
9350ecb62b | ||
|
|
a4a026e8da | ||
|
|
342fd03e72 | ||
|
|
e3f1fd9b63 | ||
|
|
e4a4dfd038 | ||
|
|
a377e99088 | ||
|
|
1d3d7a3033 | ||
|
|
e7086cb3a3 | ||
|
|
01ef7340aa | ||
|
|
1c960d22c1 | ||
|
|
ece0606fed | ||
|
|
2666422b99 | ||
|
|
4e8615f276 | ||
|
|
45456fa24c | ||
|
|
6344fa2a86 | ||
|
|
29b0e4a8a5 | ||
|
|
f7177be3b6 | ||
|
|
875b417fde | ||
|
|
2573107b32 | ||
|
|
5b85005945 | ||
|
|
53ad1645cf | ||
|
|
af9c4a7dd0 | ||
|
|
6826149a8f |
20
Dockerfile
20
Dockerfile
@@ -9,6 +9,7 @@
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
|
||||
./cmd/server
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 3: Final Runtime Image
|
||||
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 4: Final Runtime Image
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
@@ -86,8 +92,20 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||
# This ensures version consistency between backup tools and the database server
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
@@ -94,6 +94,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -230,6 +231,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -146,6 +146,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
dataManagementService := service.NewDataManagementService()
|
||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||
dbDumper := repository.NewPgDumper(configConfig)
|
||||
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -201,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -232,7 +236,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -285,6 +289,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -420,6 +425,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -27,12 +27,11 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
|
||||
AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
|
||||
@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -1718,13 +1718,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// Handle OpenAI accounts
|
||||
if account.IsOpenAI() {
|
||||
// For OAuth accounts: return default OpenAI models
|
||||
if account.IsOAuth() {
|
||||
// OpenAI 自动透传会绕过常规模型改写,测试/模型列表也应回落到默认模型集。
|
||||
if account.IsOpenAIPassthroughEnabled() {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
// For API Key accounts: check model_mapping
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
response.Success(c, openai.DefaultModels)
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type availableModelsAdminService struct {
|
||||
*stubAdminService
|
||||
account service.Account
|
||||
}
|
||||
|
||||
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
|
||||
if s.account.ID == id {
|
||||
acc := s.account
|
||||
return &acc, nil
|
||||
}
|
||||
return s.stubAdminService.GetAccount(context.Background(), id)
|
||||
}
|
||||
|
||||
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 42,
|
||||
Name: "openai-oauth",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Len(t, resp.Data, 1)
|
||||
require.Equal(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 43,
|
||||
Name: "openai-oauth-passthrough",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupAvailableModelsRouter(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
204
backend/internal/handler/admin/backup_handler.go
Normal file
204
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BackupHandler struct {
|
||||
backupService *service.BackupService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
|
||||
return &BackupHandler{
|
||||
backupService: backupService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置 ───
|
||||
|
||||
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||
}
|
||||
|
||||
// ─── 定时备份 ───
|
||||
|
||||
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||
var req service.BackupScheduleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// ─── 备份操作 ───
|
||||
|
||||
type CreateBackupRequest struct {
|
||||
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||
}
|
||||
|
||||
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||
var req CreateBackupRequest
|
||||
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||
|
||||
expireDays := 14 // 默认14天过期
|
||||
if req.ExpireDays != nil {
|
||||
expireDays = *req.ExpireDays
|
||||
}
|
||||
|
||||
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if records == nil {
|
||||
records = []service.BackupRecord{}
|
||||
}
|
||||
response.Success(c, gin.H{"items": records})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||
|
||||
type RestoreBackupRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req RestoreBackupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "password is required for restore operation")
|
||||
return
|
||||
}
|
||||
|
||||
// 从上下文获取当前管理员用户 ID
|
||||
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取管理员用户并验证密码
|
||||
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if !user.CheckPassword(req.Password) {
|
||||
response.BadRequest(c, "incorrect admin password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"restored": true})
|
||||
}
|
||||
@@ -125,6 +125,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -199,6 +200,9 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -473,6 +477,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -571,6 +576,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -725,6 +731,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||
}
|
||||
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||
changed = append(changed, "backend_mode_enabled")
|
||||
}
|
||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||
changed = append(changed, "purchase_subscription_enabled")
|
||||
}
|
||||
|
||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
// Get the user (before session deletion so we can check backend mode)
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session (only after all checks pass)
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: block non-admin token refresh
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RefreshTokenResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresIn: result.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -264,8 +264,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||
if a.IsAPIKeyOrBedrock() {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
used := a.GetQuotaUsed()
|
||||
@@ -281,6 +281,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
// 固定时间重置配置
|
||||
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||
out.QuotaDailyResetMode = &mode
|
||||
hour := a.GetQuotaDailyResetHour()
|
||||
out.QuotaDailyResetHour = &hour
|
||||
}
|
||||
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||
out.QuotaWeeklyResetMode = &mode
|
||||
day := a.GetQuotaWeeklyResetDay()
|
||||
out.QuotaWeeklyResetDay = &day
|
||||
hour := a.GetQuotaWeeklyResetHour()
|
||||
out.QuotaWeeklyResetHour = &hour
|
||||
}
|
||||
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
tz := a.GetQuotaResetTimezone()
|
||||
out.QuotaResetTimezone = &tz
|
||||
}
|
||||
if a.Extra != nil {
|
||||
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaDailyResetAt = &v
|
||||
}
|
||||
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaWeeklyResetAt = &v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
@@ -81,6 +81,9 @@ type SystemSettings struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -111,6 +114,7 @@ type PublicSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
|
||||
@@ -203,6 +203,16 @@ type Account struct {
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
// 配额固定时间重置配置
|
||||
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
Backup *admin.BackupHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
|
||||
@@ -181,13 +181,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
|
||||
@@ -655,14 +655,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultMappedModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
// 如果使用了降级模型调度,强制使用降级模型
|
||||
if fallbackModel := c.GetString("openai_messages_fallback_model"); fallbackModel != "" {
|
||||
defaultMappedModel = fallbackModel
|
||||
}
|
||||
// 仅在调度时实际触发了降级(原模型无可用账号、改用默认模型重试成功)时,
|
||||
// 才将降级模型传给 Forward 层做模型替换;否则保持用户请求的原始模型。
|
||||
defaultMappedModel := c.GetString("openai_messages_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
|
||||
@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
dataManagementHandler *admin.DataManagementHandler,
|
||||
backupHandler *admin.BackupHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
DataManagement: dataManagementHandler,
|
||||
Backup: backupHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewDataManagementHandler,
|
||||
admin.NewBackupHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
|
||||
@@ -19,6 +19,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
)
|
||||
|
||||
// ForbiddenError 表示上游返回 403 Forbidden
|
||||
type ForbiddenError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *ForbiddenError) Error() string {
|
||||
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
@@ -514,7 +524,20 @@ type ModelQuotaInfo struct {
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportsImages *bool `json:"supportsImages,omitempty"`
|
||||
SupportsThinking *bool `json:"supportsThinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"maxTokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
|
||||
}
|
||||
|
||||
// DeprecatedModelInfo 废弃模型转发信息
|
||||
type DeprecatedModelInfo struct {
|
||||
NewModelID string `json:"newModelId"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||
@@ -524,7 +547,8 @@ type FetchAvailableModelsRequest struct {
|
||||
|
||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||
type FetchAvailableModelsResponse struct {
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
@@ -573,6 +597,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return nil, nil, &ForbiddenError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBodyBytes),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
@@ -105,6 +105,7 @@ func TestAnthropicToResponses_ToolUse(t *testing.T) {
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Equal(t, "fc_call_1", items[2].CallID)
|
||||
assert.Empty(t, items[2].ID)
|
||||
assert.Equal(t, "function_call_output", items[3].Type)
|
||||
assert.Equal(t, "fc_call_1", items[3].CallID)
|
||||
assert.Equal(t, "Sunny, 72°F", items[3].Output)
|
||||
|
||||
@@ -277,7 +277,6 @@ func anthropicAssistantToResponses(raw json.RawMessage) ([]ResponsesInputItem, e
|
||||
CallID: fcID,
|
||||
Name: b.Name,
|
||||
Arguments: args,
|
||||
ID: fcID,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
|
||||
// Check function_call item
|
||||
assert.Equal(t, "function_call", items[1].Type)
|
||||
assert.Equal(t, "call_1", items[1].CallID)
|
||||
assert.Empty(t, items[1].ID)
|
||||
assert.Equal(t, "ping", items[1].Name)
|
||||
|
||||
// Check function_call_output item
|
||||
@@ -252,6 +253,55 @@ func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T)
|
||||
assert.Equal(t, "user", items[0].Role)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
assert.Equal(t, "function_call", items[2].Type)
|
||||
assert.Empty(t, items[2].ID)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantArrayContentPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"text","text":"A"},{"type":"text","text":"B"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
assert.Equal(t, "assistant", items[1].Role)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Equal(t, "AB", parts[0].Text)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{Role: "assistant", Content: json.RawMessage(`[{"type":"thinking","thinking":"internal plan"},{"type":"text","text":"final answer"}]`)},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
|
||||
assert.Contains(t, parts[0].Text, "final answer")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -344,8 +394,8 @@ func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
|
||||
|
||||
var content string
|
||||
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
|
||||
// Reasoning summary is prepended to text
|
||||
assert.Equal(t, "I thought about it.The answer is 42.", content)
|
||||
assert.Equal(t, "The answer is 42.", content)
|
||||
assert.Equal(t, "I thought about it.", chat.Choices[0].Message.ReasoningContent)
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
|
||||
@@ -582,8 +632,35 @@ func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
|
||||
Delta: "Thinking...",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.done",
|
||||
}, state)
|
||||
require.Len(t, chunks, 0)
|
||||
}
|
||||
|
||||
func TestResponsesEventToChatChunks_ReasoningThenTextAutoCloseTag(t *testing.T) {
|
||||
state := NewResponsesEventToChatState()
|
||||
state.Model = "gpt-4o"
|
||||
state.SentRole = true
|
||||
|
||||
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
Delta: "plan",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
assert.Equal(t, "plan", *chunks[0].Choices[0].Delta.ReasoningContent)
|
||||
|
||||
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
|
||||
Type: "response.output_text.delta",
|
||||
Delta: "answer",
|
||||
}, state)
|
||||
require.Len(t, chunks, 1)
|
||||
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
|
||||
assert.Equal(t, "answer", *chunks[0].Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestFinalizeResponsesChatStream(t *testing.T) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package apicompat
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChatCompletionsToResponses converts a Chat Completions request into a
|
||||
@@ -174,8 +175,11 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
|
||||
// Emit assistant message with output_text if content is non-empty.
|
||||
if len(m.Content) > 0 {
|
||||
var s string
|
||||
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
|
||||
s, err := parseAssistantContent(m.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
@@ -196,13 +200,82 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
CallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
ID: tc.ID,
|
||||
})
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// parseAssistantContent returns assistant content as plain text.
|
||||
//
|
||||
// Supported formats:
|
||||
// - JSON string
|
||||
// - JSON array of typed parts (e.g. [{"type":"text","text":"..."}])
|
||||
//
|
||||
// For structured thinking/reasoning parts, it preserves semantics by wrapping
|
||||
// the text in explicit tags so downstream can still distinguish it from normal text.
|
||||
func parseAssistantContent(raw json.RawMessage) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var parts []map[string]any
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
// Keep compatibility with prior behavior: unsupported assistant content
|
||||
// formats are ignored instead of failing the whole request conversion.
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
write := func(v string) error {
|
||||
_, err := b.WriteString(v)
|
||||
return err
|
||||
}
|
||||
for _, p := range parts {
|
||||
typ, _ := p["type"].(string)
|
||||
text, _ := p["text"].(string)
|
||||
thinking, _ := p["thinking"].(string)
|
||||
|
||||
switch typ {
|
||||
case "thinking", "reasoning":
|
||||
if thinking != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(thinking); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else if text != "" {
|
||||
if err := write("<thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := write("</thinking>"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
default:
|
||||
if text != "" {
|
||||
if err := write(text); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
// chatToolToResponses converts a tool result message (role=tool) into a
|
||||
// function_call_output item.
|
||||
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
|
||||
@@ -29,6 +29,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
}
|
||||
|
||||
var contentText string
|
||||
var reasoningText string
|
||||
var toolCalls []ChatToolCall
|
||||
|
||||
for _, item := range resp.Output {
|
||||
@@ -51,7 +52,7 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
case "reasoning":
|
||||
for _, s := range item.Summary {
|
||||
if s.Type == "summary_text" && s.Text != "" {
|
||||
contentText += s.Text
|
||||
reasoningText += s.Text
|
||||
}
|
||||
}
|
||||
case "web_search_call":
|
||||
@@ -67,6 +68,9 @@ func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatComp
|
||||
raw, _ := json.Marshal(contentText)
|
||||
msg.Content = raw
|
||||
}
|
||||
if reasoningText != "" {
|
||||
msg.ReasoningContent = reasoningText
|
||||
}
|
||||
|
||||
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
|
||||
|
||||
@@ -153,6 +157,8 @@ func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEvent
|
||||
return resToChatHandleFuncArgsDelta(evt, state)
|
||||
case "response.reasoning_summary_text.delta":
|
||||
return resToChatHandleReasoningDelta(evt, state)
|
||||
case "response.reasoning_summary_text.done":
|
||||
return nil
|
||||
case "response.completed", "response.incomplete", "response.failed":
|
||||
return resToChatHandleCompleted(evt, state)
|
||||
default:
|
||||
@@ -276,8 +282,8 @@ func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEv
|
||||
if evt.Delta == "" {
|
||||
return nil
|
||||
}
|
||||
content := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
|
||||
reasoning := evt.Delta
|
||||
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ReasoningContent: &reasoning})}
|
||||
}
|
||||
|
||||
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
|
||||
|
||||
@@ -361,11 +361,12 @@ type ChatStreamOptions struct {
|
||||
|
||||
// ChatMessage is a single message in the Chat Completions conversation.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
|
||||
// Legacy function calling
|
||||
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
|
||||
@@ -466,9 +467,10 @@ type ChatChunkChoice struct {
|
||||
|
||||
// ChatDelta carries incremental content in a streaming chunk.
|
||||
type ChatDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -397,9 +397,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
}
|
||||
// 普通账号编辑(如 model_mapping / credentials)也需要立即刷新单账号快照,
|
||||
// 否则网关在 outbox worker 延迟或异常时仍可能读到旧配置。
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1727,8 +1727,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
|
||||
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
|
||||
const dailyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
|
||||
const weeklyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextDailyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute today's reset point in the configured timezone, then pick next future one
|
||||
CASE WHEN NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is at or past today's reset point → next reset is tomorrow
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '1 day'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is before today's reset point → next reset is today
|
||||
ELSE (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextWeeklyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute this week's reset point in the configured timezone
|
||||
-- Step 1: get today's date at reset hour in configured tz
|
||||
-- Step 2: compute days forward to target weekday
|
||||
-- Step 3: if same day but past reset hour, advance 7 days
|
||||
CASE
|
||||
WHEN (
|
||||
-- days_forward = (target_day - current_day + 7) % 7
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) = 0 AND NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- Same weekday and past reset hour → next week
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '7 days'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
ELSE (
|
||||
-- Advance to target weekday this week (or next if days_forward > 0)
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ ((
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) || ' days')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
@@ -1739,31 +1827,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
@@ -1796,12 +1888,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
||||
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
|
||||
@@ -142,6 +142,35 @@ func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "sync-credentials-update",
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.2",
|
||||
},
|
||||
}
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
|
||||
s.Require().True(ok)
|
||||
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
|
||||
98
backend/internal/repository/backup_pg_dumper.go
Normal file
98
backend/internal/repository/backup_pg_dumper.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// PgDumper implements service.DBDumper using pg_dump/psql
|
||||
type PgDumper struct {
|
||||
cfg *config.DatabaseConfig
|
||||
}
|
||||
|
||||
// NewPgDumper creates a new PgDumper
|
||||
func NewPgDumper(cfg *config.Config) service.DBDumper {
|
||||
return &PgDumper{cfg: &cfg.Database}
|
||||
}
|
||||
|
||||
// Dump executes pg_dump and returns a streaming reader of the output
|
||||
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--no-owner",
|
||||
"--no-acl",
|
||||
"--clean",
|
||||
"--if-exists",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_dump", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 返回一个 ReadCloser:读 stdout,关闭时等待进程退出
|
||||
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
|
||||
}
|
||||
|
||||
// Restore executes psql to restore from a streaming reader
|
||||
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--single-transaction",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
cmd.Stdin = data
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %s", err, string(output))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
|
||||
type cmdReadCloser struct {
|
||||
io.ReadCloser
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func (c *cmdReadCloser) Close() error {
|
||||
// Close the pipe first
|
||||
_ = c.ReadCloser.Close()
|
||||
// Wait for the process to exit
|
||||
if err := c.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("pg_dump exited with error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
116
backend/internal/repository/backup_s3_store.go
Normal file
116
backend/internal/repository/backup_s3_store.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
|
||||
type S3BackupStore struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
}
|
||||
|
||||
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
|
||||
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
||||
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
|
||||
region := cfg.Region
|
||||
if region == "" {
|
||||
region = "auto" // Cloudflare R2 默认 region
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(region),
|
||||
awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = &cfg.Endpoint
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
})
|
||||
|
||||
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
Body: bytes.NewReader(data),
|
||||
ContentType: &contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S3 PutObject: %w", err)
|
||||
}
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("S3 GetObject: %w", err)
|
||||
}
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
|
||||
presignClient := s3.NewPresignClient(s.client)
|
||||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
}, s3.WithPresignExpires(expiry))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
|
||||
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &s.bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -100,6 +100,10 @@ var ProviderSet = wire.NewSet(
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
// Backup infrastructure
|
||||
NewPgDumper,
|
||||
NewS3BackupStoreFactory,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
ProvidePricingRemoteClient,
|
||||
|
||||
@@ -537,6 +537,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"purchase_subscription_url": "",
|
||||
"min_claude_code_version": "",
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"backend_mode_enabled": false,
|
||||
"custom_menu_items": []
|
||||
}
|
||||
}`,
|
||||
|
||||
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
|
||||
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
|
||||
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
role, _ := GetUserRoleFromContext(c)
|
||||
if role == "admin" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
|
||||
// Allows: login, login/2fa, logout, refresh (admin needs these).
|
||||
// Blocks: register, forgot-password, reset-password, OAuth, etc.
|
||||
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
// Allow login, 2FA, logout, refresh, public settings
|
||||
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
|
||||
for _, suffix := range allowedSuffixes {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type bmSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
v, ok := r.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
if r.values == nil {
|
||||
r.values = make(map[string]string, len(settings))
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
|
||||
t.Helper()
|
||||
|
||||
repo := &bmSettingRepo{
|
||||
values: map[string]string{
|
||||
service.SettingKeyBackendModeEnabled: enabled,
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{})
|
||||
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
|
||||
BackendModeEnabled: enabled == "true",
|
||||
}))
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
func stringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestBackendModeUserGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
role *string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_admin_allowed",
|
||||
enabled: "true",
|
||||
role: stringPtr("admin"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_user_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_no_role_blocked",
|
||||
enabled: "true",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_empty_role_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr(""),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
if tc.role != nil {
|
||||
role := *tc.role
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(ContextKeyUserRole), role)
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeUserGuard(svc))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendModeAuthGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
path string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login_2fa",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login/2fa",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_logout",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/logout",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_refresh",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/refresh",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_register",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_forgot_password",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/forgot-password",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeAuthGuard(svc))
|
||||
r.Any("/*path", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -107,9 +107,9 @@ func registerRoutes(
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
}
|
||||
|
||||
@@ -58,6 +58,9 @@ func RegisterAdminRoutes(
|
||||
// 数据管理
|
||||
registerDataManagementRoutes(admin, h)
|
||||
|
||||
// 数据库备份恢复
|
||||
registerBackupRoutes(admin, h)
|
||||
|
||||
// 运维监控(Ops)
|
||||
registerOpsRoutes(admin, h)
|
||||
|
||||
@@ -440,6 +443,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
backup := admin.Group("/backups")
|
||||
{
|
||||
// S3 存储配置
|
||||
backup.GET("/s3-config", h.Admin.Backup.GetS3Config)
|
||||
backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config)
|
||||
backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection)
|
||||
|
||||
// 定时备份配置
|
||||
backup.GET("/schedule", h.Admin.Backup.GetSchedule)
|
||||
backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule)
|
||||
|
||||
// 备份操作
|
||||
backup.POST("", h.Admin.Backup.CreateBackup)
|
||||
backup.GET("", h.Admin.Backup.ListBackups)
|
||||
backup.GET("/:id", h.Admin.Backup.GetBackup)
|
||||
backup.DELETE("/:id", h.Admin.Backup.DeleteBackup)
|
||||
backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL)
|
||||
|
||||
// 恢复操作
|
||||
backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
system := admin.Group("/system")
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
|
||||
h *handler.Handlers,
|
||||
jwtAuth servermiddleware.JWTAuthMiddleware,
|
||||
redisClient *redis.Client,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
// 创建速率限制器
|
||||
rateLimiter := middleware.NewRateLimiter(redisClient)
|
||||
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
auth.Use(servermiddleware.BackendModeAuthGuard(settingService))
|
||||
{
|
||||
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||
@@ -78,6 +81,7 @@ func RegisterAuthRoutes(
|
||||
// 需要认证的当前用户信息
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
// 撤销所有会话(需要认证)
|
||||
|
||||
@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||
c.Next()
|
||||
}),
|
||||
redisClient,
|
||||
nil,
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
if h.SoraClient == nil {
|
||||
return
|
||||
@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
|
||||
|
||||
authenticated := v1.Group("/sora")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.POST("/generate", h.SoraClient.Generate)
|
||||
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
||||
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,9 +13,11 @@ func RegisterUserRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
// 用户接口
|
||||
user := authenticated.Group("/user")
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
"reflect"
|
||||
"sort"
|
||||
@@ -522,16 +523,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
||||
// 如果未配置 mapping,返回原始模型名
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mappedModel, _ := a.ResolveMappedModel(requestedModel)
|
||||
return mappedModel
|
||||
}
|
||||
|
||||
// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。
|
||||
// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。
|
||||
func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel
|
||||
return requestedModel, false
|
||||
}
|
||||
// 精确匹配优先
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
return mappedModel
|
||||
return mappedModel, true
|
||||
}
|
||||
// 通配符匹配(最长优先)
|
||||
return matchWildcardMapping(mapping, requestedModel)
|
||||
return matchWildcardMappingResult(mapping, requestedModel)
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
@@ -605,9 +613,7 @@ func matchWildcard(pattern, str string) bool {
|
||||
return matchAntigravityWildcard(pattern, str)
|
||||
}
|
||||
|
||||
// matchWildcardMapping 通配符映射匹配(最长优先)
|
||||
// 如果没有匹配,返回原始字符串
|
||||
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
|
||||
func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) {
|
||||
// 收集所有匹配的 pattern,按长度降序排序(最长优先)
|
||||
type patternMatch struct {
|
||||
pattern string
|
||||
@@ -622,7 +628,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return requestedModel // 无匹配,返回原始模型名
|
||||
return requestedModel, false // 无匹配,返回原始模型名
|
||||
}
|
||||
|
||||
// 按 pattern 长度降序排序
|
||||
@@ -633,7 +639,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
||||
return matches[i].pattern < matches[j].pattern
|
||||
})
|
||||
|
||||
return matches[0].target
|
||||
return matches[0].target, true
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
@@ -651,7 +657,7 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
||||
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
||||
func (a *Account) IsPoolMode() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
if !a.IsAPIKeyOrBedrock() || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["pool_mode"]; ok {
|
||||
@@ -766,11 +772,16 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrock() bool {
|
||||
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey)
|
||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrockAPIKey() bool {
|
||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey
|
||||
return a.IsBedrock() && a.GetCredential("auth_mode") == "apikey"
|
||||
}
|
||||
|
||||
// IsAPIKeyOrBedrock 返回账号类型是否支持配额和池模式等特性
|
||||
func (a *Account) IsAPIKeyOrBedrock() bool {
|
||||
return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAI() bool {
|
||||
@@ -1269,6 +1280,240 @@ func (a *Account) getExtraTime(key string) time.Time {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// getExtraString 从 Extra 中读取指定 key 的字符串值
|
||||
func (a *Account) getExtraString(key string) string {
|
||||
if a.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getExtraInt 从 Extra 中读取指定 key 的 int 值
|
||||
func (a *Account) getExtraInt(key string) int {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
return int(parseExtraFloat64(v))
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetQuotaDailyResetMode 获取日额度重置模式:"rolling"(默认)或 "fixed"
|
||||
func (a *Account) GetQuotaDailyResetMode() string {
|
||||
if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" {
|
||||
return "fixed"
|
||||
}
|
||||
return "rolling"
|
||||
}
|
||||
|
||||
// GetQuotaDailyResetHour 获取固定重置的小时(0-23),默认 0
|
||||
func (a *Account) GetQuotaDailyResetHour() int {
|
||||
return a.getExtraInt("quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetMode 获取周额度重置模式:"rolling"(默认)或 "fixed"
|
||||
func (a *Account) GetQuotaWeeklyResetMode() string {
|
||||
if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" {
|
||||
return "fixed"
|
||||
}
|
||||
return "rolling"
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetDay 获取固定重置的星期几(0=周日, 1=周一, ..., 6=周六),默认 1(周一)
|
||||
func (a *Account) GetQuotaWeeklyResetDay() int {
|
||||
if a.Extra == nil {
|
||||
return 1
|
||||
}
|
||||
if _, ok := a.Extra["quota_weekly_reset_day"]; !ok {
|
||||
return 1
|
||||
}
|
||||
return a.getExtraInt("quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetHour 获取周配额固定重置的小时(0-23),默认 0
|
||||
func (a *Account) GetQuotaWeeklyResetHour() int {
|
||||
return a.getExtraInt("quota_weekly_reset_hour")
|
||||
}
|
||||
|
||||
// GetQuotaResetTimezone 获取固定重置的时区名(IANA),默认 "UTC"
|
||||
func (a *Account) GetQuotaResetTimezone() string {
|
||||
if tz := a.getExtraString("quota_reset_timezone"); tz != "" {
|
||||
return tz
|
||||
}
|
||||
return "UTC"
|
||||
}
|
||||
|
||||
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
|
||||
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
|
||||
t := after.In(tz)
|
||||
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
if !after.Before(today) {
|
||||
return today.AddDate(0, 0, 1)
|
||||
}
|
||||
return today
|
||||
}
|
||||
|
||||
// lastFixedDailyReset 计算 now 之前最近一次的每日固定重置时间点
|
||||
func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time {
|
||||
t := now.In(tz)
|
||||
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
if now.Before(today) {
|
||||
return today.AddDate(0, 0, -1)
|
||||
}
|
||||
return today
|
||||
}
|
||||
|
||||
// nextFixedWeeklyReset 计算在 after 之后的下一个每周固定重置时间点
|
||||
// day: 0=Sunday, 1=Monday, ..., 6=Saturday
|
||||
func nextFixedWeeklyReset(day, hour int, tz *time.Location, after time.Time) time.Time {
|
||||
t := after.In(tz)
|
||||
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
currentDay := int(todayReset.Weekday())
|
||||
|
||||
daysForward := (day - currentDay + 7) % 7
|
||||
if daysForward == 0 && !after.Before(todayReset) {
|
||||
daysForward = 7
|
||||
}
|
||||
return todayReset.AddDate(0, 0, daysForward)
|
||||
}
|
||||
|
||||
// lastFixedWeeklyReset 计算 now 之前最近一次的每周固定重置时间点
|
||||
func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time {
|
||||
t := now.In(tz)
|
||||
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
currentDay := int(todayReset.Weekday())
|
||||
|
||||
daysBack := (currentDay - day + 7) % 7
|
||||
if daysBack == 0 && now.Before(todayReset) {
|
||||
daysBack = 7
|
||||
}
|
||||
return todayReset.AddDate(0, 0, -daysBack)
|
||||
}
|
||||
|
||||
// isFixedDailyPeriodExpired 检查日配额是否在固定时间模式下已过期
|
||||
func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true
|
||||
}
|
||||
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now())
|
||||
return periodStart.Before(lastReset)
|
||||
}
|
||||
|
||||
// isFixedWeeklyPeriodExpired 检查周配额是否在固定时间模式下已过期
|
||||
func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true
|
||||
}
|
||||
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now())
|
||||
return periodStart.Before(lastReset)
|
||||
}
|
||||
|
||||
// ComputeQuotaResetAt 根据当前配置计算并填充 extra 中的 quota_daily_reset_at / quota_weekly_reset_at
|
||||
// 在保存账号配置时调用
|
||||
func ComputeQuotaResetAt(extra map[string]any) {
|
||||
now := time.Now()
|
||||
tzName, _ := extra["quota_reset_timezone"].(string)
|
||||
if tzName == "" {
|
||||
tzName = "UTC"
|
||||
}
|
||||
tz, err := time.LoadLocation(tzName)
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
|
||||
// 日配额固定重置时间
|
||||
if mode, _ := extra["quota_daily_reset_mode"].(string); mode == "fixed" {
|
||||
hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"]))
|
||||
if hour < 0 || hour > 23 {
|
||||
hour = 0
|
||||
}
|
||||
resetAt := nextFixedDailyReset(hour, tz, now)
|
||||
extra["quota_daily_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||
} else {
|
||||
delete(extra, "quota_daily_reset_at")
|
||||
}
|
||||
|
||||
// 周配额固定重置时间
|
||||
if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" {
|
||||
day := 1 // 默认周一
|
||||
if d, ok := extra["quota_weekly_reset_day"]; ok {
|
||||
day = int(parseExtraFloat64(d))
|
||||
}
|
||||
if day < 0 || day > 6 {
|
||||
day = 1
|
||||
}
|
||||
hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"]))
|
||||
if hour < 0 || hour > 23 {
|
||||
hour = 0
|
||||
}
|
||||
resetAt := nextFixedWeeklyReset(day, hour, tz, now)
|
||||
extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||
} else {
|
||||
delete(extra, "quota_weekly_reset_at")
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateQuotaResetConfig 校验配额固定重置时间配置的合法性
|
||||
func ValidateQuotaResetConfig(extra map[string]any) error {
|
||||
if extra == nil {
|
||||
return nil
|
||||
}
|
||||
// 校验时区
|
||||
if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" {
|
||||
if _, err := time.LoadLocation(tz); err != nil {
|
||||
return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name")
|
||||
}
|
||||
}
|
||||
// 日配额重置模式
|
||||
if mode, ok := extra["quota_daily_reset_mode"].(string); ok {
|
||||
if mode != "rolling" && mode != "fixed" {
|
||||
return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'")
|
||||
}
|
||||
}
|
||||
// 日配额重置小时
|
||||
if v, ok := extra["quota_daily_reset_hour"]; ok {
|
||||
hour := int(parseExtraFloat64(v))
|
||||
if hour < 0 || hour > 23 {
|
||||
return errors.New("quota_daily_reset_hour must be between 0 and 23")
|
||||
}
|
||||
}
|
||||
// 周配额重置模式
|
||||
if mode, ok := extra["quota_weekly_reset_mode"].(string); ok {
|
||||
if mode != "rolling" && mode != "fixed" {
|
||||
return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'")
|
||||
}
|
||||
}
|
||||
// 周配额重置星期几
|
||||
if v, ok := extra["quota_weekly_reset_day"]; ok {
|
||||
day := int(parseExtraFloat64(v))
|
||||
if day < 0 || day > 6 {
|
||||
return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)")
|
||||
}
|
||||
}
|
||||
// 周配额重置小时
|
||||
if v, ok := extra["quota_weekly_reset_hour"]; ok {
|
||||
hour := int(parseExtraFloat64(v))
|
||||
if hour < 0 || hour > 23 {
|
||||
return errors.New("quota_weekly_reset_hour must be between 0 and 23")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
||||
func (a *Account) HasAnyQuotaLimit() bool {
|
||||
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
||||
@@ -1291,14 +1536,26 @@ func (a *Account) IsQuotaExceeded() bool {
|
||||
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_daily_start")
|
||||
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
|
||||
var expired bool
|
||||
if a.GetQuotaDailyResetMode() == "fixed" {
|
||||
expired = a.isFixedDailyPeriodExpired(start)
|
||||
} else {
|
||||
expired = isPeriodExpired(start, 24*time.Hour)
|
||||
}
|
||||
if !expired && a.GetQuotaDailyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 周额度
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_weekly_start")
|
||||
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
|
||||
var expired bool
|
||||
if a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
expired = a.isFixedWeeklyPeriodExpired(start)
|
||||
} else {
|
||||
expired = isPeriodExpired(start, 7*24*time.Hour)
|
||||
}
|
||||
if !expired && a.GetQuotaWeeklyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
516
backend/internal/service/account_quota_reset_test.go
Normal file
516
backend/internal/service/account_quota_reset_test.go
Normal file
@@ -0,0 +1,516 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 06:00 UTC, reset hour = 9
|
||||
after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Exactly at reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// After reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_MidnightReset(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Reset at hour 0 (midnight), currently 23:59
|
||||
after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz)
|
||||
got := nextFixedDailyReset(0, tz, after)
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
|
||||
after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// Before today's 9:00 → yesterday 9:00
|
||||
want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// At exactly 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// After 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
|
||||
after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
|
||||
after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
|
||||
after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-23
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_Sunday(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(0, 0, tz, after)
|
||||
// Next Sunday = 2026-03-15
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
|
||||
now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
|
||||
now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday at 9:00 = 2026-03-09
|
||||
want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedDailyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started after the most recent reset → not expired
|
||||
// (This test uses a time very close to "now", which is after the last reset)
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 3 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Invalid/Timezone",
|
||||
}}
|
||||
// Invalid timezone falls back to UTC
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedWeeklyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 1 minute ago → not expired
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 10 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-240 * time.Hour)
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ValidateQuotaResetConfig
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestValidateQuotaResetConfig_NilExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(nil))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{}))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_reset_timezone": "Not/A/Timezone",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_reset_timezone")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "invalid",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(24),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "unknown",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(7),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_hour": float64(25),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) {
|
||||
// All boundary values should be valid
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(23),
|
||||
"quota_weekly_reset_day": float64(0), // Sunday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
|
||||
extra2 := map[string]any{
|
||||
"quota_daily_reset_hour": float64(0),
|
||||
"quota_weekly_reset_day": float64(6), // Saturday
|
||||
"quota_weekly_reset_hour": float64(23),
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra2))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ComputeQuotaResetAt
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
"quota_daily_reset_at": "2026-03-14T09:00:00Z",
|
||||
"quota_weekly_reset_at": "2026-03-16T09:00:00Z",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok, "quota_daily_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset hour should be 9 UTC
|
||||
assert.Equal(t, 9, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1), // Monday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_weekly_reset_at"].(string)
|
||||
require.True(t, ok, "quota_weekly_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset day should be Monday
|
||||
assert.Equal(t, time.Monday, resetAt.UTC().Weekday())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// In Shanghai timezone, the hour should be 9
|
||||
assert.Equal(t, 9, resetAt.In(tz).Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(12),
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Default timezone is UTC
|
||||
assert.Equal(t, 12, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(99),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Invalid hour → clamped to 0
|
||||
assert.Equal(t, 0, resetAt.UTC().Hour())
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -100,6 +101,7 @@ type antigravityUsageCache struct {
|
||||
const (
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
||||
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
|
||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
openAIProbeCacheTTL = 10 * time.Minute
|
||||
@@ -108,11 +110,12 @@ const (
|
||||
|
||||
// UsageCache 封装账户使用量相关的缓存
|
||||
type UsageCache struct {
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
|
||||
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
}
|
||||
|
||||
// NewUsageCache 创建 UsageCache 实例
|
||||
@@ -149,6 +152,18 @@ type AntigravityModelQuota struct {
|
||||
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
|
||||
}
|
||||
|
||||
// AntigravityModelDetail Antigravity 单个模型的详细能力信息
|
||||
type AntigravityModelDetail struct {
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
SupportsImages *bool `json:"supports_images,omitempty"`
|
||||
SupportsThinking *bool `json:"supports_thinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinking_budget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
@@ -164,6 +179,33 @@ type UsageInfo struct {
|
||||
|
||||
// Antigravity 多模型配额
|
||||
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
||||
|
||||
// Antigravity 账号级信息
|
||||
SubscriptionTier string `json:"subscription_tier,omitempty"` // 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN
|
||||
SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // 上游原始订阅等级名称
|
||||
|
||||
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key)
|
||||
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
|
||||
|
||||
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
||||
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
||||
|
||||
// Antigravity 账号是否被上游禁止 (HTTP 403)
|
||||
IsForbidden bool `json:"is_forbidden,omitempty"`
|
||||
ForbiddenReason string `json:"forbidden_reason,omitempty"`
|
||||
ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden"
|
||||
ValidationURL string `json:"validation_url,omitempty"` // 验证/申诉链接
|
||||
|
||||
// 状态标记(从 ForbiddenType / HTTP 错误码推导)
|
||||
NeedsVerify bool `json:"needs_verify,omitempty"` // 需要人工验证(forbidden_type=validation)
|
||||
IsBanned bool `json:"is_banned,omitempty"` // 账号被封(forbidden_type=violation)
|
||||
NeedsReauth bool `json:"needs_reauth,omitempty"` // token 失效需重新授权(401)
|
||||
|
||||
// 错误码(机器可读):forbidden / unauthenticated / rate_limited / network_error
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
|
||||
// 获取 usage 时的错误信息(降级返回,而非 500)
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||
@@ -648,34 +690,157 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
|
||||
// 1. 检查缓存(10 分钟)
|
||||
// 1. 检查缓存
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
// 重新计算 RemainingSeconds
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||
if time.Since(cache.timestamp) < ttl {
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取代理 URL
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
|
||||
// 2. singleflight 防止并发击穿
|
||||
flightKey := fmt.Sprintf("ag-usage:%d", account.ID)
|
||||
result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) {
|
||||
// 再次检查缓存(等待期间可能已被填充)
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||
if time.Since(cache.timestamp) < ttl {
|
||||
usage := cache.usageInfo
|
||||
// 重新计算 RemainingSeconds,避免返回过时的剩余秒数
|
||||
recalcAntigravityRemainingSeconds(usage)
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 调用 API 获取额度
|
||||
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
|
||||
}
|
||||
// 使用独立 context,避免调用方 cancel 导致所有共享 flight 的请求失败
|
||||
fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer fetchCancel()
|
||||
|
||||
// 4. 缓存结果
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: result.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account)
|
||||
fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL)
|
||||
if err != nil {
|
||||
degraded := buildAntigravityDegradedUsage(err)
|
||||
enrichUsageWithAccountError(degraded, account)
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: degraded,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return degraded, nil
|
||||
}
|
||||
|
||||
enrichUsageWithAccountError(fetchResult.UsageInfo, account)
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: fetchResult.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return fetchResult.UsageInfo, nil
|
||||
})
|
||||
|
||||
return result.UsageInfo, nil
|
||||
if flightErr != nil {
|
||||
return nil, flightErr
|
||||
}
|
||||
usage, ok := result.(*UsageInfo)
|
||||
if !ok || usage == nil {
|
||||
now := time.Now()
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
|
||||
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
|
||||
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
if info.FiveHour != nil && info.FiveHour.ResetsAt != nil {
|
||||
remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds())
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
info.FiveHour.RemainingSeconds = remaining
|
||||
}
|
||||
}
|
||||
|
||||
// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL
|
||||
// 403 forbidden 状态稳定,缓存与成功相同(3 分钟);
|
||||
// 其他错误(401/网络)可能快速恢复,缓存 1 分钟。
|
||||
func antigravityCacheTTL(info *UsageInfo) time.Duration {
|
||||
if info == nil {
|
||||
return antigravityErrorTTL
|
||||
}
|
||||
if info.IsForbidden {
|
||||
return apiCacheTTL // 封号/验证状态不会很快变
|
||||
}
|
||||
if info.ErrorCode != "" || info.Error != "" {
|
||||
return antigravityErrorTTL
|
||||
}
|
||||
return apiCacheTTL
|
||||
}
|
||||
|
||||
// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo
|
||||
func buildAntigravityDegradedUsage(err error) *UsageInfo {
|
||||
now := time.Now()
|
||||
errMsg := fmt.Sprintf("usage API error: %v", err)
|
||||
slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err)
|
||||
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
Error: errMsg,
|
||||
}
|
||||
|
||||
// 从错误信息推断 error_code 和状态标记
|
||||
// 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..."
|
||||
errStr := err.Error()
|
||||
switch {
|
||||
case strings.Contains(errStr, "HTTP 401") ||
|
||||
strings.Contains(errStr, "UNAUTHENTICATED") ||
|
||||
strings.Contains(errStr, "invalid_grant"):
|
||||
info.ErrorCode = errorCodeUnauthenticated
|
||||
info.NeedsReauth = true
|
||||
case strings.Contains(errStr, "HTTP 429"):
|
||||
info.ErrorCode = errorCodeRateLimited
|
||||
default:
|
||||
info.ErrorCode = errorCodeNetworkError
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo
|
||||
// 场景 1(成功路径):FetchAvailableModels 正常返回,但账号已因 403 被标记为 error,
|
||||
//
|
||||
// 需要在正常 usage 数据上附加 forbidden/validation 信息。
|
||||
//
|
||||
// 场景 2(降级路径):被封号的账号 OAuth token 失效,FetchAvailableModels 返回 401,
|
||||
//
|
||||
// 降级逻辑设置了 needs_reauth,但账号实际是 403 封号/需验证,需覆盖为正确状态。
|
||||
func enrichUsageWithAccountError(info *UsageInfo, account *Account) {
|
||||
if info == nil || account == nil || account.Status != StatusError {
|
||||
return
|
||||
}
|
||||
msg := strings.ToLower(account.ErrorMessage)
|
||||
if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") &&
|
||||
!strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") {
|
||||
return
|
||||
}
|
||||
fbType := classifyForbiddenType(account.ErrorMessage)
|
||||
info.IsForbidden = true
|
||||
info.ForbiddenType = fbType
|
||||
info.ForbiddenReason = account.ErrorMessage
|
||||
info.NeedsVerify = fbType == forbiddenTypeValidation
|
||||
info.IsBanned = fbType == forbiddenTypeViolation
|
||||
info.ValidationURL = extractValidationURL(account.ErrorMessage)
|
||||
info.ErrorCode = errorCodeForbidden
|
||||
info.NeedsReauth = false
|
||||
}
|
||||
|
||||
// addWindowStats 为 usage 数据添加窗口期统计
|
||||
|
||||
@@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchWildcardMapping(t *testing.T) {
|
||||
func TestMatchWildcardMappingResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mapping map[string]string
|
||||
requestedModel string
|
||||
expected string
|
||||
matched bool
|
||||
}{
|
||||
// 精确匹配优先于通配符
|
||||
{
|
||||
@@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5-exact",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 最长通配符优先
|
||||
@@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-series",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 单个通配符
|
||||
@@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "claude-opus-4-5",
|
||||
expected: "claude-mapped",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 无匹配返回原始模型
|
||||
@@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "gemini-3-flash",
|
||||
expected: "gemini-3-flash",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// 空映射返回原始模型
|
||||
@@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
mapping: map[string]string{},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// Gemini 模型映射
|
||||
@@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) {
|
||||
},
|
||||
requestedModel: "gemini-3-flash-preview",
|
||||
expected: "gemini-3-pro-high",
|
||||
matched: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchWildcardMapping(tt.mapping, tt.requestedModel)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected)
|
||||
result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
|
||||
if result != tt.expected || matched != tt.matched {
|
||||
t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -268,6 +275,69 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountResolveMappedModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expectedModel string
|
||||
expectedMatch bool
|
||||
}{
|
||||
{
|
||||
name: "no mapping reports unmatched",
|
||||
credentials: nil,
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
{
|
||||
name: "exact passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-*": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "missing mapping reports unmatched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.2": "gpt-5.2",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
|
||||
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
|
||||
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
|
||||
@@ -1462,6 +1462,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
// 预计算固定时间重置的下次重置时间
|
||||
if account.Extra != nil {
|
||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ComputeQuotaResetAt(account.Extra)
|
||||
}
|
||||
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
|
||||
expiresAt := time.Unix(*input.ExpiresAt, 0)
|
||||
account.ExpiresAt = &expiresAt
|
||||
@@ -1535,6 +1542,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
// 校验并预计算固定时间重置的下次重置时间
|
||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ComputeQuotaResetAt(account.Extra)
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
|
||||
|
||||
@@ -2,12 +2,29 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
const (
|
||||
forbiddenTypeValidation = "validation"
|
||||
forbiddenTypeViolation = "violation"
|
||||
forbiddenTypeForbidden = "forbidden"
|
||||
|
||||
// 机器可读的错误码
|
||||
errorCodeForbidden = "forbidden"
|
||||
errorCodeUnauthenticated = "unauthenticated"
|
||||
errorCodeRateLimited = "rate_limited"
|
||||
errorCodeNetworkError = "network_error"
|
||||
)
|
||||
|
||||
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
|
||||
type AntigravityQuotaFetcher struct {
|
||||
proxyRepo ProxyRepository
|
||||
@@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
// 403 Forbidden: 不报错,返回 is_forbidden 标记
|
||||
var forbiddenErr *antigravity.ForbiddenError
|
||||
if errors.As(err, &forbiddenErr) {
|
||||
now := time.Now()
|
||||
fbType := classifyForbiddenType(forbiddenErr.Body)
|
||||
return &QuotaResult{
|
||||
UsageInfo: &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
IsForbidden: true,
|
||||
ForbiddenReason: forbiddenErr.Body,
|
||||
ForbiddenType: fbType,
|
||||
ValidationURL: extractValidationURL(forbiddenErr.Body),
|
||||
NeedsVerify: fbType == forbiddenTypeValidation,
|
||||
IsBanned: fbType == forbiddenTypeViolation,
|
||||
ErrorCode: errorCodeForbidden,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程)
|
||||
tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken)
|
||||
|
||||
// 转换为 UsageInfo
|
||||
usageInfo := f.buildUsageInfo(modelsResp)
|
||||
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized)
|
||||
|
||||
return &QuotaResult{
|
||||
UsageInfo: usageInfo,
|
||||
@@ -52,15 +90,52 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
|
||||
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) {
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
if err != nil {
|
||||
slog.Warn("failed to fetch subscription tier", "error", err)
|
||||
return "", ""
|
||||
}
|
||||
if loadResp == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota
|
||||
raw = loadResp.GetTier() // 已有方法:paidTier > currentTier
|
||||
normalized = normalizeTier(raw)
|
||||
return raw, normalized
|
||||
}
|
||||
|
||||
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
|
||||
func normalizeTier(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(raw)
|
||||
switch {
|
||||
case strings.Contains(lower, "ultra"):
|
||||
return "ULTRA"
|
||||
case strings.Contains(lower, "pro"):
|
||||
return "PRO"
|
||||
case strings.Contains(lower, "free"):
|
||||
return "FREE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail),
|
||||
SubscriptionTier: tierNormalized,
|
||||
SubscriptionTierRaw: tierRaw,
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
@@ -73,6 +148,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
|
||||
Utilization: utilization,
|
||||
ResetTime: modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
|
||||
// 填充模型详细能力信息
|
||||
detail := &AntigravityModelDetail{
|
||||
DisplayName: modelInfo.DisplayName,
|
||||
SupportsImages: modelInfo.SupportsImages,
|
||||
SupportsThinking: modelInfo.SupportsThinking,
|
||||
ThinkingBudget: modelInfo.ThinkingBudget,
|
||||
Recommended: modelInfo.Recommended,
|
||||
MaxTokens: modelInfo.MaxTokens,
|
||||
MaxOutputTokens: modelInfo.MaxOutputTokens,
|
||||
SupportedMimeTypes: modelInfo.SupportedMimeTypes,
|
||||
}
|
||||
info.AntigravityQuotaDetails[modelName] = detail
|
||||
}
|
||||
|
||||
// 废弃模型转发规则
|
||||
if len(modelsResp.DeprecatedModelIDs) > 0 {
|
||||
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs))
|
||||
for oldID, deprecated := range modelsResp.DeprecatedModelIDs {
|
||||
info.ModelForwardingRules[oldID] = deprecated.NewModelID
|
||||
}
|
||||
}
|
||||
|
||||
// 同时设置 FiveHour 用于兼容展示(取主要模型)
|
||||
@@ -108,3 +204,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco
|
||||
}
|
||||
return proxy.URL()
|
||||
}
|
||||
|
||||
// classifyForbiddenType 根据 403 响应体判断禁止类型
|
||||
func classifyForbiddenType(body string) string {
|
||||
lower := strings.ToLower(body)
|
||||
switch {
|
||||
case strings.Contains(lower, "validation_required") ||
|
||||
strings.Contains(lower, "verify your account") ||
|
||||
strings.Contains(lower, "validation_url"):
|
||||
return forbiddenTypeValidation
|
||||
case strings.Contains(lower, "terms of service") ||
|
||||
strings.Contains(lower, "violation"):
|
||||
return forbiddenTypeViolation
|
||||
default:
|
||||
return forbiddenTypeForbidden
|
||||
}
|
||||
}
|
||||
|
||||
// urlPattern 用于从 403 响应体中提取 URL(降级方案)
|
||||
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)
|
||||
|
||||
// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接
|
||||
func extractValidationURL(body string) string {
|
||||
// 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url
|
||||
var parsed struct {
|
||||
Error struct {
|
||||
Details []struct {
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
} `json:"details"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal([]byte(body), &parsed) == nil {
|
||||
for _, detail := range parsed.Error.Details {
|
||||
if u := detail.Metadata["validation_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
if u := detail.Metadata["appeal_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 降级:正则匹配 URL
|
||||
lower := strings.ToLower(body)
|
||||
if !strings.Contains(lower, "validation") &&
|
||||
!strings.Contains(lower, "verify") &&
|
||||
!strings.Contains(lower, "appeal") {
|
||||
return ""
|
||||
}
|
||||
// 先解码常见转义再匹配
|
||||
normalized := strings.ReplaceAll(body, `\u0026`, "&")
|
||||
if m := urlPattern.FindString(normalized); m != "" {
|
||||
return m
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
497
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
497
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
@@ -0,0 +1,497 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// normalizeTier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNormalizeTier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
expected string
|
||||
}{
|
||||
{name: "empty string", raw: "", expected: ""},
|
||||
{name: "free-tier", raw: "free-tier", expected: "FREE"},
|
||||
{name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"},
|
||||
{name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"},
|
||||
{name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"},
|
||||
{name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"},
|
||||
{name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"},
|
||||
{name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"},
|
||||
{name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeTier(tt.raw)
|
||||
require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildUsageInfo
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func aqfBoolPtr(v bool) *bool { return &v }
|
||||
func aqfIntPtr(v int) *int { return &v }
|
||||
|
||||
func TestBuildUsageInfo_BasicModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.75,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
DisplayName: "Claude Sonnet 4",
|
||||
SupportsImages: aqfBoolPtr(true),
|
||||
SupportsThinking: aqfBoolPtr(false),
|
||||
ThinkingBudget: aqfIntPtr(0),
|
||||
Recommended: aqfBoolPtr(true),
|
||||
MaxTokens: aqfIntPtr(200000),
|
||||
MaxOutputTokens: aqfIntPtr(16384),
|
||||
SupportedMimeTypes: map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "2026-03-08T15:00:00Z",
|
||||
},
|
||||
DisplayName: "Gemini 2.5 Pro",
|
||||
MaxTokens: aqfIntPtr(1000000),
|
||||
MaxOutputTokens: aqfIntPtr(65536),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO")
|
||||
|
||||
// 基本字段
|
||||
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
|
||||
require.Equal(t, "PRO", info.SubscriptionTier)
|
||||
require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw)
|
||||
|
||||
// AntigravityQuota
|
||||
require.Len(t, info.AntigravityQuota, 2)
|
||||
|
||||
sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetQuota)
|
||||
require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25
|
||||
require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime)
|
||||
|
||||
geminiQuota := info.AntigravityQuota["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiQuota)
|
||||
require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50
|
||||
require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime)
|
||||
|
||||
// AntigravityQuotaDetails
|
||||
require.Len(t, info.AntigravityQuotaDetails, 2)
|
||||
|
||||
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetDetail)
|
||||
require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages)
|
||||
require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended)
|
||||
require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens)
|
||||
require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes)
|
||||
|
||||
geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiDetail)
|
||||
require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName)
|
||||
require.Nil(t, geminiDetail.SupportsImages)
|
||||
require.Nil(t, geminiDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{
|
||||
"claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"},
|
||||
"claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Len(t, info.ModelForwardingRules, 2)
|
||||
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
|
||||
require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"])
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_EmptyModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.NotNil(t, info.AntigravityQuota)
|
||||
require.Empty(t, info.AntigravityQuota)
|
||||
require.NotNil(t, info.AntigravityQuotaDetails)
|
||||
require.Empty(t, info.AntigravityQuotaDetails)
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"model-without-quota": {
|
||||
DisplayName: "No Quota Model",
|
||||
// QuotaInfo is nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
|
||||
require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
|
||||
// When the first priority model exists, it should be used for FiveHour
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.40,
|
||||
ResetTime: "2026-03-08T18:00:00Z",
|
||||
},
|
||||
},
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.80,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
|
||||
// claude-sonnet-4-20250514 is first in priority list, so it should be used
|
||||
expectedUtilization := (1.0 - 0.80) * 100 // 20
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.60,
|
||||
ResetTime: "2026-03-08T14:00:00Z",
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.60) * 100 // 40
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only gemini-2.5-pro exists (third in priority list)
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
"other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.90,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.30) * 100 // 70
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// None of the priority models exist
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "", // empty reset time
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
|
||||
require.Equal(t, 0, info.FiveHour.RemainingSeconds)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FullUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.0, // fully used
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 100, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0, // fully available
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 0, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
|
||||
// 模拟 FetchQuota 遇到 403 时的行为:
|
||||
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
|
||||
forbiddenErr := &antigravity.ForbiddenError{
|
||||
StatusCode: 403,
|
||||
Body: "Access denied",
|
||||
}
|
||||
|
||||
// 验证 ForbiddenError 满足 errors.As
|
||||
var target *antigravity.ForbiddenError
|
||||
require.True(t, errors.As(forbiddenErr, &target))
|
||||
require.Equal(t, 403, target.StatusCode)
|
||||
require.Equal(t, "Access denied", target.Body)
|
||||
require.Contains(t, forbiddenErr.Error(), "403")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// classifyForbiddenType
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestClassifyForbiddenType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "VALIDATION_REQUIRED keyword",
|
||||
body: `{"error":{"message":"VALIDATION_REQUIRED"}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "verify your account",
|
||||
body: `Please verify your account to continue`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "contains validation_url field",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "terms of service violation",
|
||||
body: `Your account has been suspended for Terms of Service violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "violation keyword",
|
||||
body: `Account suspended due to policy violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "generic 403",
|
||||
body: `Access denied`,
|
||||
expected: "forbidden",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "forbidden",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyForbiddenType(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extractValidationURL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractValidationURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "structured validation_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
|
||||
expected: "https://accounts.google.com/verify?token=abc",
|
||||
},
|
||||
{
|
||||
name: "structured appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`,
|
||||
expected: "https://support.google.com/appeal/123",
|
||||
},
|
||||
{
|
||||
name: "validation_url takes priority over appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`,
|
||||
expected: "https://v.com",
|
||||
},
|
||||
{
|
||||
name: "fallback regex with verify keyword",
|
||||
body: `Please verify your account at https://accounts.google.com/verify`,
|
||||
expected: "https://accounts.google.com/verify",
|
||||
},
|
||||
{
|
||||
name: "no URL in generic forbidden",
|
||||
body: `Access denied`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "URL present but no validation keywords",
|
||||
body: `Error at https://example.com/something`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "unicode escaped ampersand",
|
||||
body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`,
|
||||
expected: "https://accounts.google.com/verify?a=1&b=2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractValidationURL(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1087,6 +1087,12 @@ type TokenPair struct {
|
||||
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
|
||||
}
|
||||
|
||||
// TokenPairWithUser extends TokenPair with user role for backend mode checks
|
||||
type TokenPairWithUser struct {
|
||||
TokenPair
|
||||
UserRole string
|
||||
}
|
||||
|
||||
// GenerateTokenPair 生成Access Token和Refresh Token对
|
||||
// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系
|
||||
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
|
||||
@@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
|
||||
|
||||
// RefreshTokenPair 使用Refresh Token刷新Token对
|
||||
// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
@@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
||||
}
|
||||
|
||||
// 生成新的Token对,保持同一个家族ID
|
||||
return s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenPairWithUser{
|
||||
TokenPair: *pair,
|
||||
UserRole: user.Role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeRefreshToken 撤销单个Refresh Token
|
||||
|
||||
770
backend/internal/service/backup_service.go
Normal file
770
backend/internal/service/backup_service.go
Normal file
@@ -0,0 +1,770 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
settingKeyBackupS3Config = "backup_s3_config"
|
||||
settingKeyBackupSchedule = "backup_schedule"
|
||||
settingKeyBackupRecords = "backup_records"
|
||||
|
||||
maxBackupRecords = 100
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured")
|
||||
ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found")
|
||||
ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress")
|
||||
ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress")
|
||||
ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted")
|
||||
ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted")
|
||||
)
|
||||
|
||||
// ─── 接口定义 ───
|
||||
|
||||
// DBDumper abstracts database dump/restore operations
|
||||
type DBDumper interface {
|
||||
Dump(ctx context.Context) (io.ReadCloser, error)
|
||||
Restore(ctx context.Context, data io.Reader) error
|
||||
}
|
||||
|
||||
// BackupObjectStore abstracts object storage for backup files
|
||||
type BackupObjectStore interface {
|
||||
Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error)
|
||||
Download(ctx context.Context, key string) (io.ReadCloser, error)
|
||||
Delete(ctx context.Context, key string) error
|
||||
PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error)
|
||||
HeadBucket(ctx context.Context) error
|
||||
}
|
||||
|
||||
// BackupObjectStoreFactory creates an object store from S3 config
|
||||
type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error)
|
||||
|
||||
// ─── 数据模型 ───
|
||||
|
||||
// BackupS3Config S3 兼容存储配置(支持 Cloudflare R2)
|
||||
type BackupS3Config struct {
|
||||
Endpoint string `json:"endpoint"` // e.g. https://<account_id>.r2.cloudflarestorage.com
|
||||
Region string `json:"region"` // R2 用 "auto"
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key,omitempty"` //nolint:revive // field name follows AWS convention
|
||||
Prefix string `json:"prefix"` // S3 key 前缀,如 "backups/"
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
}
|
||||
|
||||
// IsConfigured 检查必要字段是否已配置
|
||||
func (c *BackupS3Config) IsConfigured() bool {
|
||||
return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != ""
|
||||
}
|
||||
|
||||
// BackupScheduleConfig 定时备份配置
|
||||
type BackupScheduleConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CronExpr string `json:"cron_expr"` // cron 表达式,如 "0 2 * * *" 每天凌晨2点
|
||||
RetainDays int `json:"retain_days"` // 备份文件过期天数,默认14,0=不自动清理
|
||||
RetainCount int `json:"retain_count"` // 最多保留份数,0=不限制
|
||||
}
|
||||
|
||||
// BackupRecord 备份记录
|
||||
type BackupRecord struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"` // pending, running, completed, failed
|
||||
BackupType string `json:"backup_type"` // postgres
|
||||
FileName string `json:"file_name"`
|
||||
S3Key string `json:"s3_key"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TriggeredBy string `json:"triggered_by"` // manual, scheduled
|
||||
ErrorMsg string `json:"error_message,omitempty"`
|
||||
StartedAt string `json:"started_at"`
|
||||
FinishedAt string `json:"finished_at,omitempty"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
|
||||
}
|
||||
|
||||
// BackupService 数据库备份恢复服务
|
||||
type BackupService struct {
|
||||
settingRepo SettingRepository
|
||||
dbCfg *config.DatabaseConfig
|
||||
encryptor SecretEncryptor
|
||||
storeFactory BackupObjectStoreFactory
|
||||
dumper DBDumper
|
||||
|
||||
mu sync.Mutex
|
||||
store BackupObjectStore
|
||||
s3Cfg *BackupS3Config
|
||||
backingUp bool
|
||||
restoring bool
|
||||
|
||||
recordsMu sync.Mutex // 保护 records 的 load/save 操作
|
||||
|
||||
cronMu sync.Mutex
|
||||
cronSched *cron.Cron
|
||||
cronEntryID cron.EntryID
|
||||
}
|
||||
|
||||
func NewBackupService(
|
||||
settingRepo SettingRepository,
|
||||
cfg *config.Config,
|
||||
encryptor SecretEncryptor,
|
||||
storeFactory BackupObjectStoreFactory,
|
||||
dumper DBDumper,
|
||||
) *BackupService {
|
||||
return &BackupService{
|
||||
settingRepo: settingRepo,
|
||||
dbCfg: &cfg.Database,
|
||||
encryptor: encryptor,
|
||||
storeFactory: storeFactory,
|
||||
dumper: dumper,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动定时备份调度器
|
||||
func (s *BackupService) Start() {
|
||||
s.cronSched = cron.New()
|
||||
s.cronSched.Start()
|
||||
|
||||
// 加载已有的定时配置
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
schedule, err := s.GetSchedule(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 加载定时备份配置失败: %v", err)
|
||||
return
|
||||
}
|
||||
if schedule.Enabled && schedule.CronExpr != "" {
|
||||
if err := s.applyCronSchedule(schedule); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 应用定时备份配置失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止定时备份
|
||||
func (s *BackupService) Stop() {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
if s.cronSched != nil {
|
||||
s.cronSched.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置管理 ───
|
||||
|
||||
func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg == nil {
|
||||
return &BackupS3Config{}, nil
|
||||
}
|
||||
// 脱敏返回
|
||||
cfg.SecretAccessKey = ""
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) {
|
||||
// 如果没提供 secret,保留原有值
|
||||
if cfg.SecretAccessKey == "" {
|
||||
old, _ := s.loadS3Config(ctx)
|
||||
if old != nil {
|
||||
cfg.SecretAccessKey = old.SecretAccessKey
|
||||
}
|
||||
} else {
|
||||
// 加密 SecretAccessKey
|
||||
encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encrypt secret: %w", err)
|
||||
}
|
||||
cfg.SecretAccessKey = encrypted
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal s3 config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil {
|
||||
return nil, fmt.Errorf("save s3 config: %w", err)
|
||||
}
|
||||
|
||||
// 清除缓存的 S3 客户端
|
||||
s.mu.Lock()
|
||||
s.store = nil
|
||||
s.s3Cfg = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
cfg.SecretAccessKey = ""
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error {
|
||||
// 如果没提供 secret,用已保存的
|
||||
if cfg.SecretAccessKey == "" {
|
||||
old, _ := s.loadS3Config(ctx)
|
||||
if old != nil {
|
||||
cfg.SecretAccessKey = old.SecretAccessKey
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
|
||||
return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required")
|
||||
}
|
||||
|
||||
store, err := s.storeFactory(ctx, &cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return store.HeadBucket(ctx)
|
||||
}
|
||||
|
||||
// ─── 定时备份管理 ───
|
||||
|
||||
func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule)
|
||||
if err != nil || raw == "" {
|
||||
return &BackupScheduleConfig{}, nil
|
||||
}
|
||||
var cfg BackupScheduleConfig
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return &BackupScheduleConfig{}, nil
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) {
|
||||
if cfg.Enabled && cfg.CronExpr == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled")
|
||||
}
|
||||
// 验证 cron 表达式
|
||||
if cfg.CronExpr != "" {
|
||||
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
if _, err := parser.Parse(cfg.CronExpr); err != nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal schedule config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil {
|
||||
return nil, fmt.Errorf("save schedule config: %w", err)
|
||||
}
|
||||
|
||||
// 应用或停止定时任务
|
||||
if cfg.Enabled {
|
||||
if err := s.applyCronSchedule(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
s.removeCronSchedule()
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
|
||||
if s.cronSched == nil {
|
||||
return fmt.Errorf("cron scheduler not initialized")
|
||||
}
|
||||
|
||||
// 移除旧任务
|
||||
if s.cronEntryID != 0 {
|
||||
s.cronSched.Remove(s.cronEntryID)
|
||||
s.cronEntryID = 0
|
||||
}
|
||||
|
||||
entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() {
|
||||
s.runScheduledBackup()
|
||||
})
|
||||
if err != nil {
|
||||
return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err))
|
||||
}
|
||||
s.cronEntryID = entryID
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已启用: %s", cfg.CronExpr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) removeCronSchedule() {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
if s.cronSched != nil && s.cronEntryID != 0 {
|
||||
s.cronSched.Remove(s.cronEntryID)
|
||||
s.cronEntryID = 0
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已停用")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupService) runScheduledBackup() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// 读取定时备份配置中的过期天数
|
||||
schedule, _ := s.GetSchedule(ctx)
|
||||
expireDays := 14 // 默认14天过期
|
||||
if schedule != nil && schedule.RetainDays > 0 {
|
||||
expireDays = schedule.RetainDays
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
|
||||
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
|
||||
|
||||
// 清理过期备份(复用已加载的 schedule)
|
||||
if schedule == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cleanupOldBackups(ctx, schedule); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 清理过期备份失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── 备份/恢复核心 ───
|
||||
|
||||
// CreateBackup 创建全量数据库备份并上传到 S3(流式处理)
|
||||
// expireDays: 备份过期天数,0=永不过期,默认14天
|
||||
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
||||
s.mu.Lock()
|
||||
if s.backingUp {
|
||||
s.mu.Unlock()
|
||||
return nil, ErrBackupInProgress
|
||||
}
|
||||
s.backingUp = true
|
||||
s.mu.Unlock()
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.backingUp = false
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s3Cfg == nil || !s3Cfg.IsConfigured() {
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
backupID := uuid.New().String()[:8]
|
||||
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
|
||||
s3Key := s.buildS3Key(s3Cfg, fileName)
|
||||
|
||||
var expiresAt string
|
||||
if expireDays > 0 {
|
||||
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
record := &BackupRecord{
|
||||
ID: backupID,
|
||||
Status: "running",
|
||||
BackupType: "postgres",
|
||||
FileName: fileName,
|
||||
S3Key: s3Key,
|
||||
TriggeredBy: triggeredBy,
|
||||
StartedAt: now.Format(time.RFC3339),
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
|
||||
// 流式执行: pg_dump -> gzip -> S3 upload
|
||||
dumpReader, err := s.dumper.Dump(ctx)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
|
||||
pr, pw := io.Pipe()
|
||||
var gzipErr error
|
||||
go func() {
|
||||
gzWriter := gzip.NewWriter(pw)
|
||||
_, gzipErr = io.Copy(gzWriter, dumpReader)
|
||||
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if gzipErr != nil {
|
||||
_ = pw.CloseWithError(gzipErr)
|
||||
} else {
|
||||
_ = pw.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
contentType := "application/gzip"
|
||||
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
||||
if gzipErr != nil {
|
||||
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
|
||||
}
|
||||
record.ErrorMsg = errMsg
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("backup upload: %w", err)
|
||||
}
|
||||
|
||||
record.SizeBytes = sizeBytes
|
||||
record.Status = "completed"
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
if err := s.saveRecord(ctx, record); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
|
||||
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
|
||||
s.mu.Lock()
|
||||
if s.restoring {
|
||||
s.mu.Unlock()
|
||||
return ErrRestoreInProgress
|
||||
}
|
||||
s.restoring = true
|
||||
s.mu.Unlock()
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.restoring = false
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
record, err := s.GetBackupRecord(ctx, backupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if record.Status != "completed" {
|
||||
return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
|
||||
}
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
// 从 S3 流式下载
|
||||
body, err := objectStore.Download(ctx, record.S3Key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 download failed: %w", err)
|
||||
}
|
||||
defer func() { _ = body.Close() }()
|
||||
|
||||
// 流式解压 gzip -> psql(不将全部数据加载到内存)
|
||||
gzReader, err := gzip.NewReader(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gzip reader: %w", err)
|
||||
}
|
||||
defer func() { _ = gzReader.Close() }()
|
||||
|
||||
// 流式恢复
|
||||
if err := s.dumper.Restore(ctx, gzReader); err != nil {
|
||||
return fmt.Errorf("pg restore: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ─── 备份记录管理 ───
|
||||
|
||||
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
|
||||
records, err := s.loadRecords(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 倒序返回(最新在前)
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].StartedAt > records[j].StartedAt
|
||||
})
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) {
|
||||
records, err := s.loadRecords(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range records {
|
||||
if records[i].ID == backupID {
|
||||
return &records[i], nil
|
||||
}
|
||||
}
|
||||
return nil, ErrBackupNotFound
|
||||
}
|
||||
|
||||
func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var found *BackupRecord
|
||||
var remaining []BackupRecord
|
||||
for i := range records {
|
||||
if records[i].ID == backupID {
|
||||
found = &records[i]
|
||||
} else {
|
||||
remaining = append(remaining, records[i])
|
||||
}
|
||||
}
|
||||
if found == nil {
|
||||
return ErrBackupNotFound
|
||||
}
|
||||
|
||||
// 从 S3 删除
|
||||
if found.S3Key != "" && found.Status == "completed" {
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() {
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err == nil {
|
||||
_ = objectStore.Delete(ctx, found.S3Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.saveRecordsLocked(ctx, remaining)
|
||||
}
|
||||
|
||||
// GetBackupDownloadURL 获取备份文件预签名下载 URL
|
||||
func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) {
|
||||
record, err := s.GetBackupRecord(ctx, backupID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if record.Status != "completed" {
|
||||
return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed")
|
||||
}
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// ─── 内部方法 ───
|
||||
|
||||
func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil //nolint:nilnil // no config is a valid state
|
||||
}
|
||||
var cfg BackupS3Config
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return nil, ErrBackupS3ConfigCorrupt
|
||||
}
|
||||
// 解密 SecretAccessKey
|
||||
if cfg.SecretAccessKey != "" {
|
||||
decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
// 兼容未加密的旧数据:如果解密失败,保持原值
|
||||
logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err)
|
||||
} else {
|
||||
cfg.SecretAccessKey = decrypted
|
||||
}
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.store != nil && s.s3Cfg != nil {
|
||||
return s.store, nil
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
store, err := s.storeFactory(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.store = store
|
||||
s.s3Cfg = cfg
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string {
|
||||
prefix := strings.TrimRight(cfg.Prefix, "/")
|
||||
if prefix == "" {
|
||||
prefix = "backups"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName)
|
||||
}
|
||||
|
||||
// loadRecords 加载备份记录,区分"无数据"和"数据损坏"
|
||||
func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
return s.loadRecordsLocked(ctx)
|
||||
}
|
||||
|
||||
// loadRecordsLocked 在已持有 recordsMu 锁的情况下加载记录
|
||||
func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil //nolint:nilnil // no records is a valid state
|
||||
}
|
||||
var records []BackupRecord
|
||||
if err := json.Unmarshal([]byte(raw), &records); err != nil {
|
||||
return nil, ErrBackupRecordsCorrupt
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// saveRecordsLocked 在已持有 recordsMu 锁的情况下保存记录
|
||||
func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error {
|
||||
data, err := json.Marshal(records)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data))
|
||||
}
|
||||
|
||||
// saveRecord 保存单条记录(带互斥锁保护)
|
||||
func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, _ := s.loadRecordsLocked(ctx)
|
||||
|
||||
// 更新已有记录或追加
|
||||
found := false
|
||||
for i := range records {
|
||||
if records[i].ID == record.ID {
|
||||
records[i] = *record
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
records = append(records, *record)
|
||||
}
|
||||
|
||||
// 限制记录数量
|
||||
if len(records) > maxBackupRecords {
|
||||
records = records[len(records)-maxBackupRecords:]
|
||||
}
|
||||
|
||||
return s.saveRecordsLocked(ctx, records)
|
||||
}
|
||||
|
||||
func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error {
|
||||
if schedule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 按时间倒序
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].StartedAt > records[j].StartedAt
|
||||
})
|
||||
|
||||
var toDelete []BackupRecord
|
||||
var toKeep []BackupRecord
|
||||
|
||||
for i, r := range records {
|
||||
shouldDelete := false
|
||||
|
||||
// 按保留份数清理
|
||||
if schedule.RetainCount > 0 && i >= schedule.RetainCount {
|
||||
shouldDelete = true
|
||||
}
|
||||
|
||||
// 按保留天数清理
|
||||
if schedule.RetainDays > 0 && r.StartedAt != "" {
|
||||
startedAt, err := time.Parse(time.RFC3339, r.StartedAt)
|
||||
if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour {
|
||||
shouldDelete = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDelete && r.Status == "completed" {
|
||||
toDelete = append(toDelete, r)
|
||||
} else {
|
||||
toKeep = append(toKeep, r)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除 S3 上的文件
|
||||
for _, r := range toDelete {
|
||||
if r.S3Key != "" {
|
||||
_ = s.deleteS3Object(ctx, r.S3Key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toDelete) > 0 {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete))
|
||||
return s.saveRecordsLocked(ctx, toKeep)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteS3Object(ctx context.Context, key string) error {
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil || s3Cfg == nil {
|
||||
return nil
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return objectStore.Delete(ctx, key)
|
||||
}
|
||||
528
backend/internal/service/backup_service_test.go
Normal file
528
backend/internal/service/backup_service_test.go
Normal file
@@ -0,0 +1,528 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// ─── Mocks ───
|
||||
|
||||
type mockSettingRepo struct {
|
||||
mu sync.Mutex
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func newMockSettingRepo() *mockSettingRepo {
|
||||
return &mockSettingRepo{data: make(map[string]string)}
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
v, ok := m.data[key]
|
||||
if !ok {
|
||||
return nil, ErrSettingNotFound
|
||||
}
|
||||
return &Setting{Key: key, Value: v}, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
v, ok := m.data[key]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Set(_ context.Context, key, value string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make(map[string]string)
|
||||
for _, k := range keys {
|
||||
if v, ok := m.data[k]; ok {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for k, v := range settings {
|
||||
m.data[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make(map[string]string, len(m.data))
|
||||
for k, v := range m.data {
|
||||
result[k] = v
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Delete(_ context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.data, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// plainEncryptor 仅做 base64-like 包装,用于测试
|
||||
type plainEncryptor struct{}
|
||||
|
||||
func (e *plainEncryptor) Encrypt(plaintext string) (string, error) {
|
||||
return "ENC:" + plaintext, nil
|
||||
}
|
||||
|
||||
func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||
if strings.HasPrefix(ciphertext, "ENC:") {
|
||||
return strings.TrimPrefix(ciphertext, "ENC:"), nil
|
||||
}
|
||||
return ciphertext, fmt.Errorf("not encrypted")
|
||||
}
|
||||
|
||||
type mockDumper struct {
|
||||
dumpData []byte
|
||||
dumpErr error
|
||||
restored []byte
|
||||
restErr error
|
||||
}
|
||||
|
||||
func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) {
|
||||
if m.dumpErr != nil {
|
||||
return nil, m.dumpErr
|
||||
}
|
||||
return io.NopCloser(bytes.NewReader(m.dumpData)), nil
|
||||
}
|
||||
|
||||
func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
|
||||
if m.restErr != nil {
|
||||
return m.restErr
|
||||
}
|
||||
d, err := io.ReadAll(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.restored = d
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockObjectStore struct {
|
||||
objects map[string][]byte
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockObjectStore() *mockObjectStore {
|
||||
return &mockObjectStore{objects: make(map[string][]byte)}
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) {
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.objects[key] = data
|
||||
m.mu.Unlock()
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) {
|
||||
m.mu.Lock()
|
||||
data, ok := m.objects[key]
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("not found: %s", key)
|
||||
}
|
||||
return io.NopCloser(bytes.NewReader(data)), nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Delete(_ context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
delete(m.objects, key)
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) {
|
||||
return "https://presigned.example.com/" + key, nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) HeadBucket(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "test",
|
||||
DBName: "testdb",
|
||||
},
|
||||
}
|
||||
factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) {
|
||||
return store, nil
|
||||
}
|
||||
return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper)
|
||||
}
|
||||
|
||||
func seedS3Config(t *testing.T, repo *mockSettingRepo) {
|
||||
t.Helper()
|
||||
cfg := BackupS3Config{
|
||||
Bucket: "test-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "ENC:secret123",
|
||||
Prefix: "backups",
|
||||
}
|
||||
data, _ := json.Marshal(cfg)
|
||||
require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data)))
|
||||
}
|
||||
|
||||
// ─── Tests ───
|
||||
|
||||
func TestBackupService_S3ConfigEncryption(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 保存配置 -> SecretAccessKey 应被加密
|
||||
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "my-secret",
|
||||
Prefix: "backups",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 直接读取数据库中存储的值,应该是加密后的
|
||||
raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config)
|
||||
var stored BackupS3Config
|
||||
require.NoError(t, json.Unmarshal([]byte(raw), &stored))
|
||||
require.Equal(t, "ENC:my-secret", stored.SecretAccessKey)
|
||||
|
||||
// 通过 GetS3Config 获取应该脱敏
|
||||
cfg, err := svc.GetS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cfg.SecretAccessKey)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
|
||||
// loadS3Config 内部应解密
|
||||
internal, err := svc.loadS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "my-secret", internal.SecretAccessKey)
|
||||
}
|
||||
|
||||
func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 先保存一个有 secret 的配置
|
||||
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "original-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 再更新时不提供 secret,应保留原值
|
||||
_, err = svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID-NEW",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
internal, err := svc.loadS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "original-secret", internal.SecretAccessKey)
|
||||
require.Equal(t, "AKID-NEW", internal.AccessKeyID)
|
||||
}
|
||||
|
||||
func TestBackupService_SaveRecordConcurrency(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
n := 20
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
record := &BackupRecord{
|
||||
ID: fmt.Sprintf("rec-%d", idx),
|
||||
Status: "completed",
|
||||
StartedAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
_ = svc.saveRecord(context.Background(), record)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, records, n)
|
||||
}
|
||||
|
||||
func TestBackupService_LoadRecords_Empty(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, records) // 无数据时返回 nil
|
||||
}
|
||||
|
||||
func TestBackupService_LoadRecords_Corrupted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
_ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{")
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.Error(t, err) // 损坏数据应返回错误
|
||||
require.Nil(t, records)
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_Streaming(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "completed", record.Status)
|
||||
require.Greater(t, record.SizeBytes, int64(0))
|
||||
require.NotEmpty(t, record.S3Key)
|
||||
|
||||
// 验证 S3 上确实有文件
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 1)
|
||||
store.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_DumpFailure(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "failed", record.Status)
|
||||
require.Contains(t, record.ErrorMsg, "pg_dump")
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_NoS3Config(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.ErrorIs(t, err, ErrBackupS3NotConfigured)
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
// 使用一个慢速 dumper 来模拟正在进行的备份
|
||||
dumper := &mockDumper{dumpData: []byte("data")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
// 手动设置 backingUp 标志
|
||||
svc.mu.Lock()
|
||||
svc.backingUp = true
|
||||
svc.mu.Unlock()
|
||||
|
||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.ErrorIs(t, err, ErrBackupInProgress)
|
||||
}
|
||||
|
||||
func TestBackupService_RestoreBackup_Streaming(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
// 先创建一个备份
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 恢复
|
||||
err = svc.RestoreBackup(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证 psql 收到的数据是否与原始 dump 内容一致
|
||||
require.Equal(t, dumpContent, string(dumper.restored))
|
||||
}
|
||||
|
||||
func TestBackupService_RestoreBackup_NotCompleted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 手动插入一条 failed 记录
|
||||
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||
ID: "fail-1",
|
||||
Status: "failed",
|
||||
})
|
||||
|
||||
err := svc.RestoreBackup(context.Background(), "fail-1")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_DeleteBackup(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "data"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S3 中应有文件
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 1)
|
||||
store.mu.Unlock()
|
||||
|
||||
// 删除
|
||||
err = svc.DeleteBackup(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S3 中文件应被删除
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 0)
|
||||
store.mu.Unlock()
|
||||
|
||||
// 记录应不存在
|
||||
_, err = svc.GetBackupRecord(context.Background(), record.ID)
|
||||
require.ErrorIs(t, err, ErrBackupNotFound)
|
||||
}
|
||||
|
||||
func TestBackupService_GetDownloadURL(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumper := &mockDumper{dumpData: []byte("data")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
url, err := svc.GetBackupDownloadURL(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, url, "https://presigned.example.com/")
|
||||
}
|
||||
|
||||
func TestBackupService_ListBackups_Sorted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
now := time.Now()
|
||||
for i := 0; i < 3; i++ {
|
||||
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||
ID: fmt.Sprintf("rec-%d", i),
|
||||
Status: "completed",
|
||||
StartedAt: now.Add(time.Duration(i) * time.Hour).Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
records, err := svc.ListBackups(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, records, 3)
|
||||
// 最新在前
|
||||
require.Equal(t, "rec-2", records[0].ID)
|
||||
require.Equal(t, "rec-0", records[2].ID)
|
||||
}
|
||||
|
||||
func TestBackupService_TestS3Connection(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, store)
|
||||
|
||||
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||
Bucket: "test",
|
||||
AccessKeyID: "ak",
|
||||
SecretAccessKey: "sk",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_TestS3Connection_Incomplete(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||
Bucket: "test",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "incomplete")
|
||||
}
|
||||
|
||||
func TestBackupService_Schedule_CronValidation(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
svc.cronSched = nil // 未初始化 cron
|
||||
|
||||
// 启用但 cron 为空
|
||||
_, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||
Enabled: true,
|
||||
CronExpr: "",
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// 无效的 cron 表达式
|
||||
_, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||
Enabled: true,
|
||||
CronExpr: "invalid",
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
_ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!")
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
cfg, err := svc.loadS3Config(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Nil(t, cfg)
|
||||
}
|
||||
@@ -29,12 +29,11 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
|
||||
AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -221,6 +220,9 @@ const (
|
||||
|
||||
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403)
|
||||
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
|
||||
|
||||
// SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录
|
||||
SettingKeyBackendModeEnabled = "backend_mode_enabled"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
|
||||
@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
|
||||
// Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
|
||||
// second hit 仍然返回 TempUnscheduled。
|
||||
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
|
||||
account: &Account{
|
||||
ID: 15,
|
||||
Type: AccountTypeOAuth,
|
||||
@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyNone,
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
|
||||
@@ -2173,10 +2173,10 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
|
||||
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
|
||||
}
|
||||
|
||||
// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内
|
||||
// 仅适用于配置了 quota_limit 的 apikey 类型账号
|
||||
// isAccountSchedulableForQuota 检查账号是否在配额限制内
|
||||
// 适用于配置了 quota_limit 的 apikey 和 bedrock 类型账号
|
||||
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
|
||||
if account.Type != AccountTypeAPIKey {
|
||||
if !account.IsAPIKeyOrBedrock() {
|
||||
return true
|
||||
}
|
||||
return !account.IsQuotaExceeded()
|
||||
@@ -3532,9 +3532,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
}
|
||||
return apiKey, "apikey", nil
|
||||
case AccountTypeBedrock:
|
||||
return "", "bedrock", nil // Bedrock 使用 SigV4 签名,不需要 token
|
||||
case AccountTypeBedrockAPIKey:
|
||||
return "", "bedrock-apikey", nil // Bedrock API Key 使用 Bearer Token,由 forwardBedrock 处理
|
||||
return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理
|
||||
default:
|
||||
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
@@ -5186,7 +5184,7 @@ func (s *GatewayService) forwardBedrock(
|
||||
if account.IsBedrockAPIKey() {
|
||||
bedrockAPIKey = account.GetCredential("api_key")
|
||||
if bedrockAPIKey == "" {
|
||||
return nil, fmt.Errorf("api_key not found in bedrock-apikey credentials")
|
||||
return nil, fmt.Errorf("api_key not found in bedrock credentials")
|
||||
}
|
||||
} else {
|
||||
signer, err = NewBedrockSignerFromAccount(account)
|
||||
@@ -5375,8 +5373,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
@@ -5398,8 +5397,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5808,9 +5808,10 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
|
||||
return betaPolicyResult{}
|
||||
}
|
||||
isOAuth := account.IsOAuth()
|
||||
isBedrock := account.IsBedrock()
|
||||
var result betaPolicyResult
|
||||
for _, rule := range settings.Rules {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||
continue
|
||||
}
|
||||
switch rule.Action {
|
||||
@@ -5870,14 +5871,16 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
|
||||
}
|
||||
|
||||
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
|
||||
func betaPolicyScopeMatches(scope string, isOAuth bool) bool {
|
||||
func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
|
||||
switch scope {
|
||||
case BetaPolicyScopeAll:
|
||||
return true
|
||||
case BetaPolicyScopeOAuth:
|
||||
return isOAuth
|
||||
case BetaPolicyScopeAPIKey:
|
||||
return !isOAuth
|
||||
return !isOAuth && !isBedrock
|
||||
case BetaPolicyScopeBedrock:
|
||||
return isBedrock
|
||||
default:
|
||||
return true // unknown scope → match all (fail-open)
|
||||
}
|
||||
@@ -5959,12 +5962,13 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
|
||||
return nil
|
||||
}
|
||||
isOAuth := account.IsOAuth()
|
||||
isBedrock := account.IsBedrock()
|
||||
tokenSet := buildBetaTokenSet(tokens)
|
||||
for _, rule := range settings.Rules {
|
||||
if rule.Action != BetaPolicyActionBlock {
|
||||
continue
|
||||
}
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||
continue
|
||||
}
|
||||
if _, present := tokenSet[rule.BetaToken]; present {
|
||||
@@ -6125,6 +6129,29 @@ func extractUpstreamErrorMessage(body []byte) string {
|
||||
return gjson.GetBytes(body, "message").String()
|
||||
}
|
||||
|
||||
func extractUpstreamErrorCode(body []byte) string {
|
||||
if code := strings.TrimSpace(gjson.GetBytes(body, "error.code").String()); code != "" {
|
||||
return code
|
||||
}
|
||||
|
||||
inner := strings.TrimSpace(gjson.GetBytes(body, "error.message").String())
|
||||
if !strings.HasPrefix(inner, "{") {
|
||||
return ""
|
||||
}
|
||||
|
||||
if code := strings.TrimSpace(gjson.Get(inner, "error.code").String()); code != "" {
|
||||
return code
|
||||
}
|
||||
|
||||
if lastBrace := strings.LastIndex(inner, "}"); lastBrace >= 0 {
|
||||
if code := strings.TrimSpace(gjson.Get(inner[:lastBrace+1], "error.code").String()); code != "" {
|
||||
return code
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func isCountTokensUnsupported404(statusCode int, body []byte) bool {
|
||||
if statusCode != http.StatusNotFound {
|
||||
return false
|
||||
@@ -7176,7 +7203,7 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
@@ -7264,7 +7291,7 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
||||
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
|
||||
}
|
||||
|
||||
|
||||
@@ -339,8 +339,9 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
|
||||
// 修复 OpenAI 上游的最新校验:"Expected an ID that begins with 'fc'"
|
||||
fixIDPrefix := func(id string) string {
|
||||
// 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id;
|
||||
// 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
|
||||
fixCallIDPrefix := func(id string) string {
|
||||
if id == "" || strings.HasPrefix(id, "fc") {
|
||||
return id
|
||||
}
|
||||
@@ -358,8 +359,8 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
for key, value := range m {
|
||||
newItem[key] = value
|
||||
}
|
||||
if id, ok := newItem["id"].(string); ok && id != "" {
|
||||
newItem["id"] = fixIDPrefix(id)
|
||||
if id, ok := newItem["id"].(string); ok && strings.HasPrefix(id, "call_") {
|
||||
newItem["id"] = fixCallIDPrefix(id)
|
||||
}
|
||||
filtered = append(filtered, newItem)
|
||||
continue
|
||||
@@ -390,7 +391,7 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
}
|
||||
|
||||
if callID != "" {
|
||||
fixedCallID := fixIDPrefix(callID)
|
||||
fixedCallID := fixCallIDPrefix(callID)
|
||||
if fixedCallID != callID {
|
||||
ensureCopy()
|
||||
newItem["call_id"] = fixedCallID
|
||||
@@ -404,14 +405,6 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
if !isCodexToolCallItemType(typ) {
|
||||
delete(newItem, "call_id")
|
||||
}
|
||||
} else {
|
||||
if id, ok := newItem["id"].(string); ok && id != "" {
|
||||
fixedID := fixIDPrefix(id)
|
||||
if fixedID != id {
|
||||
ensureCopy()
|
||||
newItem["id"] = fixedID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filtered = append(filtered, newItem)
|
||||
|
||||
@@ -33,12 +33,63 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "item_reference", first["type"])
|
||||
require.Equal(t, "fc_ref1", first["id"])
|
||||
require.Equal(t, "ref1", first["id"])
|
||||
|
||||
// 校验 input[1] 为 map,确保后续字段断言安全。
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "fc_o1", second["id"])
|
||||
require.Equal(t, "o1", second["id"])
|
||||
require.Equal(t, "fc1", second["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolContinuationPreservesNativeMessageAndReasoningIDs(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
"input": []any{
|
||||
map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
|
||||
map[string]any{"type": "item_reference", "id": "rs_123"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 2)
|
||||
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "msg_0", first["id"])
|
||||
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "rs_123", second["id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
"input": []any{
|
||||
map[string]any{"type": "item_reference", "id": "call_1"},
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 2)
|
||||
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "fc1", first["id"])
|
||||
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "fc1", second["call_id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
|
||||
@@ -51,10 +51,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
mappedModel := account.GetMappedModel(originalModel)
|
||||
if mappedModel == originalModel && defaultMappedModel != "" {
|
||||
mappedModel = defaultMappedModel
|
||||
}
|
||||
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
responsesReq.Model = mappedModel
|
||||
|
||||
logger.L().Debug("openai chat_completions: model mapping applied",
|
||||
|
||||
@@ -59,11 +59,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
mappedModel := account.GetMappedModel(originalModel)
|
||||
// 分组级降级:账号未映射时使用分组默认映射模型
|
||||
if mappedModel == originalModel && defaultMappedModel != "" {
|
||||
mappedModel = defaultMappedModel
|
||||
}
|
||||
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
responsesReq.Model = mappedModel
|
||||
|
||||
logger.L().Debug("openai messages: model mapping applied",
|
||||
|
||||
@@ -480,6 +480,7 @@ func classifyOpenAIWSReconnectReason(err error) (string, bool) {
|
||||
"upgrade_required",
|
||||
"ws_unsupported",
|
||||
"auth_failed",
|
||||
"invalid_encrypted_content",
|
||||
"previous_response_not_found":
|
||||
return reason, false
|
||||
}
|
||||
@@ -530,6 +531,14 @@ func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType st
|
||||
}
|
||||
|
||||
switch reason {
|
||||
case "invalid_encrypted_content":
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
errType = "invalid_request_error"
|
||||
if upstreamMessage == "" {
|
||||
upstreamMessage = "encrypted content could not be verified"
|
||||
}
|
||||
case "previous_response_not_found":
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusBadRequest
|
||||
@@ -1924,6 +1933,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
var wsErr error
|
||||
wsLastFailureReason := ""
|
||||
wsPrevResponseRecoveryTried := false
|
||||
wsInvalidEncryptedContentRecoveryTried := false
|
||||
recoverPrevResponseNotFound := func(attempt int) bool {
|
||||
if wsPrevResponseRecoveryTried {
|
||||
return false
|
||||
@@ -1956,6 +1966,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
)
|
||||
return true
|
||||
}
|
||||
recoverInvalidEncryptedContent := func(attempt int) bool {
|
||||
if wsInvalidEncryptedContentRecoveryTried {
|
||||
return false
|
||||
}
|
||||
removedReasoningItems := trimOpenAIEncryptedReasoningItems(wsReqBody)
|
||||
if !removedReasoningItems {
|
||||
logOpenAIWSModeInfo(
|
||||
"reconnect_invalid_encrypted_content_recovery_skip account_id=%d attempt=%d reason=missing_encrypted_reasoning_items",
|
||||
account.ID,
|
||||
attempt,
|
||||
)
|
||||
return false
|
||||
}
|
||||
previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id")
|
||||
hasFunctionCallOutput := HasFunctionCallOutput(wsReqBody)
|
||||
if previousResponseID != "" && !hasFunctionCallOutput {
|
||||
delete(wsReqBody, "previous_response_id")
|
||||
}
|
||||
wsInvalidEncryptedContentRecoveryTried = true
|
||||
logOpenAIWSModeInfo(
|
||||
"reconnect_invalid_encrypted_content_recovery account_id=%d attempt=%d action=drop_encrypted_reasoning_items retry=1 previous_response_id_present=%v previous_response_id=%s previous_response_id_kind=%s has_function_call_output=%v dropped_previous_response_id=%v",
|
||||
account.ID,
|
||||
attempt,
|
||||
previousResponseID != "",
|
||||
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
|
||||
normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)),
|
||||
hasFunctionCallOutput,
|
||||
previousResponseID != "" && !hasFunctionCallOutput,
|
||||
)
|
||||
return true
|
||||
}
|
||||
retryBudget := s.openAIWSRetryTotalBudget()
|
||||
retryStartedAt := time.Now()
|
||||
wsRetryLoop:
|
||||
@@ -1992,6 +2033,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) {
|
||||
continue
|
||||
}
|
||||
if reason == "invalid_encrypted_content" && recoverInvalidEncryptedContent(attempt) {
|
||||
continue
|
||||
}
|
||||
if retryable && attempt < maxAttempts {
|
||||
backoff := s.openAIWSRetryBackoff(attempt)
|
||||
if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget {
|
||||
@@ -2075,126 +2119,143 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
return nil, wsErr
|
||||
}
|
||||
|
||||
// Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpInvalidEncryptedContentRetryTried := false
|
||||
for {
|
||||
// Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// Send request
|
||||
upstreamStart := time.Now()
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||
if err != nil {
|
||||
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
// Send request
|
||||
upstreamStart := time.Now()
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||
if err != nil {
|
||||
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
// Handle error response
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamCode := extractUpstreamErrorCode(respBody)
|
||||
if !httpInvalidEncryptedContentRetryTried && resp.StatusCode == http.StatusBadRequest && upstreamCode == "invalid_encrypted_content" {
|
||||
if trimOpenAIEncryptedReasoningItems(reqBody) {
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize invalid_encrypted_content retry body: %w", err)
|
||||
}
|
||||
setOpsUpstreamRequestBody(c, body)
|
||||
httpInvalidEncryptedContentRetryTried = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Retrying non-WSv2 request once after invalid_encrypted_content (account: %s)", account.Name)
|
||||
continue
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Skip non-WSv2 invalid_encrypted_content retry because encrypted reasoning items are missing (account: %s)", account.Name)
|
||||
}
|
||||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, body)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Handle normal response
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, body)
|
||||
}
|
||||
|
||||
// Handle normal response
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
if account.Type == AccountTypeOAuth {
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
if usage == nil {
|
||||
usage = &OpenAIUsage{}
|
||||
}
|
||||
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
serviceTier := extractOpenAIServiceTier(reqBody)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
ServiceTier: serviceTier,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
if account.Type == AccountTypeOAuth {
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &OpenAIUsage{}
|
||||
}
|
||||
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
serviceTier := extractOpenAIServiceTier(reqBody)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
ServiceTier: serviceTier,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
@@ -3756,6 +3817,109 @@ func buildOpenAIResponsesURL(base string) string {
|
||||
return normalized + "/v1/responses"
|
||||
}
|
||||
|
||||
func trimOpenAIEncryptedReasoningItems(reqBody map[string]any) bool {
|
||||
if len(reqBody) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
inputValue, has := reqBody["input"]
|
||||
if !has {
|
||||
return false
|
||||
}
|
||||
|
||||
switch input := inputValue.(type) {
|
||||
case []any:
|
||||
filtered := input[:0]
|
||||
changed := false
|
||||
for _, item := range input {
|
||||
nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item)
|
||||
if itemChanged {
|
||||
changed = true
|
||||
}
|
||||
if !keep {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, nextItem)
|
||||
}
|
||||
if !changed {
|
||||
return false
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(reqBody, "input")
|
||||
return true
|
||||
}
|
||||
reqBody["input"] = filtered
|
||||
return true
|
||||
case []map[string]any:
|
||||
filtered := input[:0]
|
||||
changed := false
|
||||
for _, item := range input {
|
||||
nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item)
|
||||
if itemChanged {
|
||||
changed = true
|
||||
}
|
||||
if !keep {
|
||||
continue
|
||||
}
|
||||
nextMap, ok := nextItem.(map[string]any)
|
||||
if !ok {
|
||||
filtered = append(filtered, item)
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, nextMap)
|
||||
}
|
||||
if !changed {
|
||||
return false
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(reqBody, "input")
|
||||
return true
|
||||
}
|
||||
reqBody["input"] = filtered
|
||||
return true
|
||||
case map[string]any:
|
||||
nextItem, changed, keep := sanitizeEncryptedReasoningInputItem(input)
|
||||
if !changed {
|
||||
return false
|
||||
}
|
||||
if !keep {
|
||||
delete(reqBody, "input")
|
||||
return true
|
||||
}
|
||||
nextMap, ok := nextItem.(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
reqBody["input"] = nextMap
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeEncryptedReasoningInputItem(item any) (next any, changed bool, keep bool) {
|
||||
inputItem, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return item, false, true
|
||||
}
|
||||
|
||||
itemType, _ := inputItem["type"].(string)
|
||||
if strings.TrimSpace(itemType) != "reasoning" {
|
||||
return item, false, true
|
||||
}
|
||||
|
||||
_, hasEncryptedContent := inputItem["encrypted_content"]
|
||||
if !hasEncryptedContent {
|
||||
return item, false, true
|
||||
}
|
||||
|
||||
delete(inputItem, "encrypted_content")
|
||||
if len(inputItem) == 1 {
|
||||
return nil, true, false
|
||||
}
|
||||
return inputItem, true, true
|
||||
}
|
||||
|
||||
func IsOpenAIResponsesCompactPathForTest(c *gin.Context) bool {
|
||||
return isOpenAIResponsesCompactPath(c)
|
||||
}
|
||||
|
||||
19
backend/internal/service/openai_model_mapping.go
Normal file
19
backend/internal/service/openai_model_mapping.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
|
||||
// forwarding. Group-level default mapping only applies when the account itself
|
||||
// did not match any explicit model_mapping rule.
|
||||
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
|
||||
if account == nil {
|
||||
if defaultMappedModel != "" {
|
||||
return defaultMappedModel
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
mappedModel, matched := account.ResolveMappedModel(requestedModel)
|
||||
if !matched && defaultMappedModel != "" {
|
||||
return defaultMappedModel
|
||||
}
|
||||
return mappedModel
|
||||
}
|
||||
70
backend/internal/service/openai_model_mapping_test.go
Normal file
70
backend/internal/service/openai_model_mapping_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
requestedModel string
|
||||
defaultMappedModel string
|
||||
expectedModel string
|
||||
}{
|
||||
{
|
||||
name: "falls back to group default when account has no mapping",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
name: "preserves exact passthrough mapping instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
name: "preserves wildcard passthrough mapping instead of group default",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-*": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
{
|
||||
name: "uses account remap when explicit target differs",
|
||||
account: &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5",
|
||||
defaultMappedModel: "gpt-4o-mini",
|
||||
expectedModel: "gpt-5.4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveOpenAIForwardModel(tt.account, tt.requestedModel, tt.defaultMappedModel); got != tt.expectedModel {
|
||||
t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", got, tt.expectedModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3922,6 +3922,8 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
|
||||
return "ws_unsupported", true
|
||||
case "websocket_connection_limit_reached":
|
||||
return "ws_connection_limit_reached", true
|
||||
case "invalid_encrypted_content":
|
||||
return "invalid_encrypted_content", true
|
||||
case "previous_response_not_found":
|
||||
return "previous_response_not_found", true
|
||||
}
|
||||
@@ -3940,6 +3942,10 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
|
||||
if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") {
|
||||
return "ws_connection_limit_reached", true
|
||||
}
|
||||
if strings.Contains(msg, "invalid_encrypted_content") ||
|
||||
(strings.Contains(msg, "encrypted content") && strings.Contains(msg, "could not be verified")) {
|
||||
return "invalid_encrypted_content", true
|
||||
}
|
||||
if strings.Contains(msg, "previous_response_not_found") ||
|
||||
(strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) {
|
||||
return "previous_response_not_found", true
|
||||
@@ -3964,6 +3970,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
|
||||
case strings.Contains(errType, "invalid_request"),
|
||||
strings.Contains(code, "invalid_request"),
|
||||
strings.Contains(code, "bad_request"),
|
||||
code == "invalid_encrypted_content",
|
||||
code == "previous_response_not_found":
|
||||
return http.StatusBadRequest
|
||||
case strings.Contains(errType, "authentication"),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
@@ -19,6 +20,47 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type httpUpstreamSequenceRecorder struct {
|
||||
mu sync.Mutex
|
||||
bodies [][]byte
|
||||
reqs []*http.Request
|
||||
|
||||
responses []*http.Response
|
||||
errs []error
|
||||
callCount int
|
||||
}
|
||||
|
||||
func (u *httpUpstreamSequenceRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
idx := u.callCount
|
||||
u.callCount++
|
||||
u.reqs = append(u.reqs, req)
|
||||
if req != nil && req.Body != nil {
|
||||
b, _ := io.ReadAll(req.Body)
|
||||
u.bodies = append(u.bodies, b)
|
||||
_ = req.Body.Close()
|
||||
req.Body = io.NopCloser(bytes.NewReader(b))
|
||||
} else {
|
||||
u.bodies = append(u.bodies, nil)
|
||||
}
|
||||
if idx < len(u.errs) && u.errs[idx] != nil {
|
||||
return nil, u.errs[idx]
|
||||
}
|
||||
if idx < len(u.responses) {
|
||||
return u.responses[idx], nil
|
||||
}
|
||||
if len(u.responses) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return u.responses[len(u.responses)-1], nil
|
||||
}
|
||||
|
||||
func (u *httpUpstreamSequenceRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -143,6 +185,176 @@ func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testi
|
||||
require.Equal(t, "client_protocol_http", reason)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_HTTPIngressRetriesInvalidEncryptedContentOnce(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer wsFallbackServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
|
||||
|
||||
upstream := &httpUpstreamSequenceRecorder{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"error":{"code":"invalid_encrypted_content","type":"invalid_request_error","message":"The encrypted content could not be verified."}}`,
|
||||
)),
|
||||
},
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"id":"resp_http_retry_ok","usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
|
||||
)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsFallbackServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me"}]},{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发")
|
||||
require.Equal(t, 2, upstream.callCount, "命中 invalid_encrypted_content 后应只在 HTTP 路径重试一次")
|
||||
require.Len(t, upstream.bodies, 2)
|
||||
|
||||
firstBody := upstream.bodies[0]
|
||||
secondBody := upstream.bodies[1]
|
||||
require.False(t, gjson.GetBytes(firstBody, "previous_response_id").Exists(), "HTTP 首次请求仍应沿用原逻辑移除 previous_response_id")
|
||||
require.True(t, gjson.GetBytes(firstBody, "input.0.encrypted_content").Exists(), "首次请求不应做发送前预清理")
|
||||
require.Equal(t, "keep me", gjson.GetBytes(firstBody, "input.0.summary.0.text").String())
|
||||
|
||||
require.False(t, gjson.GetBytes(secondBody, "previous_response_id").Exists(), "HTTP 精确重试不应重新带回 previous_response_id")
|
||||
require.False(t, gjson.GetBytes(secondBody, "input.0.encrypted_content").Exists(), "精确重试应移除 reasoning.encrypted_content")
|
||||
require.Equal(t, "keep me", gjson.GetBytes(secondBody, "input.0.summary.0.text").String(), "精确重试应保留有效 reasoning summary")
|
||||
require.Equal(t, "input_text", gjson.GetBytes(secondBody, "input.1.type").String(), "非 reasoning input 应保持原样")
|
||||
|
||||
decision, _ := c.Get("openai_ws_transport_decision")
|
||||
reason, _ := c.Get("openai_ws_transport_reason")
|
||||
require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision)
|
||||
require.Equal(t, "client_protocol_http", reason)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_HTTPIngressRetriesWrappedInvalidEncryptedContentOnce(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer wsFallbackServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
|
||||
|
||||
upstream := &httpUpstreamSequenceRecorder{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"error":{"code":null,"message":"{\"error\":{\"message\":\"The encrypted content could not be verified.\",\"type\":\"invalid_request_error\",\"param\":null,\"code\":\"invalid_encrypted_content\"}}(traceid: fb7ad1dbc7699c18f8a02f258f1af5ab)","param":null,"type":"invalid_request_error"}}`,
|
||||
)),
|
||||
},
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"x-request-id": []string{"req_http_retry_wrapped_ok"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"id":"resp_http_retry_wrapped_ok","usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
|
||||
)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Name: "openai-apikey-wrapped",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsFallbackServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry_wrapped","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me too"}]},{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发")
|
||||
require.Equal(t, 2, upstream.callCount, "wrapped invalid_encrypted_content 也应只在 HTTP 路径重试一次")
|
||||
require.Len(t, upstream.bodies, 2)
|
||||
|
||||
firstBody := upstream.bodies[0]
|
||||
secondBody := upstream.bodies[1]
|
||||
require.True(t, gjson.GetBytes(firstBody, "input.0.encrypted_content").Exists(), "首次请求不应做发送前预清理")
|
||||
require.False(t, gjson.GetBytes(secondBody, "input.0.encrypted_content").Exists(), "wrapped exact retry 应移除 reasoning.encrypted_content")
|
||||
require.Equal(t, "keep me too", gjson.GetBytes(secondBody, "input.0.summary.0.text").String(), "wrapped exact retry 应保留有效 reasoning summary")
|
||||
|
||||
decision, _ := c.Get("openai_ws_transport_decision")
|
||||
reason, _ := c.Get("openai_ws_transport_reason")
|
||||
require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision)
|
||||
require.Equal(t, "client_protocol_http", reason)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -1218,3 +1430,460 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOn
|
||||
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id")
|
||||
require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentRecoversOnce(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var wsAttempts atomic.Int32
|
||||
var wsRequestPayloads [][]byte
|
||||
var wsRequestMu sync.Mutex
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempt := wsAttempts.Add(1)
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var req map[string]any
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
reqRaw, _ := json.Marshal(req)
|
||||
wsRequestMu.Lock()
|
||||
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
|
||||
wsRequestMu.Unlock()
|
||||
if attempt == 1 {
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"code": "invalid_encrypted_content",
|
||||
"type": "invalid_request_error",
|
||||
"message": "The encrypted content could not be verified.",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "response.completed",
|
||||
"response": map[string]any{
|
||||
"id": "resp_ws_invalid_encrypted_content_recover_ok",
|
||||
"model": "gpt-5.3-codex",
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 95,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":[{"type":"reasoning","encrypted_content":"gAAA"},{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "resp_ws_invalid_encrypted_content_recover_ok", result.RequestID)
|
||||
require.Nil(t, upstream.lastReq, "invalid_encrypted_content 不应回退 HTTP")
|
||||
require.Equal(t, int32(2), wsAttempts.Load(), "invalid_encrypted_content 应触发一次清洗后重试")
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "resp_ws_invalid_encrypted_content_recover_ok", gjson.Get(rec.Body.String(), "id").String())
|
||||
|
||||
wsRequestMu.Lock()
|
||||
requests := append([][]byte(nil), wsRequestPayloads...)
|
||||
wsRequestMu.Unlock()
|
||||
require.Len(t, requests, 2)
|
||||
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id")
|
||||
require.True(t, gjson.GetBytes(requests[0], `input.0.encrypted_content`).Exists(), "首轮请求应保留 encrypted reasoning")
|
||||
require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
|
||||
require.False(t, gjson.GetBytes(requests[1], `input.0.encrypted_content`).Exists(), "恢复重试应移除 encrypted reasoning item")
|
||||
require.Equal(t, "input_text", gjson.GetBytes(requests[1], `input.0.type`).String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentSkipsRecoveryWithoutReasoningItem(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var wsAttempts atomic.Int32
|
||||
var wsRequestPayloads [][]byte
|
||||
var wsRequestMu sync.Mutex
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
wsAttempts.Add(1)
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var req map[string]any
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
reqRaw, _ := json.Marshal(req)
|
||||
wsRequestMu.Lock()
|
||||
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
|
||||
wsRequestMu.Unlock()
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"code": "invalid_encrypted_content",
|
||||
"type": "invalid_request_error",
|
||||
"message": "The encrypted content could not be verified.",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 96,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":[{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Nil(t, upstream.lastReq, "invalid_encrypted_content 不应回退 HTTP")
|
||||
require.Equal(t, int32(1), wsAttempts.Load(), "缺少 reasoning encrypted item 时应跳过自动恢复重试")
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.Contains(t, strings.ToLower(rec.Body.String()), "encrypted content")
|
||||
|
||||
wsRequestMu.Lock()
|
||||
requests := append([][]byte(nil), wsRequestPayloads...)
|
||||
wsRequestMu.Unlock()
|
||||
require.Len(t, requests, 1)
|
||||
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists())
|
||||
require.False(t, gjson.GetBytes(requests[0], `input.0.encrypted_content`).Exists())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentRecoversSingleObjectInputAndKeepsSummary(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var wsAttempts atomic.Int32
|
||||
var wsRequestPayloads [][]byte
|
||||
var wsRequestMu sync.Mutex
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempt := wsAttempts.Add(1)
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var req map[string]any
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
reqRaw, _ := json.Marshal(req)
|
||||
wsRequestMu.Lock()
|
||||
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
|
||||
wsRequestMu.Unlock()
|
||||
if attempt == 1 {
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"code": "invalid_encrypted_content",
|
||||
"type": "invalid_request_error",
|
||||
"message": "The encrypted content could not be verified.",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "response.completed",
|
||||
"response": map[string]any{
|
||||
"id": "resp_ws_invalid_encrypted_content_object_ok",
|
||||
"model": "gpt-5.3-codex",
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 97,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me"}]}}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "resp_ws_invalid_encrypted_content_object_ok", result.RequestID)
|
||||
require.Nil(t, upstream.lastReq, "invalid_encrypted_content 单对象 input 不应回退 HTTP")
|
||||
require.Equal(t, int32(2), wsAttempts.Load(), "单对象 reasoning input 也应触发一次清洗后重试")
|
||||
|
||||
wsRequestMu.Lock()
|
||||
requests := append([][]byte(nil), wsRequestPayloads...)
|
||||
wsRequestMu.Unlock()
|
||||
require.Len(t, requests, 2)
|
||||
require.True(t, gjson.GetBytes(requests[0], `input.encrypted_content`).Exists(), "首轮单对象应保留 encrypted_content")
|
||||
require.True(t, gjson.GetBytes(requests[1], `input.summary.0.text`).Exists(), "恢复重试应保留 reasoning summary")
|
||||
require.False(t, gjson.GetBytes(requests[1], `input.encrypted_content`).Exists(), "恢复重试只应移除 encrypted_content")
|
||||
require.Equal(t, "reasoning", gjson.GetBytes(requests[1], `input.type`).String())
|
||||
require.False(t, gjson.GetBytes(requests[1], `previous_response_id`).Exists(), "恢复重试应移除 previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentKeepsPreviousResponseIDForFunctionCallOutput(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var wsAttempts atomic.Int32
|
||||
var wsRequestPayloads [][]byte
|
||||
var wsRequestMu sync.Mutex
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempt := wsAttempts.Add(1)
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var req map[string]any
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
reqRaw, _ := json.Marshal(req)
|
||||
wsRequestMu.Lock()
|
||||
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
|
||||
wsRequestMu.Unlock()
|
||||
if attempt == 1 {
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"code": "invalid_encrypted_content",
|
||||
"type": "invalid_request_error",
|
||||
"message": "The encrypted content could not be verified.",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "response.completed",
|
||||
"response": map[string]any{
|
||||
"id": "resp_ws_invalid_encrypted_content_function_call_output_ok",
|
||||
"model": "gpt-5.3-codex",
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 98,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_function_call","input":[{"type":"reasoning","encrypted_content":"gAAA"},{"type":"function_call_output","call_id":"call_123","output":"ok"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "resp_ws_invalid_encrypted_content_function_call_output_ok", result.RequestID)
|
||||
require.Nil(t, upstream.lastReq, "function_call_output + invalid_encrypted_content 不应回退 HTTP")
|
||||
require.Equal(t, int32(2), wsAttempts.Load(), "应只做一次保锚点的清洗后重试")
|
||||
|
||||
wsRequestMu.Lock()
|
||||
requests := append([][]byte(nil), wsRequestPayloads...)
|
||||
wsRequestMu.Unlock()
|
||||
require.Len(t, requests, 2)
|
||||
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id")
|
||||
require.True(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "function_call_output 恢复重试不应移除 previous_response_id")
|
||||
require.False(t, gjson.GetBytes(requests[1], `input.0.encrypted_content`).Exists(), "恢复重试应移除 reasoning encrypted_content")
|
||||
require.Equal(t, "function_call_output", gjson.GetBytes(requests[1], `input.0.type`).String(), "清洗后应保留 function_call_output 作为首个输入项")
|
||||
require.Equal(t, "call_123", gjson.GetBytes(requests[1], `input.0.call_id`).String())
|
||||
require.Equal(t, "ok", gjson.GetBytes(requests[1], `input.0.output`).String())
|
||||
require.Equal(t, "resp_prev_function_call", gjson.GetBytes(requests[1], "previous_response_id").String())
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ const (
|
||||
opsAggDailyInterval = 1 * time.Hour
|
||||
|
||||
// Keep in sync with ops retention target (vNext default 30d).
|
||||
opsAggBackfillWindow = 30 * 24 * time.Hour
|
||||
opsAggBackfillWindow = 1 * time.Hour
|
||||
|
||||
// Recompute overlap to absorb late-arriving rows near boundaries.
|
||||
opsAggHourlyOverlap = 2 * time.Hour
|
||||
@@ -36,7 +36,7 @@ const (
|
||||
// that may still receive late inserts.
|
||||
opsAggSafeDelay = 5 * time.Minute
|
||||
|
||||
opsAggMaxQueryTimeout = 3 * time.Second
|
||||
opsAggMaxQueryTimeout = 5 * time.Second
|
||||
opsAggHourlyTimeout = 5 * time.Minute
|
||||
opsAggDailyTimeout = 2 * time.Minute
|
||||
|
||||
|
||||
@@ -371,6 +371,8 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
|
||||
IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
|
||||
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
|
||||
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
|
||||
DisplayOpenAITokenStats: false,
|
||||
DisplayAlertEvents: true,
|
||||
AutoRefreshEnabled: false,
|
||||
AutoRefreshIntervalSec: 30,
|
||||
}
|
||||
@@ -438,7 +440,7 @@ func (s *OpsService) GetOpsAdvancedSettings(ctx context.Context) (*OpsAdvancedSe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &OpsAdvancedSettings{}
|
||||
cfg := defaultOpsAdvancedSettings()
|
||||
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
|
||||
return defaultCfg, nil
|
||||
}
|
||||
|
||||
97
backend/internal/service/ops_settings_advanced_test.go
Normal file
97
backend/internal/service/ops_settings_advanced_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetOpsAdvancedSettings_DefaultHidesOpenAITokenStats(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
svc := &OpsService{settingRepo: repo}
|
||||
|
||||
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
|
||||
}
|
||||
if cfg.DisplayOpenAITokenStats {
|
||||
t.Fatalf("DisplayOpenAITokenStats = true, want false by default")
|
||||
}
|
||||
if !cfg.DisplayAlertEvents {
|
||||
t.Fatalf("DisplayAlertEvents = false, want true by default")
|
||||
}
|
||||
if repo.setCalls != 1 {
|
||||
t.Fatalf("expected defaults to be persisted once, got %d", repo.setCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateOpsAdvancedSettings_PersistsOpenAITokenStatsVisibility(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
svc := &OpsService{settingRepo: repo}
|
||||
|
||||
cfg := defaultOpsAdvancedSettings()
|
||||
cfg.DisplayOpenAITokenStats = true
|
||||
cfg.DisplayAlertEvents = false
|
||||
|
||||
updated, err := svc.UpdateOpsAdvancedSettings(context.Background(), cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateOpsAdvancedSettings() error = %v", err)
|
||||
}
|
||||
if !updated.DisplayOpenAITokenStats {
|
||||
t.Fatalf("DisplayOpenAITokenStats = false, want true")
|
||||
}
|
||||
if updated.DisplayAlertEvents {
|
||||
t.Fatalf("DisplayAlertEvents = true, want false")
|
||||
}
|
||||
|
||||
reloaded, err := svc.GetOpsAdvancedSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetOpsAdvancedSettings() after update error = %v", err)
|
||||
}
|
||||
if !reloaded.DisplayOpenAITokenStats {
|
||||
t.Fatalf("reloaded DisplayOpenAITokenStats = false, want true")
|
||||
}
|
||||
if reloaded.DisplayAlertEvents {
|
||||
t.Fatalf("reloaded DisplayAlertEvents = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOpsAdvancedSettings_BackfillsNewDisplayFlagsFromDefaults(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
svc := &OpsService{settingRepo: repo}
|
||||
|
||||
legacyCfg := map[string]any{
|
||||
"data_retention": map[string]any{
|
||||
"cleanup_enabled": false,
|
||||
"cleanup_schedule": "0 2 * * *",
|
||||
"error_log_retention_days": 30,
|
||||
"minute_metrics_retention_days": 30,
|
||||
"hourly_metrics_retention_days": 30,
|
||||
},
|
||||
"aggregation": map[string]any{
|
||||
"aggregation_enabled": false,
|
||||
},
|
||||
"ignore_count_tokens_errors": true,
|
||||
"ignore_context_canceled": true,
|
||||
"ignore_no_available_accounts": false,
|
||||
"ignore_invalid_api_key_errors": false,
|
||||
"auto_refresh_enabled": false,
|
||||
"auto_refresh_interval_seconds": 30,
|
||||
}
|
||||
raw, err := json.Marshal(legacyCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal legacy config: %v", err)
|
||||
}
|
||||
repo.values[SettingKeyOpsAdvancedSettings] = string(raw)
|
||||
|
||||
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
|
||||
}
|
||||
if cfg.DisplayOpenAITokenStats {
|
||||
t.Fatalf("DisplayOpenAITokenStats = true, want false default backfill")
|
||||
}
|
||||
if !cfg.DisplayAlertEvents {
|
||||
t.Fatalf("DisplayAlertEvents = false, want true default backfill")
|
||||
}
|
||||
}
|
||||
@@ -98,6 +98,8 @@ type OpsAdvancedSettings struct {
|
||||
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
|
||||
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
|
||||
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
|
||||
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
|
||||
DisplayAlertEvents bool `json:"display_alert_events"`
|
||||
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
||||
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
||||
}
|
||||
|
||||
@@ -149,8 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
// 其他 400 错误(如参数问题)不处理,不禁用账号
|
||||
case 401:
|
||||
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
|
||||
if account.Type == AccountTypeOAuth {
|
||||
// OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。
|
||||
// Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。
|
||||
if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity {
|
||||
// 1. 失效缓存
|
||||
if s.tokenCacheInvalidator != nil {
|
||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
@@ -182,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
shouldDisable = true
|
||||
} else {
|
||||
// 非 OAuth 账号(APIKey):保持原有 SetError 行为
|
||||
// 非 OAuth / Antigravity OAuth:保持 SetError 行为
|
||||
msg := "Authentication failed (401): invalid or expired credentials"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Authentication failed (401): " + upstreamMsg
|
||||
@@ -199,11 +200,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
case 403:
|
||||
// 禁止访问:停止调度,记录错误
|
||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Access forbidden (403): " + upstreamMsg
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.ratelimit",
|
||||
"[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s",
|
||||
@@ -215,8 +211,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
upstreamMsg,
|
||||
truncateForLog(responseBody, 1024),
|
||||
)
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
shouldDisable = s.handle403(ctx, account, upstreamMsg, responseBody)
|
||||
case 429:
|
||||
s.handle429(ctx, account, headers, responseBody)
|
||||
shouldDisable = false
|
||||
@@ -621,6 +616,62 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account
|
||||
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
|
||||
}
|
||||
|
||||
// handle403 处理 403 Forbidden 错误
|
||||
// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用;
|
||||
// 其他平台保持原有 SetError 行为。
|
||||
func (s *RateLimitService) handle403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
|
||||
}
|
||||
// 非 Antigravity 平台:保持原有行为
|
||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Access forbidden (403): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
// handleAntigravity403 处理 Antigravity 平台的 403 错误
|
||||
// validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复)
|
||||
// violation(违规封号)→ 永久 SetError(需人工处理)
|
||||
// generic(通用禁止)→ 永久 SetError
|
||||
func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
|
||||
fbType := classifyForbiddenType(string(responseBody))
|
||||
|
||||
switch fbType {
|
||||
case forbiddenTypeValidation:
|
||||
// VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复
|
||||
msg := "Validation required (403): account needs Google verification"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Validation required (403): " + upstreamMsg
|
||||
}
|
||||
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
|
||||
msg += " | validation_url: " + validationURL
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
|
||||
case forbiddenTypeViolation:
|
||||
// 违规封号: 永久禁用,需人工处理
|
||||
msg := "Account violation (403): terms of service violation"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Account violation (403): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
|
||||
default:
|
||||
// 通用 403: 保持原有行为
|
||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Access forbidden (403): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
||||
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
||||
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
||||
@@ -1213,7 +1264,8 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
|
||||
}
|
||||
// 401 首次命中可临时不可调度(给 token 刷新窗口);
|
||||
// 若历史上已因 401 进入过临时不可调度,则本次应升级为 error(返回 false 交由默认错误逻辑处理)。
|
||||
if statusCode == http.StatusUnauthorized {
|
||||
// Antigravity 跳过:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制,无需升级逻辑。
|
||||
if statusCode == http.StatusUnauthorized && account.Platform != PlatformAntigravity {
|
||||
reason := account.TempUnschedulableReason
|
||||
// 缓存可能没有 reason,从 DB 回退读取
|
||||
if reason == "" {
|
||||
|
||||
@@ -27,34 +27,68 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e
|
||||
|
||||
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
|
||||
// Scenario: cache account has empty TempUnschedulableReason (cache miss),
|
||||
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone.
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
// but DB account has a previous 401 record.
|
||||
// Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error).
|
||||
// Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules).
|
||||
t.Run("gemini_escalates", func(t *testing.T) {
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: "", // cache miss — reason is empty
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformGemini,
|
||||
TempUnschedulableReason: "",
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone")
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate")
|
||||
})
|
||||
|
||||
t.Run("antigravity_stays_temp", func(t *testing.T) {
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: "",
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {
|
||||
|
||||
@@ -42,45 +42,56 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
}{
|
||||
{name: "gemini", platform: PlatformGemini},
|
||||
{name: "antigravity", platform: PlatformAntigravity},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: tt.platform,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": 401,
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": 30,
|
||||
"description": "custom rule",
|
||||
},
|
||||
t.Run("gemini", func(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": 401,
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": 30,
|
||||
"description": "custom rule",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
})
|
||||
}
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
})
|
||||
|
||||
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
|
||||
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
|
||||
// HandleUpstreamError 中走 SetError 路径。
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
require.Empty(t, invalidator.accounts)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
|
||||
|
||||
@@ -65,6 +65,19 @@ const minVersionErrorTTL = 5 * time.Second
|
||||
// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context
|
||||
const minVersionDBTimeout = 5 * time.Second
|
||||
|
||||
// cachedBackendMode Backend Mode cache (in-process, 60s TTL)
|
||||
type cachedBackendMode struct {
|
||||
value bool
|
||||
expiresAt int64 // unix nano
|
||||
}
|
||||
|
||||
var backendModeCache atomic.Value // *cachedBackendMode
|
||||
var backendModeSF singleflight.Group
|
||||
|
||||
const backendModeCacheTTL = 60 * time.Second
|
||||
const backendModeErrorTTL = 5 * time.Second
|
||||
const backendModeDBTimeout = 5 * time.Second
|
||||
|
||||
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
||||
type DefaultSubscriptionGroupReader interface {
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
@@ -128,6 +141,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeySoraClientEnabled,
|
||||
SettingKeyCustomMenuItems,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
SettingKeyBackendModeEnabled,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -172,6 +186,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -223,6 +238,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
@@ -247,6 +263,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: s.version,
|
||||
}, nil
|
||||
}
|
||||
@@ -461,6 +478,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
// 分组隔离
|
||||
updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling)
|
||||
|
||||
// Backend Mode
|
||||
updates[SettingKeyBackendModeEnabled] = strconv.FormatBool(settings.BackendModeEnabled)
|
||||
|
||||
err = s.settingRepo.SetMultiple(ctx, updates)
|
||||
if err == nil {
|
||||
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
|
||||
@@ -469,6 +489,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
value: settings.MinClaudeCodeVersion,
|
||||
expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(),
|
||||
})
|
||||
backendModeSF.Forget("backend_mode")
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: settings.BackendModeEnabled,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
if s.onUpdate != nil {
|
||||
s.onUpdate() // Invalidate cache after settings update
|
||||
}
|
||||
@@ -525,6 +550,52 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsBackendModeEnabled checks if backend mode is enabled
|
||||
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path
|
||||
func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
|
||||
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.value
|
||||
}
|
||||
}
|
||||
result, _, _ := backendModeSF.Do("backend_mode", func() (any, error) {
|
||||
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.value, nil
|
||||
}
|
||||
}
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), backendModeDBTimeout)
|
||||
defer cancel()
|
||||
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyBackendModeEnabled)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
// Setting not yet created (fresh install) - default to disabled with full TTL
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: false,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
return false, nil
|
||||
}
|
||||
slog.Warn("failed to get backend_mode_enabled setting", "error", err)
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: false,
|
||||
expiresAt: time.Now().Add(backendModeErrorTTL).UnixNano(),
|
||||
})
|
||||
return false, nil
|
||||
}
|
||||
enabled := value == "true"
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: enabled,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
return enabled, nil
|
||||
})
|
||||
if val, ok := result.(bool); ok {
|
||||
return val
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsEmailVerifyEnabled 检查是否开启邮件验证
|
||||
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled)
|
||||
@@ -719,6 +790,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
@@ -1278,7 +1350,7 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
|
||||
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
|
||||
}
|
||||
validScopes := map[string]bool{
|
||||
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true,
|
||||
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
|
||||
}
|
||||
|
||||
for i, rule := range settings.Rules {
|
||||
|
||||
199
backend/internal/service/setting_service_backend_mode_test.go
Normal file
199
backend/internal/service/setting_service_backend_mode_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type bmRepoStub struct {
|
||||
getValueFn func(ctx context.Context, key string) (string, error)
|
||||
calls int
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
s.calls++
|
||||
if s.getValueFn == nil {
|
||||
panic("unexpected GetValue call")
|
||||
}
|
||||
return s.getValueFn(ctx, key)
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
type bmUpdateRepoStub struct {
|
||||
updates map[string]string
|
||||
getValueFn func(ctx context.Context, key string) (string, error)
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if s.getValueFn == nil {
|
||||
panic("unexpected GetValue call")
|
||||
}
|
||||
return s.getValueFn(ctx, key)
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
s.updates = make(map[string]string, len(settings))
|
||||
for k, v := range settings {
|
||||
s.updates[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func resetBackendModeTestCache(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendModeCache.Store((*cachedBackendMode)(nil))
|
||||
t.Cleanup(func() {
|
||||
backendModeCache.Store((*cachedBackendMode)(nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsTrue(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "true", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.True(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsFalse(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "false", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsFalseOnNotFound(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "", ErrSettingNotFound
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsFalseOnDBError(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "", errors.New("db down")
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_CachesResult(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "true", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.True(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.True(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestUpdateSettings_InvalidatesBackendModeCache(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: true,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
|
||||
repo := &bmUpdateRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "true", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||
BackendModeEnabled: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "false", repo.updates[SettingKeyBackendModeEnabled])
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
}
|
||||
@@ -69,6 +69,9 @@ type SystemSettings struct {
|
||||
|
||||
// 分组隔离:允许未分组 Key 调度(默认 false → 403)
|
||||
AllowUngroupedKeyScheduling bool
|
||||
|
||||
// Backend 模式:禁用用户注册和自助服务,仅管理员可登录
|
||||
BackendModeEnabled bool
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -101,6 +104,7 @@ type PublicSettings struct {
|
||||
CustomMenuItems string // JSON array of custom menu items
|
||||
|
||||
LinuxDoOAuthEnabled bool
|
||||
BackendModeEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
@@ -198,16 +202,17 @@ const (
|
||||
BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token
|
||||
BetaPolicyActionBlock = "block" // 拦截,直接返回错误
|
||||
|
||||
BetaPolicyScopeAll = "all" // 所有账号类型
|
||||
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
|
||||
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
|
||||
BetaPolicyScopeAll = "all" // 所有账号类型
|
||||
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
|
||||
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
|
||||
BetaPolicyScopeBedrock = "bedrock" // 仅 AWS Bedrock 账号
|
||||
)
|
||||
|
||||
// BetaPolicyRule 单条 Beta 策略规则
|
||||
type BetaPolicyRule struct {
|
||||
BetaToken string `json:"beta_token"` // beta token 值
|
||||
Action string `json:"action"` // "pass" | "filter" | "block"
|
||||
Scope string `json:"scope"` // "all" | "oauth" | "apikey"
|
||||
Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock"
|
||||
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
|
||||
}
|
||||
|
||||
|
||||
@@ -322,6 +322,19 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC
|
||||
return apiKeyService
|
||||
}
|
||||
|
||||
// ProvideBackupService creates and starts BackupService
|
||||
func ProvideBackupService(
|
||||
settingRepo SettingRepository,
|
||||
cfg *config.Config,
|
||||
encryptor SecretEncryptor,
|
||||
storeFactory BackupObjectStoreFactory,
|
||||
dumper DBDumper,
|
||||
) *BackupService {
|
||||
svc := NewBackupService(settingRepo, cfg, encryptor, storeFactory, dumper)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideSettingService wires SettingService with group reader for default subscription validation.
|
||||
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService {
|
||||
svc := NewSettingService(settingRepo, cfg)
|
||||
@@ -373,6 +386,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountTestService,
|
||||
ProvideSettingService,
|
||||
NewDataManagementService,
|
||||
ProvideBackupService,
|
||||
ProvideOpsSystemLogSink,
|
||||
NewOpsService,
|
||||
ProvideOpsMetricsCollector,
|
||||
|
||||
105
deploy/docker-compose.dev.yml
Normal file
105
deploy/docker-compose.dev.yml
Normal file
@@ -0,0 +1,105 @@
|
||||
# =============================================================================
|
||||
# Sub2API Docker Compose - Local Development Build
|
||||
# =============================================================================
|
||||
# Build from local source code for testing changes.
|
||||
#
|
||||
# Usage:
|
||||
# cd deploy
|
||||
# docker compose -f docker-compose.dev.yml up --build
|
||||
# =============================================================================
|
||||
|
||||
services:
|
||||
sub2api:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: Dockerfile
|
||||
container_name: sub2api-dev
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "${BIND_HOST:-127.0.0.1}:${SERVER_PORT:-8080}:8080"
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
environment:
|
||||
- AUTO_SETUP=true
|
||||
- SERVER_HOST=0.0.0.0
|
||||
- SERVER_PORT=8080
|
||||
- SERVER_MODE=debug
|
||||
- RUN_MODE=${RUN_MODE:-standard}
|
||||
- DATABASE_HOST=postgres
|
||||
- DATABASE_PORT=5432
|
||||
- DATABASE_USER=${POSTGRES_USER:-sub2api}
|
||||
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
|
||||
- DATABASE_SSLMODE=disable
|
||||
- REDIS_HOST=redis
|
||||
- REDIS_PORT=6379
|
||||
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
|
||||
- REDIS_DB=${REDIS_DB:-0}
|
||||
- ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local}
|
||||
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-}
|
||||
- JWT_SECRET=${JWT_SECRET:-}
|
||||
- TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-}
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- sub2api-network
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
|
||||
postgres:
|
||||
image: postgres:18-alpine
|
||||
container_name: sub2api-postgres-dev
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
||||
- PGDATA=/var/lib/postgresql/data
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
networks:
|
||||
- sub2api-network
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 10s
|
||||
|
||||
redis:
|
||||
image: redis:8-alpine
|
||||
container_name: sub2api-redis-dev
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./redis_data:/data
|
||||
command: >
|
||||
sh -c '
|
||||
redis-server
|
||||
--save 60 1
|
||||
--appendonly yes
|
||||
--appendfsync everysec
|
||||
${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}'
|
||||
environment:
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
- REDISCLI_AUTH=${REDIS_PASSWORD:-}
|
||||
networks:
|
||||
- sub2api-network
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 5s
|
||||
|
||||
networks:
|
||||
sub2api-network:
|
||||
driver: bridge
|
||||
114
frontend/src/api/admin/backup.ts
Normal file
114
frontend/src/api/admin/backup.ts
Normal file
@@ -0,0 +1,114 @@
|
||||
import { apiClient } from '../client'
|
||||
|
||||
export interface BackupS3Config {
|
||||
endpoint: string
|
||||
region: string
|
||||
bucket: string
|
||||
access_key_id: string
|
||||
secret_access_key?: string
|
||||
prefix: string
|
||||
force_path_style: boolean
|
||||
}
|
||||
|
||||
export interface BackupScheduleConfig {
|
||||
enabled: boolean
|
||||
cron_expr: string
|
||||
retain_days: number
|
||||
retain_count: number
|
||||
}
|
||||
|
||||
export interface BackupRecord {
|
||||
id: string
|
||||
status: 'pending' | 'running' | 'completed' | 'failed'
|
||||
backup_type: string
|
||||
file_name: string
|
||||
s3_key: string
|
||||
size_bytes: number
|
||||
triggered_by: string
|
||||
error_message?: string
|
||||
started_at: string
|
||||
finished_at?: string
|
||||
expires_at?: string
|
||||
}
|
||||
|
||||
export interface CreateBackupRequest {
|
||||
expire_days?: number
|
||||
}
|
||||
|
||||
export interface TestS3Response {
|
||||
ok: boolean
|
||||
message: string
|
||||
}
|
||||
|
||||
// S3 Config
|
||||
export async function getS3Config(): Promise<BackupS3Config> {
|
||||
const { data } = await apiClient.get<BackupS3Config>('/admin/backups/s3-config')
|
||||
return data
|
||||
}
|
||||
|
||||
export async function updateS3Config(config: BackupS3Config): Promise<BackupS3Config> {
|
||||
const { data } = await apiClient.put<BackupS3Config>('/admin/backups/s3-config', config)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function testS3Connection(config: BackupS3Config): Promise<TestS3Response> {
|
||||
const { data } = await apiClient.post<TestS3Response>('/admin/backups/s3-config/test', config)
|
||||
return data
|
||||
}
|
||||
|
||||
// Schedule
|
||||
export async function getSchedule(): Promise<BackupScheduleConfig> {
|
||||
const { data } = await apiClient.get<BackupScheduleConfig>('/admin/backups/schedule')
|
||||
return data
|
||||
}
|
||||
|
||||
export async function updateSchedule(config: BackupScheduleConfig): Promise<BackupScheduleConfig> {
|
||||
const { data } = await apiClient.put<BackupScheduleConfig>('/admin/backups/schedule', config)
|
||||
return data
|
||||
}
|
||||
|
||||
// Backup operations
|
||||
export async function createBackup(req?: CreateBackupRequest): Promise<BackupRecord> {
|
||||
const { data } = await apiClient.post<BackupRecord>('/admin/backups', req || {}, { timeout: 600000 })
|
||||
return data
|
||||
}
|
||||
|
||||
export async function listBackups(): Promise<{ items: BackupRecord[] }> {
|
||||
const { data } = await apiClient.get<{ items: BackupRecord[] }>('/admin/backups')
|
||||
return data
|
||||
}
|
||||
|
||||
export async function getBackup(id: string): Promise<BackupRecord> {
|
||||
const { data } = await apiClient.get<BackupRecord>(`/admin/backups/${id}`)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function deleteBackup(id: string): Promise<void> {
|
||||
await apiClient.delete(`/admin/backups/${id}`)
|
||||
}
|
||||
|
||||
export async function getDownloadURL(id: string): Promise<{ url: string }> {
|
||||
const { data } = await apiClient.get<{ url: string }>(`/admin/backups/${id}/download-url`)
|
||||
return data
|
||||
}
|
||||
|
||||
// Restore
|
||||
export async function restoreBackup(id: string, password: string): Promise<void> {
|
||||
await apiClient.post(`/admin/backups/${id}/restore`, { password }, { timeout: 600000 })
|
||||
}
|
||||
|
||||
export const backupAPI = {
|
||||
getS3Config,
|
||||
updateS3Config,
|
||||
testS3Connection,
|
||||
getSchedule,
|
||||
updateSchedule,
|
||||
createBackup,
|
||||
listBackups,
|
||||
getBackup,
|
||||
deleteBackup,
|
||||
getDownloadURL,
|
||||
restoreBackup,
|
||||
}
|
||||
|
||||
export default backupAPI
|
||||
@@ -23,6 +23,7 @@ import errorPassthroughAPI from './errorPassthrough'
|
||||
import dataManagementAPI from './dataManagement'
|
||||
import apiKeysAPI from './apiKeys'
|
||||
import scheduledTestsAPI from './scheduledTests'
|
||||
import backupAPI from './backup'
|
||||
|
||||
/**
|
||||
* Unified admin API object for convenient access
|
||||
@@ -47,7 +48,8 @@ export const adminAPI = {
|
||||
errorPassthrough: errorPassthroughAPI,
|
||||
dataManagement: dataManagementAPI,
|
||||
apiKeys: apiKeysAPI,
|
||||
scheduledTests: scheduledTestsAPI
|
||||
scheduledTests: scheduledTestsAPI,
|
||||
backup: backupAPI
|
||||
}
|
||||
|
||||
export {
|
||||
@@ -70,7 +72,8 @@ export {
|
||||
errorPassthroughAPI,
|
||||
dataManagementAPI,
|
||||
apiKeysAPI,
|
||||
scheduledTestsAPI
|
||||
scheduledTestsAPI,
|
||||
backupAPI
|
||||
}
|
||||
|
||||
export default adminAPI
|
||||
|
||||
@@ -841,6 +841,8 @@ export interface OpsAdvancedSettings {
|
||||
ignore_context_canceled: boolean
|
||||
ignore_no_available_accounts: boolean
|
||||
ignore_invalid_api_key_errors: boolean
|
||||
display_openai_token_stats: boolean
|
||||
display_alert_events: boolean
|
||||
auto_refresh_enabled: boolean
|
||||
auto_refresh_interval_seconds: number
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ export interface SystemSettings {
|
||||
purchase_subscription_enabled: boolean
|
||||
purchase_subscription_url: string
|
||||
sora_client_enabled: boolean
|
||||
backend_mode_enabled: boolean
|
||||
custom_menu_items: CustomMenuItem[]
|
||||
// SMTP settings
|
||||
smtp_host: string
|
||||
@@ -106,6 +107,7 @@ export interface UpdateSettingsRequest {
|
||||
purchase_subscription_enabled?: boolean
|
||||
purchase_subscription_url?: string
|
||||
sora_client_enabled?: boolean
|
||||
backend_mode_enabled?: boolean
|
||||
custom_menu_items?: CustomMenuItem[]
|
||||
smtp_host?: string
|
||||
smtp_port?: number
|
||||
@@ -316,7 +318,7 @@ export async function updateRectifierSettings(
|
||||
export interface BetaPolicyRule {
|
||||
beta_token: string
|
||||
action: 'pass' | 'filter' | 'block'
|
||||
scope: 'all' | 'oauth' | 'apikey'
|
||||
scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
|
||||
error_message?: string
|
||||
}
|
||||
|
||||
|
||||
@@ -292,17 +292,19 @@ const rpmTooltip = computed(() => {
|
||||
}
|
||||
})
|
||||
|
||||
// 是否显示各维度配额(仅 apikey 类型)
|
||||
// 是否显示各维度配额(apikey / bedrock 类型)
|
||||
const isQuotaEligible = computed(() => props.account.type === 'apikey' || props.account.type === 'bedrock')
|
||||
|
||||
const showDailyQuota = computed(() => {
|
||||
return props.account.type === 'apikey' && (props.account.quota_daily_limit ?? 0) > 0
|
||||
return isQuotaEligible.value && (props.account.quota_daily_limit ?? 0) > 0
|
||||
})
|
||||
|
||||
const showWeeklyQuota = computed(() => {
|
||||
return props.account.type === 'apikey' && (props.account.quota_weekly_limit ?? 0) > 0
|
||||
return isQuotaEligible.value && (props.account.quota_weekly_limit ?? 0) > 0
|
||||
})
|
||||
|
||||
const showTotalQuota = computed(() => {
|
||||
return props.account.type === 'apikey' && (props.account.quota_limit ?? 0) > 0
|
||||
return isQuotaEligible.value && (props.account.quota_limit ?? 0) > 0
|
||||
})
|
||||
|
||||
// 格式化费用显示
|
||||
|
||||
@@ -36,6 +36,10 @@
|
||||
|
||||
<!-- Usage data -->
|
||||
<div v-else-if="usageInfo" class="space-y-1">
|
||||
<!-- API error (degraded response) -->
|
||||
<div v-if="usageInfo.error" class="text-xs text-amber-600 dark:text-amber-400 truncate max-w-[200px]" :title="usageInfo.error">
|
||||
{{ usageInfo.error }}
|
||||
</div>
|
||||
<!-- 5h Window -->
|
||||
<UsageProgressBar
|
||||
v-if="usageInfo.five_hour"
|
||||
@@ -189,8 +193,53 @@
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Forbidden state (403) -->
|
||||
<div v-if="isForbidden" class="space-y-1">
|
||||
<span
|
||||
:class="[
|
||||
'inline-block rounded px-1.5 py-0.5 text-[10px] font-medium',
|
||||
forbiddenBadgeClass
|
||||
]"
|
||||
>
|
||||
{{ forbiddenLabel }}
|
||||
</span>
|
||||
<div v-if="validationURL" class="flex items-center gap-1">
|
||||
<a
|
||||
:href="validationURL"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="text-[10px] text-blue-600 hover:text-blue-800 hover:underline dark:text-blue-400 dark:hover:text-blue-300"
|
||||
:title="t('admin.accounts.openVerification')"
|
||||
>
|
||||
{{ t('admin.accounts.openVerification') }}
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-[10px] text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200"
|
||||
:title="t('admin.accounts.copyLink')"
|
||||
@click="copyValidationURL"
|
||||
>
|
||||
{{ linkCopied ? t('admin.accounts.linkCopied') : t('admin.accounts.copyLink') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Needs reauth (401) -->
|
||||
<div v-else-if="needsReauth" class="space-y-1">
|
||||
<span class="inline-block rounded px-1.5 py-0.5 text-[10px] font-medium bg-orange-100 text-orange-700 dark:bg-orange-900/40 dark:text-orange-300">
|
||||
{{ t('admin.accounts.needsReauth') }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Degraded error (non-403, non-401) -->
|
||||
<div v-else-if="usageInfo?.error" class="space-y-1">
|
||||
<span class="inline-block rounded px-1.5 py-0.5 text-[10px] font-medium bg-amber-100 text-amber-700 dark:bg-amber-900/40 dark:text-amber-300">
|
||||
{{ usageErrorLabel }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Loading state -->
|
||||
<div v-if="loading" class="space-y-1.5">
|
||||
<div v-else-if="loading" class="space-y-1.5">
|
||||
<div class="flex items-center gap-1">
|
||||
<div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div class="h-1.5 w-8 animate-pulse rounded-full bg-gray-200 dark:bg-gray-700"></div>
|
||||
@@ -816,6 +865,51 @@ const hasIneligibleTiers = computed(() => {
|
||||
return Array.isArray(ineligibleTiers) && ineligibleTiers.length > 0
|
||||
})
|
||||
|
||||
// Antigravity 403 forbidden 状态
|
||||
const isForbidden = computed(() => !!usageInfo.value?.is_forbidden)
|
||||
const forbiddenType = computed(() => usageInfo.value?.forbidden_type || 'forbidden')
|
||||
const validationURL = computed(() => usageInfo.value?.validation_url || '')
|
||||
|
||||
// 需要重新授权(401)
|
||||
const needsReauth = computed(() => !!usageInfo.value?.needs_reauth)
|
||||
|
||||
// 降级错误标签(rate_limited / network_error)
|
||||
const usageErrorLabel = computed(() => {
|
||||
const code = usageInfo.value?.error_code
|
||||
if (code === 'rate_limited') return t('admin.accounts.rateLimited')
|
||||
return t('admin.accounts.usageError')
|
||||
})
|
||||
|
||||
const forbiddenLabel = computed(() => {
|
||||
switch (forbiddenType.value) {
|
||||
case 'validation':
|
||||
return t('admin.accounts.forbiddenValidation')
|
||||
case 'violation':
|
||||
return t('admin.accounts.forbiddenViolation')
|
||||
default:
|
||||
return t('admin.accounts.forbidden')
|
||||
}
|
||||
})
|
||||
|
||||
const forbiddenBadgeClass = computed(() => {
|
||||
if (forbiddenType.value === 'validation') {
|
||||
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/40 dark:text-yellow-300'
|
||||
}
|
||||
return 'bg-red-100 text-red-700 dark:bg-red-900/40 dark:text-red-300'
|
||||
})
|
||||
|
||||
const linkCopied = ref(false)
|
||||
const copyValidationURL = async () => {
|
||||
if (!validationURL.value) return
|
||||
try {
|
||||
await navigator.clipboard.writeText(validationURL.value)
|
||||
linkCopied.value = true
|
||||
setTimeout(() => { linkCopied.value = false }, 2000)
|
||||
} catch {
|
||||
// fallback: ignore
|
||||
}
|
||||
}
|
||||
|
||||
const loadUsage = async () => {
|
||||
if (!shouldFetchUsage.value) return
|
||||
|
||||
@@ -848,18 +942,30 @@ const makeQuotaBar = (
|
||||
let resetsAt: string | null = null
|
||||
if (startKey) {
|
||||
const extra = props.account.extra as Record<string, unknown> | undefined
|
||||
const startStr = extra?.[startKey] as string | undefined
|
||||
if (startStr) {
|
||||
const startDate = new Date(startStr)
|
||||
const periodMs = startKey.includes('daily') ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000
|
||||
resetsAt = new Date(startDate.getTime() + periodMs).toISOString()
|
||||
const isDaily = startKey.includes('daily')
|
||||
const mode = isDaily
|
||||
? (extra?.quota_daily_reset_mode as string) || 'rolling'
|
||||
: (extra?.quota_weekly_reset_mode as string) || 'rolling'
|
||||
|
||||
if (mode === 'fixed') {
|
||||
// Use pre-computed next reset time for fixed mode
|
||||
const resetAtKey = isDaily ? 'quota_daily_reset_at' : 'quota_weekly_reset_at'
|
||||
resetsAt = (extra?.[resetAtKey] as string) || null
|
||||
} else {
|
||||
// Rolling mode: compute from start + period
|
||||
const startStr = extra?.[startKey] as string | undefined
|
||||
if (startStr) {
|
||||
const startDate = new Date(startStr)
|
||||
const periodMs = isDaily ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000
|
||||
resetsAt = new Date(startDate.getTime() + periodMs).toISOString()
|
||||
}
|
||||
}
|
||||
}
|
||||
return { utilization, resetsAt }
|
||||
}
|
||||
|
||||
const hasApiKeyQuota = computed(() => {
|
||||
if (props.account.type !== 'apikey') return false
|
||||
if (props.account.type !== 'apikey' && props.account.type !== 'bedrock') return false
|
||||
return (
|
||||
(props.account.quota_daily_limit ?? 0) > 0 ||
|
||||
(props.account.quota_weekly_limit ?? 0) > 0 ||
|
||||
|
||||
@@ -323,35 +323,6 @@
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'bedrock-apikey'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
accountCategory === 'bedrock-apikey'
|
||||
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
|
||||
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
|
||||
]"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
accountCategory === 'bedrock-apikey'
|
||||
? 'bg-amber-500 text-white'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="key" size="sm" />
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">{{
|
||||
t('admin.accounts.bedrockApiKeyLabel')
|
||||
}}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{
|
||||
t('admin.accounts.bedrockApiKeyDesc')
|
||||
}}</span>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -956,7 +927,7 @@
|
||||
</div>
|
||||
|
||||
<!-- API Key input (only for apikey type, excluding Antigravity which has its own fields) -->
|
||||
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity' && accountCategory !== 'bedrock-apikey'" class="space-y-4">
|
||||
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
|
||||
<input
|
||||
@@ -1341,34 +1312,75 @@
|
||||
|
||||
<!-- Bedrock credentials (only for Anthropic Bedrock type) -->
|
||||
<div v-if="form.platform === 'anthropic' && accountCategory === 'bedrock'" class="space-y-4">
|
||||
<!-- Auth Mode Radio -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
|
||||
<input
|
||||
v-model="bedrockAccessKeyId"
|
||||
type="text"
|
||||
required
|
||||
class="input font-mono"
|
||||
placeholder="AKIA..."
|
||||
/>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAuthMode') }}</label>
|
||||
<div class="mt-2 flex gap-4">
|
||||
<label class="flex cursor-pointer items-center">
|
||||
<input
|
||||
v-model="bedrockAuthMode"
|
||||
type="radio"
|
||||
value="sigv4"
|
||||
class="mr-2 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockAuthModeSigv4') }}</span>
|
||||
</label>
|
||||
<label class="flex cursor-pointer items-center">
|
||||
<input
|
||||
v-model="bedrockAuthMode"
|
||||
type="radio"
|
||||
value="apikey"
|
||||
class="mr-2 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockAuthModeApikey') }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
|
||||
|
||||
<!-- SigV4 fields -->
|
||||
<template v-if="bedrockAuthMode === 'sigv4'">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
|
||||
<input
|
||||
v-model="bedrockAccessKeyId"
|
||||
type="text"
|
||||
required
|
||||
class="input font-mono"
|
||||
placeholder="AKIA..."
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
|
||||
<input
|
||||
v-model="bedrockSecretAccessKey"
|
||||
type="password"
|
||||
required
|
||||
class="input font-mono"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
|
||||
<input
|
||||
v-model="bedrockSessionToken"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- API Key field -->
|
||||
<div v-if="bedrockAuthMode === 'apikey'">
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
|
||||
<input
|
||||
v-model="bedrockSecretAccessKey"
|
||||
v-model="bedrockApiKeyValue"
|
||||
type="password"
|
||||
required
|
||||
class="input font-mono"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
|
||||
<input
|
||||
v-model="bedrockSessionToken"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Shared: Region -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<select v-model="bedrockRegion" class="input">
|
||||
@@ -1408,6 +1420,8 @@
|
||||
</select>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Shared: Force Global -->
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
@@ -1488,142 +1502,62 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bedrock API Key credentials (only for Anthropic Bedrock API Key type) -->
|
||||
<div v-if="form.platform === 'anthropic' && accountCategory === 'bedrock-apikey'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
|
||||
<input
|
||||
v-model="bedrockApiKeyValue"
|
||||
type="password"
|
||||
required
|
||||
class="input font-mono"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<select v-model="bedrockApiKeyRegion" class="input">
|
||||
<optgroup label="US">
|
||||
<option value="us-east-1">us-east-1 (N. Virginia)</option>
|
||||
<option value="us-east-2">us-east-2 (Ohio)</option>
|
||||
<option value="us-west-1">us-west-1 (N. California)</option>
|
||||
<option value="us-west-2">us-west-2 (Oregon)</option>
|
||||
<option value="us-gov-east-1">us-gov-east-1 (GovCloud US-East)</option>
|
||||
<option value="us-gov-west-1">us-gov-west-1 (GovCloud US-West)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Europe">
|
||||
<option value="eu-west-1">eu-west-1 (Ireland)</option>
|
||||
<option value="eu-west-2">eu-west-2 (London)</option>
|
||||
<option value="eu-west-3">eu-west-3 (Paris)</option>
|
||||
<option value="eu-central-1">eu-central-1 (Frankfurt)</option>
|
||||
<option value="eu-central-2">eu-central-2 (Zurich)</option>
|
||||
<option value="eu-south-1">eu-south-1 (Milan)</option>
|
||||
<option value="eu-south-2">eu-south-2 (Spain)</option>
|
||||
<option value="eu-north-1">eu-north-1 (Stockholm)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Asia Pacific">
|
||||
<option value="ap-northeast-1">ap-northeast-1 (Tokyo)</option>
|
||||
<option value="ap-northeast-2">ap-northeast-2 (Seoul)</option>
|
||||
<option value="ap-northeast-3">ap-northeast-3 (Osaka)</option>
|
||||
<option value="ap-south-1">ap-south-1 (Mumbai)</option>
|
||||
<option value="ap-south-2">ap-south-2 (Hyderabad)</option>
|
||||
<option value="ap-southeast-1">ap-southeast-1 (Singapore)</option>
|
||||
<option value="ap-southeast-2">ap-southeast-2 (Sydney)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Canada">
|
||||
<option value="ca-central-1">ca-central-1 (Canada)</option>
|
||||
</optgroup>
|
||||
<optgroup label="South America">
|
||||
<option value="sa-east-1">sa-east-1 (São Paulo)</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
v-model="bedrockApiKeyForceGlobal"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
|
||||
</label>
|
||||
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction Section for Bedrock API Key -->
|
||||
<!-- Pool Mode Section for Bedrock -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<!-- Mode Toggle -->
|
||||
<div class="mb-4 flex gap-2">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.poolMode') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.poolModeHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'whitelist'"
|
||||
@click="poolModeEnabled = !poolModeEnabled"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'whitelist'
|
||||
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
poolModeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelWhitelist') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'mapping'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'mapping'
|
||||
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelMapping') }}
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
poolModeEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
|
||||
<div v-if="poolModeEnabled" class="rounded-lg bg-blue-50 p-3 dark:bg-blue-900/20">
|
||||
<p class="text-xs text-blue-700 dark:text-blue-400">
|
||||
<Icon name="exclamationCircle" size="sm" class="mr-1 inline" :stroke-width="2" />
|
||||
{{ t('admin.accounts.poolModeInfo') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Mapping Mode -->
|
||||
<div v-else class="space-y-3">
|
||||
<div v-for="(mapping, index) in modelMappings" :key="index" class="flex items-center gap-2">
|
||||
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
|
||||
<span class="text-gray-400">→</span>
|
||||
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
|
||||
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
|
||||
<Icon name="trash" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
|
||||
+ {{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
<!-- Bedrock Preset Mappings -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in bedrockPresets"
|
||||
:key="preset.from"
|
||||
type="button"
|
||||
@click="addPresetMapping(preset.from, preset.to)"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="poolModeEnabled" class="mt-3">
|
||||
<label class="input-label">{{ t('admin.accounts.poolModeRetryCount') }}</label>
|
||||
<input
|
||||
v-model.number="poolModeRetryCount"
|
||||
type="number"
|
||||
min="0"
|
||||
:max="MAX_POOL_MODE_RETRY_COUNT"
|
||||
step="1"
|
||||
class="input"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{
|
||||
t('admin.accounts.poolModeRetryCountHint', {
|
||||
default: DEFAULT_POOL_MODE_RETRY_COUNT,
|
||||
max: MAX_POOL_MODE_RETRY_COUNT
|
||||
})
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- API Key 账号配额限制 -->
|
||||
<div v-if="form.type === 'apikey'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<!-- API Key / Bedrock 账号配额限制 -->
|
||||
<div v-if="form.type === 'apikey' || form.type === 'bedrock'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<div class="mb-3">
|
||||
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
@@ -1634,9 +1568,21 @@
|
||||
:totalLimit="editQuotaLimit"
|
||||
:dailyLimit="editQuotaDailyLimit"
|
||||
:weeklyLimit="editQuotaWeeklyLimit"
|
||||
:dailyResetMode="editDailyResetMode"
|
||||
:dailyResetHour="editDailyResetHour"
|
||||
:weeklyResetMode="editWeeklyResetMode"
|
||||
:weeklyResetDay="editWeeklyResetDay"
|
||||
:weeklyResetHour="editWeeklyResetHour"
|
||||
:resetTimezone="editResetTimezone"
|
||||
@update:totalLimit="editQuotaLimit = $event"
|
||||
@update:dailyLimit="editQuotaDailyLimit = $event"
|
||||
@update:weeklyLimit="editQuotaWeeklyLimit = $event"
|
||||
@update:dailyResetMode="editDailyResetMode = $event"
|
||||
@update:dailyResetHour="editDailyResetHour = $event"
|
||||
@update:weeklyResetMode="editWeeklyResetMode = $event"
|
||||
@update:weeklyResetDay="editWeeklyResetDay = $event"
|
||||
@update:weeklyResetHour="editWeeklyResetHour = $event"
|
||||
@update:resetTimezone="editResetTimezone = $event"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -3014,13 +2960,19 @@ interface TempUnschedRuleForm {
|
||||
// State
|
||||
const step = ref(1)
|
||||
const submitting = ref(false)
|
||||
const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'bedrock-apikey'>('oauth-based') // UI selection for account category
|
||||
const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock'>('oauth-based') // UI selection for account category
|
||||
const addMethod = ref<AddMethod>('oauth') // For oauth-based: 'oauth' or 'setup-token'
|
||||
const apiKeyBaseUrl = ref('https://api.anthropic.com')
|
||||
const apiKeyValue = ref('')
|
||||
const editQuotaLimit = ref<number | null>(null)
|
||||
const editQuotaDailyLimit = ref<number | null>(null)
|
||||
const editQuotaWeeklyLimit = ref<number | null>(null)
|
||||
const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null)
|
||||
const editDailyResetHour = ref<number | null>(null)
|
||||
const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null)
|
||||
const editWeeklyResetDay = ref<number | null>(null)
|
||||
const editWeeklyResetHour = ref<number | null>(null)
|
||||
const editResetTimezone = ref<string | null>(null)
|
||||
const modelMappings = ref<ModelMapping[]>([])
|
||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const allowedModels = ref<string[]>([])
|
||||
@@ -3050,16 +3002,13 @@ const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('an
|
||||
const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock'))
|
||||
|
||||
// Bedrock credentials
|
||||
const bedrockAuthMode = ref<'sigv4' | 'apikey'>('sigv4')
|
||||
const bedrockAccessKeyId = ref('')
|
||||
const bedrockSecretAccessKey = ref('')
|
||||
const bedrockSessionToken = ref('')
|
||||
const bedrockRegion = ref('us-east-1')
|
||||
const bedrockForceGlobal = ref(false)
|
||||
|
||||
// Bedrock API Key credentials
|
||||
const bedrockApiKeyValue = ref('')
|
||||
const bedrockApiKeyRegion = ref('us-east-1')
|
||||
const bedrockApiKeyForceGlobal = ref(false)
|
||||
const tempUnschedEnabled = ref(false)
|
||||
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
||||
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
|
||||
@@ -3343,7 +3292,8 @@ watch(
|
||||
bedrockSessionToken.value = ''
|
||||
bedrockRegion.value = 'us-east-1'
|
||||
bedrockForceGlobal.value = false
|
||||
bedrockApiKeyForceGlobal.value = false
|
||||
bedrockAuthMode.value = 'sigv4'
|
||||
bedrockApiKeyValue.value = ''
|
||||
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
|
||||
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
|
||||
interceptWarmupRequests.value = false
|
||||
@@ -3719,6 +3669,12 @@ const resetForm = () => {
|
||||
editQuotaLimit.value = null
|
||||
editQuotaDailyLimit.value = null
|
||||
editQuotaWeeklyLimit.value = null
|
||||
editDailyResetMode.value = null
|
||||
editDailyResetHour.value = null
|
||||
editWeeklyResetMode.value = null
|
||||
editWeeklyResetDay.value = null
|
||||
editWeeklyResetHour.value = null
|
||||
editResetTimezone.value = null
|
||||
modelMappings.value = []
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
allowedModels.value = [...claudeModels] // Default fill related models
|
||||
@@ -3919,27 +3875,34 @@ const handleSubmit = async () => {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
|
||||
return
|
||||
}
|
||||
if (!bedrockAccessKeyId.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired'))
|
||||
return
|
||||
}
|
||||
if (!bedrockSecretAccessKey.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired'))
|
||||
return
|
||||
}
|
||||
if (!bedrockRegion.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockRegionRequired'))
|
||||
return
|
||||
}
|
||||
|
||||
const credentials: Record<string, unknown> = {
|
||||
aws_access_key_id: bedrockAccessKeyId.value.trim(),
|
||||
aws_secret_access_key: bedrockSecretAccessKey.value.trim(),
|
||||
aws_region: bedrockRegion.value.trim(),
|
||||
auth_mode: bedrockAuthMode.value,
|
||||
aws_region: bedrockRegion.value.trim() || 'us-east-1',
|
||||
}
|
||||
if (bedrockSessionToken.value.trim()) {
|
||||
credentials.aws_session_token = bedrockSessionToken.value.trim()
|
||||
|
||||
if (bedrockAuthMode.value === 'sigv4') {
|
||||
if (!bedrockAccessKeyId.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired'))
|
||||
return
|
||||
}
|
||||
if (!bedrockSecretAccessKey.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired'))
|
||||
return
|
||||
}
|
||||
credentials.aws_access_key_id = bedrockAccessKeyId.value.trim()
|
||||
credentials.aws_secret_access_key = bedrockSecretAccessKey.value.trim()
|
||||
if (bedrockSessionToken.value.trim()) {
|
||||
credentials.aws_session_token = bedrockSessionToken.value.trim()
|
||||
}
|
||||
} else {
|
||||
if (!bedrockApiKeyValue.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockApiKeyRequired'))
|
||||
return
|
||||
}
|
||||
credentials.api_key = bedrockApiKeyValue.value.trim()
|
||||
}
|
||||
|
||||
if (bedrockForceGlobal.value) {
|
||||
credentials.aws_force_global = 'true'
|
||||
}
|
||||
@@ -3952,45 +3915,18 @@ const handleSubmit = async () => {
|
||||
credentials.model_mapping = modelMapping
|
||||
}
|
||||
|
||||
// Pool mode
|
||||
if (poolModeEnabled.value) {
|
||||
credentials.pool_mode = true
|
||||
credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value)
|
||||
}
|
||||
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
|
||||
await createAccountAndFinish('anthropic', 'bedrock' as AccountType, credentials)
|
||||
return
|
||||
}
|
||||
|
||||
// For Bedrock API Key type, create directly
|
||||
if (form.platform === 'anthropic' && accountCategory.value === 'bedrock-apikey') {
|
||||
if (!form.name.trim()) {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
|
||||
return
|
||||
}
|
||||
if (!bedrockApiKeyValue.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockApiKeyRequired'))
|
||||
return
|
||||
}
|
||||
|
||||
const credentials: Record<string, unknown> = {
|
||||
api_key: bedrockApiKeyValue.value.trim(),
|
||||
aws_region: bedrockApiKeyRegion.value.trim() || 'us-east-1',
|
||||
}
|
||||
if (bedrockApiKeyForceGlobal.value) {
|
||||
credentials.aws_force_global = 'true'
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
const modelMapping = buildModelMappingObject(
|
||||
modelRestrictionMode.value, allowedModels.value, modelMappings.value
|
||||
)
|
||||
if (modelMapping) {
|
||||
credentials.model_mapping = modelMapping
|
||||
}
|
||||
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
|
||||
await createAccountAndFinish('anthropic', 'bedrock-apikey' as AccountType, credentials)
|
||||
return
|
||||
}
|
||||
|
||||
// For Antigravity upstream type, create directly
|
||||
if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') {
|
||||
if (!form.name.trim()) {
|
||||
@@ -4233,9 +4169,9 @@ const createAccountAndFinish = async (
|
||||
if (!applyTempUnschedConfig(credentials)) {
|
||||
return
|
||||
}
|
||||
// Inject quota limits for apikey accounts
|
||||
// Inject quota limits for apikey/bedrock accounts
|
||||
let finalExtra = extra
|
||||
if (type === 'apikey') {
|
||||
if (type === 'apikey' || type === 'bedrock') {
|
||||
const quotaExtra: Record<string, unknown> = { ...(extra || {}) }
|
||||
if (editQuotaLimit.value != null && editQuotaLimit.value > 0) {
|
||||
quotaExtra.quota_limit = editQuotaLimit.value
|
||||
@@ -4246,6 +4182,19 @@ const createAccountAndFinish = async (
|
||||
if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) {
|
||||
quotaExtra.quota_weekly_limit = editQuotaWeeklyLimit.value
|
||||
}
|
||||
// Quota reset mode config
|
||||
if (editDailyResetMode.value === 'fixed') {
|
||||
quotaExtra.quota_daily_reset_mode = 'fixed'
|
||||
quotaExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0
|
||||
}
|
||||
if (editWeeklyResetMode.value === 'fixed') {
|
||||
quotaExtra.quota_weekly_reset_mode = 'fixed'
|
||||
quotaExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1
|
||||
quotaExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 0
|
||||
}
|
||||
if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') {
|
||||
quotaExtra.quota_reset_timezone = editResetTimezone.value || 'UTC'
|
||||
}
|
||||
if (Object.keys(quotaExtra).length > 0) {
|
||||
finalExtra = quotaExtra
|
||||
}
|
||||
|
||||
@@ -563,37 +563,54 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bedrock fields (only for bedrock type) -->
|
||||
<!-- Bedrock fields (for bedrock type, both SigV4 and API Key modes) -->
|
||||
<div v-if="account.type === 'bedrock'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
|
||||
<!-- SigV4 fields -->
|
||||
<template v-if="!isBedrockAPIKeyMode">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
|
||||
<input
|
||||
v-model="editBedrockAccessKeyId"
|
||||
type="text"
|
||||
class="input font-mono"
|
||||
placeholder="AKIA..."
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
|
||||
<input
|
||||
v-model="editBedrockSecretAccessKey"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSecretKeyLeaveEmpty') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
|
||||
<input
|
||||
v-model="editBedrockSessionToken"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- API Key field -->
|
||||
<div v-if="isBedrockAPIKeyMode">
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
|
||||
<input
|
||||
v-model="editBedrockAccessKeyId"
|
||||
type="text"
|
||||
class="input font-mono"
|
||||
placeholder="AKIA..."
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
|
||||
<input
|
||||
v-model="editBedrockSecretAccessKey"
|
||||
v-model="editBedrockApiKeyValue"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
|
||||
:placeholder="t('admin.accounts.bedrockApiKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSecretKeyLeaveEmpty') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
|
||||
<input
|
||||
v-model="editBedrockSessionToken"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Shared: Region -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<input
|
||||
@@ -604,6 +621,8 @@
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Shared: Force Global -->
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
@@ -684,108 +703,56 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bedrock API Key fields (only for bedrock-apikey type) -->
|
||||
<div v-if="account.type === 'bedrock-apikey'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
|
||||
<input
|
||||
v-model="editBedrockApiKeyValue"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockApiKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<input
|
||||
v-model="editBedrockApiKeyRegion"
|
||||
type="text"
|
||||
class="input"
|
||||
placeholder="us-east-1"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
v-model="editBedrockApiKeyForceGlobal"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
|
||||
</label>
|
||||
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction for Bedrock API Key -->
|
||||
<!-- Pool Mode Section for Bedrock -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<!-- Mode Toggle -->
|
||||
<div class="mb-4 flex gap-2">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.poolMode') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.poolModeHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'whitelist'"
|
||||
@click="poolModeEnabled = !poolModeEnabled"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'whitelist'
|
||||
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
poolModeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelWhitelist') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'mapping'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'mapping'
|
||||
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelMapping') }}
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
poolModeEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
|
||||
<div v-if="poolModeEnabled" class="rounded-lg bg-blue-50 p-3 dark:bg-blue-900/20">
|
||||
<p class="text-xs text-blue-700 dark:text-blue-400">
|
||||
<Icon name="exclamationCircle" size="sm" class="mr-1 inline" :stroke-width="2" />
|
||||
{{ t('admin.accounts.poolModeInfo') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Mapping Mode -->
|
||||
<div v-else class="space-y-3">
|
||||
<div v-for="(mapping, index) in modelMappings" :key="getModelMappingKey(mapping)" class="flex items-center gap-2">
|
||||
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
|
||||
<span class="text-gray-400">→</span>
|
||||
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
|
||||
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
|
||||
<Icon name="trash" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
|
||||
+ {{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
<!-- Bedrock Preset Mappings -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in bedrockPresets"
|
||||
:key="preset.from"
|
||||
type="button"
|
||||
@click="modelMappings.push({ from: preset.from, to: preset.to })"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="poolModeEnabled" class="mt-3">
|
||||
<label class="input-label">{{ t('admin.accounts.poolModeRetryCount') }}</label>
|
||||
<input
|
||||
v-model.number="poolModeRetryCount"
|
||||
type="number"
|
||||
min="0"
|
||||
:max="MAX_POOL_MODE_RETRY_COUNT"
|
||||
step="1"
|
||||
class="input"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{
|
||||
t('admin.accounts.poolModeRetryCountHint', {
|
||||
default: DEFAULT_POOL_MODE_RETRY_COUNT,
|
||||
max: MAX_POOL_MODE_RETRY_COUNT
|
||||
})
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -1182,8 +1149,8 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- API Key 账号配额限制 -->
|
||||
<div v-if="account?.type === 'apikey'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<!-- API Key / Bedrock 账号配额限制 -->
|
||||
<div v-if="account?.type === 'apikey' || account?.type === 'bedrock'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<div class="mb-3">
|
||||
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
@@ -1194,9 +1161,21 @@
|
||||
:totalLimit="editQuotaLimit"
|
||||
:dailyLimit="editQuotaDailyLimit"
|
||||
:weeklyLimit="editQuotaWeeklyLimit"
|
||||
:dailyResetMode="editDailyResetMode"
|
||||
:dailyResetHour="editDailyResetHour"
|
||||
:weeklyResetMode="editWeeklyResetMode"
|
||||
:weeklyResetDay="editWeeklyResetDay"
|
||||
:weeklyResetHour="editWeeklyResetHour"
|
||||
:resetTimezone="editResetTimezone"
|
||||
@update:totalLimit="editQuotaLimit = $event"
|
||||
@update:dailyLimit="editQuotaDailyLimit = $event"
|
||||
@update:weeklyLimit="editQuotaWeeklyLimit = $event"
|
||||
@update:dailyResetMode="editDailyResetMode = $event"
|
||||
@update:dailyResetHour="editDailyResetHour = $event"
|
||||
@update:weeklyResetMode="editWeeklyResetMode = $event"
|
||||
@update:weeklyResetDay="editWeeklyResetDay = $event"
|
||||
@update:weeklyResetHour="editWeeklyResetHour = $event"
|
||||
@update:resetTimezone="editResetTimezone = $event"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -1781,11 +1760,11 @@ const editBedrockSecretAccessKey = ref('')
|
||||
const editBedrockSessionToken = ref('')
|
||||
const editBedrockRegion = ref('')
|
||||
const editBedrockForceGlobal = ref(false)
|
||||
|
||||
// Bedrock API Key credentials
|
||||
const editBedrockApiKeyValue = ref('')
|
||||
const editBedrockApiKeyRegion = ref('')
|
||||
const editBedrockApiKeyForceGlobal = ref(false)
|
||||
const isBedrockAPIKeyMode = computed(() =>
|
||||
props.account?.type === 'bedrock' &&
|
||||
(props.account?.credentials as Record<string, unknown>)?.auth_mode === 'apikey'
|
||||
)
|
||||
const modelMappings = ref<ModelMapping[]>([])
|
||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const allowedModels = ref<string[]>([])
|
||||
@@ -1847,6 +1826,12 @@ const anthropicPassthroughEnabled = ref(false)
|
||||
const editQuotaLimit = ref<number | null>(null)
|
||||
const editQuotaDailyLimit = ref<number | null>(null)
|
||||
const editQuotaWeeklyLimit = ref<number | null>(null)
|
||||
const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null)
|
||||
const editDailyResetHour = ref<number | null>(null)
|
||||
const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null)
|
||||
const editWeeklyResetDay = ref<number | null>(null)
|
||||
const editWeeklyResetHour = ref<number | null>(null)
|
||||
const editResetTimezone = ref<string | null>(null)
|
||||
const openAIWSModeOptions = computed(() => [
|
||||
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
|
||||
// TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
|
||||
@@ -2026,18 +2011,31 @@ watch(
|
||||
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
|
||||
}
|
||||
|
||||
// Load quota limit for apikey accounts
|
||||
if (newAccount.type === 'apikey') {
|
||||
// Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above)
|
||||
if (newAccount.type === 'apikey' || newAccount.type === 'bedrock') {
|
||||
const quotaVal = extra?.quota_limit as number | undefined
|
||||
editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null
|
||||
const dailyVal = extra?.quota_daily_limit as number | undefined
|
||||
editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null
|
||||
const weeklyVal = extra?.quota_weekly_limit as number | undefined
|
||||
editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null
|
||||
// Load quota reset mode config
|
||||
editDailyResetMode.value = (extra?.quota_daily_reset_mode as 'rolling' | 'fixed') || null
|
||||
editDailyResetHour.value = (extra?.quota_daily_reset_hour as number) ?? null
|
||||
editWeeklyResetMode.value = (extra?.quota_weekly_reset_mode as 'rolling' | 'fixed') || null
|
||||
editWeeklyResetDay.value = (extra?.quota_weekly_reset_day as number) ?? null
|
||||
editWeeklyResetHour.value = (extra?.quota_weekly_reset_hour as number) ?? null
|
||||
editResetTimezone.value = (extra?.quota_reset_timezone as string) || null
|
||||
} else {
|
||||
editQuotaLimit.value = null
|
||||
editQuotaDailyLimit.value = null
|
||||
editQuotaWeeklyLimit.value = null
|
||||
editDailyResetMode.value = null
|
||||
editDailyResetHour.value = null
|
||||
editWeeklyResetMode.value = null
|
||||
editWeeklyResetDay.value = null
|
||||
editWeeklyResetHour.value = null
|
||||
editResetTimezone.value = null
|
||||
}
|
||||
|
||||
// Load antigravity model mapping (Antigravity 只支持映射模式)
|
||||
@@ -2130,11 +2128,28 @@ watch(
|
||||
}
|
||||
} else if (newAccount.type === 'bedrock' && newAccount.credentials) {
|
||||
const bedrockCreds = newAccount.credentials as Record<string, unknown>
|
||||
editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || ''
|
||||
const authMode = (bedrockCreds.auth_mode as string) || 'sigv4'
|
||||
editBedrockRegion.value = (bedrockCreds.aws_region as string) || ''
|
||||
editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true'
|
||||
editBedrockSecretAccessKey.value = ''
|
||||
editBedrockSessionToken.value = ''
|
||||
|
||||
if (authMode === 'apikey') {
|
||||
editBedrockApiKeyValue.value = ''
|
||||
} else {
|
||||
editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || ''
|
||||
editBedrockSecretAccessKey.value = ''
|
||||
editBedrockSessionToken.value = ''
|
||||
}
|
||||
|
||||
// Load pool mode for bedrock
|
||||
poolModeEnabled.value = bedrockCreds.pool_mode === true
|
||||
const retryCount = bedrockCreds.pool_mode_retry_count
|
||||
poolModeRetryCount.value = (typeof retryCount === 'number' && retryCount >= 0) ? retryCount : DEFAULT_POOL_MODE_RETRY_COUNT
|
||||
|
||||
// Load quota limits for bedrock
|
||||
const bedrockExtra = (newAccount.extra as Record<string, unknown>) || {}
|
||||
editQuotaLimit.value = typeof bedrockExtra.quota_limit === 'number' ? bedrockExtra.quota_limit : null
|
||||
editQuotaDailyLimit.value = typeof bedrockExtra.quota_daily_limit === 'number' ? bedrockExtra.quota_daily_limit : null
|
||||
editQuotaWeeklyLimit.value = typeof bedrockExtra.quota_weekly_limit === 'number' ? bedrockExtra.quota_weekly_limit : null
|
||||
|
||||
// Load model mappings for bedrock
|
||||
const existingMappings = bedrockCreds.model_mapping as Record<string, string> | undefined
|
||||
@@ -2155,31 +2170,6 @@ watch(
|
||||
modelMappings.value = []
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else if (newAccount.type === 'bedrock-apikey' && newAccount.credentials) {
|
||||
const bedrockApiKeyCreds = newAccount.credentials as Record<string, unknown>
|
||||
editBedrockApiKeyRegion.value = (bedrockApiKeyCreds.aws_region as string) || 'us-east-1'
|
||||
editBedrockApiKeyForceGlobal.value = (bedrockApiKeyCreds.aws_force_global as string) === 'true'
|
||||
editBedrockApiKeyValue.value = ''
|
||||
|
||||
// Load model mappings for bedrock-apikey
|
||||
const existingMappings = bedrockApiKeyCreds.model_mapping as Record<string, string> | undefined
|
||||
if (existingMappings && typeof existingMappings === 'object') {
|
||||
const entries = Object.entries(existingMappings)
|
||||
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
|
||||
if (isWhitelistMode) {
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
allowedModels.value = entries.map(([from]) => from)
|
||||
modelMappings.value = []
|
||||
} else {
|
||||
modelRestrictionMode.value = 'mapping'
|
||||
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else {
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
modelMappings.value = []
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else if (newAccount.type === 'upstream' && newAccount.credentials) {
|
||||
const credentials = newAccount.credentials as Record<string, unknown>
|
||||
editBaseUrl.value = (credentials.base_url as string) || ''
|
||||
@@ -2727,7 +2717,6 @@ const handleSubmit = async () => {
|
||||
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
||||
const newCredentials: Record<string, unknown> = { ...currentCredentials }
|
||||
|
||||
newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim()
|
||||
newCredentials.aws_region = editBedrockRegion.value.trim()
|
||||
if (editBedrockForceGlobal.value) {
|
||||
newCredentials.aws_force_global = 'true'
|
||||
@@ -2735,42 +2724,29 @@ const handleSubmit = async () => {
|
||||
delete newCredentials.aws_force_global
|
||||
}
|
||||
|
||||
// Only update secrets if user provided new values
|
||||
if (editBedrockSecretAccessKey.value.trim()) {
|
||||
newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim()
|
||||
}
|
||||
if (editBedrockSessionToken.value.trim()) {
|
||||
newCredentials.aws_session_token = editBedrockSessionToken.value.trim()
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
|
||||
if (modelMapping) {
|
||||
newCredentials.model_mapping = modelMapping
|
||||
if (isBedrockAPIKeyMode.value) {
|
||||
// API Key mode: only update api_key if user provided new value
|
||||
if (editBedrockApiKeyValue.value.trim()) {
|
||||
newCredentials.api_key = editBedrockApiKeyValue.value.trim()
|
||||
}
|
||||
} else {
|
||||
delete newCredentials.model_mapping
|
||||
// SigV4 mode
|
||||
newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim()
|
||||
if (editBedrockSecretAccessKey.value.trim()) {
|
||||
newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim()
|
||||
}
|
||||
if (editBedrockSessionToken.value.trim()) {
|
||||
newCredentials.aws_session_token = editBedrockSessionToken.value.trim()
|
||||
}
|
||||
}
|
||||
|
||||
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
|
||||
if (!applyTempUnschedConfig(newCredentials)) {
|
||||
return
|
||||
}
|
||||
|
||||
updatePayload.credentials = newCredentials
|
||||
} else if (props.account.type === 'bedrock-apikey') {
|
||||
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
||||
const newCredentials: Record<string, unknown> = { ...currentCredentials }
|
||||
|
||||
newCredentials.aws_region = editBedrockApiKeyRegion.value.trim() || 'us-east-1'
|
||||
if (editBedrockApiKeyForceGlobal.value) {
|
||||
newCredentials.aws_force_global = 'true'
|
||||
// Pool mode
|
||||
if (poolModeEnabled.value) {
|
||||
newCredentials.pool_mode = true
|
||||
newCredentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value)
|
||||
} else {
|
||||
delete newCredentials.aws_force_global
|
||||
}
|
||||
|
||||
// Only update API key if user provided new value
|
||||
if (editBedrockApiKeyValue.value.trim()) {
|
||||
newCredentials.api_key = editBedrockApiKeyValue.value.trim()
|
||||
delete newCredentials.pool_mode
|
||||
delete newCredentials.pool_mode_retry_count
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
@@ -2980,8 +2956,8 @@ const handleSubmit = async () => {
|
||||
updatePayload.extra = newExtra
|
||||
}
|
||||
|
||||
// For apikey accounts, handle quota_limit in extra
|
||||
if (props.account.type === 'apikey') {
|
||||
// For apikey/bedrock accounts, handle quota_limit in extra
|
||||
if (props.account.type === 'apikey' || props.account.type === 'bedrock') {
|
||||
const currentExtra = (updatePayload.extra as Record<string, unknown>) ||
|
||||
(props.account.extra as Record<string, unknown>) || {}
|
||||
const newExtra: Record<string, unknown> = { ...currentExtra }
|
||||
@@ -3000,6 +2976,28 @@ const handleSubmit = async () => {
|
||||
} else {
|
||||
delete newExtra.quota_weekly_limit
|
||||
}
|
||||
// Quota reset mode config
|
||||
if (editDailyResetMode.value === 'fixed') {
|
||||
newExtra.quota_daily_reset_mode = 'fixed'
|
||||
newExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0
|
||||
} else {
|
||||
delete newExtra.quota_daily_reset_mode
|
||||
delete newExtra.quota_daily_reset_hour
|
||||
}
|
||||
if (editWeeklyResetMode.value === 'fixed') {
|
||||
newExtra.quota_weekly_reset_mode = 'fixed'
|
||||
newExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1
|
||||
newExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 0
|
||||
} else {
|
||||
delete newExtra.quota_weekly_reset_mode
|
||||
delete newExtra.quota_weekly_reset_day
|
||||
delete newExtra.quota_weekly_reset_hour
|
||||
}
|
||||
if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') {
|
||||
newExtra.quota_reset_timezone = editResetTimezone.value || 'UTC'
|
||||
} else {
|
||||
delete newExtra.quota_reset_timezone
|
||||
}
|
||||
updatePayload.extra = newExtra
|
||||
}
|
||||
|
||||
|
||||
@@ -8,12 +8,24 @@ const props = defineProps<{
|
||||
totalLimit: number | null
|
||||
dailyLimit: number | null
|
||||
weeklyLimit: number | null
|
||||
dailyResetMode: 'rolling' | 'fixed' | null
|
||||
dailyResetHour: number | null
|
||||
weeklyResetMode: 'rolling' | 'fixed' | null
|
||||
weeklyResetDay: number | null
|
||||
weeklyResetHour: number | null
|
||||
resetTimezone: string | null
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:totalLimit': [value: number | null]
|
||||
'update:dailyLimit': [value: number | null]
|
||||
'update:weeklyLimit': [value: number | null]
|
||||
'update:dailyResetMode': [value: 'rolling' | 'fixed' | null]
|
||||
'update:dailyResetHour': [value: number | null]
|
||||
'update:weeklyResetMode': [value: 'rolling' | 'fixed' | null]
|
||||
'update:weeklyResetDay': [value: number | null]
|
||||
'update:weeklyResetHour': [value: number | null]
|
||||
'update:resetTimezone': [value: string | null]
|
||||
}>()
|
||||
|
||||
const enabled = computed(() =>
|
||||
@@ -35,9 +47,56 @@ watch(localEnabled, (val) => {
|
||||
emit('update:totalLimit', null)
|
||||
emit('update:dailyLimit', null)
|
||||
emit('update:weeklyLimit', null)
|
||||
emit('update:dailyResetMode', null)
|
||||
emit('update:dailyResetHour', null)
|
||||
emit('update:weeklyResetMode', null)
|
||||
emit('update:weeklyResetDay', null)
|
||||
emit('update:weeklyResetHour', null)
|
||||
emit('update:resetTimezone', null)
|
||||
}
|
||||
})
|
||||
|
||||
// Whether any fixed mode is active (to show timezone selector)
|
||||
const hasFixedMode = computed(() =>
|
||||
props.dailyResetMode === 'fixed' || props.weeklyResetMode === 'fixed'
|
||||
)
|
||||
|
||||
// Common timezone options
|
||||
const timezoneOptions = [
|
||||
'UTC',
|
||||
'Asia/Shanghai',
|
||||
'Asia/Tokyo',
|
||||
'Asia/Seoul',
|
||||
'Asia/Singapore',
|
||||
'Asia/Kolkata',
|
||||
'Asia/Dubai',
|
||||
'Europe/London',
|
||||
'Europe/Paris',
|
||||
'Europe/Berlin',
|
||||
'Europe/Moscow',
|
||||
'America/New_York',
|
||||
'America/Chicago',
|
||||
'America/Denver',
|
||||
'America/Los_Angeles',
|
||||
'America/Sao_Paulo',
|
||||
'Australia/Sydney',
|
||||
'Pacific/Auckland',
|
||||
]
|
||||
|
||||
// Hours for dropdown (0-23)
|
||||
const hourOptions = Array.from({ length: 24 }, (_, i) => i)
|
||||
|
||||
// Day of week options
|
||||
const dayOptions = [
|
||||
{ value: 1, key: 'monday' },
|
||||
{ value: 2, key: 'tuesday' },
|
||||
{ value: 3, key: 'wednesday' },
|
||||
{ value: 4, key: 'thursday' },
|
||||
{ value: 5, key: 'friday' },
|
||||
{ value: 6, key: 'saturday' },
|
||||
{ value: 0, key: 'sunday' },
|
||||
]
|
||||
|
||||
const onTotalInput = (e: Event) => {
|
||||
const raw = (e.target as HTMLInputElement).valueAsNumber
|
||||
emit('update:totalLimit', Number.isNaN(raw) ? null : raw)
|
||||
@@ -50,6 +109,25 @@ const onWeeklyInput = (e: Event) => {
|
||||
const raw = (e.target as HTMLInputElement).valueAsNumber
|
||||
emit('update:weeklyLimit', Number.isNaN(raw) ? null : raw)
|
||||
}
|
||||
|
||||
const onDailyModeChange = (e: Event) => {
|
||||
const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed'
|
||||
emit('update:dailyResetMode', val)
|
||||
if (val === 'fixed') {
|
||||
if (props.dailyResetHour == null) emit('update:dailyResetHour', 0)
|
||||
if (!props.resetTimezone) emit('update:resetTimezone', 'UTC')
|
||||
}
|
||||
}
|
||||
|
||||
const onWeeklyModeChange = (e: Event) => {
|
||||
const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed'
|
||||
emit('update:weeklyResetMode', val)
|
||||
if (val === 'fixed') {
|
||||
if (props.weeklyResetDay == null) emit('update:weeklyResetDay', 1)
|
||||
if (props.weeklyResetHour == null) emit('update:weeklyResetHour', 0)
|
||||
if (!props.resetTimezone) emit('update:resetTimezone', 'UTC')
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -94,7 +172,37 @@ const onWeeklyInput = (e: Event) => {
|
||||
:placeholder="t('admin.accounts.quotaLimitPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<p class="input-hint">{{ t('admin.accounts.quotaDailyLimitHint') }}</p>
|
||||
<!-- 日配额重置模式 -->
|
||||
<div class="mt-2 flex items-center gap-2">
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetMode') }}</label>
|
||||
<select
|
||||
:value="dailyResetMode || 'rolling'"
|
||||
@change="onDailyModeChange"
|
||||
class="input py-1 text-xs"
|
||||
>
|
||||
<option value="rolling">{{ t('admin.accounts.quotaResetModeRolling') }}</option>
|
||||
<option value="fixed">{{ t('admin.accounts.quotaResetModeFixed') }}</option>
|
||||
</select>
|
||||
</div>
|
||||
<!-- 固定模式:小时选择 -->
|
||||
<div v-if="dailyResetMode === 'fixed'" class="mt-2 flex items-center gap-2">
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetHour') }}</label>
|
||||
<select
|
||||
:value="dailyResetHour ?? 0"
|
||||
@change="emit('update:dailyResetHour', Number(($event.target as HTMLSelectElement).value))"
|
||||
class="input py-1 text-xs w-24"
|
||||
>
|
||||
<option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option>
|
||||
</select>
|
||||
</div>
|
||||
<p class="input-hint">
|
||||
<template v-if="dailyResetMode === 'fixed'">
|
||||
{{ t('admin.accounts.quotaDailyLimitHintFixed', { hour: String(dailyResetHour ?? 0).padStart(2, '0'), timezone: resetTimezone || 'UTC' }) }}
|
||||
</template>
|
||||
<template v-else>
|
||||
{{ t('admin.accounts.quotaDailyLimitHint') }}
|
||||
</template>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 周配额 -->
|
||||
@@ -112,7 +220,57 @@ const onWeeklyInput = (e: Event) => {
|
||||
:placeholder="t('admin.accounts.quotaLimitPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<p class="input-hint">{{ t('admin.accounts.quotaWeeklyLimitHint') }}</p>
|
||||
<!-- 周配额重置模式 -->
|
||||
<div class="mt-2 flex items-center gap-2">
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetMode') }}</label>
|
||||
<select
|
||||
:value="weeklyResetMode || 'rolling'"
|
||||
@change="onWeeklyModeChange"
|
||||
class="input py-1 text-xs"
|
||||
>
|
||||
<option value="rolling">{{ t('admin.accounts.quotaResetModeRolling') }}</option>
|
||||
<option value="fixed">{{ t('admin.accounts.quotaResetModeFixed') }}</option>
|
||||
</select>
|
||||
</div>
|
||||
<!-- 固定模式:星期几 + 小时 -->
|
||||
<div v-if="weeklyResetMode === 'fixed'" class="mt-2 flex items-center gap-2 flex-wrap">
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaWeeklyResetDay') }}</label>
|
||||
<select
|
||||
:value="weeklyResetDay ?? 1"
|
||||
@change="emit('update:weeklyResetDay', Number(($event.target as HTMLSelectElement).value))"
|
||||
class="input py-1 text-xs w-28"
|
||||
>
|
||||
<option v-for="d in dayOptions" :key="d.value" :value="d.value">{{ t('admin.accounts.dayOfWeek.' + d.key) }}</option>
|
||||
</select>
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetHour') }}</label>
|
||||
<select
|
||||
:value="weeklyResetHour ?? 0"
|
||||
@change="emit('update:weeklyResetHour', Number(($event.target as HTMLSelectElement).value))"
|
||||
class="input py-1 text-xs w-24"
|
||||
>
|
||||
<option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option>
|
||||
</select>
|
||||
</div>
|
||||
<p class="input-hint">
|
||||
<template v-if="weeklyResetMode === 'fixed'">
|
||||
{{ t('admin.accounts.quotaWeeklyLimitHintFixed', { day: t('admin.accounts.dayOfWeek.' + (dayOptions.find(d => d.value === (weeklyResetDay ?? 1))?.key || 'monday')), hour: String(weeklyResetHour ?? 0).padStart(2, '0'), timezone: resetTimezone || 'UTC' }) }}
|
||||
</template>
|
||||
<template v-else>
|
||||
{{ t('admin.accounts.quotaWeeklyLimitHint') }}
|
||||
</template>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 时区选择(当任一维度使用固定模式时显示) -->
|
||||
<div v-if="hasFixedMode">
|
||||
<label class="input-label">{{ t('admin.accounts.quotaResetTimezone') }}</label>
|
||||
<select
|
||||
:value="resetTimezone || 'UTC'"
|
||||
@change="emit('update:resetTimezone', ($event.target as HTMLSelectElement).value)"
|
||||
class="input text-sm"
|
||||
>
|
||||
<option v-for="tz in timezoneOptions" :key="tz" :value="tz">{{ tz }}</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- 总配额 -->
|
||||
|
||||
@@ -76,7 +76,7 @@ const hasRecoverableState = computed(() => {
|
||||
return props.account?.status === 'error' || Boolean(isRateLimited.value) || Boolean(isOverloaded.value) || Boolean(isTempUnschedulable.value)
|
||||
})
|
||||
const hasQuotaLimit = computed(() => {
|
||||
return props.account?.type === 'apikey' && (
|
||||
return (props.account?.type === 'apikey' || props.account?.type === 'bedrock') && (
|
||||
(props.account?.quota_limit ?? 0) > 0 ||
|
||||
(props.account?.quota_daily_limit ?? 0) > 0 ||
|
||||
(props.account?.quota_weekly_limit ?? 0) > 0
|
||||
|
||||
@@ -83,7 +83,7 @@ const typeLabel = computed(() => {
|
||||
case 'apikey':
|
||||
return 'Key'
|
||||
case 'bedrock':
|
||||
return 'Bedrock'
|
||||
return 'AWS'
|
||||
default:
|
||||
return props.type
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@
|
||||
</template>
|
||||
|
||||
<!-- Regular User View -->
|
||||
<template v-else>
|
||||
<template v-else-if="!appStore.backendModeEnabled">
|
||||
<div class="sidebar-section">
|
||||
<router-link
|
||||
v-for="item in userNavItems"
|
||||
@@ -357,36 +357,6 @@ const ServerIcon = {
|
||||
)
|
||||
}
|
||||
|
||||
const DatabaseIcon = {
|
||||
render: () =>
|
||||
h(
|
||||
'svg',
|
||||
{ fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
|
||||
[
|
||||
h('path', {
|
||||
'stroke-linecap': 'round',
|
||||
'stroke-linejoin': 'round',
|
||||
d: 'M3.75 5.25C3.75 4.007 7.443 3 12 3s8.25 1.007 8.25 2.25S16.557 7.5 12 7.5 3.75 6.493 3.75 5.25z'
|
||||
}),
|
||||
h('path', {
|
||||
'stroke-linecap': 'round',
|
||||
'stroke-linejoin': 'round',
|
||||
d: 'M3.75 5.25v4.5C3.75 10.993 7.443 12 12 12s8.25-1.007 8.25-2.25v-4.5'
|
||||
}),
|
||||
h('path', {
|
||||
'stroke-linecap': 'round',
|
||||
'stroke-linejoin': 'round',
|
||||
d: 'M3.75 9.75v4.5c0 1.243 3.693 2.25 8.25 2.25s8.25-1.007 8.25-2.25v-4.5'
|
||||
}),
|
||||
h('path', {
|
||||
'stroke-linecap': 'round',
|
||||
'stroke-linejoin': 'round',
|
||||
d: 'M3.75 14.25v4.5C3.75 19.993 7.443 21 12 21s8.25-1.007 8.25-2.25v-4.5'
|
||||
})
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
const BellIcon = {
|
||||
render: () =>
|
||||
h(
|
||||
@@ -611,7 +581,6 @@ const adminNavItems = computed((): NavItem[] => {
|
||||
if (authStore.isSimpleMode) {
|
||||
const filtered = baseItems.filter(item => !item.hideInSimpleMode)
|
||||
filtered.push({ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon })
|
||||
filtered.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon })
|
||||
filtered.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
|
||||
// Add admin custom menu items after settings
|
||||
for (const cm of customMenuItemsForAdmin.value) {
|
||||
@@ -620,7 +589,6 @@ const adminNavItems = computed((): NavItem[] => {
|
||||
return filtered
|
||||
}
|
||||
|
||||
baseItems.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon })
|
||||
baseItems.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
|
||||
// Add admin custom menu items after settings
|
||||
for (const cm of customMenuItemsForAdmin.value) {
|
||||
|
||||
@@ -84,9 +84,7 @@ onUnmounted(() => {
|
||||
}
|
||||
|
||||
.table-scroll-container :deep(th) {
|
||||
/* 表头高度和文字加粗优化 */
|
||||
@apply px-5 py-4 text-left text-sm font-bold text-gray-900 dark:text-white border-b border-gray-200 dark:border-dark-700;
|
||||
@apply uppercase tracking-wider; /* 让表头更有设计感 */
|
||||
@apply px-5 py-4 text-left text-sm font-medium text-gray-600 dark:text-dark-300 border-b border-gray-200 dark:border-dark-700;
|
||||
}
|
||||
|
||||
.table-scroll-container :deep(td) {
|
||||
|
||||
@@ -412,7 +412,7 @@ export function getPresetMappingsByPlatform(platform: string) {
|
||||
if (platform === 'gemini') return geminiPresetMappings
|
||||
if (platform === 'sora') return soraPresetMappings
|
||||
if (platform === 'antigravity') return antigravityPresetMappings
|
||||
if (platform === 'bedrock' || platform === 'bedrock-apikey') return bedrockPresetMappings
|
||||
if (platform === 'bedrock') return bedrockPresetMappings
|
||||
return anthropicPresetMappings
|
||||
}
|
||||
|
||||
|
||||
@@ -340,7 +340,6 @@ export default {
|
||||
redeemCodes: 'Redeem Codes',
|
||||
ops: 'Ops',
|
||||
promoCodes: 'Promo Codes',
|
||||
dataManagement: 'Data Management',
|
||||
settings: 'Settings',
|
||||
myAccount: 'My Account',
|
||||
lightMode: 'Light Mode',
|
||||
@@ -978,6 +977,111 @@ export default {
|
||||
failedToLoad: 'Failed to load dashboard statistics'
|
||||
},
|
||||
|
||||
backup: {
|
||||
title: 'Database Backup',
|
||||
description: 'Full database backup to S3-compatible storage with scheduled backup and restore',
|
||||
s3: {
|
||||
title: 'S3 Storage Configuration',
|
||||
description: 'Configure S3-compatible storage (supports Cloudflare R2)',
|
||||
descriptionPrefix: 'Configure S3-compatible storage (supports',
|
||||
descriptionSuffix: ')',
|
||||
enabled: 'Enable S3 Storage',
|
||||
endpoint: 'Endpoint',
|
||||
region: 'Region',
|
||||
bucket: 'Bucket',
|
||||
prefix: 'Key Prefix',
|
||||
accessKeyId: 'Access Key ID',
|
||||
secretAccessKey: 'Secret Access Key',
|
||||
secretConfigured: 'Already configured, leave empty to keep',
|
||||
forcePathStyle: 'Force Path Style',
|
||||
testConnection: 'Test Connection',
|
||||
testSuccess: 'S3 connection test successful',
|
||||
testFailed: 'S3 connection test failed',
|
||||
saved: 'S3 configuration saved'
|
||||
},
|
||||
schedule: {
|
||||
title: 'Scheduled Backup',
|
||||
description: 'Configure automatic scheduled backups',
|
||||
enabled: 'Enable Scheduled Backup',
|
||||
cronExpr: 'Cron Expression',
|
||||
cronHint: 'e.g. "0 2 * * *" means every day at 2:00 AM',
|
||||
retainDays: 'Backup Expire Days',
|
||||
retainDaysHint: 'Backup files auto-delete after this many days, 0 = never expire',
|
||||
retainCount: 'Max Retain Count',
|
||||
retainCountHint: 'Maximum number of backups to keep, 0 = unlimited',
|
||||
saved: 'Schedule configuration saved'
|
||||
},
|
||||
operations: {
|
||||
title: 'Backup Records',
|
||||
description: 'Create manual backups and manage existing backup records',
|
||||
createBackup: 'Create Backup',
|
||||
backing: 'Backing up...',
|
||||
backupCreated: 'Backup created successfully',
|
||||
expireDays: 'Expire Days'
|
||||
},
|
||||
columns: {
|
||||
status: 'Status',
|
||||
fileName: 'File Name',
|
||||
size: 'Size',
|
||||
expiresAt: 'Expires At',
|
||||
triggeredBy: 'Triggered By',
|
||||
startedAt: 'Started At',
|
||||
actions: 'Actions'
|
||||
},
|
||||
status: {
|
||||
pending: 'Pending',
|
||||
running: 'Running',
|
||||
completed: 'Completed',
|
||||
failed: 'Failed'
|
||||
},
|
||||
trigger: {
|
||||
manual: 'Manual',
|
||||
scheduled: 'Scheduled'
|
||||
},
|
||||
neverExpire: 'Never',
|
||||
empty: 'No backup records',
|
||||
actions: {
|
||||
download: 'Download',
|
||||
restore: 'Restore',
|
||||
restoreConfirm: 'Are you sure you want to restore from this backup? This will overwrite the current database!',
|
||||
restorePasswordPrompt: 'Please enter your admin password to confirm the restore operation',
|
||||
restoreSuccess: 'Database restored successfully',
|
||||
deleteConfirm: 'Are you sure you want to delete this backup?',
|
||||
deleted: 'Backup deleted'
|
||||
},
|
||||
r2Guide: {
|
||||
title: 'Cloudflare R2 Setup Guide',
|
||||
intro: 'Cloudflare R2 provides S3-compatible object storage with a free tier of 10GB storage + 1M Class A requests/month, ideal for database backups.',
|
||||
step1: {
|
||||
title: 'Create an R2 Bucket',
|
||||
line1: 'Log in to the Cloudflare Dashboard (dash.cloudflare.com), select "R2 Object Storage" from the sidebar',
|
||||
line2: 'Click "Create bucket", enter a name (e.g. sub2api-backups), choose a region',
|
||||
line3: 'Click create to finish'
|
||||
},
|
||||
step2: {
|
||||
title: 'Create an API Token',
|
||||
line1: 'On the R2 page, click "Manage R2 API Tokens" in the top right',
|
||||
line2: 'Click "Create API token", set permission to "Object Read & Write"',
|
||||
line3: 'Recommended: restrict to specific bucket for better security',
|
||||
line4: 'After creation, you will see the Access Key ID and Secret Access Key',
|
||||
warning: 'The Secret Access Key is only shown once — copy and save it immediately!'
|
||||
},
|
||||
step3: {
|
||||
title: 'Get the S3 Endpoint',
|
||||
desc: 'Find your Account ID on the R2 overview page (in the URL or the right panel). The endpoint format is:',
|
||||
accountId: 'your_account_id'
|
||||
},
|
||||
step4: {
|
||||
title: 'Fill in the Configuration',
|
||||
checkEnabled: 'Checked',
|
||||
bucketValue: 'Your bucket name',
|
||||
fromStep2: 'Value from Step 2',
|
||||
unchecked: 'Unchecked'
|
||||
},
|
||||
freeTier: 'R2 Free Tier: 10GB storage + 1M Class A requests + 10M Class B requests per month — more than enough for database backups.'
|
||||
}
|
||||
},
|
||||
|
||||
dataManagement: {
|
||||
title: 'Data Management',
|
||||
description: 'Manage data management agent status, object storage settings, and backup jobs in one place',
|
||||
@@ -1866,6 +1970,23 @@ export default {
|
||||
quotaWeeklyLimitHint: 'Automatically resets every 7 days from first usage.',
|
||||
quotaTotalLimit: 'Total Limit',
|
||||
quotaTotalLimitHint: 'Cumulative spending limit. Does not auto-reset — use "Reset Quota" to clear.',
|
||||
quotaResetMode: 'Reset Mode',
|
||||
quotaResetModeRolling: 'Rolling Window',
|
||||
quotaResetModeFixed: 'Fixed Time',
|
||||
quotaResetHour: 'Reset Hour',
|
||||
quotaWeeklyResetDay: 'Reset Day',
|
||||
quotaResetTimezone: 'Reset Timezone',
|
||||
quotaDailyLimitHintFixed: 'Resets daily at {hour}:00 ({timezone}).',
|
||||
quotaWeeklyLimitHintFixed: 'Resets every {day} at {hour}:00 ({timezone}).',
|
||||
dayOfWeek: {
|
||||
monday: 'Monday',
|
||||
tuesday: 'Tuesday',
|
||||
wednesday: 'Wednesday',
|
||||
thursday: 'Thursday',
|
||||
friday: 'Friday',
|
||||
saturday: 'Saturday',
|
||||
sunday: 'Sunday',
|
||||
},
|
||||
quotaLimitAmount: 'Total Limit',
|
||||
quotaLimitAmountHint: 'Cumulative spending limit. Does not auto-reset.',
|
||||
testConnection: 'Test Connection',
|
||||
@@ -1934,7 +2055,7 @@ export default {
|
||||
claudeCode: 'Claude Code',
|
||||
claudeConsole: 'Claude Console',
|
||||
bedrockLabel: 'AWS Bedrock',
|
||||
bedrockDesc: 'SigV4 Signing',
|
||||
bedrockDesc: 'SigV4 / API Key',
|
||||
oauthSetupToken: 'OAuth / Setup Token',
|
||||
addMethod: 'Add Method',
|
||||
setupTokenLongLived: 'Setup Token (Long-lived)',
|
||||
@@ -2136,6 +2257,9 @@ export default {
|
||||
bedrockRegionRequired: 'Please select AWS Region',
|
||||
bedrockSessionTokenHint: 'Optional, for temporary credentials',
|
||||
bedrockSecretKeyLeaveEmpty: 'Leave empty to keep current key',
|
||||
bedrockAuthMode: 'Authentication Mode',
|
||||
bedrockAuthModeSigv4: 'SigV4 Signing',
|
||||
bedrockAuthModeApikey: 'Bedrock API Key',
|
||||
bedrockApiKeyLabel: 'Bedrock API Key',
|
||||
bedrockApiKeyDesc: 'Bearer Token',
|
||||
bedrockApiKeyInput: 'API Key',
|
||||
@@ -2555,7 +2679,16 @@ export default {
|
||||
unlimited: 'Unlimited'
|
||||
},
|
||||
ineligibleWarning:
|
||||
'This account is not eligible for Antigravity, but API forwarding still works. Use at your own risk.'
|
||||
'This account is not eligible for Antigravity, but API forwarding still works. Use at your own risk.',
|
||||
forbidden: 'Forbidden',
|
||||
forbiddenValidation: 'Verification Required',
|
||||
forbiddenViolation: 'Violation Ban',
|
||||
openVerification: 'Open Verification Link',
|
||||
copyLink: 'Copy Link',
|
||||
linkCopied: 'Link Copied',
|
||||
needsReauth: 'Re-auth Required',
|
||||
rateLimited: 'Rate Limited',
|
||||
usageError: 'Fetch Error'
|
||||
},
|
||||
|
||||
// Scheduled Tests
|
||||
@@ -3709,6 +3842,11 @@ export default {
|
||||
refreshInterval15s: '15 seconds',
|
||||
refreshInterval30s: '30 seconds',
|
||||
refreshInterval60s: '60 seconds',
|
||||
dashboardCards: 'Dashboard Cards',
|
||||
displayAlertEvents: 'Display alert events',
|
||||
displayAlertEventsHint: 'Show or hide the recent alert events card on the ops dashboard. Enabled by default.',
|
||||
displayOpenAITokenStats: 'Display OpenAI token request stats',
|
||||
displayOpenAITokenStatsHint: 'Show or hide the OpenAI token request stats card on the ops dashboard. Hidden by default.',
|
||||
autoRefreshCountdown: 'Auto refresh: {seconds}s',
|
||||
validation: {
|
||||
title: 'Please fix the following issues',
|
||||
@@ -3798,6 +3936,8 @@ export default {
|
||||
users: 'Users',
|
||||
gateway: 'Gateway',
|
||||
email: 'Email',
|
||||
backup: 'Backup',
|
||||
data: 'Sora Storage',
|
||||
},
|
||||
emailTabDisabledTitle: 'Email Verification Not Enabled',
|
||||
emailTabDisabledHint: 'Enable email verification in the Security tab to configure SMTP settings.',
|
||||
@@ -3888,6 +4028,9 @@ export default {
|
||||
site: {
|
||||
title: 'Site Settings',
|
||||
description: 'Customize site branding',
|
||||
backendMode: 'Backend Mode',
|
||||
backendModeDescription:
|
||||
'Disables user registration, public site, and self-service features. Only admin can log in and manage the platform.',
|
||||
siteName: 'Site Name',
|
||||
siteNamePlaceholder: 'Sub2API',
|
||||
siteNameHint: 'Displayed in emails and page titles',
|
||||
@@ -4127,6 +4270,7 @@ export default {
|
||||
scopeAll: 'All accounts',
|
||||
scopeOAuth: 'OAuth only',
|
||||
scopeAPIKey: 'API Key only',
|
||||
scopeBedrock: 'Bedrock only',
|
||||
errorMessage: 'Error message',
|
||||
errorMessagePlaceholder: 'Custom error message when blocked',
|
||||
errorMessageHint: 'Leave empty for default message',
|
||||
|
||||
@@ -340,7 +340,6 @@ export default {
|
||||
redeemCodes: '兑换码',
|
||||
ops: '运维监控',
|
||||
promoCodes: '优惠码',
|
||||
dataManagement: '数据管理',
|
||||
settings: '系统设置',
|
||||
myAccount: '我的账户',
|
||||
lightMode: '浅色模式',
|
||||
@@ -1000,6 +999,111 @@ export default {
|
||||
failedToLoad: '加载仪表盘数据失败'
|
||||
},
|
||||
|
||||
backup: {
|
||||
title: '数据库备份',
|
||||
description: '全量数据库备份到 S3 兼容存储,支持定时备份与恢复',
|
||||
s3: {
|
||||
title: 'S3 存储配置',
|
||||
description: '配置 S3 兼容存储(支持 Cloudflare R2)',
|
||||
descriptionPrefix: '配置 S3 兼容存储(支持',
|
||||
descriptionSuffix: ')',
|
||||
enabled: '启用 S3 存储',
|
||||
endpoint: '端点地址',
|
||||
region: '区域',
|
||||
bucket: '存储桶',
|
||||
prefix: 'Key 前缀',
|
||||
accessKeyId: 'Access Key ID',
|
||||
secretAccessKey: 'Secret Access Key',
|
||||
secretConfigured: '已配置,留空保持不变',
|
||||
forcePathStyle: '强制路径风格',
|
||||
testConnection: '测试连接',
|
||||
testSuccess: 'S3 连接测试成功',
|
||||
testFailed: 'S3 连接测试失败',
|
||||
saved: 'S3 配置已保存'
|
||||
},
|
||||
schedule: {
|
||||
title: '定时备份',
|
||||
description: '配置自动定时备份',
|
||||
enabled: '启用定时备份',
|
||||
cronExpr: 'Cron 表达式',
|
||||
cronHint: '例如 "0 2 * * *" 表示每天凌晨 2 点',
|
||||
retainDays: '备份过期天数',
|
||||
retainDaysHint: '备份文件超过此天数后自动删除,0 = 永不过期',
|
||||
retainCount: '最大保留份数',
|
||||
retainCountHint: '最多保留的备份数量,0 = 不限制',
|
||||
saved: '定时备份配置已保存'
|
||||
},
|
||||
operations: {
|
||||
title: '备份记录',
|
||||
description: '创建手动备份和管理已有备份记录',
|
||||
createBackup: '创建备份',
|
||||
backing: '备份中...',
|
||||
backupCreated: '备份创建成功',
|
||||
expireDays: '过期天数'
|
||||
},
|
||||
columns: {
|
||||
status: '状态',
|
||||
fileName: '文件名',
|
||||
size: '大小',
|
||||
expiresAt: '过期时间',
|
||||
triggeredBy: '触发方式',
|
||||
startedAt: '开始时间',
|
||||
actions: '操作'
|
||||
},
|
||||
status: {
|
||||
pending: '等待中',
|
||||
running: '执行中',
|
||||
completed: '已完成',
|
||||
failed: '失败'
|
||||
},
|
||||
trigger: {
|
||||
manual: '手动',
|
||||
scheduled: '定时'
|
||||
},
|
||||
neverExpire: '永不过期',
|
||||
empty: '暂无备份记录',
|
||||
actions: {
|
||||
download: '下载',
|
||||
restore: '恢复',
|
||||
restoreConfirm: '确定要从此备份恢复吗?这将覆盖当前数据库!',
|
||||
restorePasswordPrompt: '请输入管理员密码以确认恢复操作',
|
||||
restoreSuccess: '数据库恢复成功',
|
||||
deleteConfirm: '确定要删除此备份吗?',
|
||||
deleted: '备份已删除'
|
||||
},
|
||||
r2Guide: {
|
||||
title: 'Cloudflare R2 配置教程',
|
||||
intro: 'Cloudflare R2 提供 S3 兼容的对象存储,免费额度为 10GB 存储 + 每月 100 万次 A 类请求,非常适合数据库备份。',
|
||||
step1: {
|
||||
title: '创建 R2 存储桶',
|
||||
line1: '登录 Cloudflare Dashboard (dash.cloudflare.com),左侧菜单选择「R2 对象存储」',
|
||||
line2: '点击「创建存储桶」,输入名称(如 sub2api-backups),选择区域',
|
||||
line3: '点击创建完成'
|
||||
},
|
||||
step2: {
|
||||
title: '创建 API 令牌',
|
||||
line1: '在 R2 页面,点击右上角「管理 R2 API 令牌」',
|
||||
line2: '点击「创建 API 令牌」,权限选择「对象读和写」',
|
||||
line3: '建议指定存储桶范围(仅允许访问备份桶,更安全)',
|
||||
line4: '创建后会显示 Access Key ID 和 Secret Access Key',
|
||||
warning: 'Secret Access Key 只会显示一次,请立即复制保存!'
|
||||
},
|
||||
step3: {
|
||||
title: '获取 S3 端点地址',
|
||||
desc: '在 R2 概览页面找到你的账户 ID(在 URL 或右侧面板中),端点格式为:',
|
||||
accountId: '你的账户 ID'
|
||||
},
|
||||
step4: {
|
||||
title: '填写以下配置',
|
||||
checkEnabled: '勾选',
|
||||
bucketValue: '你创建的存储桶名称',
|
||||
fromStep2: '第 2 步获取的值',
|
||||
unchecked: '不勾选'
|
||||
},
|
||||
freeTier: 'R2 免费额度:10GB 存储 + 每月 100 万次 A 类请求 + 1000 万次 B 类请求,对数据库备份完全够用。'
|
||||
}
|
||||
},
|
||||
|
||||
dataManagement: {
|
||||
title: '数据管理',
|
||||
description: '统一管理数据管理代理状态、对象存储配置和备份任务',
|
||||
@@ -1872,6 +1976,23 @@ export default {
|
||||
quotaWeeklyLimitHint: '从首次使用起每 7 天自动重置。',
|
||||
quotaTotalLimit: '总限额',
|
||||
quotaTotalLimitHint: '累计消费上限,不会自动重置 — 使用「重置配额」手动清零。',
|
||||
quotaResetMode: '重置方式',
|
||||
quotaResetModeRolling: '滚动窗口',
|
||||
quotaResetModeFixed: '固定时间',
|
||||
quotaResetHour: '重置时间',
|
||||
quotaWeeklyResetDay: '重置日',
|
||||
quotaResetTimezone: '重置时区',
|
||||
quotaDailyLimitHintFixed: '每天 {hour}:00({timezone})重置。',
|
||||
quotaWeeklyLimitHintFixed: '每{day} {hour}:00({timezone})重置。',
|
||||
dayOfWeek: {
|
||||
monday: '周一',
|
||||
tuesday: '周二',
|
||||
wednesday: '周三',
|
||||
thursday: '周四',
|
||||
friday: '周五',
|
||||
saturday: '周六',
|
||||
sunday: '周日',
|
||||
},
|
||||
quotaLimitAmount: '总限额',
|
||||
quotaLimitAmountHint: '累计消费上限,不会自动重置。',
|
||||
testConnection: '测试连接',
|
||||
@@ -1992,6 +2113,15 @@ export default {
|
||||
},
|
||||
ineligibleWarning:
|
||||
'该账号无 Antigravity 使用权限,但仍能进行 API 转发。继续使用请自行承担风险。',
|
||||
forbidden: '已封禁',
|
||||
forbiddenValidation: '需要验证',
|
||||
forbiddenViolation: '违规封禁',
|
||||
openVerification: '打开验证链接',
|
||||
copyLink: '复制链接',
|
||||
linkCopied: '链接已复制',
|
||||
needsReauth: '需要重新授权',
|
||||
rateLimited: '限流中',
|
||||
usageError: '获取失败',
|
||||
form: {
|
||||
nameLabel: '账号名称',
|
||||
namePlaceholder: '请输入账号名称',
|
||||
@@ -2082,7 +2212,7 @@ export default {
|
||||
claudeCode: 'Claude Code',
|
||||
claudeConsole: 'Claude Console',
|
||||
bedrockLabel: 'AWS Bedrock',
|
||||
bedrockDesc: 'SigV4 签名',
|
||||
bedrockDesc: 'SigV4 / API Key',
|
||||
oauthSetupToken: 'OAuth / Setup Token',
|
||||
addMethod: '添加方式',
|
||||
setupTokenLongLived: 'Setup Token(长期有效)',
|
||||
@@ -2277,6 +2407,9 @@ export default {
|
||||
bedrockRegionRequired: '请选择 AWS Region',
|
||||
bedrockSessionTokenHint: '可选,用于临时凭证',
|
||||
bedrockSecretKeyLeaveEmpty: '留空以保持当前密钥',
|
||||
bedrockAuthMode: '认证方式',
|
||||
bedrockAuthModeSigv4: 'SigV4 签名',
|
||||
bedrockAuthModeApikey: 'Bedrock API Key',
|
||||
bedrockApiKeyLabel: 'Bedrock API Key',
|
||||
bedrockApiKeyDesc: 'Bearer Token 认证',
|
||||
bedrockApiKeyInput: 'API Key',
|
||||
@@ -3883,6 +4016,11 @@ export default {
|
||||
refreshInterval15s: '15 秒',
|
||||
refreshInterval30s: '30 秒',
|
||||
refreshInterval60s: '60 秒',
|
||||
dashboardCards: '仪表盘卡片',
|
||||
displayAlertEvents: '展示告警事件',
|
||||
displayAlertEventsHint: '控制运维监控仪表盘中告警事件卡片是否显示,默认开启。',
|
||||
displayOpenAITokenStats: '展示 OpenAI Token 请求统计',
|
||||
displayOpenAITokenStatsHint: '控制运维监控仪表盘中 OpenAI Token 请求统计卡片是否显示,默认关闭。',
|
||||
autoRefreshCountdown: '自动刷新:{seconds}s',
|
||||
validation: {
|
||||
title: '请先修正以下问题',
|
||||
@@ -3972,6 +4110,8 @@ export default {
|
||||
users: '用户默认值',
|
||||
gateway: '网关服务',
|
||||
email: '邮件设置',
|
||||
backup: '数据备份',
|
||||
data: 'Sora 存储',
|
||||
},
|
||||
emailTabDisabledTitle: '邮箱验证未启用',
|
||||
emailTabDisabledHint: '请在「安全与认证」选项卡中启用邮箱验证后,再配置 SMTP 设置。',
|
||||
@@ -4060,6 +4200,9 @@ export default {
|
||||
site: {
|
||||
title: '站点设置',
|
||||
description: '自定义站点品牌',
|
||||
backendMode: 'Backend 模式',
|
||||
backendModeDescription:
|
||||
'禁用用户注册、公开页面和自助服务功能。仅管理员可以登录和管理平台。',
|
||||
siteName: '站点名称',
|
||||
siteNameHint: '显示在邮件和页面标题中',
|
||||
siteNamePlaceholder: 'Sub2API',
|
||||
@@ -4300,6 +4443,7 @@ export default {
|
||||
scopeAll: '全部账号',
|
||||
scopeOAuth: '仅 OAuth 账号',
|
||||
scopeAPIKey: '仅 API Key 账号',
|
||||
scopeBedrock: '仅 Bedrock 账号',
|
||||
errorMessage: '错误消息',
|
||||
errorMessagePlaceholder: '拦截时返回的自定义错误消息',
|
||||
errorMessageHint: '留空则使用默认错误消息',
|
||||
|
||||
@@ -51,6 +51,7 @@ interface MockAuthState {
|
||||
isAuthenticated: boolean
|
||||
isAdmin: boolean
|
||||
isSimpleMode: boolean
|
||||
backendModeEnabled: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -70,8 +71,17 @@ function simulateGuard(
|
||||
authState.isAuthenticated &&
|
||||
(toPath === '/login' || toPath === '/register')
|
||||
) {
|
||||
if (authState.backendModeEnabled && !authState.isAdmin) {
|
||||
return null
|
||||
}
|
||||
return authState.isAdmin ? '/admin/dashboard' : '/dashboard'
|
||||
}
|
||||
if (authState.backendModeEnabled && !authState.isAuthenticated) {
|
||||
const allowed = ['/login', '/key-usage', '/setup']
|
||||
if (!allowed.some((path) => toPath === path || toPath.startsWith(path))) {
|
||||
return '/login'
|
||||
}
|
||||
}
|
||||
return null // 允许通过
|
||||
}
|
||||
|
||||
@@ -99,6 +109,17 @@ function simulateGuard(
|
||||
}
|
||||
}
|
||||
|
||||
// Backend mode: admin gets full access, non-admin blocked
|
||||
if (authState.backendModeEnabled) {
|
||||
if (authState.isAuthenticated && authState.isAdmin) {
|
||||
return null
|
||||
}
|
||||
const allowed = ['/login', '/key-usage', '/setup']
|
||||
if (!allowed.some((path) => toPath === path || toPath.startsWith(path))) {
|
||||
return '/login'
|
||||
}
|
||||
}
|
||||
|
||||
return null // 允许通过
|
||||
}
|
||||
|
||||
@@ -114,6 +135,7 @@ describe('路由守卫逻辑', () => {
|
||||
isAuthenticated: false,
|
||||
isAdmin: false,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: false,
|
||||
}
|
||||
|
||||
it('访问需要认证的页面重定向到 /login', () => {
|
||||
@@ -144,6 +166,7 @@ describe('路由守卫逻辑', () => {
|
||||
isAuthenticated: true,
|
||||
isAdmin: false,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: false,
|
||||
}
|
||||
|
||||
it('访问 /login 重定向到 /dashboard', () => {
|
||||
@@ -179,6 +202,7 @@ describe('路由守卫逻辑', () => {
|
||||
isAuthenticated: true,
|
||||
isAdmin: true,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: false,
|
||||
}
|
||||
|
||||
it('访问 /login 重定向到 /admin/dashboard', () => {
|
||||
@@ -205,6 +229,7 @@ describe('路由守卫逻辑', () => {
|
||||
isAuthenticated: true,
|
||||
isAdmin: false,
|
||||
isSimpleMode: true,
|
||||
backendModeEnabled: false,
|
||||
}
|
||||
const redirect = simulateGuard('/subscriptions', {}, authState)
|
||||
expect(redirect).toBe('/dashboard')
|
||||
@@ -215,6 +240,7 @@ describe('路由守卫逻辑', () => {
|
||||
isAuthenticated: true,
|
||||
isAdmin: false,
|
||||
isSimpleMode: true,
|
||||
backendModeEnabled: false,
|
||||
}
|
||||
const redirect = simulateGuard('/redeem', {}, authState)
|
||||
expect(redirect).toBe('/dashboard')
|
||||
@@ -225,6 +251,7 @@ describe('路由守卫逻辑', () => {
|
||||
isAuthenticated: true,
|
||||
isAdmin: true,
|
||||
isSimpleMode: true,
|
||||
backendModeEnabled: false,
|
||||
}
|
||||
const redirect = simulateGuard('/admin/groups', { requiresAdmin: true }, authState)
|
||||
expect(redirect).toBe('/admin/dashboard')
|
||||
@@ -235,6 +262,7 @@ describe('路由守卫逻辑', () => {
|
||||
isAuthenticated: true,
|
||||
isAdmin: true,
|
||||
isSimpleMode: true,
|
||||
backendModeEnabled: false,
|
||||
}
|
||||
const redirect = simulateGuard(
|
||||
'/admin/subscriptions',
|
||||
@@ -249,6 +277,7 @@ describe('路由守卫逻辑', () => {
|
||||
isAuthenticated: true,
|
||||
isAdmin: false,
|
||||
isSimpleMode: true,
|
||||
backendModeEnabled: false,
|
||||
}
|
||||
const redirect = simulateGuard('/dashboard', {}, authState)
|
||||
expect(redirect).toBeNull()
|
||||
@@ -259,9 +288,111 @@ describe('路由守卫逻辑', () => {
|
||||
isAuthenticated: true,
|
||||
isAdmin: false,
|
||||
isSimpleMode: true,
|
||||
backendModeEnabled: false,
|
||||
}
|
||||
const redirect = simulateGuard('/keys', {}, authState)
|
||||
expect(redirect).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Backend Mode', () => {
|
||||
it('unauthenticated: /home redirects to /login', () => {
|
||||
const authState: MockAuthState = {
|
||||
isAuthenticated: false,
|
||||
isAdmin: false,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: true,
|
||||
}
|
||||
const redirect = simulateGuard('/home', { requiresAuth: false }, authState)
|
||||
expect(redirect).toBe('/login')
|
||||
})
|
||||
|
||||
it('unauthenticated: /login is allowed', () => {
|
||||
const authState: MockAuthState = {
|
||||
isAuthenticated: false,
|
||||
isAdmin: false,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: true,
|
||||
}
|
||||
const redirect = simulateGuard('/login', { requiresAuth: false }, authState)
|
||||
expect(redirect).toBeNull()
|
||||
})
|
||||
|
||||
it('unauthenticated: /key-usage is allowed', () => {
|
||||
const authState: MockAuthState = {
|
||||
isAuthenticated: false,
|
||||
isAdmin: false,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: true,
|
||||
}
|
||||
const redirect = simulateGuard('/key-usage', { requiresAuth: false }, authState)
|
||||
expect(redirect).toBeNull()
|
||||
})
|
||||
|
||||
it('unauthenticated: /setup is allowed', () => {
|
||||
const authState: MockAuthState = {
|
||||
isAuthenticated: false,
|
||||
isAdmin: false,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: true,
|
||||
}
|
||||
const redirect = simulateGuard('/setup', { requiresAuth: false }, authState)
|
||||
expect(redirect).toBeNull()
|
||||
})
|
||||
|
||||
it('admin: /admin/dashboard is allowed', () => {
|
||||
const authState: MockAuthState = {
|
||||
isAuthenticated: true,
|
||||
isAdmin: true,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: true,
|
||||
}
|
||||
const redirect = simulateGuard('/admin/dashboard', { requiresAdmin: true }, authState)
|
||||
expect(redirect).toBeNull()
|
||||
})
|
||||
|
||||
it('admin: /login redirects to /admin/dashboard', () => {
|
||||
const authState: MockAuthState = {
|
||||
isAuthenticated: true,
|
||||
isAdmin: true,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: true,
|
||||
}
|
||||
const redirect = simulateGuard('/login', { requiresAuth: false }, authState)
|
||||
expect(redirect).toBe('/admin/dashboard')
|
||||
})
|
||||
|
||||
it('non-admin authenticated: /dashboard redirects to /login', () => {
|
||||
const authState: MockAuthState = {
|
||||
isAuthenticated: true,
|
||||
isAdmin: false,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: true,
|
||||
}
|
||||
const redirect = simulateGuard('/dashboard', {}, authState)
|
||||
expect(redirect).toBe('/login')
|
||||
})
|
||||
|
||||
it('non-admin authenticated: /login is allowed (no redirect loop)', () => {
|
||||
const authState: MockAuthState = {
|
||||
isAuthenticated: true,
|
||||
isAdmin: false,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: true,
|
||||
}
|
||||
const redirect = simulateGuard('/login', { requiresAuth: false }, authState)
|
||||
expect(redirect).toBeNull()
|
||||
})
|
||||
|
||||
it('non-admin authenticated: /key-usage is allowed', () => {
|
||||
const authState: MockAuthState = {
|
||||
isAuthenticated: true,
|
||||
isAdmin: false,
|
||||
isSimpleMode: false,
|
||||
backendModeEnabled: true,
|
||||
}
|
||||
const redirect = simulateGuard('/key-usage', { requiresAuth: false }, authState)
|
||||
expect(redirect).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -350,18 +350,6 @@ const routes: RouteRecordRaw[] = [
|
||||
descriptionKey: 'admin.promo.description'
|
||||
}
|
||||
},
|
||||
{
|
||||
path: '/admin/data-management',
|
||||
name: 'AdminDataManagement',
|
||||
component: () => import('@/views/admin/DataManagementView.vue'),
|
||||
meta: {
|
||||
requiresAuth: true,
|
||||
requiresAdmin: true,
|
||||
title: 'Data Management',
|
||||
titleKey: 'admin.dataManagement.title',
|
||||
descriptionKey: 'admin.dataManagement.description'
|
||||
}
|
||||
},
|
||||
{
|
||||
path: '/admin/settings',
|
||||
name: 'AdminSettings',
|
||||
@@ -423,6 +411,7 @@ let authInitialized = false
|
||||
const navigationLoading = useNavigationLoadingState()
|
||||
// 延迟初始化预加载,传入 router 实例
|
||||
let routePrefetch: ReturnType<typeof useRoutePrefetch> | null = null
|
||||
const BACKEND_MODE_ALLOWED_PATHS = ['/login', '/key-usage', '/setup']
|
||||
|
||||
router.beforeEach((to, _from, next) => {
|
||||
// 开始导航加载状态
|
||||
@@ -463,10 +452,24 @@ router.beforeEach((to, _from, next) => {
|
||||
if (!requiresAuth) {
|
||||
// If already authenticated and trying to access login/register, redirect to appropriate dashboard
|
||||
if (authStore.isAuthenticated && (to.path === '/login' || to.path === '/register')) {
|
||||
// In backend mode, non-admin users should NOT be redirected away from login
|
||||
// (they are blocked from all protected routes, so redirecting would cause a loop)
|
||||
if (appStore.backendModeEnabled && !authStore.isAdmin) {
|
||||
next()
|
||||
return
|
||||
}
|
||||
// Admin users go to admin dashboard, regular users go to user dashboard
|
||||
next(authStore.isAdmin ? '/admin/dashboard' : '/dashboard')
|
||||
return
|
||||
}
|
||||
// Backend mode: block public pages for unauthenticated users (except login, key-usage, setup)
|
||||
if (appStore.backendModeEnabled && !authStore.isAuthenticated) {
|
||||
const isAllowed = BACKEND_MODE_ALLOWED_PATHS.some((p) => to.path === p || to.path.startsWith(p))
|
||||
if (!isAllowed) {
|
||||
next('/login')
|
||||
return
|
||||
}
|
||||
}
|
||||
next()
|
||||
return
|
||||
}
|
||||
@@ -505,6 +508,19 @@ router.beforeEach((to, _from, next) => {
|
||||
}
|
||||
}
|
||||
|
||||
// Backend mode: admin gets full access, non-admin blocked
|
||||
if (appStore.backendModeEnabled) {
|
||||
if (authStore.isAuthenticated && authStore.isAdmin) {
|
||||
next()
|
||||
return
|
||||
}
|
||||
const isAllowed = BACKEND_MODE_ALLOWED_PATHS.some((p) => to.path === p || to.path.startsWith(p))
|
||||
if (!isAllowed) {
|
||||
next('/login')
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// All checks passed, allow navigation
|
||||
next()
|
||||
})
|
||||
|
||||
@@ -47,6 +47,7 @@ export const useAppStore = defineStore('app', () => {
|
||||
// ==================== Computed ====================
|
||||
|
||||
const hasActiveToasts = computed(() => toasts.value.length > 0)
|
||||
const backendModeEnabled = computed(() => cachedPublicSettings.value?.backend_mode_enabled ?? false)
|
||||
|
||||
const loadingCount = ref<number>(0)
|
||||
|
||||
@@ -331,6 +332,7 @@ export const useAppStore = defineStore('app', () => {
|
||||
custom_menu_items: [],
|
||||
linuxdo_oauth_enabled: false,
|
||||
sora_client_enabled: false,
|
||||
backend_mode_enabled: false,
|
||||
version: siteVersion.value
|
||||
}
|
||||
}
|
||||
@@ -404,6 +406,7 @@ export const useAppStore = defineStore('app', () => {
|
||||
|
||||
// Computed
|
||||
hasActiveToasts,
|
||||
backendModeEnabled,
|
||||
|
||||
// Actions
|
||||
toggleSidebar,
|
||||
|
||||
@@ -106,6 +106,7 @@ export interface PublicSettings {
|
||||
custom_menu_items: CustomMenuItem[]
|
||||
linuxdo_oauth_enabled: boolean
|
||||
sora_client_enabled: boolean
|
||||
backend_mode_enabled: boolean
|
||||
version: string
|
||||
}
|
||||
|
||||
@@ -531,7 +532,7 @@ export interface UpdateGroupRequest {
|
||||
// ==================== Account & Proxy Types ====================
|
||||
|
||||
export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' | 'sora'
|
||||
export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock' | 'bedrock-apikey'
|
||||
export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock'
|
||||
export type OAuthAddMethod = 'oauth' | 'setup-token'
|
||||
export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h'
|
||||
|
||||
@@ -727,6 +728,16 @@ export interface Account {
|
||||
quota_weekly_limit?: number | null
|
||||
quota_weekly_used?: number | null
|
||||
|
||||
// 配额固定时间重置配置
|
||||
quota_daily_reset_mode?: 'rolling' | 'fixed' | null
|
||||
quota_daily_reset_hour?: number | null
|
||||
quota_weekly_reset_mode?: 'rolling' | 'fixed' | null
|
||||
quota_weekly_reset_day?: number | null
|
||||
quota_weekly_reset_hour?: number | null
|
||||
quota_reset_timezone?: string | null
|
||||
quota_daily_reset_at?: string | null
|
||||
quota_weekly_reset_at?: string | null
|
||||
|
||||
// 运行时状态(仅当启用对应限制时返回)
|
||||
current_window_cost?: number | null // 当前窗口费用
|
||||
active_sessions?: number | null // 当前活跃会话数
|
||||
@@ -769,6 +780,21 @@ export interface AccountUsageInfo {
|
||||
gemini_pro_minute?: UsageProgress | null
|
||||
gemini_flash_minute?: UsageProgress | null
|
||||
antigravity_quota?: Record<string, AntigravityModelQuota> | null
|
||||
// Antigravity 403 forbidden 状态
|
||||
is_forbidden?: boolean
|
||||
forbidden_reason?: string
|
||||
forbidden_type?: string // "validation" | "violation" | "forbidden"
|
||||
validation_url?: string // 验证/申诉链接
|
||||
|
||||
// 状态标记(后端自动推导)
|
||||
needs_verify?: boolean // 需要人工验证(forbidden_type=validation)
|
||||
is_banned?: boolean // 账号被封(forbidden_type=violation)
|
||||
needs_reauth?: boolean // token 失效需重新授权(401)
|
||||
|
||||
// 机器可读错误码:forbidden / unauthenticated / rate_limited / network_error
|
||||
error_code?: string
|
||||
|
||||
error?: string // usage 获取失败时的错误信息
|
||||
}
|
||||
|
||||
// OpenAI Codex usage snapshot (from response headers)
|
||||
|
||||
@@ -171,7 +171,15 @@
|
||||
<span v-else class="text-sm text-gray-400 dark:text-dark-500">-</span>
|
||||
</template>
|
||||
<template #cell-platform_type="{ row }">
|
||||
<PlatformTypeBadge :platform="row.platform" :type="row.type" :plan-type="row.credentials?.plan_type" :privacy-mode="row.extra?.privacy_mode" />
|
||||
<div class="flex flex-wrap items-center gap-1">
|
||||
<PlatformTypeBadge :platform="row.platform" :type="row.type" :plan-type="row.credentials?.plan_type" :privacy-mode="row.extra?.privacy_mode" />
|
||||
<span
|
||||
v-if="getAntigravityTierLabel(row)"
|
||||
:class="['inline-block rounded px-1.5 py-0.5 text-[10px] font-medium', getAntigravityTierClass(row)]"
|
||||
>
|
||||
{{ getAntigravityTierLabel(row) }}
|
||||
</span>
|
||||
</div>
|
||||
</template>
|
||||
<template #cell-capacity="{ row }">
|
||||
<AccountCapacityCell :account="row" />
|
||||
@@ -794,6 +802,40 @@ const { pause: pauseAutoRefresh, resume: resumeAutoRefresh } = useIntervalFn(
|
||||
{ immediate: false }
|
||||
)
|
||||
|
||||
// Antigravity 订阅等级辅助函数
|
||||
function getAntigravityTierFromRow(row: any): string | null {
|
||||
if (row.platform !== 'antigravity') return null
|
||||
const extra = row.extra as Record<string, unknown> | undefined
|
||||
if (!extra) return null
|
||||
const lca = extra.load_code_assist as Record<string, unknown> | undefined
|
||||
if (!lca) return null
|
||||
const paid = lca.paidTier as Record<string, unknown> | undefined
|
||||
if (paid && typeof paid.id === 'string') return paid.id
|
||||
const current = lca.currentTier as Record<string, unknown> | undefined
|
||||
if (current && typeof current.id === 'string') return current.id
|
||||
return null
|
||||
}
|
||||
|
||||
function getAntigravityTierLabel(row: any): string | null {
|
||||
const tier = getAntigravityTierFromRow(row)
|
||||
switch (tier) {
|
||||
case 'free-tier': return t('admin.accounts.tier.free')
|
||||
case 'g1-pro-tier': return t('admin.accounts.tier.pro')
|
||||
case 'g1-ultra-tier': return t('admin.accounts.tier.ultra')
|
||||
default: return null
|
||||
}
|
||||
}
|
||||
|
||||
function getAntigravityTierClass(row: any): string {
|
||||
const tier = getAntigravityTierFromRow(row)
|
||||
switch (tier) {
|
||||
case 'free-tier': return 'bg-gray-100 text-gray-600 dark:bg-gray-700 dark:text-gray-300'
|
||||
case 'g1-pro-tier': return 'bg-blue-100 text-blue-600 dark:bg-blue-900/40 dark:text-blue-300'
|
||||
case 'g1-ultra-tier': return 'bg-purple-100 text-purple-600 dark:bg-purple-900/40 dark:text-purple-300'
|
||||
default: return ''
|
||||
}
|
||||
}
|
||||
|
||||
// All available columns
|
||||
const allColumns = computed(() => {
|
||||
const c = [
|
||||
|
||||
505
frontend/src/views/admin/BackupView.vue
Normal file
505
frontend/src/views/admin/BackupView.vue
Normal file
@@ -0,0 +1,505 @@
|
||||
<template>
|
||||
<div class="space-y-6">
|
||||
<!-- S3 Storage Config -->
|
||||
<div class="card p-6">
|
||||
<div class="mb-4 flex flex-wrap items-center justify-between gap-3">
|
||||
<div>
|
||||
<h3 class="text-base font-semibold text-gray-900 dark:text-white">
|
||||
{{ t('admin.backup.s3.title') }}
|
||||
</h3>
|
||||
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.backup.s3.descriptionPrefix') }}
|
||||
<button type="button" class="text-primary-600 underline hover:text-primary-700 dark:text-primary-400 dark:hover:text-primary-300" @click="showR2Guide = true">Cloudflare R2</button>
|
||||
{{ t('admin.backup.s3.descriptionSuffix') }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="grid grid-cols-1 gap-3 md:grid-cols-2">
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.endpoint') }}</label>
|
||||
<input v-model="s3Form.endpoint" class="input w-full" placeholder="https://<account_id>.r2.cloudflarestorage.com" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.region') }}</label>
|
||||
<input v-model="s3Form.region" class="input w-full" placeholder="auto" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.bucket') }}</label>
|
||||
<input v-model="s3Form.bucket" class="input w-full" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.prefix') }}</label>
|
||||
<input v-model="s3Form.prefix" class="input w-full" placeholder="backups/" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.accessKeyId') }}</label>
|
||||
<input v-model="s3Form.access_key_id" class="input w-full" />
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.s3.secretAccessKey') }}</label>
|
||||
<input v-model="s3Form.secret_access_key" type="password" class="input w-full" :placeholder="s3SecretConfigured ? t('admin.backup.s3.secretConfigured') : ''" />
|
||||
</div>
|
||||
<label class="inline-flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300 md:col-span-2">
|
||||
<input v-model="s3Form.force_path_style" type="checkbox" />
|
||||
<span>{{ t('admin.backup.s3.forcePathStyle') }}</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="mt-4 flex flex-wrap gap-2">
|
||||
<button type="button" class="btn btn-secondary btn-sm" :disabled="testingS3" @click="testS3">
|
||||
{{ testingS3 ? t('common.loading') : t('admin.backup.s3.testConnection') }}
|
||||
</button>
|
||||
<button type="button" class="btn btn-primary btn-sm" :disabled="savingS3" @click="saveS3Config">
|
||||
{{ savingS3 ? t('common.loading') : t('common.save') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Schedule Config -->
|
||||
<div class="card p-6">
|
||||
<div class="mb-4">
|
||||
<h3 class="text-base font-semibold text-gray-900 dark:text-white">
|
||||
{{ t('admin.backup.schedule.title') }}
|
||||
</h3>
|
||||
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.backup.schedule.description') }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="grid grid-cols-1 gap-3 md:grid-cols-2">
|
||||
<label class="inline-flex items-center gap-2 text-sm text-gray-700 dark:text-gray-300 md:col-span-2">
|
||||
<input v-model="scheduleForm.enabled" type="checkbox" />
|
||||
<span>{{ t('admin.backup.schedule.enabled') }}</span>
|
||||
</label>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.schedule.cronExpr') }}</label>
|
||||
<input v-model="scheduleForm.cron_expr" class="input w-full" placeholder="0 2 * * *" />
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">{{ t('admin.backup.schedule.cronHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.schedule.retainDays') }}</label>
|
||||
<input v-model.number="scheduleForm.retain_days" type="number" min="0" class="input w-full" />
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">{{ t('admin.backup.schedule.retainDaysHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400">{{ t('admin.backup.schedule.retainCount') }}</label>
|
||||
<input v-model.number="scheduleForm.retain_count" type="number" min="0" class="input w-full" />
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">{{ t('admin.backup.schedule.retainCountHint') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="mt-4">
|
||||
<button type="button" class="btn btn-primary btn-sm" :disabled="savingSchedule" @click="saveSchedule">
|
||||
{{ savingSchedule ? t('common.loading') : t('common.save') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Backup Operations -->
|
||||
<div class="card p-6">
|
||||
<div class="mb-4 flex flex-wrap items-center justify-between gap-3">
|
||||
<div>
|
||||
<h3 class="text-base font-semibold text-gray-900 dark:text-white">
|
||||
{{ t('admin.backup.operations.title') }}
|
||||
</h3>
|
||||
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.backup.operations.description') }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
<div class="flex items-center gap-1">
|
||||
<label class="text-xs text-gray-600 dark:text-gray-400">{{ t('admin.backup.operations.expireDays') }}</label>
|
||||
<input v-model.number="manualExpireDays" type="number" min="0" class="input w-20 text-xs" />
|
||||
</div>
|
||||
<button type="button" class="btn btn-primary btn-sm" :disabled="creatingBackup" @click="createBackup">
|
||||
{{ creatingBackup ? t('admin.backup.operations.backing') : t('admin.backup.operations.createBackup') }}
|
||||
</button>
|
||||
<button type="button" class="btn btn-secondary btn-sm" :disabled="loadingBackups" @click="loadBackups">
|
||||
{{ loadingBackups ? t('common.loading') : t('common.refresh') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="overflow-x-auto">
|
||||
<table class="w-full min-w-[800px] text-sm">
|
||||
<thead>
|
||||
<tr class="border-b border-gray-200 text-left text-xs uppercase tracking-wide text-gray-500 dark:border-dark-700 dark:text-gray-400">
|
||||
<th class="py-2 pr-4">ID</th>
|
||||
<th class="py-2 pr-4">{{ t('admin.backup.columns.status') }}</th>
|
||||
<th class="py-2 pr-4">{{ t('admin.backup.columns.fileName') }}</th>
|
||||
<th class="py-2 pr-4">{{ t('admin.backup.columns.size') }}</th>
|
||||
<th class="py-2 pr-4">{{ t('admin.backup.columns.expiresAt') }}</th>
|
||||
<th class="py-2 pr-4">{{ t('admin.backup.columns.triggeredBy') }}</th>
|
||||
<th class="py-2 pr-4">{{ t('admin.backup.columns.startedAt') }}</th>
|
||||
<th class="py-2">{{ t('admin.backup.columns.actions') }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="record in backups" :key="record.id" class="border-b border-gray-100 align-top dark:border-dark-800">
|
||||
<td class="py-3 pr-4 font-mono text-xs">{{ record.id }}</td>
|
||||
<td class="py-3 pr-4">
|
||||
<span
|
||||
class="rounded px-2 py-0.5 text-xs"
|
||||
:class="statusClass(record.status)"
|
||||
>
|
||||
{{ t(`admin.backup.status.${record.status}`) }}
|
||||
</span>
|
||||
</td>
|
||||
<td class="py-3 pr-4 text-xs">{{ record.file_name }}</td>
|
||||
<td class="py-3 pr-4 text-xs">{{ formatSize(record.size_bytes) }}</td>
|
||||
<td class="py-3 pr-4 text-xs">
|
||||
{{ record.expires_at ? formatDate(record.expires_at) : t('admin.backup.neverExpire') }}
|
||||
</td>
|
||||
<td class="py-3 pr-4 text-xs">
|
||||
{{ record.triggered_by === 'scheduled' ? t('admin.backup.trigger.scheduled') : t('admin.backup.trigger.manual') }}
|
||||
</td>
|
||||
<td class="py-3 pr-4 text-xs">{{ formatDate(record.started_at) }}</td>
|
||||
<td class="py-3 text-xs">
|
||||
<div class="flex flex-wrap gap-1">
|
||||
<button
|
||||
v-if="record.status === 'completed'"
|
||||
type="button"
|
||||
class="btn btn-secondary btn-xs"
|
||||
@click="downloadBackup(record.id)"
|
||||
>
|
||||
{{ t('admin.backup.actions.download') }}
|
||||
</button>
|
||||
<button
|
||||
v-if="record.status === 'completed'"
|
||||
type="button"
|
||||
class="btn btn-secondary btn-xs"
|
||||
:disabled="restoringId === record.id"
|
||||
@click="restoreBackup(record.id)"
|
||||
>
|
||||
{{ restoringId === record.id ? t('common.loading') : t('admin.backup.actions.restore') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="btn btn-danger btn-xs"
|
||||
@click="removeBackup(record.id)"
|
||||
>
|
||||
{{ t('common.delete') }}
|
||||
</button>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
<tr v-if="backups.length === 0">
|
||||
<td colspan="8" class="py-6 text-center text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.backup.empty') }}
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Cloudflare R2 Setup Guide Modal -->
|
||||
<teleport to="body">
|
||||
<transition name="modal">
|
||||
<div v-if="showR2Guide" class="fixed inset-0 z-50 flex items-center justify-center p-4" @mousedown.self="showR2Guide = false">
|
||||
<div class="fixed inset-0 bg-black/50" @click="showR2Guide = false"></div>
|
||||
<div class="relative max-h-[85vh] w-full max-w-2xl overflow-y-auto rounded-xl bg-white p-6 shadow-2xl dark:bg-dark-800">
|
||||
<button type="button" class="absolute right-4 top-4 text-gray-400 hover:text-gray-600 dark:hover:text-gray-200" @click="showR2Guide = false">
|
||||
<svg class="h-5 w-5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2"><path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" /></svg>
|
||||
</button>
|
||||
|
||||
<h2 class="mb-4 text-lg font-bold text-gray-900 dark:text-white">{{ t('admin.backup.r2Guide.title') }}</h2>
|
||||
<p class="mb-4 text-sm text-gray-500 dark:text-gray-400">{{ t('admin.backup.r2Guide.intro') }}</p>
|
||||
|
||||
<!-- Step 1 -->
|
||||
<div class="mb-5">
|
||||
<h3 class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-900 dark:text-white">
|
||||
<span class="flex h-6 w-6 items-center justify-center rounded-full bg-primary-100 text-xs font-bold text-primary-700 dark:bg-primary-900/40 dark:text-primary-300">1</span>
|
||||
{{ t('admin.backup.r2Guide.step1.title') }}
|
||||
</h3>
|
||||
<ol class="ml-8 list-decimal space-y-1 text-sm text-gray-600 dark:text-gray-300">
|
||||
<li>{{ t('admin.backup.r2Guide.step1.line1') }}</li>
|
||||
<li>{{ t('admin.backup.r2Guide.step1.line2') }}</li>
|
||||
<li>{{ t('admin.backup.r2Guide.step1.line3') }}</li>
|
||||
</ol>
|
||||
</div>
|
||||
|
||||
<!-- Step 2 -->
|
||||
<div class="mb-5">
|
||||
<h3 class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-900 dark:text-white">
|
||||
<span class="flex h-6 w-6 items-center justify-center rounded-full bg-primary-100 text-xs font-bold text-primary-700 dark:bg-primary-900/40 dark:text-primary-300">2</span>
|
||||
{{ t('admin.backup.r2Guide.step2.title') }}
|
||||
</h3>
|
||||
<ol class="ml-8 list-decimal space-y-1 text-sm text-gray-600 dark:text-gray-300">
|
||||
<li>{{ t('admin.backup.r2Guide.step2.line1') }}</li>
|
||||
<li>{{ t('admin.backup.r2Guide.step2.line2') }}</li>
|
||||
<li>{{ t('admin.backup.r2Guide.step2.line3') }}</li>
|
||||
<li>{{ t('admin.backup.r2Guide.step2.line4') }}</li>
|
||||
</ol>
|
||||
<div class="mt-2 rounded-lg bg-amber-50 p-3 text-xs text-amber-700 dark:bg-amber-900/20 dark:text-amber-300">
|
||||
{{ t('admin.backup.r2Guide.step2.warning') }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Step 3 -->
|
||||
<div class="mb-5">
|
||||
<h3 class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-900 dark:text-white">
|
||||
<span class="flex h-6 w-6 items-center justify-center rounded-full bg-primary-100 text-xs font-bold text-primary-700 dark:bg-primary-900/40 dark:text-primary-300">3</span>
|
||||
{{ t('admin.backup.r2Guide.step3.title') }}
|
||||
</h3>
|
||||
<p class="ml-8 text-sm text-gray-600 dark:text-gray-300">{{ t('admin.backup.r2Guide.step3.desc') }}</p>
|
||||
<code class="ml-8 mt-1 block rounded bg-gray-100 px-3 py-2 text-xs text-gray-800 dark:bg-dark-700 dark:text-gray-200">https://<{{ t('admin.backup.r2Guide.step3.accountId') }}>.r2.cloudflarestorage.com</code>
|
||||
</div>
|
||||
|
||||
<!-- Step 4: Fill form -->
|
||||
<div class="mb-5">
|
||||
<h3 class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-900 dark:text-white">
|
||||
<span class="flex h-6 w-6 items-center justify-center rounded-full bg-primary-100 text-xs font-bold text-primary-700 dark:bg-primary-900/40 dark:text-primary-300">4</span>
|
||||
{{ t('admin.backup.r2Guide.step4.title') }}
|
||||
</h3>
|
||||
<div class="ml-8 overflow-hidden rounded-lg border border-gray-200 dark:border-dark-600">
|
||||
<table class="w-full text-sm">
|
||||
<tbody>
|
||||
<tr v-for="(row, i) in r2ConfigRows" :key="i" class="border-b border-gray-100 dark:border-dark-700 last:border-0">
|
||||
<td class="whitespace-nowrap bg-gray-50 px-3 py-2 font-medium text-gray-700 dark:bg-dark-700 dark:text-gray-300">{{ row.field }}</td>
|
||||
<td class="px-3 py-2 text-gray-600 dark:text-gray-400"><code class="text-xs">{{ row.value }}</code></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Free tier note -->
|
||||
<div class="rounded-lg bg-green-50 p-3 text-xs text-green-700 dark:bg-green-900/20 dark:text-green-300">
|
||||
{{ t('admin.backup.r2Guide.freeTier') }}
|
||||
</div>
|
||||
|
||||
<div class="mt-4 text-right">
|
||||
<button type="button" class="btn btn-primary btn-sm" @click="showR2Guide = false">{{ t('common.close') }}</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</transition>
|
||||
</teleport>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, onMounted, ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { adminAPI } from '@/api'
|
||||
import { useAppStore } from '@/stores'
|
||||
import type { BackupS3Config, BackupScheduleConfig, BackupRecord } from '@/api/admin/backup'
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
|
||||
// S3 config
|
||||
const s3Form = ref<BackupS3Config>({
|
||||
endpoint: '',
|
||||
region: 'auto',
|
||||
bucket: '',
|
||||
access_key_id: '',
|
||||
secret_access_key: '',
|
||||
prefix: 'backups/',
|
||||
force_path_style: false,
|
||||
})
|
||||
const s3SecretConfigured = ref(false)
|
||||
const savingS3 = ref(false)
|
||||
const testingS3 = ref(false)
|
||||
|
||||
// Schedule config
|
||||
const scheduleForm = ref<BackupScheduleConfig>({
|
||||
enabled: false,
|
||||
cron_expr: '0 2 * * *',
|
||||
retain_days: 14,
|
||||
retain_count: 10,
|
||||
})
|
||||
const savingSchedule = ref(false)
|
||||
|
||||
// Backups
|
||||
const backups = ref<BackupRecord[]>([])
|
||||
const loadingBackups = ref(false)
|
||||
const creatingBackup = ref(false)
|
||||
const restoringId = ref('')
|
||||
const manualExpireDays = ref(14)
|
||||
|
||||
// R2 guide
|
||||
const showR2Guide = ref(false)
|
||||
const r2ConfigRows = computed(() => [
|
||||
{ field: t('admin.backup.s3.endpoint'), value: 'https://<account_id>.r2.cloudflarestorage.com' },
|
||||
{ field: t('admin.backup.s3.region'), value: 'auto' },
|
||||
{ field: t('admin.backup.s3.bucket'), value: t('admin.backup.r2Guide.step4.bucketValue') },
|
||||
{ field: t('admin.backup.s3.prefix'), value: 'backups/' },
|
||||
{ field: 'Access Key ID', value: t('admin.backup.r2Guide.step4.fromStep2') },
|
||||
{ field: 'Secret Access Key', value: t('admin.backup.r2Guide.step4.fromStep2') },
|
||||
{ field: t('admin.backup.s3.forcePathStyle'), value: t('admin.backup.r2Guide.step4.unchecked') },
|
||||
])
|
||||
|
||||
async function loadS3Config() {
|
||||
try {
|
||||
const cfg = await adminAPI.backup.getS3Config()
|
||||
s3Form.value = {
|
||||
endpoint: cfg.endpoint || '',
|
||||
region: cfg.region || 'auto',
|
||||
bucket: cfg.bucket || '',
|
||||
access_key_id: cfg.access_key_id || '',
|
||||
secret_access_key: '',
|
||||
prefix: cfg.prefix || 'backups/',
|
||||
force_path_style: cfg.force_path_style,
|
||||
}
|
||||
s3SecretConfigured.value = Boolean(cfg.access_key_id)
|
||||
} catch (error) {
|
||||
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
|
||||
}
|
||||
}
|
||||
|
||||
async function saveS3Config() {
|
||||
savingS3.value = true
|
||||
try {
|
||||
await adminAPI.backup.updateS3Config(s3Form.value)
|
||||
appStore.showSuccess(t('admin.backup.s3.saved'))
|
||||
await loadS3Config()
|
||||
} catch (error) {
|
||||
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
|
||||
} finally {
|
||||
savingS3.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function testS3() {
|
||||
testingS3.value = true
|
||||
try {
|
||||
const result = await adminAPI.backup.testS3Connection(s3Form.value)
|
||||
if (result.ok) {
|
||||
appStore.showSuccess(result.message || t('admin.backup.s3.testSuccess'))
|
||||
} else {
|
||||
appStore.showError(result.message || t('admin.backup.s3.testFailed'))
|
||||
}
|
||||
} catch (error) {
|
||||
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
|
||||
} finally {
|
||||
testingS3.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function loadSchedule() {
|
||||
try {
|
||||
const cfg = await adminAPI.backup.getSchedule()
|
||||
scheduleForm.value = {
|
||||
enabled: cfg.enabled,
|
||||
cron_expr: cfg.cron_expr || '0 2 * * *',
|
||||
retain_days: cfg.retain_days || 14,
|
||||
retain_count: cfg.retain_count || 10,
|
||||
}
|
||||
} catch (error) {
|
||||
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
|
||||
}
|
||||
}
|
||||
|
||||
async function saveSchedule() {
|
||||
savingSchedule.value = true
|
||||
try {
|
||||
await adminAPI.backup.updateSchedule(scheduleForm.value)
|
||||
appStore.showSuccess(t('admin.backup.schedule.saved'))
|
||||
} catch (error) {
|
||||
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
|
||||
} finally {
|
||||
savingSchedule.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function loadBackups() {
|
||||
loadingBackups.value = true
|
||||
try {
|
||||
const result = await adminAPI.backup.listBackups()
|
||||
backups.value = result.items || []
|
||||
} catch (error) {
|
||||
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
|
||||
} finally {
|
||||
loadingBackups.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function createBackup() {
|
||||
creatingBackup.value = true
|
||||
try {
|
||||
await adminAPI.backup.createBackup({ expire_days: manualExpireDays.value })
|
||||
appStore.showSuccess(t('admin.backup.operations.backupCreated'))
|
||||
await loadBackups()
|
||||
} catch (error) {
|
||||
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
|
||||
} finally {
|
||||
creatingBackup.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function downloadBackup(id: string) {
|
||||
try {
|
||||
const result = await adminAPI.backup.getDownloadURL(id)
|
||||
window.open(result.url, '_blank')
|
||||
} catch (error) {
|
||||
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
|
||||
}
|
||||
}
|
||||
|
||||
async function restoreBackup(id: string) {
|
||||
if (!window.confirm(t('admin.backup.actions.restoreConfirm'))) return
|
||||
const password = window.prompt(t('admin.backup.actions.restorePasswordPrompt'))
|
||||
if (!password) return
|
||||
restoringId.value = id
|
||||
try {
|
||||
await adminAPI.backup.restoreBackup(id, password)
|
||||
appStore.showSuccess(t('admin.backup.actions.restoreSuccess'))
|
||||
} catch (error) {
|
||||
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
|
||||
} finally {
|
||||
restoringId.value = ''
|
||||
}
|
||||
}
|
||||
|
||||
async function removeBackup(id: string) {
|
||||
if (!window.confirm(t('admin.backup.actions.deleteConfirm'))) return
|
||||
try {
|
||||
await adminAPI.backup.deleteBackup(id)
|
||||
appStore.showSuccess(t('admin.backup.actions.deleted'))
|
||||
await loadBackups()
|
||||
} catch (error) {
|
||||
appStore.showError((error as { message?: string })?.message || t('errors.networkError'))
|
||||
}
|
||||
}
|
||||
|
||||
function statusClass(status: string): string {
|
||||
switch (status) {
|
||||
case 'completed':
|
||||
return 'bg-green-100 text-green-700 dark:bg-green-900/30 dark:text-green-300'
|
||||
case 'running':
|
||||
return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300'
|
||||
case 'failed':
|
||||
return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300'
|
||||
default:
|
||||
return 'bg-gray-100 text-gray-700 dark:bg-dark-800 dark:text-gray-300'
|
||||
}
|
||||
}
|
||||
|
||||
function formatSize(bytes: number): string {
|
||||
if (!bytes || bytes <= 0) return '-'
|
||||
if (bytes < 1024) return `${bytes} B`
|
||||
if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)} KB`
|
||||
return `${(bytes / (1024 * 1024)).toFixed(1)} MB`
|
||||
}
|
||||
|
||||
function formatDate(value?: string): string {
|
||||
if (!value) return '-'
|
||||
const date = new Date(value)
|
||||
if (Number.isNaN(date.getTime())) return value
|
||||
return date.toLocaleString()
|
||||
}
|
||||
|
||||
onMounted(async () => {
|
||||
await Promise.all([loadS3Config(), loadSchedule(), loadBackups()])
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.modal-enter-active,
|
||||
.modal-leave-active {
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
.modal-enter-from,
|
||||
.modal-leave-to {
|
||||
opacity: 0;
|
||||
}
|
||||
</style>
|
||||
@@ -1,5 +1,4 @@
|
||||
<template>
|
||||
<AppLayout>
|
||||
<div class="space-y-6">
|
||||
<div class="card p-6">
|
||||
<div class="mb-4 flex flex-wrap items-center justify-between gap-3">
|
||||
@@ -183,13 +182,11 @@
|
||||
</div>
|
||||
</Transition>
|
||||
</Teleport>
|
||||
</AppLayout>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { onMounted, ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||
import type { SoraS3Profile } from '@/api/admin/settings'
|
||||
import { adminAPI } from '@/api'
|
||||
import { useAppStore } from '@/stores'
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
<!-- Settings Form -->
|
||||
<form v-else @submit.prevent="saveSettings" class="space-y-6">
|
||||
<!-- Tab Navigation -->
|
||||
<div class="sticky top-0 z-10 overflow-x-auto scrollbar-hide">
|
||||
<div class="sticky top-0 z-10 overflow-x-auto settings-tabs-scroll">
|
||||
<nav class="settings-tabs">
|
||||
<button
|
||||
v-for="tab in settingsTabs"
|
||||
@@ -1070,6 +1070,21 @@
|
||||
</p>
|
||||
</div>
|
||||
<div class="space-y-6 p-6">
|
||||
<!-- Backend Mode -->
|
||||
<div
|
||||
class="flex items-center justify-between rounded-lg border border-amber-200 bg-amber-50 p-4 dark:border-amber-800 dark:bg-amber-900/20"
|
||||
>
|
||||
<div>
|
||||
<h3 class="text-sm font-medium text-gray-900 dark:text-white">
|
||||
{{ t('admin.settings.site.backendMode') }}
|
||||
</h3>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.settings.site.backendModeDescription') }}
|
||||
</p>
|
||||
</div>
|
||||
<Toggle v-model="form.backend_mode_enabled" />
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-1 gap-6 md:grid-cols-2">
|
||||
<div>
|
||||
<label class="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
@@ -1634,8 +1649,18 @@
|
||||
</div>
|
||||
</div><!-- /Tab: Email -->
|
||||
|
||||
<!-- Tab: Backup -->
|
||||
<div v-show="activeTab === 'backup'">
|
||||
<BackupSettings />
|
||||
</div>
|
||||
|
||||
<!-- Tab: Data Management -->
|
||||
<div v-show="activeTab === 'data'">
|
||||
<DataManagementSettings />
|
||||
</div>
|
||||
|
||||
<!-- Save Button -->
|
||||
<div class="flex justify-end">
|
||||
<div v-show="activeTab !== 'backup' && activeTab !== 'data'" class="flex justify-end">
|
||||
<button type="submit" :disabled="saving" class="btn btn-primary">
|
||||
<svg v-if="saving" class="h-4 w-4 animate-spin" fill="none" viewBox="0 0 24 24">
|
||||
<circle
|
||||
@@ -1677,6 +1702,8 @@ import GroupBadge from '@/components/common/GroupBadge.vue'
|
||||
import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
|
||||
import Toggle from '@/components/common/Toggle.vue'
|
||||
import ImageUpload from '@/components/common/ImageUpload.vue'
|
||||
import BackupSettings from '@/views/admin/BackupView.vue'
|
||||
import DataManagementSettings from '@/views/admin/DataManagementView.vue'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
import { useAppStore } from '@/stores'
|
||||
import { useAdminSettingsStore } from '@/stores/adminSettings'
|
||||
@@ -1691,7 +1718,7 @@ const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
const adminSettingsStore = useAdminSettingsStore()
|
||||
|
||||
type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'email'
|
||||
type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'email' | 'backup' | 'data'
|
||||
const activeTab = ref<SettingsTab>('general')
|
||||
const settingsTabs = [
|
||||
{ key: 'general' as SettingsTab, icon: 'home' as const },
|
||||
@@ -1699,6 +1726,8 @@ const settingsTabs = [
|
||||
{ key: 'users' as SettingsTab, icon: 'user' as const },
|
||||
{ key: 'gateway' as SettingsTab, icon: 'server' as const },
|
||||
{ key: 'email' as SettingsTab, icon: 'mail' as const },
|
||||
{ key: 'backup' as SettingsTab, icon: 'database' as const },
|
||||
{ key: 'data' as SettingsTab, icon: 'cube' as const },
|
||||
]
|
||||
const { copyToClipboard } = useClipboard()
|
||||
|
||||
@@ -1745,7 +1774,7 @@ const betaPolicyForm = reactive({
|
||||
rules: [] as Array<{
|
||||
beta_token: string
|
||||
action: 'pass' | 'filter' | 'block'
|
||||
scope: 'all' | 'oauth' | 'apikey'
|
||||
scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
|
||||
error_message?: string
|
||||
}>
|
||||
})
|
||||
@@ -1785,6 +1814,7 @@ const form = reactive<SettingsForm>({
|
||||
contact_info: '',
|
||||
doc_url: '',
|
||||
home_content: '',
|
||||
backend_mode_enabled: false,
|
||||
hide_ccs_import_button: false,
|
||||
purchase_subscription_enabled: false,
|
||||
purchase_subscription_url: '',
|
||||
@@ -1962,6 +1992,7 @@ async function loadSettings() {
|
||||
try {
|
||||
const settings = await adminAPI.settings.getSettings()
|
||||
Object.assign(form, settings)
|
||||
form.backend_mode_enabled = settings.backend_mode_enabled
|
||||
form.default_subscriptions = Array.isArray(settings.default_subscriptions)
|
||||
? settings.default_subscriptions
|
||||
.filter((item) => item.group_id > 0 && item.validity_days > 0)
|
||||
@@ -2060,6 +2091,7 @@ async function saveSettings() {
|
||||
contact_info: form.contact_info,
|
||||
doc_url: form.doc_url,
|
||||
home_content: form.home_content,
|
||||
backend_mode_enabled: form.backend_mode_enabled,
|
||||
hide_ccs_import_button: form.hide_ccs_import_button,
|
||||
purchase_subscription_enabled: form.purchase_subscription_enabled,
|
||||
purchase_subscription_url: form.purchase_subscription_url,
|
||||
@@ -2297,7 +2329,8 @@ const betaPolicyActionOptions = computed(() => [
|
||||
const betaPolicyScopeOptions = computed(() => [
|
||||
{ value: 'all', label: t('admin.settings.betaPolicy.scopeAll') },
|
||||
{ value: 'oauth', label: t('admin.settings.betaPolicy.scopeOAuth') },
|
||||
{ value: 'apikey', label: t('admin.settings.betaPolicy.scopeAPIKey') }
|
||||
{ value: 'apikey', label: t('admin.settings.betaPolicy.scopeAPIKey') },
|
||||
{ value: 'bedrock', label: t('admin.settings.betaPolicy.scopeBedrock') }
|
||||
])
|
||||
|
||||
// Beta Policy 方法
|
||||
@@ -2359,9 +2392,38 @@ onMounted(() => {
|
||||
}
|
||||
|
||||
/* ============ Settings Tab Navigation ============ */
|
||||
|
||||
/* Scroll container: thin scrollbar on PC, auto-hide on mobile */
|
||||
.settings-tabs-scroll {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: transparent transparent;
|
||||
}
|
||||
.settings-tabs-scroll:hover {
|
||||
scrollbar-color: rgb(0 0 0 / 0.15) transparent;
|
||||
}
|
||||
:root.dark .settings-tabs-scroll:hover {
|
||||
scrollbar-color: rgb(255 255 255 / 0.2) transparent;
|
||||
}
|
||||
.settings-tabs-scroll::-webkit-scrollbar {
|
||||
height: 3px;
|
||||
}
|
||||
.settings-tabs-scroll::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
.settings-tabs-scroll::-webkit-scrollbar-thumb {
|
||||
background: transparent;
|
||||
border-radius: 3px;
|
||||
}
|
||||
.settings-tabs-scroll:hover::-webkit-scrollbar-thumb {
|
||||
background: rgb(0 0 0 / 0.15);
|
||||
}
|
||||
:root.dark .settings-tabs-scroll:hover::-webkit-scrollbar-thumb {
|
||||
background: rgb(255 255 255 / 0.2);
|
||||
}
|
||||
|
||||
.settings-tabs {
|
||||
@apply inline-flex min-w-full gap-1 rounded-2xl
|
||||
border border-gray-100 bg-white/80 p-1.5 backdrop-blur-sm
|
||||
@apply inline-flex min-w-full gap-0.5 rounded-2xl
|
||||
border border-gray-100 bg-white/80 p-1 backdrop-blur-sm
|
||||
dark:border-dark-700/50 dark:bg-dark-800/80;
|
||||
box-shadow: 0 1px 3px rgb(0 0 0 / 0.04), 0 1px 2px rgb(0 0 0 / 0.02);
|
||||
}
|
||||
@@ -2373,8 +2435,8 @@ onMounted(() => {
|
||||
}
|
||||
|
||||
.settings-tab {
|
||||
@apply relative flex flex-1 items-center justify-center gap-2
|
||||
whitespace-nowrap rounded-xl px-4 py-2.5
|
||||
@apply relative flex flex-1 items-center justify-center gap-1.5
|
||||
whitespace-nowrap rounded-xl px-2.5 py-2
|
||||
text-sm font-medium
|
||||
text-gray-500 dark:text-dark-400
|
||||
transition-all duration-200 ease-out;
|
||||
@@ -2401,7 +2463,7 @@ onMounted(() => {
|
||||
}
|
||||
|
||||
.settings-tab-icon {
|
||||
@apply flex h-7 w-7 items-center justify-center rounded-lg
|
||||
@apply flex h-6 w-6 items-center justify-center rounded-lg
|
||||
transition-all duration-200;
|
||||
}
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@
|
||||
</div>
|
||||
|
||||
<!-- Row: OpenAI Token Stats -->
|
||||
<div v-if="opsEnabled && !(loading && !hasLoadedOnce)" class="grid grid-cols-1 gap-6">
|
||||
<div v-if="opsEnabled && showOpenAITokenStats && !(loading && !hasLoadedOnce)" class="grid grid-cols-1 gap-6">
|
||||
<OpsOpenAITokenStatsCard
|
||||
:platform-filter="platform"
|
||||
:group-id-filter="groupId"
|
||||
@@ -94,7 +94,7 @@
|
||||
</div>
|
||||
|
||||
<!-- Alert Events -->
|
||||
<OpsAlertEventsCard v-if="opsEnabled && !(loading && !hasLoadedOnce)" />
|
||||
<OpsAlertEventsCard v-if="opsEnabled && showAlertEvents && !(loading && !hasLoadedOnce)" />
|
||||
|
||||
<!-- System Logs -->
|
||||
<OpsSystemLogTable
|
||||
@@ -381,6 +381,8 @@ const showSettingsDialog = ref(false)
|
||||
const showAlertRulesCard = ref(false)
|
||||
|
||||
// Auto refresh settings
|
||||
const showAlertEvents = ref(true)
|
||||
const showOpenAITokenStats = ref(false)
|
||||
const autoRefreshEnabled = ref(false)
|
||||
const autoRefreshIntervalMs = ref(30000) // default 30 seconds
|
||||
const autoRefreshCountdown = ref(0)
|
||||
@@ -408,15 +410,22 @@ const { pause: pauseCountdown, resume: resumeCountdown } = useIntervalFn(
|
||||
{ immediate: false }
|
||||
)
|
||||
|
||||
// Load auto refresh settings from backend
|
||||
async function loadAutoRefreshSettings() {
|
||||
// Load ops dashboard presentation settings from backend.
|
||||
async function loadDashboardAdvancedSettings() {
|
||||
try {
|
||||
const settings = await opsAPI.getAdvancedSettings()
|
||||
showAlertEvents.value = settings.display_alert_events
|
||||
showOpenAITokenStats.value = settings.display_openai_token_stats
|
||||
autoRefreshEnabled.value = settings.auto_refresh_enabled
|
||||
autoRefreshIntervalMs.value = settings.auto_refresh_interval_seconds * 1000
|
||||
autoRefreshCountdown.value = settings.auto_refresh_interval_seconds
|
||||
} catch (err) {
|
||||
console.error('[OpsDashboard] Failed to load auto refresh settings', err)
|
||||
console.error('[OpsDashboard] Failed to load dashboard advanced settings', err)
|
||||
showAlertEvents.value = true
|
||||
showOpenAITokenStats.value = false
|
||||
autoRefreshEnabled.value = false
|
||||
autoRefreshIntervalMs.value = 30000
|
||||
autoRefreshCountdown.value = 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -464,7 +473,8 @@ function onCustomTimeRangeChange(startTime: string, endTime: string) {
|
||||
customEndTime.value = endTime
|
||||
}
|
||||
|
||||
function onSettingsSaved() {
|
||||
async function onSettingsSaved() {
|
||||
await loadDashboardAdvancedSettings()
|
||||
loadThresholds()
|
||||
fetchData()
|
||||
}
|
||||
@@ -774,7 +784,7 @@ onMounted(async () => {
|
||||
loadThresholds()
|
||||
|
||||
// Load auto refresh settings
|
||||
await loadAutoRefreshSettings()
|
||||
await loadDashboardAdvancedSettings()
|
||||
|
||||
if (opsEnabled.value) {
|
||||
await fetchData()
|
||||
@@ -816,7 +826,7 @@ watch(autoRefreshEnabled, (enabled) => {
|
||||
// Reload auto refresh settings after settings dialog is closed
|
||||
watch(showSettingsDialog, async (show) => {
|
||||
if (!show) {
|
||||
await loadAutoRefreshSettings()
|
||||
await loadDashboardAdvancedSettings()
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -208,35 +208,39 @@ function onNextPage() {
|
||||
:description="t('admin.ops.openaiTokenStats.empty')"
|
||||
/>
|
||||
|
||||
<div v-else class="overflow-x-auto">
|
||||
<table class="min-w-full text-left text-xs md:text-sm">
|
||||
<thead>
|
||||
<tr class="border-b border-gray-200 text-gray-500 dark:border-dark-700 dark:text-gray-400">
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.model') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.requestCount') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgTokensPerSec') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgFirstTokenMs') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.totalOutputTokens') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgDurationMs') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.requestsWithFirstToken') }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr
|
||||
v-for="row in items"
|
||||
:key="row.model"
|
||||
class="border-b border-gray-100 text-gray-700 dark:border-dark-800 dark:text-gray-200"
|
||||
>
|
||||
<td class="px-2 py-2 font-medium">{{ row.model }}</td>
|
||||
<td class="px-2 py-2">{{ formatInt(row.request_count) }}</td>
|
||||
<td class="px-2 py-2">{{ formatRate(row.avg_tokens_per_sec) }}</td>
|
||||
<td class="px-2 py-2">{{ formatRate(row.avg_first_token_ms) }}</td>
|
||||
<td class="px-2 py-2">{{ formatInt(row.total_output_tokens) }}</td>
|
||||
<td class="px-2 py-2">{{ formatInt(row.avg_duration_ms) }}</td>
|
||||
<td class="px-2 py-2">{{ formatInt(row.requests_with_first_token) }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
<div v-else class="space-y-3">
|
||||
<div class="overflow-hidden rounded-xl border border-gray-200 dark:border-dark-700">
|
||||
<div class="max-h-[420px] overflow-auto">
|
||||
<table class="min-w-full text-left text-xs md:text-sm">
|
||||
<thead class="sticky top-0 z-10 bg-white dark:bg-dark-800">
|
||||
<tr class="border-b border-gray-200 text-gray-500 dark:border-dark-700 dark:text-gray-400">
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.model') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.requestCount') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgTokensPerSec') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgFirstTokenMs') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.totalOutputTokens') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.avgDurationMs') }}</th>
|
||||
<th class="px-2 py-2 font-semibold">{{ t('admin.ops.openaiTokenStats.table.requestsWithFirstToken') }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr
|
||||
v-for="row in items"
|
||||
:key="row.model"
|
||||
class="border-b border-gray-100 text-gray-700 last:border-b-0 dark:border-dark-800 dark:text-gray-200"
|
||||
>
|
||||
<td class="px-2 py-2 font-medium">{{ row.model }}</td>
|
||||
<td class="px-2 py-2">{{ formatInt(row.request_count) }}</td>
|
||||
<td class="px-2 py-2">{{ formatRate(row.avg_tokens_per_sec) }}</td>
|
||||
<td class="px-2 py-2">{{ formatRate(row.avg_first_token_ms) }}</td>
|
||||
<td class="px-2 py-2">{{ formatInt(row.total_output_tokens) }}</td>
|
||||
<td class="px-2 py-2">{{ formatInt(row.avg_duration_ms) }}</td>
|
||||
<td class="px-2 py-2">{{ formatInt(row.requests_with_first_token) }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="viewMode === 'topn'" class="mt-3 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.ops.openaiTokenStats.totalModels', { total }) }}
|
||||
</div>
|
||||
|
||||
@@ -131,15 +131,7 @@ const validation = computed(() => {
|
||||
}
|
||||
}
|
||||
|
||||
// 验证邮件配置
|
||||
if (emailConfig.value) {
|
||||
if (emailConfig.value.alert.enabled && emailConfig.value.alert.recipients.length === 0) {
|
||||
errors.push(t('admin.ops.email.validation.alertRecipientsRequired'))
|
||||
}
|
||||
if (emailConfig.value.report.enabled && emailConfig.value.report.recipients.length === 0) {
|
||||
errors.push(t('admin.ops.email.validation.reportRecipientsRequired'))
|
||||
}
|
||||
}
|
||||
// 邮件配置: 启用但无收件人时不阻断保存, 保存时会自动禁用
|
||||
|
||||
// 验证高级设置
|
||||
if (advancedSettings.value) {
|
||||
@@ -181,6 +173,15 @@ async function saveAllSettings() {
|
||||
|
||||
saving.value = true
|
||||
try {
|
||||
// 无收件人时自动禁用邮件通知
|
||||
if (emailConfig.value) {
|
||||
if (emailConfig.value.alert.enabled && emailConfig.value.alert.recipients.length === 0) {
|
||||
emailConfig.value.alert.enabled = false
|
||||
}
|
||||
if (emailConfig.value.report.enabled && emailConfig.value.report.recipients.length === 0) {
|
||||
emailConfig.value.report.enabled = false
|
||||
}
|
||||
}
|
||||
await Promise.all([
|
||||
runtimeSettings.value ? opsAPI.updateAlertRuntimeSettings(runtimeSettings.value) : Promise.resolve(),
|
||||
emailConfig.value ? opsAPI.updateEmailNotificationConfig(emailConfig.value) : Promise.resolve(),
|
||||
@@ -543,6 +544,31 @@ async function saveAllSettings() {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Dashboard Cards -->
|
||||
<div class="space-y-3">
|
||||
<h5 class="text-xs font-semibold text-gray-700 dark:text-gray-300">{{ t('admin.ops.settings.dashboardCards') }}</h5>
|
||||
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">{{ t('admin.ops.settings.displayAlertEvents') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500">
|
||||
{{ t('admin.ops.settings.displayAlertEventsHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<Toggle v-model="advancedSettings.display_alert_events" />
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">{{ t('admin.ops.settings.displayOpenAITokenStats') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500">
|
||||
{{ t('admin.ops.settings.displayOpenAITokenStatsHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<Toggle v-model="advancedSettings.display_openai_token_stats" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
</div>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user