mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-04 15:32:13 +08:00
Compare commits
57 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d3a9f5bb88 | ||
|
|
7eb0415a8a | ||
|
|
bdbc8fa08f | ||
|
|
63f3af0f94 | ||
|
|
686f890fbf | ||
|
|
220fbe6544 | ||
|
|
ae44a94325 | ||
|
|
3718d6dcd4 | ||
|
|
90b3838173 | ||
|
|
19d3ecc76f | ||
|
|
6fba4ebb13 | ||
|
|
c31974c913 | ||
|
|
6177fa5dd8 | ||
|
|
cfe72159d0 | ||
|
|
8321e4a647 | ||
|
|
3084330d0c | ||
|
|
b566649e79 | ||
|
|
10a6180e4a | ||
|
|
cbe9e78977 | ||
|
|
74145b1f39 | ||
|
|
359e56751b | ||
|
|
5899784aa4 | ||
|
|
9e8959c56d | ||
|
|
1bff2292a6 | ||
|
|
cf9247754e | ||
|
|
eefab15958 | ||
|
|
0e23732631 | ||
|
|
37c044fb4b | ||
|
|
6da5fa01b9 | ||
|
|
616930f9d3 | ||
|
|
b9c31fa7c4 | ||
|
|
17b339972c | ||
|
|
39f8bd91b9 | ||
|
|
aa4e37d085 | ||
|
|
f59b66b7d4 | ||
|
|
8f0ea7a02d | ||
|
|
a1dc00890e | ||
|
|
dfbcc363d1 | ||
|
|
1047f973d5 | ||
|
|
e32977dd73 | ||
|
|
b5f78ec1e8 | ||
|
|
e0f290fdc8 | ||
|
|
fc00a4e3b2 | ||
|
|
db1f6ded88 | ||
|
|
4644af2ccc | ||
|
|
e4a4dfd038 | ||
|
|
2666422b99 | ||
|
|
45456fa24c | ||
|
|
6344fa2a86 | ||
|
|
29b0e4a8a5 | ||
|
|
f7177be3b6 | ||
|
|
875b417fde | ||
|
|
2573107b32 | ||
|
|
5b85005945 | ||
|
|
53ad1645cf | ||
|
|
af9c4a7dd0 | ||
|
|
6826149a8f |
20
Dockerfile
20
Dockerfile
@@ -9,6 +9,7 @@
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
|
||||
@@ -73,7 +74,12 @@ RUN VERSION_VALUE="${VERSION}" && \
|
||||
./cmd/server
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 3: Final Runtime Image
|
||||
# Stage 3: PostgreSQL Client (version-matched with docker-compose)
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 4: Final Runtime Image
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
@@ -86,8 +92,20 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||
# This ensures version consistency between backup tools and the database server
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
@@ -5,7 +5,12 @@
|
||||
# It only packages the pre-built binary, no compilation needed.
|
||||
# =============================================================================
|
||||
|
||||
FROM alpine:3.19
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
FROM ${ALPINE_IMAGE}
|
||||
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
LABEL description="Sub2API - AI API Gateway Platform"
|
||||
@@ -16,8 +21,20 @@ RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
curl \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
|
||||
# Copy pg_dump and psql from a version-matched PostgreSQL image so backup and
|
||||
# restore work in the runtime container without requiring Docker socket access.
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
|
||||
@@ -94,6 +94,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -230,6 +231,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -146,6 +146,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
|
||||
dataManagementService := service.NewDataManagementService()
|
||||
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
|
||||
backupObjectStoreFactory := repository.NewS3BackupStoreFactory()
|
||||
dbDumper := repository.NewPgDumper(configConfig)
|
||||
backupService := service.ProvideBackupService(settingRepository, configConfig, secretEncryptor, backupObjectStoreFactory, dbDumper)
|
||||
backupHandler := admin.NewBackupHandler(backupService, userService)
|
||||
oAuthHandler := admin.NewOAuthHandler(oAuthService)
|
||||
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
|
||||
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
|
||||
@@ -201,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler, scheduledTestHandler)
|
||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
|
||||
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
|
||||
@@ -232,7 +236,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -285,6 +289,7 @@ func provideCleanup(
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
openAIGateway *service.OpenAIGatewayService,
|
||||
scheduledTestRunner *service.ScheduledTestRunnerService,
|
||||
backupSvc *service.BackupService,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -420,6 +425,12 @@ func provideCleanup(
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"BackupService", func() error {
|
||||
if backupSvc != nil {
|
||||
backupSvc.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
}
|
||||
|
||||
infraSteps := []cleanupStep{
|
||||
|
||||
@@ -75,6 +75,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
antigravityOAuthSvc,
|
||||
nil, // openAIGateway
|
||||
nil, // scheduledTestRunner
|
||||
nil, // backupSvc
|
||||
)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
|
||||
@@ -27,12 +27,11 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
|
||||
AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
|
||||
@@ -97,7 +97,7 @@ type CreateAccountRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials" binding:"required"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
@@ -116,7 +116,7 @@ type CreateAccountRequest struct {
|
||||
type UpdateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
|
||||
204
backend/internal/handler/admin/backup_handler.go
Normal file
204
backend/internal/handler/admin/backup_handler.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BackupHandler struct {
|
||||
backupService *service.BackupService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
|
||||
return &BackupHandler{
|
||||
backupService: backupService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置 ───
|
||||
|
||||
func (h *BackupHandler) GetS3Config(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetS3Config(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
|
||||
var req service.BackupS3Config
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
err := h.backupService.TestS3Connection(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.Success(c, gin.H{"ok": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
|
||||
}
|
||||
|
||||
// ─── 定时备份 ───
|
||||
|
||||
func (h *BackupHandler) GetSchedule(c *gin.Context) {
|
||||
cfg, err := h.backupService.GetSchedule(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
|
||||
var req service.BackupScheduleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// ─── 备份操作 ───
|
||||
|
||||
type CreateBackupRequest struct {
|
||||
ExpireDays *int `json:"expire_days"` // nil=使用默认值14,0=永不过期
|
||||
}
|
||||
|
||||
func (h *BackupHandler) CreateBackup(c *gin.Context) {
|
||||
var req CreateBackupRequest
|
||||
_ = c.ShouldBindJSON(&req) // 允许空 body
|
||||
|
||||
expireDays := 14 // 默认14天过期
|
||||
if req.ExpireDays != nil {
|
||||
expireDays = *req.ExpireDays
|
||||
}
|
||||
|
||||
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) ListBackups(c *gin.Context) {
|
||||
records, err := h.backupService.ListBackups(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if records == nil {
|
||||
records = []service.BackupRecord{}
|
||||
}
|
||||
response.Success(c, gin.H{"items": records})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, record)
|
||||
}
|
||||
|
||||
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// ─── 恢复操作(需要重新输入管理员密码) ───
|
||||
|
||||
type RestoreBackupRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
|
||||
backupID := c.Param("id")
|
||||
if backupID == "" {
|
||||
response.BadRequest(c, "backup ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req RestoreBackupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "password is required for restore operation")
|
||||
return
|
||||
}
|
||||
|
||||
// 从上下文获取当前管理员用户 ID
|
||||
sub, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取管理员用户并验证密码
|
||||
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if !user.CheckPassword(req.Password) {
|
||||
response.BadRequest(c, "incorrect admin password")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"restored": true})
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -16,6 +19,55 @@ type GroupHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
type optionalLimitField struct {
|
||||
set bool
|
||||
value *float64
|
||||
}
|
||||
|
||||
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
|
||||
f.set = true
|
||||
|
||||
trimmed := bytes.TrimSpace(data)
|
||||
if bytes.Equal(trimmed, []byte("null")) {
|
||||
f.value = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
var number float64
|
||||
if err := json.Unmarshal(trimmed, &number); err == nil {
|
||||
f.value = &number
|
||||
return nil
|
||||
}
|
||||
|
||||
var text string
|
||||
if err := json.Unmarshal(trimmed, &text); err == nil {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
f.value = nil
|
||||
return nil
|
||||
}
|
||||
number, err = strconv.ParseFloat(text, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
|
||||
}
|
||||
f.value = &number
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid limit value: %s", string(trimmed))
|
||||
}
|
||||
|
||||
func (f optionalLimitField) ToServiceInput() *float64 {
|
||||
if !f.set {
|
||||
return nil
|
||||
}
|
||||
if f.value != nil {
|
||||
return f.value
|
||||
}
|
||||
zero := 0.0
|
||||
return &zero
|
||||
}
|
||||
|
||||
// NewGroupHandler creates a new admin group handler
|
||||
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
return &GroupHandler{
|
||||
@@ -25,15 +77,15 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
|
||||
// CreateGroupRequest represents create group request
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
@@ -62,16 +114,16 @@ type CreateGroupRequest struct {
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
@@ -191,9 +243,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
@@ -244,9 +296,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
|
||||
@@ -80,6 +80,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
FrontendURL: settings.FrontendURL,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -125,6 +126,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -136,6 +138,7 @@ type UpdateSettingsRequest struct {
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
@@ -199,6 +202,9 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -322,6 +328,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Frontend URL 验证
|
||||
req.FrontendURL = strings.TrimSpace(req.FrontendURL)
|
||||
if req.FrontendURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.FrontendURL); err != nil {
|
||||
response.BadRequest(c, "Frontend URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义菜单项验证
|
||||
const (
|
||||
maxCustomMenuItems = 20
|
||||
@@ -433,6 +448,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
@@ -473,6 +489,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -526,6 +543,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
FrontendURL: updatedSettings.FrontendURL,
|
||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
@@ -571,6 +589,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: updatedSettings.BackendModeEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -608,6 +627,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
if before.FrontendURL != after.FrontendURL {
|
||||
changed = append(changed, "frontend_url")
|
||||
}
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
@@ -725,6 +747,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||
}
|
||||
if before.BackendModeEnabled != after.BackendModeEnabled {
|
||||
changed = append(changed, "backend_mode_enabled")
|
||||
}
|
||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||
changed = append(changed, "purchase_subscription_enabled")
|
||||
}
|
||||
|
||||
@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
// Get the user (before session deletion so we can check backend mode)
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: only admin can login (check BEFORE deleting session)
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session (only after all checks pass)
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
h.respondWithTokenPair(c, user)
|
||||
}
|
||||
|
||||
@@ -447,9 +459,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
||||
frontendBaseURL := strings.TrimSpace(h.settingSvc.GetFrontendURL(c.Request.Context()))
|
||||
if frontendBaseURL == "" {
|
||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
||||
slog.Error("frontend_url not configured in settings or config; cannot build password reset link")
|
||||
response.InternalError(c, "Password reset is not configured")
|
||||
return
|
||||
}
|
||||
@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Backend mode: block non-admin token refresh
|
||||
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
|
||||
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, RefreshTokenResponse{
|
||||
AccessToken: tokenPair.AccessToken,
|
||||
RefreshToken: tokenPair.RefreshToken,
|
||||
ExpiresIn: tokenPair.ExpiresIn,
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresIn: result.ExpiresIn,
|
||||
TokenType: "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -264,8 +264,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取 API Key 账号配额限制(仅 apikey 类型有效)
|
||||
if a.Type == service.AccountTypeAPIKey {
|
||||
// 提取账号配额限制(apikey / bedrock 类型有效)
|
||||
if a.IsAPIKeyOrBedrock() {
|
||||
if limit := a.GetQuotaLimit(); limit > 0 {
|
||||
out.QuotaLimit = &limit
|
||||
used := a.GetQuotaUsed()
|
||||
@@ -281,6 +281,31 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
used := a.GetQuotaWeeklyUsed()
|
||||
out.QuotaWeeklyUsed = &used
|
||||
}
|
||||
// 固定时间重置配置
|
||||
if mode := a.GetQuotaDailyResetMode(); mode == "fixed" {
|
||||
out.QuotaDailyResetMode = &mode
|
||||
hour := a.GetQuotaDailyResetHour()
|
||||
out.QuotaDailyResetHour = &hour
|
||||
}
|
||||
if mode := a.GetQuotaWeeklyResetMode(); mode == "fixed" {
|
||||
out.QuotaWeeklyResetMode = &mode
|
||||
day := a.GetQuotaWeeklyResetDay()
|
||||
out.QuotaWeeklyResetDay = &day
|
||||
hour := a.GetQuotaWeeklyResetHour()
|
||||
out.QuotaWeeklyResetHour = &hour
|
||||
}
|
||||
if a.GetQuotaDailyResetMode() == "fixed" || a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
tz := a.GetQuotaResetTimezone()
|
||||
out.QuotaResetTimezone = &tz
|
||||
}
|
||||
if a.Extra != nil {
|
||||
if v, ok := a.Extra["quota_daily_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaDailyResetAt = &v
|
||||
}
|
||||
if v, ok := a.Extra["quota_weekly_reset_at"].(string); ok && v != "" {
|
||||
out.QuotaWeeklyResetAt = &v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -498,6 +523,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
Model: l.Model,
|
||||
ServiceTier: l.ServiceTier,
|
||||
ReasoningEffort: l.ReasoningEffort,
|
||||
InboundEndpoint: l.InboundEndpoint,
|
||||
UpstreamEndpoint: l.UpstreamEndpoint,
|
||||
GroupID: l.GroupID,
|
||||
SubscriptionID: l.SubscriptionID,
|
||||
InputTokens: l.InputTokens,
|
||||
|
||||
@@ -76,10 +76,14 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceTier := "priority"
|
||||
inboundEndpoint := "/v1/chat/completions"
|
||||
upstreamEndpoint := "/v1/responses"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_3",
|
||||
Model: "gpt-5.4",
|
||||
ServiceTier: &serviceTier,
|
||||
InboundEndpoint: &inboundEndpoint,
|
||||
UpstreamEndpoint: &upstreamEndpoint,
|
||||
AccountRateMultiplier: f64Ptr(1.5),
|
||||
}
|
||||
|
||||
@@ -88,8 +92,16 @@ func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) {
|
||||
|
||||
require.NotNil(t, userDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *userDTO.ServiceTier)
|
||||
require.NotNil(t, userDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *userDTO.InboundEndpoint)
|
||||
require.NotNil(t, userDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *userDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.ServiceTier)
|
||||
require.Equal(t, serviceTier, *adminDTO.ServiceTier)
|
||||
require.NotNil(t, adminDTO.InboundEndpoint)
|
||||
require.Equal(t, inboundEndpoint, *adminDTO.InboundEndpoint)
|
||||
require.NotNil(t, adminDTO.UpstreamEndpoint)
|
||||
require.Equal(t, upstreamEndpoint, *adminDTO.UpstreamEndpoint)
|
||||
require.NotNil(t, adminDTO.AccountRateMultiplier)
|
||||
require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ type SystemSettings struct {
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
FrontendURL string `json:"frontend_url"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
@@ -81,6 +82,9 @@ type SystemSettings struct {
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
|
||||
// Backend Mode
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -111,6 +115,7 @@ type PublicSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
|
||||
@@ -203,6 +203,16 @@ type Account struct {
|
||||
QuotaWeeklyLimit *float64 `json:"quota_weekly_limit,omitempty"`
|
||||
QuotaWeeklyUsed *float64 `json:"quota_weekly_used,omitempty"`
|
||||
|
||||
// 配额固定时间重置配置
|
||||
QuotaDailyResetMode *string `json:"quota_daily_reset_mode,omitempty"`
|
||||
QuotaDailyResetHour *int `json:"quota_daily_reset_hour,omitempty"`
|
||||
QuotaWeeklyResetMode *string `json:"quota_weekly_reset_mode,omitempty"`
|
||||
QuotaWeeklyResetDay *int `json:"quota_weekly_reset_day,omitempty"`
|
||||
QuotaWeeklyResetHour *int `json:"quota_weekly_reset_hour,omitempty"`
|
||||
QuotaResetTimezone *string `json:"quota_reset_timezone,omitempty"`
|
||||
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -324,9 +334,13 @@ type UsageLog struct {
|
||||
Model string `json:"model"`
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
|
||||
// nil means not provided / not applicable.
|
||||
// ReasoningEffort is the request's reasoning effort level.
|
||||
// OpenAI: "low"/"medium"/"high"/"xhigh"; Claude: "low"/"medium"/"high"/"max".
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||
InboundEndpoint *string `json:"inbound_endpoint,omitempty"`
|
||||
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||
UpstreamEndpoint *string `json:"upstream_endpoint,omitempty"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
SubscriptionID *int64 `json:"subscription_id"`
|
||||
|
||||
@@ -391,6 +391,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -402,6 +404,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -436,6 +443,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -637,6 +648,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
// 记录 Forward 前已写入字节数,Forward 后若增加则说明 SSE 内容已发,禁止 failover
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
} else {
|
||||
@@ -706,6 +719,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
// 流式内容已写入客户端,无法撤销,禁止 failover 以防止流拼接腐化
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, true)
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
@@ -740,6 +758,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
if result.ReasoningEffort == nil {
|
||||
result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort)
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
|
||||
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
122
backend/internal/handler/gateway_handler_stream_failover_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// partialMessageStartSSE 模拟 handleStreamingResponse 已写入的首批 SSE 事件。
|
||||
const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" +
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
|
||||
|
||||
// TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten 验证:
|
||||
// 当 Forward 在返回 UpstreamFailoverError 前已向客户端写入 SSE 内容时,
|
||||
// 故障转移保护逻辑必须终止循环并发送 SSE 错误事件,而不是进行下一次 Forward。
|
||||
// 具体验证:
|
||||
// 1. c.Writer.Size() 检测条件正确触发(字节数已增加)
|
||||
// 2. handleFailoverExhausted 以 streamStarted=true 调用后,响应体以 SSE 错误事件结尾
|
||||
// 3. 响应体中只出现一个 message_start,不存在第二个(防止流拼接腐化)
|
||||
func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 步骤 1:记录 Forward 前的 writer size(模拟 writerSizeBeforeForward := c.Writer.Size())
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)")
|
||||
|
||||
// 步骤 2:模拟 Forward 已向客户端写入部分 SSE 内容(message_start + content_block_start)
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
// 步骤 3:验证守卫条件成立(c.Writer.Size() != sizeBeforeForward)
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size(),
|
||||
"写入 SSE 内容后 writer size 必须增加,守卫条件应为 true")
|
||||
|
||||
// 步骤 4:模拟 UpstreamFailoverError(上游在流中途返回 403)
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`),
|
||||
}
|
||||
|
||||
// 步骤 5:守卫触发 → 调用 handleFailoverExhausted,streamStarted=true
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// 断言 A:响应体中包含最初写入的 message_start SSE 事件行
|
||||
require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件")
|
||||
|
||||
// 断言 B:响应体以 SSE 错误事件结尾(data: {"type":"error",...}\n\n)
|
||||
require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"),
|
||||
"响应体应以 JSON 对象结尾(SSE error event 的 data 字段)")
|
||||
require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件")
|
||||
|
||||
// 断言 C:SSE event 行 "event: message_start" 只出现一次(防止双 message_start 拼接腐化)
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx,
|
||||
"响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten 与上述测试相同,
|
||||
// 验证 Gemini 路径使用 service.PlatformGemini(而非 account.Platform)时行为一致。
|
||||
func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil)
|
||||
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
_, err := c.Writer.Write([]byte(partialMessageStartSSE))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEqual(t, sizeBeforeForward, c.Writer.Size())
|
||||
|
||||
failoverErr := &service.UpstreamFailoverError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
}
|
||||
|
||||
h := &GatewayHandler{}
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
require.Contains(t, body, "event: message_start")
|
||||
require.Contains(t, body, `"type":"error"`)
|
||||
|
||||
firstIdx := strings.Index(body, "event: message_start")
|
||||
lastIdx := strings.LastIndex(body, "event: message_start")
|
||||
assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start")
|
||||
}
|
||||
|
||||
// TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered 验证反向场景:
|
||||
// 当 Forward 返回 UpstreamFailoverError 时若未向客户端写入任何 SSE 内容,
|
||||
// 守卫条件(c.Writer.Size() != sizeBeforeForward)为 false,不应中止 failover。
|
||||
func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 模拟 writerSizeBeforeForward:初始为 -1
|
||||
sizeBeforeForward := c.Writer.Size()
|
||||
|
||||
// Forward 未写入任何字节直接返回错误(例如 401 发生在连接建立前)
|
||||
// c.Writer.Size() 仍为 -1
|
||||
|
||||
// 守卫条件:sizeBeforeForward == c.Writer.Size() → 不触发
|
||||
guardTriggered := c.Writer.Size() != sizeBeforeForward
|
||||
require.False(t, guardTriggered,
|
||||
"未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续")
|
||||
}
|
||||
@@ -12,6 +12,7 @@ type AdminHandlers struct {
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
Backup *admin.BackupHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
|
||||
@@ -256,14 +256,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointChatCompletions),
|
||||
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNormalizedOpenAIUpstreamEndpoint(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
fallback string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "responses root maps to responses upstream",
|
||||
path: "/v1/responses",
|
||||
fallback: openAIUpstreamEndpointResponses,
|
||||
want: "/v1/responses",
|
||||
},
|
||||
{
|
||||
name: "responses compact keeps compact suffix",
|
||||
path: "/openai/v1/responses/compact",
|
||||
fallback: openAIUpstreamEndpointResponses,
|
||||
want: "/v1/responses/compact",
|
||||
},
|
||||
{
|
||||
name: "responses nested suffix preserved",
|
||||
path: "/openai/v1/responses/compact/detail",
|
||||
fallback: openAIUpstreamEndpointResponses,
|
||||
want: "/v1/responses/compact/detail",
|
||||
},
|
||||
{
|
||||
name: "non responses path uses fallback",
|
||||
path: "/v1/messages",
|
||||
fallback: openAIUpstreamEndpointResponses,
|
||||
want: "/v1/responses",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, tt.path, nil)
|
||||
|
||||
got := normalizedOpenAIUpstreamEndpoint(c, tt.fallback)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,13 @@ type OpenAIGatewayHandler struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
const (
|
||||
openAIInboundEndpointResponses = "/v1/responses"
|
||||
openAIInboundEndpointMessages = "/v1/messages"
|
||||
openAIInboundEndpointChatCompletions = "/v1/chat/completions"
|
||||
openAIUpstreamEndpointResponses = "/v1/responses"
|
||||
)
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
@@ -362,6 +369,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses),
|
||||
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -738,6 +747,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointMessages),
|
||||
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
@@ -1235,6 +1246,8 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: normalizedOpenAIInboundEndpoint(c, openAIInboundEndpointResponses),
|
||||
UpstreamEndpoint: normalizedOpenAIUpstreamEndpoint(c, openAIUpstreamEndpointResponses),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
@@ -1530,6 +1543,62 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64)
|
||||
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
|
||||
}
|
||||
|
||||
func normalizedOpenAIInboundEndpoint(c *gin.Context, fallback string) string {
|
||||
path := strings.TrimSpace(fallback)
|
||||
if c != nil {
|
||||
if fullPath := strings.TrimSpace(c.FullPath()); fullPath != "" {
|
||||
path = fullPath
|
||||
} else if c.Request != nil && c.Request.URL != nil {
|
||||
if requestPath := strings.TrimSpace(c.Request.URL.Path); requestPath != "" {
|
||||
path = requestPath
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.Contains(path, openAIInboundEndpointChatCompletions):
|
||||
return openAIInboundEndpointChatCompletions
|
||||
case strings.Contains(path, openAIInboundEndpointMessages):
|
||||
return openAIInboundEndpointMessages
|
||||
case strings.Contains(path, openAIInboundEndpointResponses):
|
||||
return openAIInboundEndpointResponses
|
||||
default:
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
func normalizedOpenAIUpstreamEndpoint(c *gin.Context, fallback string) string {
|
||||
base := strings.TrimSpace(fallback)
|
||||
if base == "" {
|
||||
base = openAIUpstreamEndpointResponses
|
||||
}
|
||||
base = strings.TrimRight(base, "/")
|
||||
|
||||
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||
return base
|
||||
}
|
||||
|
||||
path := strings.TrimRight(strings.TrimSpace(c.Request.URL.Path), "/")
|
||||
if path == "" {
|
||||
return base
|
||||
}
|
||||
|
||||
idx := strings.LastIndex(path, "/responses")
|
||||
if idx < 0 {
|
||||
return base
|
||||
}
|
||||
|
||||
suffix := strings.TrimSpace(path[idx+len("/responses"):])
|
||||
if suffix == "" || suffix == "/" {
|
||||
return base
|
||||
}
|
||||
if !strings.HasPrefix(suffix, "/") {
|
||||
return base
|
||||
}
|
||||
|
||||
return base + suffix
|
||||
}
|
||||
|
||||
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
|
||||
@@ -26,6 +26,22 @@ const (
|
||||
opsStreamKey = "ops_stream"
|
||||
opsRequestBodyKey = "ops_request_body"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
|
||||
// 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用
|
||||
opsErrContextCanceled = "context canceled"
|
||||
opsErrNoAvailableAccounts = "no available accounts"
|
||||
opsErrInvalidAPIKey = "invalid_api_key"
|
||||
opsErrAPIKeyRequired = "api_key_required"
|
||||
opsErrInsufficientBalance = "insufficient balance"
|
||||
opsErrInsufficientAccountBalance = "insufficient account balance"
|
||||
opsErrInsufficientQuota = "insufficient_quota"
|
||||
|
||||
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
|
||||
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||
opsCodeUserInactive = "USER_INACTIVE"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -1024,9 +1040,9 @@ func normalizeOpsErrorType(errType string, code string) string {
|
||||
return errType
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE":
|
||||
case opsCodeInsufficientBalance:
|
||||
return "billing_error"
|
||||
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "subscription_error"
|
||||
default:
|
||||
return "api_error"
|
||||
@@ -1038,7 +1054,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||
// Map billing/concurrency/response => request; scheduling => routing.
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "request"
|
||||
}
|
||||
|
||||
@@ -1057,7 +1073,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
case "upstream_error", "overloaded_error":
|
||||
return "upstream"
|
||||
case "api_error":
|
||||
if strings.Contains(msg, "no available accounts") {
|
||||
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||
return "routing"
|
||||
}
|
||||
return "internal"
|
||||
@@ -1103,7 +1119,7 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
||||
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID", "USER_INACTIVE":
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||
return true
|
||||
}
|
||||
if phase == "billing" || phase == "concurrency" {
|
||||
@@ -1197,21 +1213,30 @@ func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message
|
||||
|
||||
// Check if context canceled errors should be ignored (client disconnects)
|
||||
if settings.IgnoreContextCanceled {
|
||||
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
|
||||
if strings.Contains(msgLower, opsErrContextCanceled) || strings.Contains(bodyLower, opsErrContextCanceled) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if "no available accounts" errors should be ignored
|
||||
if settings.IgnoreNoAvailableAccounts {
|
||||
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
|
||||
if strings.Contains(msgLower, opsErrNoAvailableAccounts) || strings.Contains(bodyLower, opsErrNoAvailableAccounts) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if invalid/missing API key errors should be ignored (user misconfiguration)
|
||||
if settings.IgnoreInvalidApiKeyErrors {
|
||||
if strings.Contains(bodyLower, "invalid_api_key") || strings.Contains(bodyLower, "api_key_required") {
|
||||
if strings.Contains(bodyLower, opsErrInvalidAPIKey) || strings.Contains(bodyLower, opsErrAPIKeyRequired) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if insufficient balance errors should be ignored
|
||||
if settings.IgnoreInsufficientBalanceErrors {
|
||||
if strings.Contains(bodyLower, opsErrInsufficientBalance) || strings.Contains(bodyLower, opsErrInsufficientAccountBalance) ||
|
||||
strings.Contains(bodyLower, opsErrInsufficientQuota) ||
|
||||
strings.Contains(msgLower, opsErrInsufficientBalance) || strings.Contains(msgLower, opsErrInsufficientAccountBalance) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -334,6 +334,14 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
|
||||
func (s *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return []usagestats.EndpointStat{}, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ func ProvideAdminHandlers(
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
dataManagementHandler *admin.DataManagementHandler,
|
||||
backupHandler *admin.BackupHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
DataManagement: dataManagementHandler,
|
||||
Backup: backupHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewDataManagementHandler,
|
||||
admin.NewBackupHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
|
||||
@@ -19,6 +19,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
)
|
||||
|
||||
// ForbiddenError 表示上游返回 403 Forbidden
|
||||
type ForbiddenError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *ForbiddenError) Error() string {
|
||||
return fmt.Sprintf("fetchAvailableModels 失败 (HTTP %d): %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
@@ -514,7 +524,20 @@ type ModelQuotaInfo struct {
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportsImages *bool `json:"supportsImages,omitempty"`
|
||||
SupportsThinking *bool `json:"supportsThinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"maxTokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supportedMimeTypes,omitempty"`
|
||||
}
|
||||
|
||||
// DeprecatedModelInfo 废弃模型转发信息
|
||||
type DeprecatedModelInfo struct {
|
||||
NewModelID string `json:"newModelId"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||
@@ -524,7 +547,8 @@ type FetchAvailableModelsRequest struct {
|
||||
|
||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||
type FetchAvailableModelsResponse struct {
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
DeprecatedModelIDs map[string]DeprecatedModelInfo `json:"deprecatedModelIds,omitempty"`
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
@@ -573,6 +597,13 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return nil, nil, &ForbiddenError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBodyBytes),
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
@@ -81,6 +81,15 @@ type ModelStat struct {
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// EndpointStat represents usage statistics for a single request endpoint.
|
||||
type EndpointStat struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
Requests int64 `json:"requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"` // 标准计费
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// GroupStat represents usage statistics for a single group
|
||||
type GroupStat struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
@@ -179,15 +188,18 @@ type UsageLogFilters struct {
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
Endpoints []EndpointStat `json:"endpoints,omitempty"`
|
||||
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints,omitempty"`
|
||||
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
@@ -254,7 +266,9 @@ type AccountUsageSummary struct {
|
||||
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse struct {
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
History []AccountUsageHistory `json:"history"`
|
||||
Summary AccountUsageSummary `json:"summary"`
|
||||
Models []ModelStat `json:"models"`
|
||||
Endpoints []EndpointStat `json:"endpoints"`
|
||||
UpstreamEndpoints []EndpointStat `json:"upstream_endpoints"`
|
||||
}
|
||||
|
||||
@@ -1727,8 +1727,96 @@ func (r *accountRepository) FindByExtraField(ctx context.Context, key string, va
|
||||
// nowUTC is a SQL expression to generate a UTC RFC3339 timestamp string.
|
||||
const nowUTC = `to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"')`
|
||||
|
||||
// dailyExpiredExpr is a SQL expression that evaluates to TRUE when daily quota period has expired.
|
||||
// Supports both rolling (24h from start) and fixed (pre-computed reset_at) modes.
|
||||
const dailyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_daily_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// weeklyExpiredExpr is a SQL expression that evaluates to TRUE when weekly quota period has expired.
|
||||
const weeklyExpiredExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN NOW() >= COALESCE((extra->>'quota_weekly_reset_at')::timestamptz, '1970-01-01'::timestamptz)
|
||||
ELSE COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
END
|
||||
)`
|
||||
|
||||
// nextDailyResetAtExpr is a SQL expression to compute the next daily reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, and configured hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextDailyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_daily_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute today's reset point in the configured timezone, then pick next future one
|
||||
CASE WHEN NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is at or past today's reset point → next reset is tomorrow
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '1 day'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- NOW() is before today's reset point → next reset is today
|
||||
ELSE (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_daily_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// nextWeeklyResetAtExpr is a SQL expression to compute the next weekly reset_at when a reset occurs.
|
||||
// For fixed mode: computes the next future reset time based on NOW(), timezone, configured day and hour.
|
||||
// This correctly handles long-inactive accounts by jumping directly to the next valid reset point.
|
||||
const nextWeeklyResetAtExpr = `(
|
||||
CASE WHEN COALESCE(extra->>'quota_weekly_reset_mode', 'rolling') = 'fixed'
|
||||
THEN to_char((
|
||||
-- Compute this week's reset point in the configured timezone
|
||||
-- Step 1: get today's date at reset hour in configured tz
|
||||
-- Step 2: compute days forward to target weekday
|
||||
-- Step 3: if same day but past reset hour, advance 7 days
|
||||
CASE
|
||||
WHEN (
|
||||
-- days_forward = (target_day - current_day + 7) % 7
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) = 0 AND NOW() >= (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
-- Same weekday and past reset hour → next week
|
||||
THEN (
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ '7 days'::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
ELSE (
|
||||
-- Advance to target weekday this week (or next if days_forward > 0)
|
||||
date_trunc('day', NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))
|
||||
+ (COALESCE((extra->>'quota_weekly_reset_hour')::int, 0) || ' hours')::interval
|
||||
+ ((
|
||||
(COALESCE((extra->>'quota_weekly_reset_day')::int, 1)
|
||||
- EXTRACT(DOW FROM NOW() AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC'))::int
|
||||
+ 7) % 7
|
||||
) || ' days')::interval
|
||||
) AT TIME ZONE COALESCE(extra->>'quota_reset_timezone', 'UTC')
|
||||
END
|
||||
) AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
|
||||
ELSE NULL END
|
||||
)`
|
||||
|
||||
// IncrementQuotaUsed 原子递增账号的配额用量(总/日/周三个维度)
|
||||
// 日/周额度在周期过期时自动重置为 0 再递增。
|
||||
// 支持滚动窗口(rolling)和固定时间(fixed)两种重置模式。
|
||||
func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
rows, err := r.sql.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
@@ -1739,31 +1827,35 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
CASE WHEN `+dailyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
-- 周额度:仅在 quota_weekly_limit > 0 时处理
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
CASE WHEN `+weeklyExpiredExpr+`
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
-- 固定模式重置时更新下次重置时间
|
||||
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||||
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||||
ELSE '{}'::jsonb END
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
@@ -1796,12 +1888,13 @@ func (r *accountRepository) IncrementQuotaUsed(ctx context.Context, id int64, am
|
||||
}
|
||||
|
||||
// ResetQuotaUsed 重置账号所有维度的配额用量为 0
|
||||
// 保留固定重置模式的配置字段(quota_daily_reset_mode 等),仅清零用量和窗口起始时间
|
||||
func (r *accountRepository) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| '{"quota_used": 0, "quota_daily_used": 0, "quota_weekly_used": 0}'::jsonb
|
||||
) - 'quota_daily_start' - 'quota_weekly_start', updated_at = NOW()
|
||||
) - 'quota_daily_start' - 'quota_weekly_start' - 'quota_daily_reset_at' - 'quota_weekly_reset_at', updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
|
||||
98
backend/internal/repository/backup_pg_dumper.go
Normal file
98
backend/internal/repository/backup_pg_dumper.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// PgDumper implements service.DBDumper using pg_dump/psql
|
||||
type PgDumper struct {
|
||||
cfg *config.DatabaseConfig
|
||||
}
|
||||
|
||||
// NewPgDumper creates a new PgDumper
|
||||
func NewPgDumper(cfg *config.Config) service.DBDumper {
|
||||
return &PgDumper{cfg: &cfg.Database}
|
||||
}
|
||||
|
||||
// Dump executes pg_dump and returns a streaming reader of the output
|
||||
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--no-owner",
|
||||
"--no-acl",
|
||||
"--clean",
|
||||
"--if-exists",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_dump", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 返回一个 ReadCloser:读 stdout,关闭时等待进程退出
|
||||
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
|
||||
}
|
||||
|
||||
// Restore executes psql to restore from a streaming reader
|
||||
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--single-transaction",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
cmd.Stdin = data
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %s", err, string(output))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
|
||||
type cmdReadCloser struct {
|
||||
io.ReadCloser
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func (c *cmdReadCloser) Close() error {
|
||||
// Close the pipe first
|
||||
_ = c.ReadCloser.Close()
|
||||
// Wait for the process to exit
|
||||
if err := c.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("pg_dump exited with error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
116
backend/internal/repository/backup_s3_store.go
Normal file
116
backend/internal/repository/backup_s3_store.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
|
||||
type S3BackupStore struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
}
|
||||
|
||||
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
|
||||
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
||||
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
|
||||
region := cfg.Region
|
||||
if region == "" {
|
||||
region = "auto" // Cloudflare R2 默认 region
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(region),
|
||||
awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = &cfg.Endpoint
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
})
|
||||
|
||||
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
Body: bytes.NewReader(data),
|
||||
ContentType: &contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S3 PutObject: %w", err)
|
||||
}
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("S3 GetObject: %w", err)
|
||||
}
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
|
||||
presignClient := s3.NewPresignClient(s.client)
|
||||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
}, s3.WithPresignExpires(expiry))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
|
||||
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &s.bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
|
||||
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
|
||||
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
|
||||
|
||||
var usageLogInsertArgTypes = [...]string{
|
||||
"bigint",
|
||||
@@ -65,6 +65,8 @@ var usageLogInsertArgTypes = [...]string{
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
"boolean",
|
||||
"timestamptz",
|
||||
}
|
||||
@@ -304,6 +306,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
@@ -312,7 +316,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -732,11 +736,13 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(keys)*37)
|
||||
args := make([]any, 0, len(keys)*38)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
@@ -799,6 +805,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
)
|
||||
@@ -837,6 +845,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
FROM input
|
||||
@@ -915,11 +925,13 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(preparedList)*36)
|
||||
args := make([]any, 0, len(preparedList)*38)
|
||||
argPos := 1
|
||||
for idx, prepared := range preparedList {
|
||||
if idx > 0 {
|
||||
@@ -979,6 +991,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
)
|
||||
@@ -1017,6 +1031,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
FROM input
|
||||
@@ -1063,6 +1079,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
inbound_endpoint,
|
||||
upstream_endpoint,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
@@ -1071,7 +1089,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
`, prepared.args...)
|
||||
@@ -1101,6 +1119,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
mediaType := nullString(log.MediaType)
|
||||
serviceTier := nullString(log.ServiceTier)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
inboundEndpoint := nullString(log.InboundEndpoint)
|
||||
upstreamEndpoint := nullString(log.UpstreamEndpoint)
|
||||
|
||||
var requestIDArg any
|
||||
if requestID != "" {
|
||||
@@ -1147,6 +1167,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
mediaType,
|
||||
serviceTier,
|
||||
reasoningEffort,
|
||||
inboundEndpoint,
|
||||
upstreamEndpoint,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
},
|
||||
@@ -2505,7 +2527,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
args = append(args, *filters.StartTime)
|
||||
}
|
||||
if filters.EndTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||||
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
|
||||
args = append(args, *filters.EndTime)
|
||||
}
|
||||
|
||||
@@ -3040,7 +3062,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
args = append(args, *filters.StartTime)
|
||||
}
|
||||
if filters.EndTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||||
conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1))
|
||||
args = append(args, *filters.EndTime)
|
||||
}
|
||||
|
||||
@@ -3080,6 +3102,35 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
stats.TotalAccountCost = &totalAccountCost
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
|
||||
start := time.Unix(0, 0).UTC()
|
||||
if filters.StartTime != nil {
|
||||
start = *filters.StartTime
|
||||
}
|
||||
end := time.Now().UTC()
|
||||
if filters.EndTime != nil {
|
||||
end = *filters.EndTime
|
||||
}
|
||||
|
||||
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||
if endpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetStatsWithFilters: %v", endpointErr)
|
||||
endpoints = []EndpointStat{}
|
||||
}
|
||||
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||
if upstreamEndpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetStatsWithFilters: %v", upstreamEndpointErr)
|
||||
upstreamEndpoints = []EndpointStat{}
|
||||
}
|
||||
endpointPaths, endpointPathErr := r.getEndpointPathStatsWithFilters(ctx, start, end, filters.UserID, filters.APIKeyID, filters.AccountID, filters.GroupID, filters.Model, filters.RequestType, filters.Stream, filters.BillingType)
|
||||
if endpointPathErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "getEndpointPathStatsWithFilters failed in GetStatsWithFilters: %v", endpointPathErr)
|
||||
endpointPaths = []EndpointStat{}
|
||||
}
|
||||
stats.Endpoints = endpoints
|
||||
stats.UpstreamEndpoints = upstreamEndpoints
|
||||
stats.EndpointPaths = endpointPaths
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
@@ -3092,6 +3143,163 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
|
||||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||||
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
||||
|
||||
// EndpointStat represents endpoint usage statistics row.
|
||||
type EndpointStat = usagestats.EndpointStat
|
||||
|
||||
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COALESCE(NULLIF(TRIM(%s), ''), 'unknown') AS endpoint,
|
||||
COUNT(*) AS requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
%s
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`, endpointColumn, actualCostExpr)
|
||||
|
||||
args := []any{startTime, endTime}
|
||||
if userID > 0 {
|
||||
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||||
args = append(args, userID)
|
||||
}
|
||||
if apiKeyID > 0 {
|
||||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||||
args = append(args, apiKeyID)
|
||||
}
|
||||
if accountID > 0 {
|
||||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||
args = append(args, accountID)
|
||||
}
|
||||
if groupID > 0 {
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if model != "" {
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
}
|
||||
query += " GROUP BY endpoint ORDER BY requests DESC"
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
results = make([]EndpointStat, 0)
|
||||
for rows.Next() {
|
||||
var row EndpointStat
|
||||
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
CONCAT(
|
||||
COALESCE(NULLIF(TRIM(inbound_endpoint), ''), 'unknown'),
|
||||
' -> ',
|
||||
COALESCE(NULLIF(TRIM(upstream_endpoint), ''), 'unknown')
|
||||
) AS endpoint,
|
||||
COUNT(*) AS requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
%s
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`, actualCostExpr)
|
||||
|
||||
args := []any{startTime, endTime}
|
||||
if userID > 0 {
|
||||
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||||
args = append(args, userID)
|
||||
}
|
||||
if apiKeyID > 0 {
|
||||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||||
args = append(args, apiKeyID)
|
||||
}
|
||||
if accountID > 0 {
|
||||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||
args = append(args, accountID)
|
||||
}
|
||||
if groupID > 0 {
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if model != "" {
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
}
|
||||
query += " GROUP BY endpoint ORDER BY requests DESC"
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
results = make([]EndpointStat, 0)
|
||||
for rows.Next() {
|
||||
var row EndpointStat
|
||||
if err := rows.Scan(&row.Endpoint, &row.Requests, &row.TotalTokens, &row.Cost, &row.ActualCost); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetEndpointStatsWithFilters returns inbound endpoint statistics with optional filters.
|
||||
func (r *usageLogRepository) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
|
||||
return r.getEndpointStatsByColumnWithFilters(ctx, "inbound_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
}
|
||||
|
||||
// GetUpstreamEndpointStatsWithFilters returns upstream endpoint statistics with optional filters.
|
||||
func (r *usageLogRepository) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]EndpointStat, error) {
|
||||
return r.getEndpointStatsByColumnWithFilters(ctx, "upstream_endpoint", startTime, endTime, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
}
|
||||
|
||||
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||||
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
|
||||
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||||
@@ -3254,11 +3462,23 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
if err != nil {
|
||||
models = []ModelStat{}
|
||||
}
|
||||
endpoints, endpointErr := r.GetEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
|
||||
if endpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetEndpointStatsWithFilters failed in GetAccountUsageStats: %v", endpointErr)
|
||||
endpoints = []EndpointStat{}
|
||||
}
|
||||
upstreamEndpoints, upstreamEndpointErr := r.GetUpstreamEndpointStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, "", nil, nil, nil)
|
||||
if upstreamEndpointErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "GetUpstreamEndpointStatsWithFilters failed in GetAccountUsageStats: %v", upstreamEndpointErr)
|
||||
upstreamEndpoints = []EndpointStat{}
|
||||
}
|
||||
|
||||
resp = &AccountUsageStatsResponse{
|
||||
History: history,
|
||||
Summary: summary,
|
||||
Models: models,
|
||||
History: history,
|
||||
Summary: summary,
|
||||
Models: models,
|
||||
Endpoints: endpoints,
|
||||
UpstreamEndpoints: upstreamEndpoints,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -3541,6 +3761,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
mediaType sql.NullString
|
||||
serviceTier sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
inboundEndpoint sql.NullString
|
||||
upstreamEndpoint sql.NullString
|
||||
cacheTTLOverridden bool
|
||||
createdAt time.Time
|
||||
)
|
||||
@@ -3581,6 +3803,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&mediaType,
|
||||
&serviceTier,
|
||||
&reasoningEffort,
|
||||
&inboundEndpoint,
|
||||
&upstreamEndpoint,
|
||||
&cacheTTLOverridden,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
@@ -3656,6 +3880,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if reasoningEffort.Valid {
|
||||
log.ReasoningEffort = &reasoningEffort.String
|
||||
}
|
||||
if inboundEndpoint.Valid {
|
||||
log.InboundEndpoint = &inboundEndpoint.String
|
||||
}
|
||||
if upstreamEndpoint.Valid {
|
||||
log.UpstreamEndpoint = &upstreamEndpoint.String
|
||||
}
|
||||
|
||||
return log, nil
|
||||
}
|
||||
|
||||
@@ -73,6 +73,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
sqlmock.AnyArg(), // media_type
|
||||
sqlmock.AnyArg(), // service_tier
|
||||
sqlmock.AnyArg(), // reasoning_effort
|
||||
sqlmock.AnyArg(), // inbound_endpoint
|
||||
sqlmock.AnyArg(), // upstream_endpoint
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
@@ -141,6 +143,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
||||
sqlmock.AnyArg(),
|
||||
serviceTier,
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
sqlmock.AnyArg(),
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
@@ -376,6 +380,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
@@ -415,6 +421,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "flex"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
@@ -454,6 +462,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
sql.NullString{},
|
||||
sql.NullString{Valid: true, String: "priority"},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
|
||||
@@ -100,6 +100,10 @@ var ProviderSet = wire.NewSet(
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
// Backup infrastructure
|
||||
NewPgDumper,
|
||||
NewS3BackupStoreFactory,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
ProvidePricingRemoteClient,
|
||||
|
||||
@@ -493,6 +493,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"registration_email_suffix_whitelist": [],
|
||||
"promo_code_enabled": true,
|
||||
"password_reset_enabled": false,
|
||||
"frontend_url": "",
|
||||
"totp_enabled": false,
|
||||
"totp_encryption_key_configured": false,
|
||||
"smtp_host": "smtp.example.com",
|
||||
@@ -537,6 +538,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"purchase_subscription_url": "",
|
||||
"min_claude_code_version": "",
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"backend_mode_enabled": false,
|
||||
"custom_menu_items": []
|
||||
}
|
||||
}`,
|
||||
@@ -1623,6 +1625,14 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
|
||||
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
|
||||
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
role, _ := GetUserRoleFromContext(c)
|
||||
if role == "admin" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
|
||||
// Allows: login, login/2fa, logout, refresh (admin needs these).
|
||||
// Blocks: register, forgot-password, reset-password, OAuth, etc.
|
||||
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
// Allow login, 2FA, logout, refresh, public settings
|
||||
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
|
||||
for _, suffix := range allowedSuffixes {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type bmSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
v, ok := r.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
if r.values == nil {
|
||||
r.values = make(map[string]string, len(settings))
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
|
||||
t.Helper()
|
||||
|
||||
repo := &bmSettingRepo{
|
||||
values: map[string]string{
|
||||
service.SettingKeyBackendModeEnabled: enabled,
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{})
|
||||
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
|
||||
BackendModeEnabled: enabled == "true",
|
||||
}))
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
func stringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestBackendModeUserGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
role *string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_admin_allowed",
|
||||
enabled: "true",
|
||||
role: stringPtr("admin"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_user_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_no_role_blocked",
|
||||
enabled: "true",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_empty_role_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr(""),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
if tc.role != nil {
|
||||
role := *tc.role
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(ContextKeyUserRole), role)
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeUserGuard(svc))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendModeAuthGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
path string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login_2fa",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login/2fa",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_logout",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/logout",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_refresh",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/refresh",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_register",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_forgot_password",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/forgot-password",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeAuthGuard(svc))
|
||||
r.Any("/*path", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -107,9 +107,9 @@ func registerRoutes(
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
}
|
||||
|
||||
@@ -58,6 +58,9 @@ func RegisterAdminRoutes(
|
||||
// 数据管理
|
||||
registerDataManagementRoutes(admin, h)
|
||||
|
||||
// 数据库备份恢复
|
||||
registerBackupRoutes(admin, h)
|
||||
|
||||
// 运维监控(Ops)
|
||||
registerOpsRoutes(admin, h)
|
||||
|
||||
@@ -440,6 +443,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
backup := admin.Group("/backups")
|
||||
{
|
||||
// S3 存储配置
|
||||
backup.GET("/s3-config", h.Admin.Backup.GetS3Config)
|
||||
backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config)
|
||||
backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection)
|
||||
|
||||
// 定时备份配置
|
||||
backup.GET("/schedule", h.Admin.Backup.GetSchedule)
|
||||
backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule)
|
||||
|
||||
// 备份操作
|
||||
backup.POST("", h.Admin.Backup.CreateBackup)
|
||||
backup.GET("", h.Admin.Backup.ListBackups)
|
||||
backup.GET("/:id", h.Admin.Backup.GetBackup)
|
||||
backup.DELETE("/:id", h.Admin.Backup.DeleteBackup)
|
||||
backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL)
|
||||
|
||||
// 恢复操作
|
||||
backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
system := admin.Group("/system")
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
|
||||
h *handler.Handlers,
|
||||
jwtAuth servermiddleware.JWTAuthMiddleware,
|
||||
redisClient *redis.Client,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
// 创建速率限制器
|
||||
rateLimiter := middleware.NewRateLimiter(redisClient)
|
||||
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
auth.Use(servermiddleware.BackendModeAuthGuard(settingService))
|
||||
{
|
||||
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||
@@ -78,6 +81,7 @@ func RegisterAuthRoutes(
|
||||
// 需要认证的当前用户信息
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
// 撤销所有会话(需要认证)
|
||||
|
||||
@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||
c.Next()
|
||||
}),
|
||||
redisClient,
|
||||
nil,
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
if h.SoraClient == nil {
|
||||
return
|
||||
@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
|
||||
|
||||
authenticated := v1.Group("/sora")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.POST("/generate", h.SoraClient.Generate)
|
||||
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
||||
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,9 +13,11 @@ func RegisterUserRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
// 用户接口
|
||||
user := authenticated.Group("/user")
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
"reflect"
|
||||
"sort"
|
||||
@@ -656,7 +657,7 @@ func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
// IsPoolMode 检查 API Key 账号是否启用池模式。
|
||||
// 池模式下,上游错误不标记本地账号状态,而是在同一账号上重试。
|
||||
func (a *Account) IsPoolMode() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
if !a.IsAPIKeyOrBedrock() || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["pool_mode"]; ok {
|
||||
@@ -771,11 +772,16 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrock() bool {
|
||||
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey)
|
||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrockAPIKey() bool {
|
||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey
|
||||
return a.IsBedrock() && a.GetCredential("auth_mode") == "apikey"
|
||||
}
|
||||
|
||||
// IsAPIKeyOrBedrock 返回账号类型是否支持配额和池模式等特性
|
||||
func (a *Account) IsAPIKeyOrBedrock() bool {
|
||||
return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAI() bool {
|
||||
@@ -1274,6 +1280,240 @@ func (a *Account) getExtraTime(key string) time.Time {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// getExtraString 从 Extra 中读取指定 key 的字符串值
|
||||
func (a *Account) getExtraString(key string) string {
|
||||
if a.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getExtraInt 从 Extra 中读取指定 key 的 int 值
|
||||
func (a *Account) getExtraInt(key string) int {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
return int(parseExtraFloat64(v))
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetQuotaDailyResetMode 获取日额度重置模式:"rolling"(默认)或 "fixed"
|
||||
func (a *Account) GetQuotaDailyResetMode() string {
|
||||
if m := a.getExtraString("quota_daily_reset_mode"); m == "fixed" {
|
||||
return "fixed"
|
||||
}
|
||||
return "rolling"
|
||||
}
|
||||
|
||||
// GetQuotaDailyResetHour 获取固定重置的小时(0-23),默认 0
|
||||
func (a *Account) GetQuotaDailyResetHour() int {
|
||||
return a.getExtraInt("quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetMode 获取周额度重置模式:"rolling"(默认)或 "fixed"
|
||||
func (a *Account) GetQuotaWeeklyResetMode() string {
|
||||
if m := a.getExtraString("quota_weekly_reset_mode"); m == "fixed" {
|
||||
return "fixed"
|
||||
}
|
||||
return "rolling"
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetDay 获取固定重置的星期几(0=周日, 1=周一, ..., 6=周六),默认 1(周一)
|
||||
func (a *Account) GetQuotaWeeklyResetDay() int {
|
||||
if a.Extra == nil {
|
||||
return 1
|
||||
}
|
||||
if _, ok := a.Extra["quota_weekly_reset_day"]; !ok {
|
||||
return 1
|
||||
}
|
||||
return a.getExtraInt("quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
// GetQuotaWeeklyResetHour 获取周配额固定重置的小时(0-23),默认 0
|
||||
func (a *Account) GetQuotaWeeklyResetHour() int {
|
||||
return a.getExtraInt("quota_weekly_reset_hour")
|
||||
}
|
||||
|
||||
// GetQuotaResetTimezone 获取固定重置的时区名(IANA),默认 "UTC"
|
||||
func (a *Account) GetQuotaResetTimezone() string {
|
||||
if tz := a.getExtraString("quota_reset_timezone"); tz != "" {
|
||||
return tz
|
||||
}
|
||||
return "UTC"
|
||||
}
|
||||
|
||||
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
|
||||
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
|
||||
t := after.In(tz)
|
||||
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
if !after.Before(today) {
|
||||
return today.AddDate(0, 0, 1)
|
||||
}
|
||||
return today
|
||||
}
|
||||
|
||||
// lastFixedDailyReset 计算 now 之前最近一次的每日固定重置时间点
|
||||
func lastFixedDailyReset(hour int, tz *time.Location, now time.Time) time.Time {
|
||||
t := now.In(tz)
|
||||
today := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
if now.Before(today) {
|
||||
return today.AddDate(0, 0, -1)
|
||||
}
|
||||
return today
|
||||
}
|
||||
|
||||
// nextFixedWeeklyReset 计算在 after 之后的下一个每周固定重置时间点
|
||||
// day: 0=Sunday, 1=Monday, ..., 6=Saturday
|
||||
func nextFixedWeeklyReset(day, hour int, tz *time.Location, after time.Time) time.Time {
|
||||
t := after.In(tz)
|
||||
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
currentDay := int(todayReset.Weekday())
|
||||
|
||||
daysForward := (day - currentDay + 7) % 7
|
||||
if daysForward == 0 && !after.Before(todayReset) {
|
||||
daysForward = 7
|
||||
}
|
||||
return todayReset.AddDate(0, 0, daysForward)
|
||||
}
|
||||
|
||||
// lastFixedWeeklyReset 计算 now 之前最近一次的每周固定重置时间点
|
||||
func lastFixedWeeklyReset(day, hour int, tz *time.Location, now time.Time) time.Time {
|
||||
t := now.In(tz)
|
||||
todayReset := time.Date(t.Year(), t.Month(), t.Day(), hour, 0, 0, 0, tz)
|
||||
currentDay := int(todayReset.Weekday())
|
||||
|
||||
daysBack := (currentDay - day + 7) % 7
|
||||
if daysBack == 0 && now.Before(todayReset) {
|
||||
daysBack = 7
|
||||
}
|
||||
return todayReset.AddDate(0, 0, -daysBack)
|
||||
}
|
||||
|
||||
// isFixedDailyPeriodExpired 检查日配额是否在固定时间模式下已过期
|
||||
func (a *Account) isFixedDailyPeriodExpired(periodStart time.Time) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true
|
||||
}
|
||||
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
lastReset := lastFixedDailyReset(a.GetQuotaDailyResetHour(), tz, time.Now())
|
||||
return periodStart.Before(lastReset)
|
||||
}
|
||||
|
||||
// isFixedWeeklyPeriodExpired 检查周配额是否在固定时间模式下已过期
|
||||
func (a *Account) isFixedWeeklyPeriodExpired(periodStart time.Time) bool {
|
||||
if periodStart.IsZero() {
|
||||
return true
|
||||
}
|
||||
tz, err := time.LoadLocation(a.GetQuotaResetTimezone())
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
lastReset := lastFixedWeeklyReset(a.GetQuotaWeeklyResetDay(), a.GetQuotaWeeklyResetHour(), tz, time.Now())
|
||||
return periodStart.Before(lastReset)
|
||||
}
|
||||
|
||||
// ComputeQuotaResetAt 根据当前配置计算并填充 extra 中的 quota_daily_reset_at / quota_weekly_reset_at
|
||||
// 在保存账号配置时调用
|
||||
func ComputeQuotaResetAt(extra map[string]any) {
|
||||
now := time.Now()
|
||||
tzName, _ := extra["quota_reset_timezone"].(string)
|
||||
if tzName == "" {
|
||||
tzName = "UTC"
|
||||
}
|
||||
tz, err := time.LoadLocation(tzName)
|
||||
if err != nil {
|
||||
tz = time.UTC
|
||||
}
|
||||
|
||||
// 日配额固定重置时间
|
||||
if mode, _ := extra["quota_daily_reset_mode"].(string); mode == "fixed" {
|
||||
hour := int(parseExtraFloat64(extra["quota_daily_reset_hour"]))
|
||||
if hour < 0 || hour > 23 {
|
||||
hour = 0
|
||||
}
|
||||
resetAt := nextFixedDailyReset(hour, tz, now)
|
||||
extra["quota_daily_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||
} else {
|
||||
delete(extra, "quota_daily_reset_at")
|
||||
}
|
||||
|
||||
// 周配额固定重置时间
|
||||
if mode, _ := extra["quota_weekly_reset_mode"].(string); mode == "fixed" {
|
||||
day := 1 // 默认周一
|
||||
if d, ok := extra["quota_weekly_reset_day"]; ok {
|
||||
day = int(parseExtraFloat64(d))
|
||||
}
|
||||
if day < 0 || day > 6 {
|
||||
day = 1
|
||||
}
|
||||
hour := int(parseExtraFloat64(extra["quota_weekly_reset_hour"]))
|
||||
if hour < 0 || hour > 23 {
|
||||
hour = 0
|
||||
}
|
||||
resetAt := nextFixedWeeklyReset(day, hour, tz, now)
|
||||
extra["quota_weekly_reset_at"] = resetAt.UTC().Format(time.RFC3339)
|
||||
} else {
|
||||
delete(extra, "quota_weekly_reset_at")
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateQuotaResetConfig 校验配额固定重置时间配置的合法性
|
||||
func ValidateQuotaResetConfig(extra map[string]any) error {
|
||||
if extra == nil {
|
||||
return nil
|
||||
}
|
||||
// 校验时区
|
||||
if tz, ok := extra["quota_reset_timezone"].(string); ok && tz != "" {
|
||||
if _, err := time.LoadLocation(tz); err != nil {
|
||||
return errors.New("invalid quota_reset_timezone: must be a valid IANA timezone name")
|
||||
}
|
||||
}
|
||||
// 日配额重置模式
|
||||
if mode, ok := extra["quota_daily_reset_mode"].(string); ok {
|
||||
if mode != "rolling" && mode != "fixed" {
|
||||
return errors.New("quota_daily_reset_mode must be 'rolling' or 'fixed'")
|
||||
}
|
||||
}
|
||||
// 日配额重置小时
|
||||
if v, ok := extra["quota_daily_reset_hour"]; ok {
|
||||
hour := int(parseExtraFloat64(v))
|
||||
if hour < 0 || hour > 23 {
|
||||
return errors.New("quota_daily_reset_hour must be between 0 and 23")
|
||||
}
|
||||
}
|
||||
// 周配额重置模式
|
||||
if mode, ok := extra["quota_weekly_reset_mode"].(string); ok {
|
||||
if mode != "rolling" && mode != "fixed" {
|
||||
return errors.New("quota_weekly_reset_mode must be 'rolling' or 'fixed'")
|
||||
}
|
||||
}
|
||||
// 周配额重置星期几
|
||||
if v, ok := extra["quota_weekly_reset_day"]; ok {
|
||||
day := int(parseExtraFloat64(v))
|
||||
if day < 0 || day > 6 {
|
||||
return errors.New("quota_weekly_reset_day must be between 0 (Sunday) and 6 (Saturday)")
|
||||
}
|
||||
}
|
||||
// 周配额重置小时
|
||||
if v, ok := extra["quota_weekly_reset_hour"]; ok {
|
||||
hour := int(parseExtraFloat64(v))
|
||||
if hour < 0 || hour > 23 {
|
||||
return errors.New("quota_weekly_reset_hour must be between 0 and 23")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasAnyQuotaLimit 检查是否配置了任一维度的配额限制
|
||||
func (a *Account) HasAnyQuotaLimit() bool {
|
||||
return a.GetQuotaLimit() > 0 || a.GetQuotaDailyLimit() > 0 || a.GetQuotaWeeklyLimit() > 0
|
||||
@@ -1296,14 +1536,26 @@ func (a *Account) IsQuotaExceeded() bool {
|
||||
// 日额度(周期过期视为未超限,下次 increment 会重置)
|
||||
if limit := a.GetQuotaDailyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_daily_start")
|
||||
if !isPeriodExpired(start, 24*time.Hour) && a.GetQuotaDailyUsed() >= limit {
|
||||
var expired bool
|
||||
if a.GetQuotaDailyResetMode() == "fixed" {
|
||||
expired = a.isFixedDailyPeriodExpired(start)
|
||||
} else {
|
||||
expired = isPeriodExpired(start, 24*time.Hour)
|
||||
}
|
||||
if !expired && a.GetQuotaDailyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 周额度
|
||||
if limit := a.GetQuotaWeeklyLimit(); limit > 0 {
|
||||
start := a.getExtraTime("quota_weekly_start")
|
||||
if !isPeriodExpired(start, 7*24*time.Hour) && a.GetQuotaWeeklyUsed() >= limit {
|
||||
var expired bool
|
||||
if a.GetQuotaWeeklyResetMode() == "fixed" {
|
||||
expired = a.isFixedWeeklyPeriodExpired(start)
|
||||
} else {
|
||||
expired = isPeriodExpired(start, 7*24*time.Hour)
|
||||
}
|
||||
if !expired && a.GetQuotaWeeklyUsed() >= limit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
516
backend/internal/service/account_quota_reset_test.go
Normal file
516
backend/internal/service/account_quota_reset_test.go
Normal file
@@ -0,0 +1,516 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 06:00 UTC, reset hour = 9
|
||||
after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Exactly at reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// After reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_MidnightReset(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Reset at hour 0 (midnight), currently 23:59
|
||||
after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz)
|
||||
got := nextFixedDailyReset(0, tz, after)
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
|
||||
after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// Before today's 9:00 → yesterday 9:00
|
||||
want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// At exactly 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// After 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
|
||||
after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
|
||||
after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
|
||||
after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-23
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_Sunday(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(0, 0, tz, after)
|
||||
// Next Sunday = 2026-03-15
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
|
||||
now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
|
||||
now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday at 9:00 = 2026-03-09
|
||||
want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedDailyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started after the most recent reset → not expired
|
||||
// (This test uses a time very close to "now", which is after the last reset)
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 3 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Invalid/Timezone",
|
||||
}}
|
||||
// Invalid timezone falls back to UTC
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedWeeklyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 1 minute ago → not expired
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 10 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-240 * time.Hour)
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ValidateQuotaResetConfig
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestValidateQuotaResetConfig_NilExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(nil))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{}))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_reset_timezone": "Not/A/Timezone",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_reset_timezone")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "invalid",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(24),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "unknown",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(7),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_hour": float64(25),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) {
|
||||
// All boundary values should be valid
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(23),
|
||||
"quota_weekly_reset_day": float64(0), // Sunday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
|
||||
extra2 := map[string]any{
|
||||
"quota_daily_reset_hour": float64(0),
|
||||
"quota_weekly_reset_day": float64(6), // Saturday
|
||||
"quota_weekly_reset_hour": float64(23),
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra2))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ComputeQuotaResetAt
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
"quota_daily_reset_at": "2026-03-14T09:00:00Z",
|
||||
"quota_weekly_reset_at": "2026-03-16T09:00:00Z",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok, "quota_daily_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset hour should be 9 UTC
|
||||
assert.Equal(t, 9, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1), // Monday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_weekly_reset_at"].(string)
|
||||
require.True(t, ok, "quota_weekly_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset day should be Monday
|
||||
assert.Equal(t, time.Monday, resetAt.UTC().Weekday())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// In Shanghai timezone, the hour should be 9
|
||||
assert.Equal(t, 9, resetAt.In(tz).Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(12),
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Default timezone is UTC
|
||||
assert.Equal(t, 12, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(99),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Invalid hour → clamped to 0
|
||||
assert.Equal(t, 0, resetAt.UTC().Hour())
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -44,6 +45,8 @@ type UsageLogRepository interface {
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||
GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error)
|
||||
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
@@ -100,6 +103,7 @@ type antigravityUsageCache struct {
|
||||
const (
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
apiErrorCacheTTL = 1 * time.Minute // 负缓存 TTL:429 等错误缓存 1 分钟
|
||||
antigravityErrorTTL = 1 * time.Minute // Antigravity 错误缓存 TTL(可恢复错误)
|
||||
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
openAIProbeCacheTTL = 10 * time.Minute
|
||||
@@ -108,11 +112,12 @@ const (
|
||||
|
||||
// UsageCache 封装账户使用量相关的缓存
|
||||
type UsageCache struct {
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存(Anthropic)
|
||||
antigravityFlight singleflight.Group // 防止同一 Antigravity 账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
}
|
||||
|
||||
// NewUsageCache 创建 UsageCache 实例
|
||||
@@ -149,6 +154,18 @@ type AntigravityModelQuota struct {
|
||||
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
|
||||
}
|
||||
|
||||
// AntigravityModelDetail Antigravity 单个模型的详细能力信息
|
||||
type AntigravityModelDetail struct {
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
SupportsImages *bool `json:"supports_images,omitempty"`
|
||||
SupportsThinking *bool `json:"supports_thinking,omitempty"`
|
||||
ThinkingBudget *int `json:"thinking_budget,omitempty"`
|
||||
Recommended *bool `json:"recommended,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||
SupportedMimeTypes map[string]bool `json:"supported_mime_types,omitempty"`
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
@@ -164,6 +181,33 @@ type UsageInfo struct {
|
||||
|
||||
// Antigravity 多模型配额
|
||||
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
||||
|
||||
// Antigravity 账号级信息
|
||||
SubscriptionTier string `json:"subscription_tier,omitempty"` // 归一化订阅等级: FREE/PRO/ULTRA/UNKNOWN
|
||||
SubscriptionTierRaw string `json:"subscription_tier_raw,omitempty"` // 上游原始订阅等级名称
|
||||
|
||||
// Antigravity 模型详细能力信息(与 antigravity_quota 同 key)
|
||||
AntigravityQuotaDetails map[string]*AntigravityModelDetail `json:"antigravity_quota_details,omitempty"`
|
||||
|
||||
// Antigravity 废弃模型转发规则 (old_model_id -> new_model_id)
|
||||
ModelForwardingRules map[string]string `json:"model_forwarding_rules,omitempty"`
|
||||
|
||||
// Antigravity 账号是否被上游禁止 (HTTP 403)
|
||||
IsForbidden bool `json:"is_forbidden,omitempty"`
|
||||
ForbiddenReason string `json:"forbidden_reason,omitempty"`
|
||||
ForbiddenType string `json:"forbidden_type,omitempty"` // "validation" / "violation" / "forbidden"
|
||||
ValidationURL string `json:"validation_url,omitempty"` // 验证/申诉链接
|
||||
|
||||
// 状态标记(从 ForbiddenType / HTTP 错误码推导)
|
||||
NeedsVerify bool `json:"needs_verify,omitempty"` // 需要人工验证(forbidden_type=validation)
|
||||
IsBanned bool `json:"is_banned,omitempty"` // 账号被封(forbidden_type=violation)
|
||||
NeedsReauth bool `json:"needs_reauth,omitempty"` // token 失效需重新授权(401)
|
||||
|
||||
// 错误码(机器可读):forbidden / unauthenticated / rate_limited / network_error
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
|
||||
// 获取 usage 时的错误信息(降级返回,而非 500)
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||
@@ -648,34 +692,157 @@ func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
|
||||
// 1. 检查缓存(10 分钟)
|
||||
// 1. 检查缓存
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
// 重新计算 RemainingSeconds
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||
if time.Since(cache.timestamp) < ttl {
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取代理 URL
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
|
||||
// 2. singleflight 防止并发击穿
|
||||
flightKey := fmt.Sprintf("ag-usage:%d", account.ID)
|
||||
result, flightErr, _ := s.cache.antigravityFlight.Do(flightKey, func() (any, error) {
|
||||
// 再次检查缓存(等待期间可能已被填充)
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok {
|
||||
ttl := antigravityCacheTTL(cache.usageInfo)
|
||||
if time.Since(cache.timestamp) < ttl {
|
||||
usage := cache.usageInfo
|
||||
// 重新计算 RemainingSeconds,避免返回过时的剩余秒数
|
||||
recalcAntigravityRemainingSeconds(usage)
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 调用 API 获取额度
|
||||
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
|
||||
}
|
||||
// 使用独立 context,避免调用方 cancel 导致所有共享 flight 的请求失败
|
||||
fetchCtx, fetchCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer fetchCancel()
|
||||
|
||||
// 4. 缓存结果
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: result.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(fetchCtx, account)
|
||||
fetchResult, err := s.antigravityQuotaFetcher.FetchQuota(fetchCtx, account, proxyURL)
|
||||
if err != nil {
|
||||
degraded := buildAntigravityDegradedUsage(err)
|
||||
enrichUsageWithAccountError(degraded, account)
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: degraded,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return degraded, nil
|
||||
}
|
||||
|
||||
enrichUsageWithAccountError(fetchResult.UsageInfo, account)
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: fetchResult.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
return fetchResult.UsageInfo, nil
|
||||
})
|
||||
|
||||
return result.UsageInfo, nil
|
||||
if flightErr != nil {
|
||||
return nil, flightErr
|
||||
}
|
||||
usage, ok := result.(*UsageInfo)
|
||||
if !ok || usage == nil {
|
||||
now := time.Now()
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
|
||||
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
|
||||
func recalcAntigravityRemainingSeconds(info *UsageInfo) {
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
if info.FiveHour != nil && info.FiveHour.ResetsAt != nil {
|
||||
remaining := int(time.Until(*info.FiveHour.ResetsAt).Seconds())
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
info.FiveHour.RemainingSeconds = remaining
|
||||
}
|
||||
}
|
||||
|
||||
// antigravityCacheTTL 根据 UsageInfo 内容决定缓存 TTL
|
||||
// 403 forbidden 状态稳定,缓存与成功相同(3 分钟);
|
||||
// 其他错误(401/网络)可能快速恢复,缓存 1 分钟。
|
||||
func antigravityCacheTTL(info *UsageInfo) time.Duration {
|
||||
if info == nil {
|
||||
return antigravityErrorTTL
|
||||
}
|
||||
if info.IsForbidden {
|
||||
return apiCacheTTL // 封号/验证状态不会很快变
|
||||
}
|
||||
if info.ErrorCode != "" || info.Error != "" {
|
||||
return antigravityErrorTTL
|
||||
}
|
||||
return apiCacheTTL
|
||||
}
|
||||
|
||||
// buildAntigravityDegradedUsage 从 FetchQuota 错误构建降级 UsageInfo
|
||||
func buildAntigravityDegradedUsage(err error) *UsageInfo {
|
||||
now := time.Now()
|
||||
errMsg := fmt.Sprintf("usage API error: %v", err)
|
||||
slog.Warn("antigravity usage fetch failed, returning degraded response", "error", err)
|
||||
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
Error: errMsg,
|
||||
}
|
||||
|
||||
// 从错误信息推断 error_code 和状态标记
|
||||
// 错误格式来自 antigravity/client.go: "fetchAvailableModels 失败 (HTTP %d): ..."
|
||||
errStr := err.Error()
|
||||
switch {
|
||||
case strings.Contains(errStr, "HTTP 401") ||
|
||||
strings.Contains(errStr, "UNAUTHENTICATED") ||
|
||||
strings.Contains(errStr, "invalid_grant"):
|
||||
info.ErrorCode = errorCodeUnauthenticated
|
||||
info.NeedsReauth = true
|
||||
case strings.Contains(errStr, "HTTP 429"):
|
||||
info.ErrorCode = errorCodeRateLimited
|
||||
default:
|
||||
info.ErrorCode = errorCodeNetworkError
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// enrichUsageWithAccountError 结合账号错误状态修正 UsageInfo
|
||||
// 场景 1(成功路径):FetchAvailableModels 正常返回,但账号已因 403 被标记为 error,
|
||||
//
|
||||
// 需要在正常 usage 数据上附加 forbidden/validation 信息。
|
||||
//
|
||||
// 场景 2(降级路径):被封号的账号 OAuth token 失效,FetchAvailableModels 返回 401,
|
||||
//
|
||||
// 降级逻辑设置了 needs_reauth,但账号实际是 403 封号/需验证,需覆盖为正确状态。
|
||||
func enrichUsageWithAccountError(info *UsageInfo, account *Account) {
|
||||
if info == nil || account == nil || account.Status != StatusError {
|
||||
return
|
||||
}
|
||||
msg := strings.ToLower(account.ErrorMessage)
|
||||
if !strings.Contains(msg, "403") && !strings.Contains(msg, "forbidden") &&
|
||||
!strings.Contains(msg, "violation") && !strings.Contains(msg, "validation") {
|
||||
return
|
||||
}
|
||||
fbType := classifyForbiddenType(account.ErrorMessage)
|
||||
info.IsForbidden = true
|
||||
info.ForbiddenType = fbType
|
||||
info.ForbiddenReason = account.ErrorMessage
|
||||
info.NeedsVerify = fbType == forbiddenTypeValidation
|
||||
info.IsBanned = fbType == forbiddenTypeViolation
|
||||
info.ValidationURL = extractValidationURL(account.ErrorMessage)
|
||||
info.ErrorCode = errorCodeForbidden
|
||||
info.NeedsReauth = false
|
||||
}
|
||||
|
||||
// addWindowStats 为 usage 数据添加窗口期统计
|
||||
|
||||
@@ -832,7 +832,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
subscriptionType = SubscriptionTypeStandard
|
||||
}
|
||||
|
||||
// 限额字段:0 和 nil 都表示"无限制"
|
||||
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
|
||||
dailyLimit := normalizeLimit(input.DailyLimitUSD)
|
||||
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
||||
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
||||
@@ -944,9 +944,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
|
||||
// normalizeLimit 将负数转换为 nil(表示无限制),0 保留(表示限额为零)
|
||||
func normalizeLimit(limit *float64) *float64 {
|
||||
if limit == nil || *limit <= 0 {
|
||||
if limit == nil || *limit < 0 {
|
||||
return nil
|
||||
}
|
||||
return limit
|
||||
@@ -1058,16 +1058,11 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.SubscriptionType != "" {
|
||||
group.SubscriptionType = input.SubscriptionType
|
||||
}
|
||||
// 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
|
||||
if input.DailyLimitUSD != nil {
|
||||
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
||||
}
|
||||
if input.WeeklyLimitUSD != nil {
|
||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||
}
|
||||
if input.MonthlyLimitUSD != nil {
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
}
|
||||
// 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额
|
||||
// 前端始终发送这三个字段,无需 nil 守卫
|
||||
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
// 图片生成计费配置:负数表示清除(使用默认价格)
|
||||
if input.ImagePrice1K != nil {
|
||||
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
||||
@@ -1462,6 +1457,13 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
// 预计算固定时间重置的下次重置时间
|
||||
if account.Extra != nil {
|
||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ComputeQuotaResetAt(account.Extra)
|
||||
}
|
||||
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
|
||||
expiresAt := time.Unix(*input.ExpiresAt, 0)
|
||||
account.ExpiresAt = &expiresAt
|
||||
@@ -1535,6 +1537,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
}
|
||||
}
|
||||
account.Extra = input.Extra
|
||||
// 校验并预计算固定时间重置的下次重置时间
|
||||
if err := ValidateQuotaResetConfig(account.Extra); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ComputeQuotaResetAt(account.Extra)
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
|
||||
|
||||
@@ -2,12 +2,29 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
const (
|
||||
forbiddenTypeValidation = "validation"
|
||||
forbiddenTypeViolation = "violation"
|
||||
forbiddenTypeForbidden = "forbidden"
|
||||
|
||||
// 机器可读的错误码
|
||||
errorCodeForbidden = "forbidden"
|
||||
errorCodeUnauthenticated = "unauthenticated"
|
||||
errorCodeRateLimited = "rate_limited"
|
||||
errorCodeNetworkError = "network_error"
|
||||
)
|
||||
|
||||
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
|
||||
type AntigravityQuotaFetcher struct {
|
||||
proxyRepo ProxyRepository
|
||||
@@ -40,11 +57,32 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
// 403 Forbidden: 不报错,返回 is_forbidden 标记
|
||||
var forbiddenErr *antigravity.ForbiddenError
|
||||
if errors.As(err, &forbiddenErr) {
|
||||
now := time.Now()
|
||||
fbType := classifyForbiddenType(forbiddenErr.Body)
|
||||
return &QuotaResult{
|
||||
UsageInfo: &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
IsForbidden: true,
|
||||
ForbiddenReason: forbiddenErr.Body,
|
||||
ForbiddenType: fbType,
|
||||
ValidationURL: extractValidationURL(forbiddenErr.Body),
|
||||
NeedsVerify: fbType == forbiddenTypeValidation,
|
||||
IsBanned: fbType == forbiddenTypeViolation,
|
||||
ErrorCode: errorCodeForbidden,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用 LoadCodeAssist 获取订阅等级(非关键路径,失败不影响主流程)
|
||||
tierRaw, tierNormalized := f.fetchSubscriptionTier(ctx, client, accessToken)
|
||||
|
||||
// 转换为 UsageInfo
|
||||
usageInfo := f.buildUsageInfo(modelsResp)
|
||||
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized)
|
||||
|
||||
return &QuotaResult{
|
||||
UsageInfo: usageInfo,
|
||||
@@ -52,15 +90,52 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串
|
||||
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string) {
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
if err != nil {
|
||||
slog.Warn("failed to fetch subscription tier", "error", err)
|
||||
return "", ""
|
||||
}
|
||||
if loadResp == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota
|
||||
raw = loadResp.GetTier() // 已有方法:paidTier > currentTier
|
||||
normalized = normalizeTier(raw)
|
||||
return raw, normalized
|
||||
}
|
||||
|
||||
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
|
||||
func normalizeTier(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(raw)
|
||||
switch {
|
||||
case strings.Contains(lower, "ultra"):
|
||||
return "ULTRA"
|
||||
case strings.Contains(lower, "pro"):
|
||||
return "PRO"
|
||||
case strings.Contains(lower, "free"):
|
||||
return "FREE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail),
|
||||
SubscriptionTier: tierNormalized,
|
||||
SubscriptionTierRaw: tierRaw,
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
@@ -73,6 +148,27 @@ func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAv
|
||||
Utilization: utilization,
|
||||
ResetTime: modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
|
||||
// 填充模型详细能力信息
|
||||
detail := &AntigravityModelDetail{
|
||||
DisplayName: modelInfo.DisplayName,
|
||||
SupportsImages: modelInfo.SupportsImages,
|
||||
SupportsThinking: modelInfo.SupportsThinking,
|
||||
ThinkingBudget: modelInfo.ThinkingBudget,
|
||||
Recommended: modelInfo.Recommended,
|
||||
MaxTokens: modelInfo.MaxTokens,
|
||||
MaxOutputTokens: modelInfo.MaxOutputTokens,
|
||||
SupportedMimeTypes: modelInfo.SupportedMimeTypes,
|
||||
}
|
||||
info.AntigravityQuotaDetails[modelName] = detail
|
||||
}
|
||||
|
||||
// 废弃模型转发规则
|
||||
if len(modelsResp.DeprecatedModelIDs) > 0 {
|
||||
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs))
|
||||
for oldID, deprecated := range modelsResp.DeprecatedModelIDs {
|
||||
info.ModelForwardingRules[oldID] = deprecated.NewModelID
|
||||
}
|
||||
}
|
||||
|
||||
// 同时设置 FiveHour 用于兼容展示(取主要模型)
|
||||
@@ -108,3 +204,58 @@ func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Acco
|
||||
}
|
||||
return proxy.URL()
|
||||
}
|
||||
|
||||
// classifyForbiddenType 根据 403 响应体判断禁止类型
|
||||
func classifyForbiddenType(body string) string {
|
||||
lower := strings.ToLower(body)
|
||||
switch {
|
||||
case strings.Contains(lower, "validation_required") ||
|
||||
strings.Contains(lower, "verify your account") ||
|
||||
strings.Contains(lower, "validation_url"):
|
||||
return forbiddenTypeValidation
|
||||
case strings.Contains(lower, "terms of service") ||
|
||||
strings.Contains(lower, "violation"):
|
||||
return forbiddenTypeViolation
|
||||
default:
|
||||
return forbiddenTypeForbidden
|
||||
}
|
||||
}
|
||||
|
||||
// urlPattern 用于从 403 响应体中提取 URL(降级方案)
|
||||
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)
|
||||
|
||||
// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接
|
||||
func extractValidationURL(body string) string {
|
||||
// 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url
|
||||
var parsed struct {
|
||||
Error struct {
|
||||
Details []struct {
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
} `json:"details"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal([]byte(body), &parsed) == nil {
|
||||
for _, detail := range parsed.Error.Details {
|
||||
if u := detail.Metadata["validation_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
if u := detail.Metadata["appeal_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 降级:正则匹配 URL
|
||||
lower := strings.ToLower(body)
|
||||
if !strings.Contains(lower, "validation") &&
|
||||
!strings.Contains(lower, "verify") &&
|
||||
!strings.Contains(lower, "appeal") {
|
||||
return ""
|
||||
}
|
||||
// 先解码常见转义再匹配
|
||||
normalized := strings.ReplaceAll(body, `\u0026`, "&")
|
||||
if m := urlPattern.FindString(normalized); m != "" {
|
||||
return m
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
497
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
497
backend/internal/service/antigravity_quota_fetcher_test.go
Normal file
@@ -0,0 +1,497 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// normalizeTier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNormalizeTier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
expected string
|
||||
}{
|
||||
{name: "empty string", raw: "", expected: ""},
|
||||
{name: "free-tier", raw: "free-tier", expected: "FREE"},
|
||||
{name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"},
|
||||
{name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"},
|
||||
{name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"},
|
||||
{name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"},
|
||||
{name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"},
|
||||
{name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"},
|
||||
{name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeTier(tt.raw)
|
||||
require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildUsageInfo
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func aqfBoolPtr(v bool) *bool { return &v }
|
||||
func aqfIntPtr(v int) *int { return &v }
|
||||
|
||||
func TestBuildUsageInfo_BasicModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.75,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
DisplayName: "Claude Sonnet 4",
|
||||
SupportsImages: aqfBoolPtr(true),
|
||||
SupportsThinking: aqfBoolPtr(false),
|
||||
ThinkingBudget: aqfIntPtr(0),
|
||||
Recommended: aqfBoolPtr(true),
|
||||
MaxTokens: aqfIntPtr(200000),
|
||||
MaxOutputTokens: aqfIntPtr(16384),
|
||||
SupportedMimeTypes: map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "2026-03-08T15:00:00Z",
|
||||
},
|
||||
DisplayName: "Gemini 2.5 Pro",
|
||||
MaxTokens: aqfIntPtr(1000000),
|
||||
MaxOutputTokens: aqfIntPtr(65536),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO")
|
||||
|
||||
// 基本字段
|
||||
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
|
||||
require.Equal(t, "PRO", info.SubscriptionTier)
|
||||
require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw)
|
||||
|
||||
// AntigravityQuota
|
||||
require.Len(t, info.AntigravityQuota, 2)
|
||||
|
||||
sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetQuota)
|
||||
require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25
|
||||
require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime)
|
||||
|
||||
geminiQuota := info.AntigravityQuota["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiQuota)
|
||||
require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50
|
||||
require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime)
|
||||
|
||||
// AntigravityQuotaDetails
|
||||
require.Len(t, info.AntigravityQuotaDetails, 2)
|
||||
|
||||
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetDetail)
|
||||
require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages)
|
||||
require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended)
|
||||
require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens)
|
||||
require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes)
|
||||
|
||||
geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiDetail)
|
||||
require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName)
|
||||
require.Nil(t, geminiDetail.SupportsImages)
|
||||
require.Nil(t, geminiDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{
|
||||
"claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"},
|
||||
"claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Len(t, info.ModelForwardingRules, 2)
|
||||
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
|
||||
require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"])
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_EmptyModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.NotNil(t, info.AntigravityQuota)
|
||||
require.Empty(t, info.AntigravityQuota)
|
||||
require.NotNil(t, info.AntigravityQuotaDetails)
|
||||
require.Empty(t, info.AntigravityQuotaDetails)
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"model-without-quota": {
|
||||
DisplayName: "No Quota Model",
|
||||
// QuotaInfo is nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
|
||||
require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
|
||||
// When the first priority model exists, it should be used for FiveHour
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.40,
|
||||
ResetTime: "2026-03-08T18:00:00Z",
|
||||
},
|
||||
},
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.80,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
|
||||
// claude-sonnet-4-20250514 is first in priority list, so it should be used
|
||||
expectedUtilization := (1.0 - 0.80) * 100 // 20
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.60,
|
||||
ResetTime: "2026-03-08T14:00:00Z",
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.60) * 100 // 40
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only gemini-2.5-pro exists (third in priority list)
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
"other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.90,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.30) * 100 // 70
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// None of the priority models exist
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "", // empty reset time
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
|
||||
require.Equal(t, 0, info.FiveHour.RemainingSeconds)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FullUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.0, // fully used
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 100, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0, // fully available
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "")
|
||||
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 0, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
|
||||
// 模拟 FetchQuota 遇到 403 时的行为:
|
||||
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
|
||||
forbiddenErr := &antigravity.ForbiddenError{
|
||||
StatusCode: 403,
|
||||
Body: "Access denied",
|
||||
}
|
||||
|
||||
// 验证 ForbiddenError 满足 errors.As
|
||||
var target *antigravity.ForbiddenError
|
||||
require.True(t, errors.As(forbiddenErr, &target))
|
||||
require.Equal(t, 403, target.StatusCode)
|
||||
require.Equal(t, "Access denied", target.Body)
|
||||
require.Contains(t, forbiddenErr.Error(), "403")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// classifyForbiddenType
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestClassifyForbiddenType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "VALIDATION_REQUIRED keyword",
|
||||
body: `{"error":{"message":"VALIDATION_REQUIRED"}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "verify your account",
|
||||
body: `Please verify your account to continue`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "contains validation_url field",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "terms of service violation",
|
||||
body: `Your account has been suspended for Terms of Service violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "violation keyword",
|
||||
body: `Account suspended due to policy violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "generic 403",
|
||||
body: `Access denied`,
|
||||
expected: "forbidden",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "forbidden",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyForbiddenType(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extractValidationURL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractValidationURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "structured validation_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
|
||||
expected: "https://accounts.google.com/verify?token=abc",
|
||||
},
|
||||
{
|
||||
name: "structured appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`,
|
||||
expected: "https://support.google.com/appeal/123",
|
||||
},
|
||||
{
|
||||
name: "validation_url takes priority over appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`,
|
||||
expected: "https://v.com",
|
||||
},
|
||||
{
|
||||
name: "fallback regex with verify keyword",
|
||||
body: `Please verify your account at https://accounts.google.com/verify`,
|
||||
expected: "https://accounts.google.com/verify",
|
||||
},
|
||||
{
|
||||
name: "no URL in generic forbidden",
|
||||
body: `Access denied`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "URL present but no validation keywords",
|
||||
body: `Error at https://example.com/something`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "unicode escaped ampersand",
|
||||
body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`,
|
||||
expected: "https://accounts.google.com/verify?a=1&b=2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractValidationURL(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -22,8 +22,9 @@ const (
|
||||
)
|
||||
|
||||
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
|
||||
// A nil windowStart is treated as expired — no initialized window means any accumulated usage is stale.
|
||||
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
|
||||
return windowStart != nil && time.Since(*windowStart) >= duration
|
||||
return windowStart == nil || time.Since(*windowStart) >= duration
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
|
||||
@@ -15,10 +15,10 @@ func TestIsWindowExpired(t *testing.T) {
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil window start",
|
||||
name: "nil window start (treated as expired)",
|
||||
start: nil,
|
||||
duration: RateLimitWindow5h,
|
||||
want: false,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "active window (started 1h ago, 5h window)",
|
||||
@@ -113,7 +113,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "nil window starts return raw usage",
|
||||
name: "nil window starts return 0 (stale usage reset)",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
@@ -122,9 +122,9 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
||||
Window1dStart: nil,
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 5.0,
|
||||
want1d: 10.0,
|
||||
want7d: 50.0,
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed: 5h expired, 1d active, 7d nil",
|
||||
@@ -138,7 +138,7 @@ func TestAPIKey_EffectiveUsage(t *testing.T) {
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 10.0,
|
||||
want7d: 50.0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "zero usage with active windows",
|
||||
@@ -210,7 +210,7 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "nil window starts return raw usage",
|
||||
name: "nil window starts return 0 (stale usage reset)",
|
||||
data: APIKeyRateLimitData{
|
||||
Usage5h: 3.0,
|
||||
Usage1d: 8.0,
|
||||
@@ -219,9 +219,9 @@ func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
|
||||
Window1dStart: nil,
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 3.0,
|
||||
want1d: 8.0,
|
||||
want7d: 40.0,
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1087,6 +1087,12 @@ type TokenPair struct {
|
||||
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
|
||||
}
|
||||
|
||||
// TokenPairWithUser extends TokenPair with user role for backend mode checks
|
||||
type TokenPairWithUser struct {
|
||||
TokenPair
|
||||
UserRole string
|
||||
}
|
||||
|
||||
// GenerateTokenPair 生成Access Token和Refresh Token对
|
||||
// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系
|
||||
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
|
||||
@@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
|
||||
|
||||
// RefreshTokenPair 使用Refresh Token刷新Token对
|
||||
// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) {
|
||||
// 检查 refreshTokenCache 是否可用
|
||||
if s.refreshTokenCache == nil {
|
||||
return nil, ErrRefreshTokenInvalid
|
||||
@@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
||||
}
|
||||
|
||||
// 生成新的Token对,保持同一个家族ID
|
||||
return s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenPairWithUser{
|
||||
TokenPair: *pair,
|
||||
UserRole: user.Role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeRefreshToken 撤销单个Refresh Token
|
||||
|
||||
770
backend/internal/service/backup_service.go
Normal file
770
backend/internal/service/backup_service.go
Normal file
@@ -0,0 +1,770 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
settingKeyBackupS3Config = "backup_s3_config"
|
||||
settingKeyBackupSchedule = "backup_schedule"
|
||||
settingKeyBackupRecords = "backup_records"
|
||||
|
||||
maxBackupRecords = 100
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured")
|
||||
ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found")
|
||||
ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress")
|
||||
ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress")
|
||||
ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted")
|
||||
ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted")
|
||||
)
|
||||
|
||||
// ─── 接口定义 ───
|
||||
|
||||
// DBDumper abstracts database dump/restore operations
|
||||
type DBDumper interface {
|
||||
Dump(ctx context.Context) (io.ReadCloser, error)
|
||||
Restore(ctx context.Context, data io.Reader) error
|
||||
}
|
||||
|
||||
// BackupObjectStore abstracts object storage for backup files
|
||||
type BackupObjectStore interface {
|
||||
Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error)
|
||||
Download(ctx context.Context, key string) (io.ReadCloser, error)
|
||||
Delete(ctx context.Context, key string) error
|
||||
PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error)
|
||||
HeadBucket(ctx context.Context) error
|
||||
}
|
||||
|
||||
// BackupObjectStoreFactory creates an object store from S3 config
|
||||
type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error)
|
||||
|
||||
// ─── 数据模型 ───
|
||||
|
||||
// BackupS3Config S3 兼容存储配置(支持 Cloudflare R2)
|
||||
type BackupS3Config struct {
|
||||
Endpoint string `json:"endpoint"` // e.g. https://<account_id>.r2.cloudflarestorage.com
|
||||
Region string `json:"region"` // R2 用 "auto"
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key,omitempty"` //nolint:revive // field name follows AWS convention
|
||||
Prefix string `json:"prefix"` // S3 key 前缀,如 "backups/"
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
}
|
||||
|
||||
// IsConfigured 检查必要字段是否已配置
|
||||
func (c *BackupS3Config) IsConfigured() bool {
|
||||
return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != ""
|
||||
}
|
||||
|
||||
// BackupScheduleConfig 定时备份配置
|
||||
type BackupScheduleConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CronExpr string `json:"cron_expr"` // cron 表达式,如 "0 2 * * *" 每天凌晨2点
|
||||
RetainDays int `json:"retain_days"` // 备份文件过期天数,默认14,0=不自动清理
|
||||
RetainCount int `json:"retain_count"` // 最多保留份数,0=不限制
|
||||
}
|
||||
|
||||
// BackupRecord 备份记录
|
||||
type BackupRecord struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"` // pending, running, completed, failed
|
||||
BackupType string `json:"backup_type"` // postgres
|
||||
FileName string `json:"file_name"`
|
||||
S3Key string `json:"s3_key"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TriggeredBy string `json:"triggered_by"` // manual, scheduled
|
||||
ErrorMsg string `json:"error_message,omitempty"`
|
||||
StartedAt string `json:"started_at"`
|
||||
FinishedAt string `json:"finished_at,omitempty"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
|
||||
}
|
||||
|
||||
// BackupService 数据库备份恢复服务
|
||||
type BackupService struct {
|
||||
settingRepo SettingRepository
|
||||
dbCfg *config.DatabaseConfig
|
||||
encryptor SecretEncryptor
|
||||
storeFactory BackupObjectStoreFactory
|
||||
dumper DBDumper
|
||||
|
||||
mu sync.Mutex
|
||||
store BackupObjectStore
|
||||
s3Cfg *BackupS3Config
|
||||
backingUp bool
|
||||
restoring bool
|
||||
|
||||
recordsMu sync.Mutex // 保护 records 的 load/save 操作
|
||||
|
||||
cronMu sync.Mutex
|
||||
cronSched *cron.Cron
|
||||
cronEntryID cron.EntryID
|
||||
}
|
||||
|
||||
func NewBackupService(
|
||||
settingRepo SettingRepository,
|
||||
cfg *config.Config,
|
||||
encryptor SecretEncryptor,
|
||||
storeFactory BackupObjectStoreFactory,
|
||||
dumper DBDumper,
|
||||
) *BackupService {
|
||||
return &BackupService{
|
||||
settingRepo: settingRepo,
|
||||
dbCfg: &cfg.Database,
|
||||
encryptor: encryptor,
|
||||
storeFactory: storeFactory,
|
||||
dumper: dumper,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动定时备份调度器
|
||||
func (s *BackupService) Start() {
|
||||
s.cronSched = cron.New()
|
||||
s.cronSched.Start()
|
||||
|
||||
// 加载已有的定时配置
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
schedule, err := s.GetSchedule(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 加载定时备份配置失败: %v", err)
|
||||
return
|
||||
}
|
||||
if schedule.Enabled && schedule.CronExpr != "" {
|
||||
if err := s.applyCronSchedule(schedule); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 应用定时备份配置失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止定时备份
|
||||
func (s *BackupService) Stop() {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
if s.cronSched != nil {
|
||||
s.cronSched.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置管理 ───
|
||||
|
||||
func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg == nil {
|
||||
return &BackupS3Config{}, nil
|
||||
}
|
||||
// 脱敏返回
|
||||
cfg.SecretAccessKey = ""
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) {
|
||||
// 如果没提供 secret,保留原有值
|
||||
if cfg.SecretAccessKey == "" {
|
||||
old, _ := s.loadS3Config(ctx)
|
||||
if old != nil {
|
||||
cfg.SecretAccessKey = old.SecretAccessKey
|
||||
}
|
||||
} else {
|
||||
// 加密 SecretAccessKey
|
||||
encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encrypt secret: %w", err)
|
||||
}
|
||||
cfg.SecretAccessKey = encrypted
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal s3 config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil {
|
||||
return nil, fmt.Errorf("save s3 config: %w", err)
|
||||
}
|
||||
|
||||
// 清除缓存的 S3 客户端
|
||||
s.mu.Lock()
|
||||
s.store = nil
|
||||
s.s3Cfg = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
cfg.SecretAccessKey = ""
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error {
|
||||
// 如果没提供 secret,用已保存的
|
||||
if cfg.SecretAccessKey == "" {
|
||||
old, _ := s.loadS3Config(ctx)
|
||||
if old != nil {
|
||||
cfg.SecretAccessKey = old.SecretAccessKey
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
|
||||
return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required")
|
||||
}
|
||||
|
||||
store, err := s.storeFactory(ctx, &cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return store.HeadBucket(ctx)
|
||||
}
|
||||
|
||||
// ─── 定时备份管理 ───
|
||||
|
||||
func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule)
|
||||
if err != nil || raw == "" {
|
||||
return &BackupScheduleConfig{}, nil
|
||||
}
|
||||
var cfg BackupScheduleConfig
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return &BackupScheduleConfig{}, nil
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) {
|
||||
if cfg.Enabled && cfg.CronExpr == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled")
|
||||
}
|
||||
// 验证 cron 表达式
|
||||
if cfg.CronExpr != "" {
|
||||
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
if _, err := parser.Parse(cfg.CronExpr); err != nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal schedule config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil {
|
||||
return nil, fmt.Errorf("save schedule config: %w", err)
|
||||
}
|
||||
|
||||
// 应用或停止定时任务
|
||||
if cfg.Enabled {
|
||||
if err := s.applyCronSchedule(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
s.removeCronSchedule()
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
|
||||
if s.cronSched == nil {
|
||||
return fmt.Errorf("cron scheduler not initialized")
|
||||
}
|
||||
|
||||
// 移除旧任务
|
||||
if s.cronEntryID != 0 {
|
||||
s.cronSched.Remove(s.cronEntryID)
|
||||
s.cronEntryID = 0
|
||||
}
|
||||
|
||||
entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() {
|
||||
s.runScheduledBackup()
|
||||
})
|
||||
if err != nil {
|
||||
return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err))
|
||||
}
|
||||
s.cronEntryID = entryID
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已启用: %s", cfg.CronExpr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) removeCronSchedule() {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
if s.cronSched != nil && s.cronEntryID != 0 {
|
||||
s.cronSched.Remove(s.cronEntryID)
|
||||
s.cronEntryID = 0
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已停用")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupService) runScheduledBackup() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// 读取定时备份配置中的过期天数
|
||||
schedule, _ := s.GetSchedule(ctx)
|
||||
expireDays := 14 // 默认14天过期
|
||||
if schedule != nil && schedule.RetainDays > 0 {
|
||||
expireDays = schedule.RetainDays
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
|
||||
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
|
||||
|
||||
// 清理过期备份(复用已加载的 schedule)
|
||||
if schedule == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cleanupOldBackups(ctx, schedule); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 清理过期备份失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── 备份/恢复核心 ───
|
||||
|
||||
// CreateBackup 创建全量数据库备份并上传到 S3(流式处理)
|
||||
// expireDays: 备份过期天数,0=永不过期,默认14天
|
||||
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
||||
s.mu.Lock()
|
||||
if s.backingUp {
|
||||
s.mu.Unlock()
|
||||
return nil, ErrBackupInProgress
|
||||
}
|
||||
s.backingUp = true
|
||||
s.mu.Unlock()
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.backingUp = false
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s3Cfg == nil || !s3Cfg.IsConfigured() {
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
backupID := uuid.New().String()[:8]
|
||||
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
|
||||
s3Key := s.buildS3Key(s3Cfg, fileName)
|
||||
|
||||
var expiresAt string
|
||||
if expireDays > 0 {
|
||||
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
record := &BackupRecord{
|
||||
ID: backupID,
|
||||
Status: "running",
|
||||
BackupType: "postgres",
|
||||
FileName: fileName,
|
||||
S3Key: s3Key,
|
||||
TriggeredBy: triggeredBy,
|
||||
StartedAt: now.Format(time.RFC3339),
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
|
||||
// 流式执行: pg_dump -> gzip -> S3 upload
|
||||
dumpReader, err := s.dumper.Dump(ctx)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
|
||||
pr, pw := io.Pipe()
|
||||
var gzipErr error
|
||||
go func() {
|
||||
gzWriter := gzip.NewWriter(pw)
|
||||
_, gzipErr = io.Copy(gzWriter, dumpReader)
|
||||
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if gzipErr != nil {
|
||||
_ = pw.CloseWithError(gzipErr)
|
||||
} else {
|
||||
_ = pw.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
contentType := "application/gzip"
|
||||
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
||||
if gzipErr != nil {
|
||||
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
|
||||
}
|
||||
record.ErrorMsg = errMsg
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("backup upload: %w", err)
|
||||
}
|
||||
|
||||
record.SizeBytes = sizeBytes
|
||||
record.Status = "completed"
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
if err := s.saveRecord(ctx, record); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
|
||||
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
|
||||
s.mu.Lock()
|
||||
if s.restoring {
|
||||
s.mu.Unlock()
|
||||
return ErrRestoreInProgress
|
||||
}
|
||||
s.restoring = true
|
||||
s.mu.Unlock()
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.restoring = false
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
record, err := s.GetBackupRecord(ctx, backupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if record.Status != "completed" {
|
||||
return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
|
||||
}
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
// 从 S3 流式下载
|
||||
body, err := objectStore.Download(ctx, record.S3Key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 download failed: %w", err)
|
||||
}
|
||||
defer func() { _ = body.Close() }()
|
||||
|
||||
// 流式解压 gzip -> psql(不将全部数据加载到内存)
|
||||
gzReader, err := gzip.NewReader(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gzip reader: %w", err)
|
||||
}
|
||||
defer func() { _ = gzReader.Close() }()
|
||||
|
||||
// 流式恢复
|
||||
if err := s.dumper.Restore(ctx, gzReader); err != nil {
|
||||
return fmt.Errorf("pg restore: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ─── 备份记录管理 ───
|
||||
|
||||
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
|
||||
records, err := s.loadRecords(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 倒序返回(最新在前)
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].StartedAt > records[j].StartedAt
|
||||
})
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) {
|
||||
records, err := s.loadRecords(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range records {
|
||||
if records[i].ID == backupID {
|
||||
return &records[i], nil
|
||||
}
|
||||
}
|
||||
return nil, ErrBackupNotFound
|
||||
}
|
||||
|
||||
func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var found *BackupRecord
|
||||
var remaining []BackupRecord
|
||||
for i := range records {
|
||||
if records[i].ID == backupID {
|
||||
found = &records[i]
|
||||
} else {
|
||||
remaining = append(remaining, records[i])
|
||||
}
|
||||
}
|
||||
if found == nil {
|
||||
return ErrBackupNotFound
|
||||
}
|
||||
|
||||
// 从 S3 删除
|
||||
if found.S3Key != "" && found.Status == "completed" {
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() {
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err == nil {
|
||||
_ = objectStore.Delete(ctx, found.S3Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.saveRecordsLocked(ctx, remaining)
|
||||
}
|
||||
|
||||
// GetBackupDownloadURL 获取备份文件预签名下载 URL
|
||||
func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) {
|
||||
record, err := s.GetBackupRecord(ctx, backupID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if record.Status != "completed" {
|
||||
return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed")
|
||||
}
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// ─── 内部方法 ───
|
||||
|
||||
func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil //nolint:nilnil // no config is a valid state
|
||||
}
|
||||
var cfg BackupS3Config
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return nil, ErrBackupS3ConfigCorrupt
|
||||
}
|
||||
// 解密 SecretAccessKey
|
||||
if cfg.SecretAccessKey != "" {
|
||||
decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
// 兼容未加密的旧数据:如果解密失败,保持原值
|
||||
logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err)
|
||||
} else {
|
||||
cfg.SecretAccessKey = decrypted
|
||||
}
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.store != nil && s.s3Cfg != nil {
|
||||
return s.store, nil
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
store, err := s.storeFactory(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.store = store
|
||||
s.s3Cfg = cfg
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string {
|
||||
prefix := strings.TrimRight(cfg.Prefix, "/")
|
||||
if prefix == "" {
|
||||
prefix = "backups"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName)
|
||||
}
|
||||
|
||||
// loadRecords 加载备份记录,区分"无数据"和"数据损坏"
|
||||
func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
return s.loadRecordsLocked(ctx)
|
||||
}
|
||||
|
||||
// loadRecordsLocked 在已持有 recordsMu 锁的情况下加载记录
|
||||
func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil //nolint:nilnil // no records is a valid state
|
||||
}
|
||||
var records []BackupRecord
|
||||
if err := json.Unmarshal([]byte(raw), &records); err != nil {
|
||||
return nil, ErrBackupRecordsCorrupt
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// saveRecordsLocked 在已持有 recordsMu 锁的情况下保存记录
|
||||
func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error {
|
||||
data, err := json.Marshal(records)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data))
|
||||
}
|
||||
|
||||
// saveRecord 保存单条记录(带互斥锁保护)
|
||||
func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, _ := s.loadRecordsLocked(ctx)
|
||||
|
||||
// 更新已有记录或追加
|
||||
found := false
|
||||
for i := range records {
|
||||
if records[i].ID == record.ID {
|
||||
records[i] = *record
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
records = append(records, *record)
|
||||
}
|
||||
|
||||
// 限制记录数量
|
||||
if len(records) > maxBackupRecords {
|
||||
records = records[len(records)-maxBackupRecords:]
|
||||
}
|
||||
|
||||
return s.saveRecordsLocked(ctx, records)
|
||||
}
|
||||
|
||||
func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error {
|
||||
if schedule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 按时间倒序
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].StartedAt > records[j].StartedAt
|
||||
})
|
||||
|
||||
var toDelete []BackupRecord
|
||||
var toKeep []BackupRecord
|
||||
|
||||
for i, r := range records {
|
||||
shouldDelete := false
|
||||
|
||||
// 按保留份数清理
|
||||
if schedule.RetainCount > 0 && i >= schedule.RetainCount {
|
||||
shouldDelete = true
|
||||
}
|
||||
|
||||
// 按保留天数清理
|
||||
if schedule.RetainDays > 0 && r.StartedAt != "" {
|
||||
startedAt, err := time.Parse(time.RFC3339, r.StartedAt)
|
||||
if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour {
|
||||
shouldDelete = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDelete && r.Status == "completed" {
|
||||
toDelete = append(toDelete, r)
|
||||
} else {
|
||||
toKeep = append(toKeep, r)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除 S3 上的文件
|
||||
for _, r := range toDelete {
|
||||
if r.S3Key != "" {
|
||||
_ = s.deleteS3Object(ctx, r.S3Key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toDelete) > 0 {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete))
|
||||
return s.saveRecordsLocked(ctx, toKeep)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteS3Object(ctx context.Context, key string) error {
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil || s3Cfg == nil {
|
||||
return nil
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return objectStore.Delete(ctx, key)
|
||||
}
|
||||
528
backend/internal/service/backup_service_test.go
Normal file
528
backend/internal/service/backup_service_test.go
Normal file
@@ -0,0 +1,528 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// ─── Mocks ───
|
||||
|
||||
type mockSettingRepo struct {
|
||||
mu sync.Mutex
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func newMockSettingRepo() *mockSettingRepo {
|
||||
return &mockSettingRepo{data: make(map[string]string)}
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
v, ok := m.data[key]
|
||||
if !ok {
|
||||
return nil, ErrSettingNotFound
|
||||
}
|
||||
return &Setting{Key: key, Value: v}, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
v, ok := m.data[key]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Set(_ context.Context, key, value string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make(map[string]string)
|
||||
for _, k := range keys {
|
||||
if v, ok := m.data[k]; ok {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for k, v := range settings {
|
||||
m.data[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make(map[string]string, len(m.data))
|
||||
for k, v := range m.data {
|
||||
result[k] = v
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Delete(_ context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.data, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// plainEncryptor 仅做 base64-like 包装,用于测试
|
||||
type plainEncryptor struct{}
|
||||
|
||||
func (e *plainEncryptor) Encrypt(plaintext string) (string, error) {
|
||||
return "ENC:" + plaintext, nil
|
||||
}
|
||||
|
||||
func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||
if strings.HasPrefix(ciphertext, "ENC:") {
|
||||
return strings.TrimPrefix(ciphertext, "ENC:"), nil
|
||||
}
|
||||
return ciphertext, fmt.Errorf("not encrypted")
|
||||
}
|
||||
|
||||
type mockDumper struct {
|
||||
dumpData []byte
|
||||
dumpErr error
|
||||
restored []byte
|
||||
restErr error
|
||||
}
|
||||
|
||||
func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) {
|
||||
if m.dumpErr != nil {
|
||||
return nil, m.dumpErr
|
||||
}
|
||||
return io.NopCloser(bytes.NewReader(m.dumpData)), nil
|
||||
}
|
||||
|
||||
func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
|
||||
if m.restErr != nil {
|
||||
return m.restErr
|
||||
}
|
||||
d, err := io.ReadAll(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.restored = d
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockObjectStore struct {
|
||||
objects map[string][]byte
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockObjectStore() *mockObjectStore {
|
||||
return &mockObjectStore{objects: make(map[string][]byte)}
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) {
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.objects[key] = data
|
||||
m.mu.Unlock()
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) {
|
||||
m.mu.Lock()
|
||||
data, ok := m.objects[key]
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("not found: %s", key)
|
||||
}
|
||||
return io.NopCloser(bytes.NewReader(data)), nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Delete(_ context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
delete(m.objects, key)
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) {
|
||||
return "https://presigned.example.com/" + key, nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) HeadBucket(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "test",
|
||||
DBName: "testdb",
|
||||
},
|
||||
}
|
||||
factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) {
|
||||
return store, nil
|
||||
}
|
||||
return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper)
|
||||
}
|
||||
|
||||
func seedS3Config(t *testing.T, repo *mockSettingRepo) {
|
||||
t.Helper()
|
||||
cfg := BackupS3Config{
|
||||
Bucket: "test-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "ENC:secret123",
|
||||
Prefix: "backups",
|
||||
}
|
||||
data, _ := json.Marshal(cfg)
|
||||
require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data)))
|
||||
}
|
||||
|
||||
// ─── Tests ───
|
||||
|
||||
func TestBackupService_S3ConfigEncryption(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 保存配置 -> SecretAccessKey 应被加密
|
||||
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "my-secret",
|
||||
Prefix: "backups",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 直接读取数据库中存储的值,应该是加密后的
|
||||
raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config)
|
||||
var stored BackupS3Config
|
||||
require.NoError(t, json.Unmarshal([]byte(raw), &stored))
|
||||
require.Equal(t, "ENC:my-secret", stored.SecretAccessKey)
|
||||
|
||||
// 通过 GetS3Config 获取应该脱敏
|
||||
cfg, err := svc.GetS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cfg.SecretAccessKey)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
|
||||
// loadS3Config 内部应解密
|
||||
internal, err := svc.loadS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "my-secret", internal.SecretAccessKey)
|
||||
}
|
||||
|
||||
func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 先保存一个有 secret 的配置
|
||||
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "original-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 再更新时不提供 secret,应保留原值
|
||||
_, err = svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID-NEW",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
internal, err := svc.loadS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "original-secret", internal.SecretAccessKey)
|
||||
require.Equal(t, "AKID-NEW", internal.AccessKeyID)
|
||||
}
|
||||
|
||||
func TestBackupService_SaveRecordConcurrency(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
n := 20
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
record := &BackupRecord{
|
||||
ID: fmt.Sprintf("rec-%d", idx),
|
||||
Status: "completed",
|
||||
StartedAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
_ = svc.saveRecord(context.Background(), record)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, records, n)
|
||||
}
|
||||
|
||||
func TestBackupService_LoadRecords_Empty(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, records) // 无数据时返回 nil
|
||||
}
|
||||
|
||||
func TestBackupService_LoadRecords_Corrupted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
_ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{")
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.Error(t, err) // 损坏数据应返回错误
|
||||
require.Nil(t, records)
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_Streaming(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "completed", record.Status)
|
||||
require.Greater(t, record.SizeBytes, int64(0))
|
||||
require.NotEmpty(t, record.S3Key)
|
||||
|
||||
// 验证 S3 上确实有文件
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 1)
|
||||
store.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_DumpFailure(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "failed", record.Status)
|
||||
require.Contains(t, record.ErrorMsg, "pg_dump")
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_NoS3Config(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.ErrorIs(t, err, ErrBackupS3NotConfigured)
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
// 使用一个慢速 dumper 来模拟正在进行的备份
|
||||
dumper := &mockDumper{dumpData: []byte("data")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
// 手动设置 backingUp 标志
|
||||
svc.mu.Lock()
|
||||
svc.backingUp = true
|
||||
svc.mu.Unlock()
|
||||
|
||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.ErrorIs(t, err, ErrBackupInProgress)
|
||||
}
|
||||
|
||||
func TestBackupService_RestoreBackup_Streaming(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
// 先创建一个备份
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 恢复
|
||||
err = svc.RestoreBackup(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证 psql 收到的数据是否与原始 dump 内容一致
|
||||
require.Equal(t, dumpContent, string(dumper.restored))
|
||||
}
|
||||
|
||||
func TestBackupService_RestoreBackup_NotCompleted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 手动插入一条 failed 记录
|
||||
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||
ID: "fail-1",
|
||||
Status: "failed",
|
||||
})
|
||||
|
||||
err := svc.RestoreBackup(context.Background(), "fail-1")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_DeleteBackup(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "data"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S3 中应有文件
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 1)
|
||||
store.mu.Unlock()
|
||||
|
||||
// 删除
|
||||
err = svc.DeleteBackup(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S3 中文件应被删除
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 0)
|
||||
store.mu.Unlock()
|
||||
|
||||
// 记录应不存在
|
||||
_, err = svc.GetBackupRecord(context.Background(), record.ID)
|
||||
require.ErrorIs(t, err, ErrBackupNotFound)
|
||||
}
|
||||
|
||||
func TestBackupService_GetDownloadURL(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumper := &mockDumper{dumpData: []byte("data")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
url, err := svc.GetBackupDownloadURL(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, url, "https://presigned.example.com/")
|
||||
}
|
||||
|
||||
func TestBackupService_ListBackups_Sorted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
now := time.Now()
|
||||
for i := 0; i < 3; i++ {
|
||||
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||
ID: fmt.Sprintf("rec-%d", i),
|
||||
Status: "completed",
|
||||
StartedAt: now.Add(time.Duration(i) * time.Hour).Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
records, err := svc.ListBackups(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, records, 3)
|
||||
// 最新在前
|
||||
require.Equal(t, "rec-2", records[0].ID)
|
||||
require.Equal(t, "rec-0", records[2].ID)
|
||||
}
|
||||
|
||||
func TestBackupService_TestS3Connection(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, store)
|
||||
|
||||
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||
Bucket: "test",
|
||||
AccessKeyID: "ak",
|
||||
SecretAccessKey: "sk",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_TestS3Connection_Incomplete(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||
Bucket: "test",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "incomplete")
|
||||
}
|
||||
|
||||
func TestBackupService_Schedule_CronValidation(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
svc.cronSched = nil // 未初始化 cron
|
||||
|
||||
// 启用但 cron 为空
|
||||
_, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||
Enabled: true,
|
||||
CronExpr: "",
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// 无效的 cron 表达式
|
||||
_, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||
Enabled: true,
|
||||
CronExpr: "invalid",
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
_ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!")
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
cfg, err := svc.loadS3Config(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Nil(t, cfg)
|
||||
}
|
||||
@@ -29,12 +29,11 @@ const (
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock)
|
||||
AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock)
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -81,6 +80,7 @@ const (
|
||||
SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单(JSON 数组)
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接
|
||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||
|
||||
// 邮件服务设置
|
||||
@@ -221,6 +221,9 @@ const (
|
||||
|
||||
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403)
|
||||
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
|
||||
|
||||
// SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录
|
||||
SettingKeyBackendModeEnabled = "backend_mode_enabled"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
|
||||
@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_second_hit_upgrades_to_none",
|
||||
// Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
|
||||
// second hit 仍然返回 TempUnscheduled。
|
||||
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
|
||||
account: &Account{
|
||||
ID: 15,
|
||||
Type: AccountTypeOAuth,
|
||||
@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyNone,
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
|
||||
@@ -440,7 +440,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded 测试所有账户被排除
|
||||
@@ -1073,7 +1073,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAccounts(t *testing.
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformAnthropic)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||
}
|
||||
|
||||
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
||||
@@ -1734,7 +1734,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||
})
|
||||
|
||||
t.Run("混合调度-不支持模型返回错误", func(t *testing.T) {
|
||||
@@ -2290,7 +2290,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
require.ErrorIs(t, err, ErrNoAvailableAccounts)
|
||||
})
|
||||
|
||||
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
|
||||
|
||||
@@ -369,3 +369,54 @@ func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 0, usageRepo.calls)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_ReasoningEffortPersisted(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
effort := "max"
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "effort_test",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
},
|
||||
Model: "claude-opus-4-6",
|
||||
Duration: time.Second,
|
||||
ReasoningEffort: &effort,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1},
|
||||
User: &User{ID: 1},
|
||||
Account: &Account{ID: 1},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.NotNil(t, usageRepo.lastLog.ReasoningEffort)
|
||||
require.Equal(t, "max", *usageRepo.lastLog.ReasoningEffort)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_ReasoningEffortNil(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "no_effort_test",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1},
|
||||
User: &User{ID: 1},
|
||||
Account: &Account{ID: 1},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Nil(t, usageRepo.lastLog.ReasoningEffort)
|
||||
}
|
||||
|
||||
@@ -60,6 +60,7 @@ type ParsedRequest struct {
|
||||
Messages []any // messages 数组
|
||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||
OutputEffort string // output_config.effort(Claude API 的推理强度控制)
|
||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
|
||||
|
||||
@@ -116,6 +117,9 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
||||
parsed.ThinkingEnabled = true
|
||||
}
|
||||
|
||||
// output_config.effort: Claude API 的推理强度控制参数
|
||||
parsed.OutputEffort = strings.TrimSpace(gjson.Get(jsonStr, "output_config.effort").String())
|
||||
|
||||
// max_tokens: 仅接受整数值
|
||||
maxTokensResult := gjson.Get(jsonStr, "max_tokens")
|
||||
if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number {
|
||||
@@ -747,6 +751,21 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
||||
return newBody
|
||||
}
|
||||
|
||||
// NormalizeClaudeOutputEffort normalizes Claude's output_config.effort value.
|
||||
// Returns nil for empty or unrecognized values.
|
||||
func NormalizeClaudeOutputEffort(raw string) *string {
|
||||
value := strings.ToLower(strings.TrimSpace(raw))
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
switch value {
|
||||
case "low", "medium", "high", "max":
|
||||
return &value
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// =========================
|
||||
// Thinking Budget Rectifier
|
||||
// =========================
|
||||
|
||||
@@ -972,6 +972,76 @@ func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_OutputEffort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantEffort string
|
||||
}{
|
||||
{
|
||||
name: "output_config.effort present",
|
||||
body: `{"model":"claude-opus-4-6","output_config":{"effort":"medium"},"messages":[]}`,
|
||||
wantEffort: "medium",
|
||||
},
|
||||
{
|
||||
name: "output_config.effort max",
|
||||
body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`,
|
||||
wantEffort: "max",
|
||||
},
|
||||
{
|
||||
name: "output_config without effort",
|
||||
body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`,
|
||||
wantEffort: "",
|
||||
},
|
||||
{
|
||||
name: "no output_config",
|
||||
body: `{"model":"claude-opus-4-6","messages":[]}`,
|
||||
wantEffort: "",
|
||||
},
|
||||
{
|
||||
name: "effort with whitespace trimmed",
|
||||
body: `{"model":"claude-opus-4-6","output_config":{"effort":" high "},"messages":[]}`,
|
||||
wantEffort: "high",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parsed, err := ParseGatewayRequest([]byte(tt.body), "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantEffort, parsed.OutputEffort)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeOutputEffort(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want *string
|
||||
}{
|
||||
{"low", strPtr("low")},
|
||||
{"medium", strPtr("medium")},
|
||||
{"high", strPtr("high")},
|
||||
{"max", strPtr("max")},
|
||||
{"LOW", strPtr("low")},
|
||||
{"Max", strPtr("max")},
|
||||
{" medium ", strPtr("medium")},
|
||||
{"", nil},
|
||||
{"unknown", nil},
|
||||
{"xhigh", nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := NormalizeClaudeOutputEffort(tt.input)
|
||||
if tt.want == nil {
|
||||
require.Nil(t, got)
|
||||
} else {
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, *tt.want, *got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseGatewayRequest_New_Large(b *testing.B) {
|
||||
data := buildLargeJSON()
|
||||
b.SetBytes(int64(len(data)))
|
||||
|
||||
@@ -346,6 +346,9 @@ var systemBlockFilterPrefixes = []string{
|
||||
"x-anthropic-billing-header",
|
||||
}
|
||||
|
||||
// ErrNoAvailableAccounts 表示没有可用的账号
|
||||
var ErrNoAvailableAccounts = errors.New("no available accounts")
|
||||
|
||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||
|
||||
@@ -492,6 +495,7 @@ type ForwardResult struct {
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
ClientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
ReasoningEffort *string
|
||||
|
||||
// 图片生成计费字段(图片生成模型使用)
|
||||
ImageCount int // 生成的图片数量
|
||||
@@ -1204,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
@@ -1552,7 +1556,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
||||
@@ -1641,7 +1645,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
@@ -2173,10 +2177,10 @@ func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []
|
||||
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
|
||||
}
|
||||
|
||||
// isAccountSchedulableForQuota 检查 API Key 账号是否在配额限制内
|
||||
// 仅适用于配置了 quota_limit 的 apikey 类型账号
|
||||
// isAccountSchedulableForQuota 检查账号是否在配额限制内
|
||||
// 适用于配置了 quota_limit 的 apikey 和 bedrock 类型账号
|
||||
func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool {
|
||||
if account.Type != AccountTypeAPIKey {
|
||||
if !account.IsAPIKeyOrBedrock() {
|
||||
return true
|
||||
}
|
||||
return !account.IsQuotaExceeded()
|
||||
@@ -2851,9 +2855,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if selected == nil {
|
||||
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
|
||||
if requestedModel != "" {
|
||||
return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats))
|
||||
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
|
||||
}
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
@@ -3089,9 +3093,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if selected == nil {
|
||||
stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true)
|
||||
if requestedModel != "" {
|
||||
return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats))
|
||||
return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
|
||||
}
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
@@ -3532,9 +3536,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
}
|
||||
return apiKey, "apikey", nil
|
||||
case AccountTypeBedrock:
|
||||
return "", "bedrock", nil // Bedrock 使用 SigV4 签名,不需要 token
|
||||
case AccountTypeBedrockAPIKey:
|
||||
return "", "bedrock-apikey", nil // Bedrock API Key 使用 Bearer Token,由 forwardBedrock 处理
|
||||
return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理
|
||||
default:
|
||||
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
@@ -5186,7 +5188,7 @@ func (s *GatewayService) forwardBedrock(
|
||||
if account.IsBedrockAPIKey() {
|
||||
bedrockAPIKey = account.GetCredential("api_key")
|
||||
if bedrockAPIKey == "" {
|
||||
return nil, fmt.Errorf("api_key not found in bedrock-apikey credentials")
|
||||
return nil, fmt.Errorf("api_key not found in bedrock credentials")
|
||||
}
|
||||
} else {
|
||||
signer, err = NewBedrockSignerFromAccount(account)
|
||||
@@ -5375,8 +5377,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
@@ -5398,8 +5401,9 @@ func (s *GatewayService) handleBedrockUpstreamErrors(
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5808,9 +5812,10 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
|
||||
return betaPolicyResult{}
|
||||
}
|
||||
isOAuth := account.IsOAuth()
|
||||
isBedrock := account.IsBedrock()
|
||||
var result betaPolicyResult
|
||||
for _, rule := range settings.Rules {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||
continue
|
||||
}
|
||||
switch rule.Action {
|
||||
@@ -5870,14 +5875,16 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
|
||||
}
|
||||
|
||||
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
|
||||
func betaPolicyScopeMatches(scope string, isOAuth bool) bool {
|
||||
func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
|
||||
switch scope {
|
||||
case BetaPolicyScopeAll:
|
||||
return true
|
||||
case BetaPolicyScopeOAuth:
|
||||
return isOAuth
|
||||
case BetaPolicyScopeAPIKey:
|
||||
return !isOAuth
|
||||
return !isOAuth && !isBedrock
|
||||
case BetaPolicyScopeBedrock:
|
||||
return isBedrock
|
||||
default:
|
||||
return true // unknown scope → match all (fail-open)
|
||||
}
|
||||
@@ -5959,12 +5966,13 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
|
||||
return nil
|
||||
}
|
||||
isOAuth := account.IsOAuth()
|
||||
isBedrock := account.IsBedrock()
|
||||
tokenSet := buildBetaTokenSet(tokens)
|
||||
for _, rule := range settings.Rules {
|
||||
if rule.Action != BetaPolicyActionBlock {
|
||||
continue
|
||||
}
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||
continue
|
||||
}
|
||||
if _, present := tokenSet[rule.BetaToken]; present {
|
||||
@@ -6125,6 +6133,29 @@ func extractUpstreamErrorMessage(body []byte) string {
|
||||
return gjson.GetBytes(body, "message").String()
|
||||
}
|
||||
|
||||
func extractUpstreamErrorCode(body []byte) string {
|
||||
if code := strings.TrimSpace(gjson.GetBytes(body, "error.code").String()); code != "" {
|
||||
return code
|
||||
}
|
||||
|
||||
inner := strings.TrimSpace(gjson.GetBytes(body, "error.message").String())
|
||||
if !strings.HasPrefix(inner, "{") {
|
||||
return ""
|
||||
}
|
||||
|
||||
if code := strings.TrimSpace(gjson.Get(inner, "error.code").String()); code != "" {
|
||||
return code
|
||||
}
|
||||
|
||||
if lastBrace := strings.LastIndex(inner, "}"); lastBrace >= 0 {
|
||||
if code := strings.TrimSpace(gjson.Get(inner[:lastBrace+1], "error.code").String()); code != "" {
|
||||
return code
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func isCountTokensUnsupported404(statusCode int, body []byte) bool {
|
||||
if statusCode != http.StatusNotFound {
|
||||
return false
|
||||
@@ -7176,7 +7207,7 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
@@ -7264,7 +7295,7 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
||||
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
|
||||
}
|
||||
|
||||
@@ -7496,6 +7527,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
@@ -7672,6 +7704,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
|
||||
@@ -3235,7 +3235,7 @@ func cleanToolSchema(schema any) any {
|
||||
for key, value := range v {
|
||||
// 跳过不支持的字段
|
||||
if key == "$schema" || key == "$id" || key == "$ref" ||
|
||||
key == "additionalProperties" || key == "minLength" ||
|
||||
key == "additionalProperties" || key == "patternProperties" || key == "minLength" ||
|
||||
key == "maxLength" || key == "minItems" || key == "maxItems" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -725,7 +725,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}
|
||||
|
||||
return nil, len(candidates), topK, loadSkew, errors.New("no available accounts")
|
||||
return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
|
||||
|
||||
@@ -226,6 +226,41 @@ func TestOpenAIGatewayServiceRecordUsage_UsesUserSpecificGroupRate(t *testing.T)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_IncludesEndpointMetadata(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
rateRepo := &openAIUserGroupRateRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_endpoint_metadata",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 2,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 1002,
|
||||
Group: &Group{RateMultiplier: 1},
|
||||
},
|
||||
User: &User{ID: 2002},
|
||||
Account: &Account{ID: 3002},
|
||||
InboundEndpoint: " /v1/chat/completions ",
|
||||
UpstreamEndpoint: " /v1/responses ",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.NotNil(t, usageRepo.lastLog.InboundEndpoint)
|
||||
require.Equal(t, "/v1/chat/completions", *usageRepo.lastLog.InboundEndpoint)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamEndpoint)
|
||||
require.Equal(t, "/v1/responses", *usageRepo.lastLog.UpstreamEndpoint)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateOnResolverError(t *testing.T) {
|
||||
groupID := int64(12)
|
||||
groupRate := 1.6
|
||||
|
||||
@@ -480,6 +480,7 @@ func classifyOpenAIWSReconnectReason(err error) (string, bool) {
|
||||
"upgrade_required",
|
||||
"ws_unsupported",
|
||||
"auth_failed",
|
||||
"invalid_encrypted_content",
|
||||
"previous_response_not_found":
|
||||
return reason, false
|
||||
}
|
||||
@@ -530,6 +531,14 @@ func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType st
|
||||
}
|
||||
|
||||
switch reason {
|
||||
case "invalid_encrypted_content":
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
errType = "invalid_request_error"
|
||||
if upstreamMessage == "" {
|
||||
upstreamMessage = "encrypted content could not be verified"
|
||||
}
|
||||
case "previous_response_not_found":
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusBadRequest
|
||||
@@ -1303,7 +1312,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
isExcluded := func(accountID int64) bool {
|
||||
@@ -1373,7 +1382,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
||||
@@ -1480,7 +1489,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("no available accounts")
|
||||
return nil, ErrNoAvailableAccounts
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
||||
@@ -1924,6 +1933,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
var wsErr error
|
||||
wsLastFailureReason := ""
|
||||
wsPrevResponseRecoveryTried := false
|
||||
wsInvalidEncryptedContentRecoveryTried := false
|
||||
recoverPrevResponseNotFound := func(attempt int) bool {
|
||||
if wsPrevResponseRecoveryTried {
|
||||
return false
|
||||
@@ -1956,6 +1966,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
)
|
||||
return true
|
||||
}
|
||||
recoverInvalidEncryptedContent := func(attempt int) bool {
|
||||
if wsInvalidEncryptedContentRecoveryTried {
|
||||
return false
|
||||
}
|
||||
removedReasoningItems := trimOpenAIEncryptedReasoningItems(wsReqBody)
|
||||
if !removedReasoningItems {
|
||||
logOpenAIWSModeInfo(
|
||||
"reconnect_invalid_encrypted_content_recovery_skip account_id=%d attempt=%d reason=missing_encrypted_reasoning_items",
|
||||
account.ID,
|
||||
attempt,
|
||||
)
|
||||
return false
|
||||
}
|
||||
previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id")
|
||||
hasFunctionCallOutput := HasFunctionCallOutput(wsReqBody)
|
||||
if previousResponseID != "" && !hasFunctionCallOutput {
|
||||
delete(wsReqBody, "previous_response_id")
|
||||
}
|
||||
wsInvalidEncryptedContentRecoveryTried = true
|
||||
logOpenAIWSModeInfo(
|
||||
"reconnect_invalid_encrypted_content_recovery account_id=%d attempt=%d action=drop_encrypted_reasoning_items retry=1 previous_response_id_present=%v previous_response_id=%s previous_response_id_kind=%s has_function_call_output=%v dropped_previous_response_id=%v",
|
||||
account.ID,
|
||||
attempt,
|
||||
previousResponseID != "",
|
||||
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
|
||||
normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)),
|
||||
hasFunctionCallOutput,
|
||||
previousResponseID != "" && !hasFunctionCallOutput,
|
||||
)
|
||||
return true
|
||||
}
|
||||
retryBudget := s.openAIWSRetryTotalBudget()
|
||||
retryStartedAt := time.Now()
|
||||
wsRetryLoop:
|
||||
@@ -1992,6 +2033,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) {
|
||||
continue
|
||||
}
|
||||
if reason == "invalid_encrypted_content" && recoverInvalidEncryptedContent(attempt) {
|
||||
continue
|
||||
}
|
||||
if retryable && attempt < maxAttempts {
|
||||
backoff := s.openAIWSRetryBackoff(attempt)
|
||||
if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget {
|
||||
@@ -2075,126 +2119,143 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
return nil, wsErr
|
||||
}
|
||||
|
||||
// Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpInvalidEncryptedContentRetryTried := false
|
||||
for {
|
||||
// Build upstream request
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// Send request
|
||||
upstreamStart := time.Now()
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||
if err != nil {
|
||||
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
// Send request
|
||||
upstreamStart := time.Now()
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||
if err != nil {
|
||||
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
// Handle error response
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamCode := extractUpstreamErrorCode(respBody)
|
||||
if !httpInvalidEncryptedContentRetryTried && resp.StatusCode == http.StatusBadRequest && upstreamCode == "invalid_encrypted_content" {
|
||||
if trimOpenAIEncryptedReasoningItems(reqBody) {
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize invalid_encrypted_content retry body: %w", err)
|
||||
}
|
||||
setOpsUpstreamRequestBody(c, body)
|
||||
httpInvalidEncryptedContentRetryTried = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Retrying non-WSv2 request once after invalid_encrypted_content (account: %s)", account.Name)
|
||||
continue
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Skip non-WSv2 invalid_encrypted_content retry because encrypted reasoning items are missing (account: %s)", account.Name)
|
||||
}
|
||||
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, body)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Handle normal response
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account, body)
|
||||
}
|
||||
|
||||
// Handle normal response
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
if account.Type == AccountTypeOAuth {
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
if usage == nil {
|
||||
usage = &OpenAIUsage{}
|
||||
}
|
||||
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
serviceTier := extractOpenAIServiceTier(reqBody)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
ServiceTier: serviceTier,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
if account.Type == AccountTypeOAuth {
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &OpenAIUsage{}
|
||||
}
|
||||
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
serviceTier := extractOpenAIServiceTier(reqBody)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
ServiceTier: serviceTier,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
@@ -3756,6 +3817,109 @@ func buildOpenAIResponsesURL(base string) string {
|
||||
return normalized + "/v1/responses"
|
||||
}
|
||||
|
||||
func trimOpenAIEncryptedReasoningItems(reqBody map[string]any) bool {
|
||||
if len(reqBody) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
inputValue, has := reqBody["input"]
|
||||
if !has {
|
||||
return false
|
||||
}
|
||||
|
||||
switch input := inputValue.(type) {
|
||||
case []any:
|
||||
filtered := input[:0]
|
||||
changed := false
|
||||
for _, item := range input {
|
||||
nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item)
|
||||
if itemChanged {
|
||||
changed = true
|
||||
}
|
||||
if !keep {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, nextItem)
|
||||
}
|
||||
if !changed {
|
||||
return false
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(reqBody, "input")
|
||||
return true
|
||||
}
|
||||
reqBody["input"] = filtered
|
||||
return true
|
||||
case []map[string]any:
|
||||
filtered := input[:0]
|
||||
changed := false
|
||||
for _, item := range input {
|
||||
nextItem, itemChanged, keep := sanitizeEncryptedReasoningInputItem(item)
|
||||
if itemChanged {
|
||||
changed = true
|
||||
}
|
||||
if !keep {
|
||||
continue
|
||||
}
|
||||
nextMap, ok := nextItem.(map[string]any)
|
||||
if !ok {
|
||||
filtered = append(filtered, item)
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, nextMap)
|
||||
}
|
||||
if !changed {
|
||||
return false
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(reqBody, "input")
|
||||
return true
|
||||
}
|
||||
reqBody["input"] = filtered
|
||||
return true
|
||||
case map[string]any:
|
||||
nextItem, changed, keep := sanitizeEncryptedReasoningInputItem(input)
|
||||
if !changed {
|
||||
return false
|
||||
}
|
||||
if !keep {
|
||||
delete(reqBody, "input")
|
||||
return true
|
||||
}
|
||||
nextMap, ok := nextItem.(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
reqBody["input"] = nextMap
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeEncryptedReasoningInputItem(item any) (next any, changed bool, keep bool) {
|
||||
inputItem, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return item, false, true
|
||||
}
|
||||
|
||||
itemType, _ := inputItem["type"].(string)
|
||||
if strings.TrimSpace(itemType) != "reasoning" {
|
||||
return item, false, true
|
||||
}
|
||||
|
||||
_, hasEncryptedContent := inputItem["encrypted_content"]
|
||||
if !hasEncryptedContent {
|
||||
return item, false, true
|
||||
}
|
||||
|
||||
delete(inputItem, "encrypted_content")
|
||||
if len(inputItem) == 1 {
|
||||
return nil, true, false
|
||||
}
|
||||
return inputItem, true, true
|
||||
}
|
||||
|
||||
func IsOpenAIResponsesCompactPathForTest(c *gin.Context) bool {
|
||||
return isOpenAIResponsesCompactPath(c)
|
||||
}
|
||||
@@ -3864,6 +4028,8 @@ type OpenAIRecordUsageInput struct {
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
InboundEndpoint string
|
||||
UpstreamEndpoint string
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string
|
||||
@@ -3942,6 +4108,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
Model: billingModel,
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
@@ -3961,7 +4129,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
if input.UserAgent != "" {
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
@@ -4504,3 +4671,11 @@ func normalizeOpenAIReasoningEffort(raw string) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func optionalTrimmedStringPtr(raw string) *string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
return &trimmed
|
||||
}
|
||||
|
||||
@@ -3922,6 +3922,8 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
|
||||
return "ws_unsupported", true
|
||||
case "websocket_connection_limit_reached":
|
||||
return "ws_connection_limit_reached", true
|
||||
case "invalid_encrypted_content":
|
||||
return "invalid_encrypted_content", true
|
||||
case "previous_response_not_found":
|
||||
return "previous_response_not_found", true
|
||||
}
|
||||
@@ -3940,6 +3942,10 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
|
||||
if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") {
|
||||
return "ws_connection_limit_reached", true
|
||||
}
|
||||
if strings.Contains(msg, "invalid_encrypted_content") ||
|
||||
(strings.Contains(msg, "encrypted content") && strings.Contains(msg, "could not be verified")) {
|
||||
return "invalid_encrypted_content", true
|
||||
}
|
||||
if strings.Contains(msg, "previous_response_not_found") ||
|
||||
(strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) {
|
||||
return "previous_response_not_found", true
|
||||
@@ -3964,6 +3970,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
|
||||
case strings.Contains(errType, "invalid_request"),
|
||||
strings.Contains(code, "invalid_request"),
|
||||
strings.Contains(code, "bad_request"),
|
||||
code == "invalid_encrypted_content",
|
||||
code == "previous_response_not_found":
|
||||
return http.StatusBadRequest
|
||||
case strings.Contains(errType, "authentication"),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
@@ -19,6 +20,47 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type httpUpstreamSequenceRecorder struct {
|
||||
mu sync.Mutex
|
||||
bodies [][]byte
|
||||
reqs []*http.Request
|
||||
|
||||
responses []*http.Response
|
||||
errs []error
|
||||
callCount int
|
||||
}
|
||||
|
||||
func (u *httpUpstreamSequenceRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
idx := u.callCount
|
||||
u.callCount++
|
||||
u.reqs = append(u.reqs, req)
|
||||
if req != nil && req.Body != nil {
|
||||
b, _ := io.ReadAll(req.Body)
|
||||
u.bodies = append(u.bodies, b)
|
||||
_ = req.Body.Close()
|
||||
req.Body = io.NopCloser(bytes.NewReader(b))
|
||||
} else {
|
||||
u.bodies = append(u.bodies, nil)
|
||||
}
|
||||
if idx < len(u.errs) && u.errs[idx] != nil {
|
||||
return nil, u.errs[idx]
|
||||
}
|
||||
if idx < len(u.responses) {
|
||||
return u.responses[idx], nil
|
||||
}
|
||||
if len(u.responses) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return u.responses[len(u.responses)-1], nil
|
||||
}
|
||||
|
||||
func (u *httpUpstreamSequenceRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -143,6 +185,176 @@ func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testi
|
||||
require.Equal(t, "client_protocol_http", reason)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_HTTPIngressRetriesInvalidEncryptedContentOnce(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer wsFallbackServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
|
||||
|
||||
upstream := &httpUpstreamSequenceRecorder{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"error":{"code":"invalid_encrypted_content","type":"invalid_request_error","message":"The encrypted content could not be verified."}}`,
|
||||
)),
|
||||
},
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"id":"resp_http_retry_ok","usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
|
||||
)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsFallbackServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me"}]},{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发")
|
||||
require.Equal(t, 2, upstream.callCount, "命中 invalid_encrypted_content 后应只在 HTTP 路径重试一次")
|
||||
require.Len(t, upstream.bodies, 2)
|
||||
|
||||
firstBody := upstream.bodies[0]
|
||||
secondBody := upstream.bodies[1]
|
||||
require.False(t, gjson.GetBytes(firstBody, "previous_response_id").Exists(), "HTTP 首次请求仍应沿用原逻辑移除 previous_response_id")
|
||||
require.True(t, gjson.GetBytes(firstBody, "input.0.encrypted_content").Exists(), "首次请求不应做发送前预清理")
|
||||
require.Equal(t, "keep me", gjson.GetBytes(firstBody, "input.0.summary.0.text").String())
|
||||
|
||||
require.False(t, gjson.GetBytes(secondBody, "previous_response_id").Exists(), "HTTP 精确重试不应重新带回 previous_response_id")
|
||||
require.False(t, gjson.GetBytes(secondBody, "input.0.encrypted_content").Exists(), "精确重试应移除 reasoning.encrypted_content")
|
||||
require.Equal(t, "keep me", gjson.GetBytes(secondBody, "input.0.summary.0.text").String(), "精确重试应保留有效 reasoning summary")
|
||||
require.Equal(t, "input_text", gjson.GetBytes(secondBody, "input.1.type").String(), "非 reasoning input 应保持原样")
|
||||
|
||||
decision, _ := c.Get("openai_ws_transport_decision")
|
||||
reason, _ := c.Get("openai_ws_transport_reason")
|
||||
require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision)
|
||||
require.Equal(t, "client_protocol_http", reason)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_HTTPIngressRetriesWrappedInvalidEncryptedContentOnce(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer wsFallbackServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
|
||||
|
||||
upstream := &httpUpstreamSequenceRecorder{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"error":{"code":null,"message":"{\"error\":{\"message\":\"The encrypted content could not be verified.\",\"type\":\"invalid_request_error\",\"param\":null,\"code\":\"invalid_encrypted_content\"}}(traceid: fb7ad1dbc7699c18f8a02f258f1af5ab)","param":null,"type":"invalid_request_error"}}`,
|
||||
)),
|
||||
},
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"x-request-id": []string{"req_http_retry_wrapped_ok"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"id":"resp_http_retry_wrapped_ok","usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
|
||||
)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Name: "openai-apikey-wrapped",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsFallbackServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_retry_wrapped","input":[{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me too"}]},{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发")
|
||||
require.Equal(t, 2, upstream.callCount, "wrapped invalid_encrypted_content 也应只在 HTTP 路径重试一次")
|
||||
require.Len(t, upstream.bodies, 2)
|
||||
|
||||
firstBody := upstream.bodies[0]
|
||||
secondBody := upstream.bodies[1]
|
||||
require.True(t, gjson.GetBytes(firstBody, "input.0.encrypted_content").Exists(), "首次请求不应做发送前预清理")
|
||||
require.False(t, gjson.GetBytes(secondBody, "input.0.encrypted_content").Exists(), "wrapped exact retry 应移除 reasoning.encrypted_content")
|
||||
require.Equal(t, "keep me too", gjson.GetBytes(secondBody, "input.0.summary.0.text").String(), "wrapped exact retry 应保留有效 reasoning summary")
|
||||
|
||||
decision, _ := c.Get("openai_ws_transport_decision")
|
||||
reason, _ := c.Get("openai_ws_transport_reason")
|
||||
require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision)
|
||||
require.Equal(t, "client_protocol_http", reason)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -1218,3 +1430,460 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOn
|
||||
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id")
|
||||
require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentRecoversOnce(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var wsAttempts atomic.Int32
|
||||
var wsRequestPayloads [][]byte
|
||||
var wsRequestMu sync.Mutex
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempt := wsAttempts.Add(1)
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var req map[string]any
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
reqRaw, _ := json.Marshal(req)
|
||||
wsRequestMu.Lock()
|
||||
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
|
||||
wsRequestMu.Unlock()
|
||||
if attempt == 1 {
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"code": "invalid_encrypted_content",
|
||||
"type": "invalid_request_error",
|
||||
"message": "The encrypted content could not be verified.",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "response.completed",
|
||||
"response": map[string]any{
|
||||
"id": "resp_ws_invalid_encrypted_content_recover_ok",
|
||||
"model": "gpt-5.3-codex",
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 95,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":[{"type":"reasoning","encrypted_content":"gAAA"},{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "resp_ws_invalid_encrypted_content_recover_ok", result.RequestID)
|
||||
require.Nil(t, upstream.lastReq, "invalid_encrypted_content 不应回退 HTTP")
|
||||
require.Equal(t, int32(2), wsAttempts.Load(), "invalid_encrypted_content 应触发一次清洗后重试")
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "resp_ws_invalid_encrypted_content_recover_ok", gjson.Get(rec.Body.String(), "id").String())
|
||||
|
||||
wsRequestMu.Lock()
|
||||
requests := append([][]byte(nil), wsRequestPayloads...)
|
||||
wsRequestMu.Unlock()
|
||||
require.Len(t, requests, 2)
|
||||
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id")
|
||||
require.True(t, gjson.GetBytes(requests[0], `input.0.encrypted_content`).Exists(), "首轮请求应保留 encrypted reasoning")
|
||||
require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
|
||||
require.False(t, gjson.GetBytes(requests[1], `input.0.encrypted_content`).Exists(), "恢复重试应移除 encrypted reasoning item")
|
||||
require.Equal(t, "input_text", gjson.GetBytes(requests[1], `input.0.type`).String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentSkipsRecoveryWithoutReasoningItem(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var wsAttempts atomic.Int32
|
||||
var wsRequestPayloads [][]byte
|
||||
var wsRequestMu sync.Mutex
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
wsAttempts.Add(1)
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var req map[string]any
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
reqRaw, _ := json.Marshal(req)
|
||||
wsRequestMu.Lock()
|
||||
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
|
||||
wsRequestMu.Unlock()
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"code": "invalid_encrypted_content",
|
||||
"type": "invalid_request_error",
|
||||
"message": "The encrypted content could not be verified.",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 96,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":[{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Nil(t, upstream.lastReq, "invalid_encrypted_content 不应回退 HTTP")
|
||||
require.Equal(t, int32(1), wsAttempts.Load(), "缺少 reasoning encrypted item 时应跳过自动恢复重试")
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.Contains(t, strings.ToLower(rec.Body.String()), "encrypted content")
|
||||
|
||||
wsRequestMu.Lock()
|
||||
requests := append([][]byte(nil), wsRequestPayloads...)
|
||||
wsRequestMu.Unlock()
|
||||
require.Len(t, requests, 1)
|
||||
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists())
|
||||
require.False(t, gjson.GetBytes(requests[0], `input.0.encrypted_content`).Exists())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentRecoversSingleObjectInputAndKeepsSummary(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var wsAttempts atomic.Int32
|
||||
var wsRequestPayloads [][]byte
|
||||
var wsRequestMu sync.Mutex
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempt := wsAttempts.Add(1)
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var req map[string]any
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
reqRaw, _ := json.Marshal(req)
|
||||
wsRequestMu.Lock()
|
||||
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
|
||||
wsRequestMu.Unlock()
|
||||
if attempt == 1 {
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"code": "invalid_encrypted_content",
|
||||
"type": "invalid_request_error",
|
||||
"message": "The encrypted content could not be verified.",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "response.completed",
|
||||
"response": map[string]any{
|
||||
"id": "resp_ws_invalid_encrypted_content_object_ok",
|
||||
"model": "gpt-5.3-codex",
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 97,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_encrypted","input":{"type":"reasoning","encrypted_content":"gAAA","summary":[{"type":"summary_text","text":"keep me"}]}}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "resp_ws_invalid_encrypted_content_object_ok", result.RequestID)
|
||||
require.Nil(t, upstream.lastReq, "invalid_encrypted_content 单对象 input 不应回退 HTTP")
|
||||
require.Equal(t, int32(2), wsAttempts.Load(), "单对象 reasoning input 也应触发一次清洗后重试")
|
||||
|
||||
wsRequestMu.Lock()
|
||||
requests := append([][]byte(nil), wsRequestPayloads...)
|
||||
wsRequestMu.Unlock()
|
||||
require.Len(t, requests, 2)
|
||||
require.True(t, gjson.GetBytes(requests[0], `input.encrypted_content`).Exists(), "首轮单对象应保留 encrypted_content")
|
||||
require.True(t, gjson.GetBytes(requests[1], `input.summary.0.text`).Exists(), "恢复重试应保留 reasoning summary")
|
||||
require.False(t, gjson.GetBytes(requests[1], `input.encrypted_content`).Exists(), "恢复重试只应移除 encrypted_content")
|
||||
require.Equal(t, "reasoning", gjson.GetBytes(requests[1], `input.type`).String())
|
||||
require.False(t, gjson.GetBytes(requests[1], `previous_response_id`).Exists(), "恢复重试应移除 previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2InvalidEncryptedContentKeepsPreviousResponseIDForFunctionCallOutput(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var wsAttempts atomic.Int32
|
||||
var wsRequestPayloads [][]byte
|
||||
var wsRequestMu sync.Mutex
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempt := wsAttempts.Add(1)
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
var req map[string]any
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
t.Errorf("read ws request failed: %v", err)
|
||||
return
|
||||
}
|
||||
reqRaw, _ := json.Marshal(req)
|
||||
wsRequestMu.Lock()
|
||||
wsRequestPayloads = append(wsRequestPayloads, reqRaw)
|
||||
wsRequestMu.Unlock()
|
||||
if attempt == 1 {
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"code": "invalid_encrypted_content",
|
||||
"type": "invalid_request_error",
|
||||
"message": "The encrypted content could not be verified.",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "response.completed",
|
||||
"response": map[string]any{
|
||||
"id": "resp_ws_invalid_encrypted_content_function_call_output_ok",
|
||||
"model": "gpt-5.3-codex",
|
||||
"usage": map[string]any{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
"input_tokens_details": map[string]any{
|
||||
"cached_tokens": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "custom-client/1.0")
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_reasoning","usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: upstream,
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 98,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": wsServer.URL,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_function_call","input":[{"type":"reasoning","encrypted_content":"gAAA"},{"type":"function_call_output","call_id":"call_123","output":"ok"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "resp_ws_invalid_encrypted_content_function_call_output_ok", result.RequestID)
|
||||
require.Nil(t, upstream.lastReq, "function_call_output + invalid_encrypted_content 不应回退 HTTP")
|
||||
require.Equal(t, int32(2), wsAttempts.Load(), "应只做一次保锚点的清洗后重试")
|
||||
|
||||
wsRequestMu.Lock()
|
||||
requests := append([][]byte(nil), wsRequestPayloads...)
|
||||
wsRequestMu.Unlock()
|
||||
require.Len(t, requests, 2)
|
||||
require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id")
|
||||
require.True(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "function_call_output 恢复重试不应移除 previous_response_id")
|
||||
require.False(t, gjson.GetBytes(requests[1], `input.0.encrypted_content`).Exists(), "恢复重试应移除 reasoning encrypted_content")
|
||||
require.Equal(t, "function_call_output", gjson.GetBytes(requests[1], `input.0.type`).String(), "清洗后应保留 function_call_output 作为首个输入项")
|
||||
require.Equal(t, "call_123", gjson.GetBytes(requests[1], `input.0.call_id`).String())
|
||||
require.Equal(t, "ok", gjson.GetBytes(requests[1], `input.0.output`).String())
|
||||
require.Equal(t, "resp_prev_function_call", gjson.GetBytes(requests[1], "previous_response_id").String())
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ const (
|
||||
opsAggDailyInterval = 1 * time.Hour
|
||||
|
||||
// Keep in sync with ops retention target (vNext default 30d).
|
||||
opsAggBackfillWindow = 30 * 24 * time.Hour
|
||||
opsAggBackfillWindow = 1 * time.Hour
|
||||
|
||||
// Recompute overlap to absorb late-arriving rows near boundaries.
|
||||
opsAggHourlyOverlap = 2 * time.Hour
|
||||
@@ -36,7 +36,7 @@ const (
|
||||
// that may still receive late inserts.
|
||||
opsAggSafeDelay = 5 * time.Minute
|
||||
|
||||
opsAggMaxQueryTimeout = 3 * time.Second
|
||||
opsAggMaxQueryTimeout = 5 * time.Second
|
||||
opsAggHourlyTimeout = 5 * time.Minute
|
||||
opsAggDailyTimeout = 2 * time.Minute
|
||||
|
||||
|
||||
@@ -467,7 +467,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: selErr.Error()}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "no available accounts"}
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: ErrNoAvailableAccounts.Error()}
|
||||
}
|
||||
|
||||
account := selection.Account
|
||||
|
||||
@@ -368,11 +368,14 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
|
||||
Aggregation: OpsAggregationSettings{
|
||||
AggregationEnabled: false,
|
||||
},
|
||||
IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
|
||||
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
|
||||
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
|
||||
AutoRefreshEnabled: false,
|
||||
AutoRefreshIntervalSec: 30,
|
||||
IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略
|
||||
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
|
||||
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
|
||||
IgnoreInsufficientBalanceErrors: false, // 默认不忽略,余额不足可能需要关注
|
||||
DisplayOpenAITokenStats: false,
|
||||
DisplayAlertEvents: true,
|
||||
AutoRefreshEnabled: false,
|
||||
AutoRefreshIntervalSec: 30,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -438,7 +441,7 @@ func (s *OpsService) GetOpsAdvancedSettings(ctx context.Context) (*OpsAdvancedSe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &OpsAdvancedSettings{}
|
||||
cfg := defaultOpsAdvancedSettings()
|
||||
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
|
||||
return defaultCfg, nil
|
||||
}
|
||||
|
||||
97
backend/internal/service/ops_settings_advanced_test.go
Normal file
97
backend/internal/service/ops_settings_advanced_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetOpsAdvancedSettings_DefaultHidesOpenAITokenStats(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
svc := &OpsService{settingRepo: repo}
|
||||
|
||||
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
|
||||
}
|
||||
if cfg.DisplayOpenAITokenStats {
|
||||
t.Fatalf("DisplayOpenAITokenStats = true, want false by default")
|
||||
}
|
||||
if !cfg.DisplayAlertEvents {
|
||||
t.Fatalf("DisplayAlertEvents = false, want true by default")
|
||||
}
|
||||
if repo.setCalls != 1 {
|
||||
t.Fatalf("expected defaults to be persisted once, got %d", repo.setCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateOpsAdvancedSettings_PersistsOpenAITokenStatsVisibility(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
svc := &OpsService{settingRepo: repo}
|
||||
|
||||
cfg := defaultOpsAdvancedSettings()
|
||||
cfg.DisplayOpenAITokenStats = true
|
||||
cfg.DisplayAlertEvents = false
|
||||
|
||||
updated, err := svc.UpdateOpsAdvancedSettings(context.Background(), cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateOpsAdvancedSettings() error = %v", err)
|
||||
}
|
||||
if !updated.DisplayOpenAITokenStats {
|
||||
t.Fatalf("DisplayOpenAITokenStats = false, want true")
|
||||
}
|
||||
if updated.DisplayAlertEvents {
|
||||
t.Fatalf("DisplayAlertEvents = true, want false")
|
||||
}
|
||||
|
||||
reloaded, err := svc.GetOpsAdvancedSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetOpsAdvancedSettings() after update error = %v", err)
|
||||
}
|
||||
if !reloaded.DisplayOpenAITokenStats {
|
||||
t.Fatalf("reloaded DisplayOpenAITokenStats = false, want true")
|
||||
}
|
||||
if reloaded.DisplayAlertEvents {
|
||||
t.Fatalf("reloaded DisplayAlertEvents = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOpsAdvancedSettings_BackfillsNewDisplayFlagsFromDefaults(t *testing.T) {
|
||||
repo := newRuntimeSettingRepoStub()
|
||||
svc := &OpsService{settingRepo: repo}
|
||||
|
||||
legacyCfg := map[string]any{
|
||||
"data_retention": map[string]any{
|
||||
"cleanup_enabled": false,
|
||||
"cleanup_schedule": "0 2 * * *",
|
||||
"error_log_retention_days": 30,
|
||||
"minute_metrics_retention_days": 30,
|
||||
"hourly_metrics_retention_days": 30,
|
||||
},
|
||||
"aggregation": map[string]any{
|
||||
"aggregation_enabled": false,
|
||||
},
|
||||
"ignore_count_tokens_errors": true,
|
||||
"ignore_context_canceled": true,
|
||||
"ignore_no_available_accounts": false,
|
||||
"ignore_invalid_api_key_errors": false,
|
||||
"auto_refresh_enabled": false,
|
||||
"auto_refresh_interval_seconds": 30,
|
||||
}
|
||||
raw, err := json.Marshal(legacyCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal legacy config: %v", err)
|
||||
}
|
||||
repo.values[SettingKeyOpsAdvancedSettings] = string(raw)
|
||||
|
||||
cfg, err := svc.GetOpsAdvancedSettings(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetOpsAdvancedSettings() error = %v", err)
|
||||
}
|
||||
if cfg.DisplayOpenAITokenStats {
|
||||
t.Fatalf("DisplayOpenAITokenStats = true, want false default backfill")
|
||||
}
|
||||
if !cfg.DisplayAlertEvents {
|
||||
t.Fatalf("DisplayAlertEvents = false, want true default backfill")
|
||||
}
|
||||
}
|
||||
@@ -92,14 +92,17 @@ type OpsAlertRuntimeSettings struct {
|
||||
|
||||
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
|
||||
type OpsAdvancedSettings struct {
|
||||
DataRetention OpsDataRetentionSettings `json:"data_retention"`
|
||||
Aggregation OpsAggregationSettings `json:"aggregation"`
|
||||
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
|
||||
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
|
||||
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
|
||||
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
|
||||
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
||||
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
||||
DataRetention OpsDataRetentionSettings `json:"data_retention"`
|
||||
Aggregation OpsAggregationSettings `json:"aggregation"`
|
||||
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
|
||||
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
|
||||
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
|
||||
IgnoreInvalidApiKeyErrors bool `json:"ignore_invalid_api_key_errors"`
|
||||
IgnoreInsufficientBalanceErrors bool `json:"ignore_insufficient_balance_errors"`
|
||||
DisplayOpenAITokenStats bool `json:"display_openai_token_stats"`
|
||||
DisplayAlertEvents bool `json:"display_alert_events"`
|
||||
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
||||
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
||||
}
|
||||
|
||||
type OpsDataRetentionSettings struct {
|
||||
|
||||
@@ -149,8 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
// 其他 400 错误(如参数问题)不处理,不禁用账号
|
||||
case 401:
|
||||
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
|
||||
if account.Type == AccountTypeOAuth {
|
||||
// OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为。
|
||||
// Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制。
|
||||
if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity {
|
||||
// 1. 失效缓存
|
||||
if s.tokenCacheInvalidator != nil {
|
||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
@@ -182,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
shouldDisable = true
|
||||
} else {
|
||||
// 非 OAuth 账号(APIKey):保持原有 SetError 行为
|
||||
// 非 OAuth / Antigravity OAuth:保持 SetError 行为
|
||||
msg := "Authentication failed (401): invalid or expired credentials"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Authentication failed (401): " + upstreamMsg
|
||||
@@ -199,11 +200,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
case 403:
|
||||
// 禁止访问:停止调度,记录错误
|
||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Access forbidden (403): " + upstreamMsg
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.ratelimit",
|
||||
"[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s",
|
||||
@@ -215,8 +211,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
upstreamMsg,
|
||||
truncateForLog(responseBody, 1024),
|
||||
)
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
shouldDisable = s.handle403(ctx, account, upstreamMsg, responseBody)
|
||||
case 429:
|
||||
s.handle429(ctx, account, headers, responseBody)
|
||||
shouldDisable = false
|
||||
@@ -621,6 +616,62 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account
|
||||
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
|
||||
}
|
||||
|
||||
// handle403 处理 403 Forbidden 错误
|
||||
// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用;
|
||||
// 其他平台保持原有 SetError 行为。
|
||||
func (s *RateLimitService) handle403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
|
||||
}
|
||||
// 非 Antigravity 平台:保持原有行为
|
||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Access forbidden (403): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
// handleAntigravity403 处理 Antigravity 平台的 403 错误
|
||||
// validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复)
|
||||
// violation(违规封号)→ 永久 SetError(需人工处理)
|
||||
// generic(通用禁止)→ 永久 SetError
|
||||
func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
|
||||
fbType := classifyForbiddenType(string(responseBody))
|
||||
|
||||
switch fbType {
|
||||
case forbiddenTypeValidation:
|
||||
// VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复
|
||||
msg := "Validation required (403): account needs Google verification"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Validation required (403): " + upstreamMsg
|
||||
}
|
||||
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
|
||||
msg += " | validation_url: " + validationURL
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
|
||||
case forbiddenTypeViolation:
|
||||
// 违规封号: 永久禁用,需人工处理
|
||||
msg := "Account violation (403): terms of service violation"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Account violation (403): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
|
||||
default:
|
||||
// 通用 403: 保持原有行为
|
||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Access forbidden (403): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
||||
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
||||
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
||||
@@ -1213,7 +1264,8 @@ func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Ac
|
||||
}
|
||||
// 401 首次命中可临时不可调度(给 token 刷新窗口);
|
||||
// 若历史上已因 401 进入过临时不可调度,则本次应升级为 error(返回 false 交由默认错误逻辑处理)。
|
||||
if statusCode == http.StatusUnauthorized {
|
||||
// Antigravity 跳过:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制,无需升级逻辑。
|
||||
if statusCode == http.StatusUnauthorized && account.Platform != PlatformAntigravity {
|
||||
reason := account.TempUnschedulableReason
|
||||
// 缓存可能没有 reason,从 DB 回退读取
|
||||
if reason == "" {
|
||||
|
||||
@@ -27,34 +27,68 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e
|
||||
|
||||
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
|
||||
// Scenario: cache account has empty TempUnschedulableReason (cache miss),
|
||||
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone.
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
// but DB account has a previous 401 record.
|
||||
// Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error).
|
||||
// Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules).
|
||||
t.Run("gemini_escalates", func(t *testing.T) {
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: "", // cache miss — reason is empty
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformGemini,
|
||||
TempUnschedulableReason: "",
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone")
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate")
|
||||
})
|
||||
|
||||
t.Run("antigravity_stays_temp", func(t *testing.T) {
|
||||
repo := &dbFallbackRepoStub{
|
||||
dbAccount: &Account{
|
||||
ID: 20,
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
},
|
||||
}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
TempUnschedulableReason: "",
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
|
||||
require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {
|
||||
|
||||
@@ -42,45 +42,56 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
}{
|
||||
{name: "gemini", platform: PlatformGemini},
|
||||
{name: "antigravity", platform: PlatformAntigravity},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: tt.platform,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": 401,
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": 30,
|
||||
"description": "custom rule",
|
||||
},
|
||||
t.Run("gemini", func(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": 401,
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": 30,
|
||||
"description": "custom rule",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
})
|
||||
}
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
})
|
||||
|
||||
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
|
||||
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
|
||||
// HandleUpstreamError 中走 SetError 路径。
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
require.Empty(t, invalidator.accounts)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
|
||||
|
||||
@@ -65,6 +65,19 @@ const minVersionErrorTTL = 5 * time.Second
|
||||
// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context
|
||||
const minVersionDBTimeout = 5 * time.Second
|
||||
|
||||
// cachedBackendMode Backend Mode cache (in-process, 60s TTL)
|
||||
type cachedBackendMode struct {
|
||||
value bool
|
||||
expiresAt int64 // unix nano
|
||||
}
|
||||
|
||||
var backendModeCache atomic.Value // *cachedBackendMode
|
||||
var backendModeSF singleflight.Group
|
||||
|
||||
const backendModeCacheTTL = 60 * time.Second
|
||||
const backendModeErrorTTL = 5 * time.Second
|
||||
const backendModeDBTimeout = 5 * time.Second
|
||||
|
||||
// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
|
||||
type DefaultSubscriptionGroupReader interface {
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
@@ -103,6 +116,15 @@ func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, e
|
||||
return s.parseSettings(settings), nil
|
||||
}
|
||||
|
||||
// GetFrontendURL 获取前端基础URL(数据库优先,fallback 到配置文件)
|
||||
func (s *SettingService) GetFrontendURL(ctx context.Context) string {
|
||||
val, err := s.settingRepo.GetValue(ctx, SettingKeyFrontendURL)
|
||||
if err == nil && strings.TrimSpace(val) != "" {
|
||||
return strings.TrimSpace(val)
|
||||
}
|
||||
return s.cfg.Server.FrontendURL
|
||||
}
|
||||
|
||||
// GetPublicSettings 获取公开设置(无需登录)
|
||||
func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings, error) {
|
||||
keys := []string{
|
||||
@@ -128,6 +150,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeySoraClientEnabled,
|
||||
SettingKeyCustomMenuItems,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
SettingKeyBackendModeEnabled,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -172,6 +195,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -223,6 +247,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
@@ -247,6 +272,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
Version: s.version,
|
||||
}, nil
|
||||
}
|
||||
@@ -384,6 +410,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON)
|
||||
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
|
||||
updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
|
||||
updates[SettingKeyFrontendURL] = settings.FrontendURL
|
||||
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
|
||||
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
|
||||
|
||||
@@ -461,6 +488,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
// 分组隔离
|
||||
updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling)
|
||||
|
||||
// Backend Mode
|
||||
updates[SettingKeyBackendModeEnabled] = strconv.FormatBool(settings.BackendModeEnabled)
|
||||
|
||||
err = s.settingRepo.SetMultiple(ctx, updates)
|
||||
if err == nil {
|
||||
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
|
||||
@@ -469,6 +499,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
value: settings.MinClaudeCodeVersion,
|
||||
expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(),
|
||||
})
|
||||
backendModeSF.Forget("backend_mode")
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: settings.BackendModeEnabled,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
if s.onUpdate != nil {
|
||||
s.onUpdate() // Invalidate cache after settings update
|
||||
}
|
||||
@@ -525,6 +560,52 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsBackendModeEnabled checks if backend mode is enabled
|
||||
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path
|
||||
func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
|
||||
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.value
|
||||
}
|
||||
}
|
||||
result, _, _ := backendModeSF.Do("backend_mode", func() (any, error) {
|
||||
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.value, nil
|
||||
}
|
||||
}
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), backendModeDBTimeout)
|
||||
defer cancel()
|
||||
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyBackendModeEnabled)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
// Setting not yet created (fresh install) - default to disabled with full TTL
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: false,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
return false, nil
|
||||
}
|
||||
slog.Warn("failed to get backend_mode_enabled setting", "error", err)
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: false,
|
||||
expiresAt: time.Now().Add(backendModeErrorTTL).UnixNano(),
|
||||
})
|
||||
return false, nil
|
||||
}
|
||||
enabled := value == "true"
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: enabled,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
return enabled, nil
|
||||
})
|
||||
if val, ok := result.(bool); ok {
|
||||
return val
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsEmailVerifyEnabled 检查是否开启邮件验证
|
||||
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled)
|
||||
@@ -696,6 +777,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
RegistrationEmailSuffixWhitelist: ParseRegistrationEmailSuffixWhitelist(settings[SettingKeyRegistrationEmailSuffixWhitelist]),
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
|
||||
FrontendURL: settings[SettingKeyFrontendURL],
|
||||
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
|
||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
@@ -719,6 +801,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
@@ -1278,7 +1361,7 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
|
||||
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
|
||||
}
|
||||
validScopes := map[string]bool{
|
||||
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true,
|
||||
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
|
||||
}
|
||||
|
||||
for i, rule := range settings.Rules {
|
||||
|
||||
199
backend/internal/service/setting_service_backend_mode_test.go
Normal file
199
backend/internal/service/setting_service_backend_mode_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type bmRepoStub struct {
|
||||
getValueFn func(ctx context.Context, key string) (string, error)
|
||||
calls int
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
s.calls++
|
||||
if s.getValueFn == nil {
|
||||
panic("unexpected GetValue call")
|
||||
}
|
||||
return s.getValueFn(ctx, key)
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *bmRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
type bmUpdateRepoStub struct {
|
||||
updates map[string]string
|
||||
getValueFn func(ctx context.Context, key string) (string, error)
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if s.getValueFn == nil {
|
||||
panic("unexpected GetValue call")
|
||||
}
|
||||
return s.getValueFn(ctx, key)
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
s.updates = make(map[string]string, len(settings))
|
||||
for k, v := range settings {
|
||||
s.updates[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *bmUpdateRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func resetBackendModeTestCache(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
backendModeCache.Store((*cachedBackendMode)(nil))
|
||||
t.Cleanup(func() {
|
||||
backendModeCache.Store((*cachedBackendMode)(nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsTrue(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "true", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.True(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsFalse(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "false", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsFalseOnNotFound(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "", ErrSettingNotFound
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_ReturnsFalseOnDBError(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "", errors.New("db down")
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestIsBackendModeEnabled_CachesResult(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
repo := &bmRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "true", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
require.True(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.True(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
require.Equal(t, 1, repo.calls)
|
||||
}
|
||||
|
||||
func TestUpdateSettings_InvalidatesBackendModeCache(t *testing.T) {
|
||||
resetBackendModeTestCache(t)
|
||||
|
||||
backendModeCache.Store(&cachedBackendMode{
|
||||
value: true,
|
||||
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
|
||||
})
|
||||
|
||||
repo := &bmUpdateRepoStub{
|
||||
getValueFn: func(ctx context.Context, key string) (string, error) {
|
||||
require.Equal(t, SettingKeyBackendModeEnabled, key)
|
||||
return "true", nil
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||
BackendModeEnabled: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "false", repo.updates[SettingKeyBackendModeEnabled])
|
||||
require.False(t, svc.IsBackendModeEnabled(context.Background()))
|
||||
}
|
||||
@@ -6,6 +6,7 @@ type SystemSettings struct {
|
||||
RegistrationEmailSuffixWhitelist []string
|
||||
PromoCodeEnabled bool
|
||||
PasswordResetEnabled bool
|
||||
FrontendURL string
|
||||
InvitationCodeEnabled bool
|
||||
TotpEnabled bool // TOTP 双因素认证
|
||||
|
||||
@@ -69,6 +70,9 @@ type SystemSettings struct {
|
||||
|
||||
// 分组隔离:允许未分组 Key 调度(默认 false → 403)
|
||||
AllowUngroupedKeyScheduling bool
|
||||
|
||||
// Backend 模式:禁用用户注册和自助服务,仅管理员可登录
|
||||
BackendModeEnabled bool
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -101,6 +105,7 @@ type PublicSettings struct {
|
||||
CustomMenuItems string // JSON array of custom menu items
|
||||
|
||||
LinuxDoOAuthEnabled bool
|
||||
BackendModeEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
@@ -198,16 +203,17 @@ const (
|
||||
BetaPolicyActionFilter = "filter" // 过滤,从 beta header 中移除该 token
|
||||
BetaPolicyActionBlock = "block" // 拦截,直接返回错误
|
||||
|
||||
BetaPolicyScopeAll = "all" // 所有账号类型
|
||||
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
|
||||
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
|
||||
BetaPolicyScopeAll = "all" // 所有账号类型
|
||||
BetaPolicyScopeOAuth = "oauth" // 仅 OAuth 账号
|
||||
BetaPolicyScopeAPIKey = "apikey" // 仅 API Key 账号
|
||||
BetaPolicyScopeBedrock = "bedrock" // 仅 AWS Bedrock 账号
|
||||
)
|
||||
|
||||
// BetaPolicyRule 单条 Beta 策略规则
|
||||
type BetaPolicyRule struct {
|
||||
BetaToken string `json:"beta_token"` // beta token 值
|
||||
Action string `json:"action"` // "pass" | "filter" | "block"
|
||||
Scope string `json:"scope"` // "all" | "oauth" | "apikey"
|
||||
Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock"
|
||||
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
|
||||
}
|
||||
|
||||
|
||||
@@ -100,9 +100,14 @@ type UsageLog struct {
|
||||
Model string
|
||||
// ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex".
|
||||
ServiceTier *string
|
||||
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API),
|
||||
// e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable.
|
||||
// ReasoningEffort is the request's reasoning effort level.
|
||||
// OpenAI: "low" / "medium" / "high" / "xhigh"; Claude: "low" / "medium" / "high" / "max".
|
||||
// Nil means not provided / not applicable.
|
||||
ReasoningEffort *string
|
||||
// InboundEndpoint is the client-facing API endpoint path, e.g. /v1/chat/completions.
|
||||
InboundEndpoint *string
|
||||
// UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses.
|
||||
UpstreamEndpoint *string
|
||||
|
||||
GroupID *int64
|
||||
SubscriptionID *int64
|
||||
|
||||
@@ -322,6 +322,19 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC
|
||||
return apiKeyService
|
||||
}
|
||||
|
||||
// ProvideBackupService creates and starts BackupService
|
||||
func ProvideBackupService(
|
||||
settingRepo SettingRepository,
|
||||
cfg *config.Config,
|
||||
encryptor SecretEncryptor,
|
||||
storeFactory BackupObjectStoreFactory,
|
||||
dumper DBDumper,
|
||||
) *BackupService {
|
||||
svc := NewBackupService(settingRepo, cfg, encryptor, storeFactory, dumper)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideSettingService wires SettingService with group reader for default subscription validation.
|
||||
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService {
|
||||
svc := NewSettingService(settingRepo, cfg)
|
||||
@@ -373,6 +386,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountTestService,
|
||||
ProvideSettingService,
|
||||
NewDataManagementService,
|
||||
ProvideBackupService,
|
||||
ProvideOpsSystemLogSink,
|
||||
NewOpsService,
|
||||
ProvideOpsMetricsCollector,
|
||||
|
||||
5
backend/migrations/074_add_usage_log_endpoints.sql
Normal file
5
backend/migrations/074_add_usage_log_endpoints.sql
Normal file
@@ -0,0 +1,5 @@
|
||||
-- Add endpoint tracking fields to usage_logs.
|
||||
-- inbound_endpoint: client-facing API route (e.g. /v1/chat/completions, /v1/messages, /v1/responses)
|
||||
-- upstream_endpoint: normalized upstream route (e.g. /v1/responses)
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(128);
|
||||
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(128);
|
||||
105
deploy/docker-compose.dev.yml
Normal file
105
deploy/docker-compose.dev.yml
Normal file
@@ -0,0 +1,105 @@
|
||||
# =============================================================================
|
||||
# Sub2API Docker Compose - Local Development Build
|
||||
# =============================================================================
|
||||
# Build from local source code for testing changes.
|
||||
#
|
||||
# Usage:
|
||||
# cd deploy
|
||||
# docker compose -f docker-compose.dev.yml up --build
|
||||
# =============================================================================
|
||||
|
||||
services:
|
||||
sub2api:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: Dockerfile
|
||||
container_name: sub2api-dev
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "${BIND_HOST:-127.0.0.1}:${SERVER_PORT:-8080}:8080"
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
environment:
|
||||
- AUTO_SETUP=true
|
||||
- SERVER_HOST=0.0.0.0
|
||||
- SERVER_PORT=8080
|
||||
- SERVER_MODE=debug
|
||||
- RUN_MODE=${RUN_MODE:-standard}
|
||||
- DATABASE_HOST=postgres
|
||||
- DATABASE_PORT=5432
|
||||
- DATABASE_USER=${POSTGRES_USER:-sub2api}
|
||||
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
|
||||
- DATABASE_SSLMODE=disable
|
||||
- REDIS_HOST=redis
|
||||
- REDIS_PORT=6379
|
||||
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
|
||||
- REDIS_DB=${REDIS_DB:-0}
|
||||
- ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local}
|
||||
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-}
|
||||
- JWT_SECRET=${JWT_SECRET:-}
|
||||
- TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-}
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- sub2api-network
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
|
||||
postgres:
|
||||
image: postgres:18-alpine
|
||||
container_name: sub2api-postgres-dev
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
||||
- PGDATA=/var/lib/postgresql/data
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
networks:
|
||||
- sub2api-network
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 10s
|
||||
|
||||
redis:
|
||||
image: redis:8-alpine
|
||||
container_name: sub2api-redis-dev
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./redis_data:/data
|
||||
command: >
|
||||
sh -c '
|
||||
redis-server
|
||||
--save 60 1
|
||||
--appendonly yes
|
||||
--appendfsync everysec
|
||||
${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}'
|
||||
environment:
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
- REDISCLI_AUTH=${REDIS_PASSWORD:-}
|
||||
networks:
|
||||
- sub2api-network
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 5s
|
||||
|
||||
networks:
|
||||
sub2api-network:
|
||||
driver: bridge
|
||||
114
frontend/src/api/admin/backup.ts
Normal file
114
frontend/src/api/admin/backup.ts
Normal file
@@ -0,0 +1,114 @@
|
||||
import { apiClient } from '../client'
|
||||
|
||||
export interface BackupS3Config {
|
||||
endpoint: string
|
||||
region: string
|
||||
bucket: string
|
||||
access_key_id: string
|
||||
secret_access_key?: string
|
||||
prefix: string
|
||||
force_path_style: boolean
|
||||
}
|
||||
|
||||
export interface BackupScheduleConfig {
|
||||
enabled: boolean
|
||||
cron_expr: string
|
||||
retain_days: number
|
||||
retain_count: number
|
||||
}
|
||||
|
||||
export interface BackupRecord {
|
||||
id: string
|
||||
status: 'pending' | 'running' | 'completed' | 'failed'
|
||||
backup_type: string
|
||||
file_name: string
|
||||
s3_key: string
|
||||
size_bytes: number
|
||||
triggered_by: string
|
||||
error_message?: string
|
||||
started_at: string
|
||||
finished_at?: string
|
||||
expires_at?: string
|
||||
}
|
||||
|
||||
export interface CreateBackupRequest {
|
||||
expire_days?: number
|
||||
}
|
||||
|
||||
export interface TestS3Response {
|
||||
ok: boolean
|
||||
message: string
|
||||
}
|
||||
|
||||
// S3 Config
|
||||
export async function getS3Config(): Promise<BackupS3Config> {
|
||||
const { data } = await apiClient.get<BackupS3Config>('/admin/backups/s3-config')
|
||||
return data
|
||||
}
|
||||
|
||||
export async function updateS3Config(config: BackupS3Config): Promise<BackupS3Config> {
|
||||
const { data } = await apiClient.put<BackupS3Config>('/admin/backups/s3-config', config)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function testS3Connection(config: BackupS3Config): Promise<TestS3Response> {
|
||||
const { data } = await apiClient.post<TestS3Response>('/admin/backups/s3-config/test', config)
|
||||
return data
|
||||
}
|
||||
|
||||
// Schedule
|
||||
export async function getSchedule(): Promise<BackupScheduleConfig> {
|
||||
const { data } = await apiClient.get<BackupScheduleConfig>('/admin/backups/schedule')
|
||||
return data
|
||||
}
|
||||
|
||||
export async function updateSchedule(config: BackupScheduleConfig): Promise<BackupScheduleConfig> {
|
||||
const { data } = await apiClient.put<BackupScheduleConfig>('/admin/backups/schedule', config)
|
||||
return data
|
||||
}
|
||||
|
||||
// Backup operations
|
||||
export async function createBackup(req?: CreateBackupRequest): Promise<BackupRecord> {
|
||||
const { data } = await apiClient.post<BackupRecord>('/admin/backups', req || {}, { timeout: 600000 })
|
||||
return data
|
||||
}
|
||||
|
||||
export async function listBackups(): Promise<{ items: BackupRecord[] }> {
|
||||
const { data } = await apiClient.get<{ items: BackupRecord[] }>('/admin/backups')
|
||||
return data
|
||||
}
|
||||
|
||||
export async function getBackup(id: string): Promise<BackupRecord> {
|
||||
const { data } = await apiClient.get<BackupRecord>(`/admin/backups/${id}`)
|
||||
return data
|
||||
}
|
||||
|
||||
export async function deleteBackup(id: string): Promise<void> {
|
||||
await apiClient.delete(`/admin/backups/${id}`)
|
||||
}
|
||||
|
||||
export async function getDownloadURL(id: string): Promise<{ url: string }> {
|
||||
const { data } = await apiClient.get<{ url: string }>(`/admin/backups/${id}/download-url`)
|
||||
return data
|
||||
}
|
||||
|
||||
// Restore
|
||||
export async function restoreBackup(id: string, password: string): Promise<void> {
|
||||
await apiClient.post(`/admin/backups/${id}/restore`, { password }, { timeout: 600000 })
|
||||
}
|
||||
|
||||
export const backupAPI = {
|
||||
getS3Config,
|
||||
updateS3Config,
|
||||
testS3Connection,
|
||||
getSchedule,
|
||||
updateSchedule,
|
||||
createBackup,
|
||||
listBackups,
|
||||
getBackup,
|
||||
deleteBackup,
|
||||
getDownloadURL,
|
||||
restoreBackup,
|
||||
}
|
||||
|
||||
export default backupAPI
|
||||
@@ -23,6 +23,7 @@ import errorPassthroughAPI from './errorPassthrough'
|
||||
import dataManagementAPI from './dataManagement'
|
||||
import apiKeysAPI from './apiKeys'
|
||||
import scheduledTestsAPI from './scheduledTests'
|
||||
import backupAPI from './backup'
|
||||
|
||||
/**
|
||||
* Unified admin API object for convenient access
|
||||
@@ -47,7 +48,8 @@ export const adminAPI = {
|
||||
errorPassthrough: errorPassthroughAPI,
|
||||
dataManagement: dataManagementAPI,
|
||||
apiKeys: apiKeysAPI,
|
||||
scheduledTests: scheduledTestsAPI
|
||||
scheduledTests: scheduledTestsAPI,
|
||||
backup: backupAPI
|
||||
}
|
||||
|
||||
export {
|
||||
@@ -70,7 +72,8 @@ export {
|
||||
errorPassthroughAPI,
|
||||
dataManagementAPI,
|
||||
apiKeysAPI,
|
||||
scheduledTestsAPI
|
||||
scheduledTestsAPI,
|
||||
backupAPI
|
||||
}
|
||||
|
||||
export default adminAPI
|
||||
|
||||
@@ -841,6 +841,9 @@ export interface OpsAdvancedSettings {
|
||||
ignore_context_canceled: boolean
|
||||
ignore_no_available_accounts: boolean
|
||||
ignore_invalid_api_key_errors: boolean
|
||||
ignore_insufficient_balance_errors: boolean
|
||||
display_openai_token_stats: boolean
|
||||
display_alert_events: boolean
|
||||
auto_refresh_enabled: boolean
|
||||
auto_refresh_interval_seconds: number
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ export interface SystemSettings {
|
||||
registration_email_suffix_whitelist: string[]
|
||||
promo_code_enabled: boolean
|
||||
password_reset_enabled: boolean
|
||||
frontend_url: string
|
||||
invitation_code_enabled: boolean
|
||||
totp_enabled: boolean // TOTP 双因素认证
|
||||
totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置
|
||||
@@ -40,6 +41,7 @@ export interface SystemSettings {
|
||||
purchase_subscription_enabled: boolean
|
||||
purchase_subscription_url: string
|
||||
sora_client_enabled: boolean
|
||||
backend_mode_enabled: boolean
|
||||
custom_menu_items: CustomMenuItem[]
|
||||
// SMTP settings
|
||||
smtp_host: string
|
||||
@@ -90,6 +92,7 @@ export interface UpdateSettingsRequest {
|
||||
registration_email_suffix_whitelist?: string[]
|
||||
promo_code_enabled?: boolean
|
||||
password_reset_enabled?: boolean
|
||||
frontend_url?: string
|
||||
invitation_code_enabled?: boolean
|
||||
totp_enabled?: boolean // TOTP 双因素认证
|
||||
default_balance?: number
|
||||
@@ -106,6 +109,7 @@ export interface UpdateSettingsRequest {
|
||||
purchase_subscription_enabled?: boolean
|
||||
purchase_subscription_url?: string
|
||||
sora_client_enabled?: boolean
|
||||
backend_mode_enabled?: boolean
|
||||
custom_menu_items?: CustomMenuItem[]
|
||||
smtp_host?: string
|
||||
smtp_port?: number
|
||||
@@ -316,7 +320,7 @@ export async function updateRectifierSettings(
|
||||
export interface BetaPolicyRule {
|
||||
beta_token: string
|
||||
action: 'pass' | 'filter' | 'block'
|
||||
scope: 'all' | 'oauth' | 'apikey'
|
||||
scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
|
||||
error_message?: string
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
import { apiClient } from '../client'
|
||||
import type { AdminUsageLog, UsageQueryParams, PaginatedResponse, UsageRequestType } from '@/types'
|
||||
import type { EndpointStat } from '@/types'
|
||||
|
||||
// ==================== Types ====================
|
||||
|
||||
@@ -18,6 +19,9 @@ export interface AdminUsageStatsResponse {
|
||||
total_actual_cost: number
|
||||
total_account_cost?: number
|
||||
average_duration_ms: number
|
||||
endpoints?: EndpointStat[]
|
||||
upstream_endpoints?: EndpointStat[]
|
||||
endpoint_paths?: EndpointStat[]
|
||||
}
|
||||
|
||||
export interface SimpleUser {
|
||||
|
||||
@@ -292,17 +292,19 @@ const rpmTooltip = computed(() => {
|
||||
}
|
||||
})
|
||||
|
||||
// 是否显示各维度配额(仅 apikey 类型)
|
||||
// 是否显示各维度配额(apikey / bedrock 类型)
|
||||
const isQuotaEligible = computed(() => props.account.type === 'apikey' || props.account.type === 'bedrock')
|
||||
|
||||
const showDailyQuota = computed(() => {
|
||||
return props.account.type === 'apikey' && (props.account.quota_daily_limit ?? 0) > 0
|
||||
return isQuotaEligible.value && (props.account.quota_daily_limit ?? 0) > 0
|
||||
})
|
||||
|
||||
const showWeeklyQuota = computed(() => {
|
||||
return props.account.type === 'apikey' && (props.account.quota_weekly_limit ?? 0) > 0
|
||||
return isQuotaEligible.value && (props.account.quota_weekly_limit ?? 0) > 0
|
||||
})
|
||||
|
||||
const showTotalQuota = computed(() => {
|
||||
return props.account.type === 'apikey' && (props.account.quota_limit ?? 0) > 0
|
||||
return isQuotaEligible.value && (props.account.quota_limit ?? 0) > 0
|
||||
})
|
||||
|
||||
// 格式化费用显示
|
||||
|
||||
@@ -446,6 +446,18 @@
|
||||
|
||||
<!-- Model Distribution -->
|
||||
<ModelDistributionChart :model-stats="stats.models" :loading="false" />
|
||||
|
||||
<EndpointDistributionChart
|
||||
:endpoint-stats="stats.endpoints || []"
|
||||
:loading="false"
|
||||
:title="t('usage.inboundEndpoint')"
|
||||
/>
|
||||
|
||||
<EndpointDistributionChart
|
||||
:endpoint-stats="stats.upstream_endpoints || []"
|
||||
:loading="false"
|
||||
:title="t('usage.upstreamEndpoint')"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<!-- No Data State -->
|
||||
@@ -489,6 +501,7 @@ import { Line } from 'vue-chartjs'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
|
||||
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'
|
||||
import EndpointDistributionChart from '@/components/charts/EndpointDistributionChart.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Account, AccountUsageStatsResponse } from '@/types'
|
||||
|
||||
@@ -36,6 +36,10 @@
|
||||
|
||||
<!-- Usage data -->
|
||||
<div v-else-if="usageInfo" class="space-y-1">
|
||||
<!-- API error (degraded response) -->
|
||||
<div v-if="usageInfo.error" class="text-xs text-amber-600 dark:text-amber-400 truncate max-w-[200px]" :title="usageInfo.error">
|
||||
{{ usageInfo.error }}
|
||||
</div>
|
||||
<!-- 5h Window -->
|
||||
<UsageProgressBar
|
||||
v-if="usageInfo.five_hour"
|
||||
@@ -189,8 +193,53 @@
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Forbidden state (403) -->
|
||||
<div v-if="isForbidden" class="space-y-1">
|
||||
<span
|
||||
:class="[
|
||||
'inline-block rounded px-1.5 py-0.5 text-[10px] font-medium',
|
||||
forbiddenBadgeClass
|
||||
]"
|
||||
>
|
||||
{{ forbiddenLabel }}
|
||||
</span>
|
||||
<div v-if="validationURL" class="flex items-center gap-1">
|
||||
<a
|
||||
:href="validationURL"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="text-[10px] text-blue-600 hover:text-blue-800 hover:underline dark:text-blue-400 dark:hover:text-blue-300"
|
||||
:title="t('admin.accounts.openVerification')"
|
||||
>
|
||||
{{ t('admin.accounts.openVerification') }}
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-[10px] text-gray-500 hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-200"
|
||||
:title="t('admin.accounts.copyLink')"
|
||||
@click="copyValidationURL"
|
||||
>
|
||||
{{ linkCopied ? t('admin.accounts.linkCopied') : t('admin.accounts.copyLink') }}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Needs reauth (401) -->
|
||||
<div v-else-if="needsReauth" class="space-y-1">
|
||||
<span class="inline-block rounded px-1.5 py-0.5 text-[10px] font-medium bg-orange-100 text-orange-700 dark:bg-orange-900/40 dark:text-orange-300">
|
||||
{{ t('admin.accounts.needsReauth') }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Degraded error (non-403, non-401) -->
|
||||
<div v-else-if="usageInfo?.error" class="space-y-1">
|
||||
<span class="inline-block rounded px-1.5 py-0.5 text-[10px] font-medium bg-amber-100 text-amber-700 dark:bg-amber-900/40 dark:text-amber-300">
|
||||
{{ usageErrorLabel }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Loading state -->
|
||||
<div v-if="loading" class="space-y-1.5">
|
||||
<div v-else-if="loading" class="space-y-1.5">
|
||||
<div class="flex items-center gap-1">
|
||||
<div class="h-3 w-[32px] animate-pulse rounded bg-gray-200 dark:bg-gray-700"></div>
|
||||
<div class="h-1.5 w-8 animate-pulse rounded-full bg-gray-200 dark:bg-gray-700"></div>
|
||||
@@ -816,6 +865,51 @@ const hasIneligibleTiers = computed(() => {
|
||||
return Array.isArray(ineligibleTiers) && ineligibleTiers.length > 0
|
||||
})
|
||||
|
||||
// Antigravity 403 forbidden 状态
|
||||
const isForbidden = computed(() => !!usageInfo.value?.is_forbidden)
|
||||
const forbiddenType = computed(() => usageInfo.value?.forbidden_type || 'forbidden')
|
||||
const validationURL = computed(() => usageInfo.value?.validation_url || '')
|
||||
|
||||
// 需要重新授权(401)
|
||||
const needsReauth = computed(() => !!usageInfo.value?.needs_reauth)
|
||||
|
||||
// 降级错误标签(rate_limited / network_error)
|
||||
const usageErrorLabel = computed(() => {
|
||||
const code = usageInfo.value?.error_code
|
||||
if (code === 'rate_limited') return t('admin.accounts.rateLimited')
|
||||
return t('admin.accounts.usageError')
|
||||
})
|
||||
|
||||
const forbiddenLabel = computed(() => {
|
||||
switch (forbiddenType.value) {
|
||||
case 'validation':
|
||||
return t('admin.accounts.forbiddenValidation')
|
||||
case 'violation':
|
||||
return t('admin.accounts.forbiddenViolation')
|
||||
default:
|
||||
return t('admin.accounts.forbidden')
|
||||
}
|
||||
})
|
||||
|
||||
const forbiddenBadgeClass = computed(() => {
|
||||
if (forbiddenType.value === 'validation') {
|
||||
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/40 dark:text-yellow-300'
|
||||
}
|
||||
return 'bg-red-100 text-red-700 dark:bg-red-900/40 dark:text-red-300'
|
||||
})
|
||||
|
||||
const linkCopied = ref(false)
|
||||
const copyValidationURL = async () => {
|
||||
if (!validationURL.value) return
|
||||
try {
|
||||
await navigator.clipboard.writeText(validationURL.value)
|
||||
linkCopied.value = true
|
||||
setTimeout(() => { linkCopied.value = false }, 2000)
|
||||
} catch {
|
||||
// fallback: ignore
|
||||
}
|
||||
}
|
||||
|
||||
const loadUsage = async () => {
|
||||
if (!shouldFetchUsage.value) return
|
||||
|
||||
@@ -848,18 +942,30 @@ const makeQuotaBar = (
|
||||
let resetsAt: string | null = null
|
||||
if (startKey) {
|
||||
const extra = props.account.extra as Record<string, unknown> | undefined
|
||||
const startStr = extra?.[startKey] as string | undefined
|
||||
if (startStr) {
|
||||
const startDate = new Date(startStr)
|
||||
const periodMs = startKey.includes('daily') ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000
|
||||
resetsAt = new Date(startDate.getTime() + periodMs).toISOString()
|
||||
const isDaily = startKey.includes('daily')
|
||||
const mode = isDaily
|
||||
? (extra?.quota_daily_reset_mode as string) || 'rolling'
|
||||
: (extra?.quota_weekly_reset_mode as string) || 'rolling'
|
||||
|
||||
if (mode === 'fixed') {
|
||||
// Use pre-computed next reset time for fixed mode
|
||||
const resetAtKey = isDaily ? 'quota_daily_reset_at' : 'quota_weekly_reset_at'
|
||||
resetsAt = (extra?.[resetAtKey] as string) || null
|
||||
} else {
|
||||
// Rolling mode: compute from start + period
|
||||
const startStr = extra?.[startKey] as string | undefined
|
||||
if (startStr) {
|
||||
const startDate = new Date(startStr)
|
||||
const periodMs = isDaily ? 24 * 60 * 60 * 1000 : 7 * 24 * 60 * 60 * 1000
|
||||
resetsAt = new Date(startDate.getTime() + periodMs).toISOString()
|
||||
}
|
||||
}
|
||||
}
|
||||
return { utilization, resetsAt }
|
||||
}
|
||||
|
||||
const hasApiKeyQuota = computed(() => {
|
||||
if (props.account.type !== 'apikey') return false
|
||||
if (props.account.type !== 'apikey' && props.account.type !== 'bedrock') return false
|
||||
return (
|
||||
(props.account.quota_daily_limit ?? 0) > 0 ||
|
||||
(props.account.quota_weekly_limit ?? 0) > 0 ||
|
||||
|
||||
@@ -164,27 +164,10 @@
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Checkbox List -->
|
||||
<div class="mb-3 grid grid-cols-2 gap-2">
|
||||
<label
|
||||
v-for="model in filteredModels"
|
||||
:key="model.value"
|
||||
class="flex cursor-pointer items-center rounded-lg border p-3 transition-all hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700"
|
||||
:class="
|
||||
allowedModels.includes(model.value)
|
||||
? 'border-primary-500 bg-primary-50 dark:bg-primary-900/20'
|
||||
: 'border-gray-200'
|
||||
"
|
||||
>
|
||||
<input
|
||||
v-model="allowedModels"
|
||||
type="checkbox"
|
||||
:value="model.value"
|
||||
class="mr-2 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ model.label }}</span>
|
||||
</label>
|
||||
</div>
|
||||
<ModelWhitelistSelector
|
||||
v-model="allowedModels"
|
||||
:platforms="selectedPlatforms"
|
||||
/>
|
||||
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
@@ -832,8 +815,12 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { buildModelMappingObject as buildModelMappingPayload } from '@/composables/useModelWhitelist'
|
||||
import {
|
||||
buildModelMappingObject as buildModelMappingPayload,
|
||||
getPresetMappingsByPlatform
|
||||
} from '@/composables/useModelWhitelist'
|
||||
|
||||
interface Props {
|
||||
show: boolean
|
||||
@@ -865,26 +852,20 @@ const allAnthropicOAuthOrSetupToken = computed(() => {
|
||||
)
|
||||
})
|
||||
|
||||
const platformModelPrefix: Record<string, string[]> = {
|
||||
anthropic: ['claude-'],
|
||||
antigravity: ['claude-', 'gemini-', 'gpt-oss-', 'tab_'],
|
||||
openai: ['gpt-'],
|
||||
gemini: ['gemini-'],
|
||||
sora: []
|
||||
}
|
||||
|
||||
const filteredModels = computed(() => {
|
||||
if (props.selectedPlatforms.length === 0) return allModels
|
||||
const prefixes = [...new Set(props.selectedPlatforms.flatMap(p => platformModelPrefix[p] || []))]
|
||||
if (prefixes.length === 0) return allModels
|
||||
return allModels.filter(m => prefixes.some(prefix => m.value.startsWith(prefix)))
|
||||
})
|
||||
|
||||
const filteredPresets = computed(() => {
|
||||
if (props.selectedPlatforms.length === 0) return presetMappings
|
||||
const prefixes = [...new Set(props.selectedPlatforms.flatMap(p => platformModelPrefix[p] || []))]
|
||||
if (prefixes.length === 0) return presetMappings
|
||||
return presetMappings.filter(m => prefixes.some(prefix => m.from.startsWith(prefix)))
|
||||
if (props.selectedPlatforms.length === 0) return []
|
||||
|
||||
const dedupedPresets = new Map<string, ReturnType<typeof getPresetMappingsByPlatform>[number]>()
|
||||
for (const platform of props.selectedPlatforms) {
|
||||
for (const preset of getPresetMappingsByPlatform(platform)) {
|
||||
const key = `${preset.from}=>${preset.to}`
|
||||
if (!dedupedPresets.has(key)) {
|
||||
dedupedPresets.set(key, preset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(dedupedPresets.values())
|
||||
})
|
||||
|
||||
// Model mapping type
|
||||
@@ -937,204 +918,6 @@ const umqModeOptions = computed(() => [
|
||||
{ value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') },
|
||||
])
|
||||
|
||||
// All models list (combined Anthropic + OpenAI + Gemini)
|
||||
const allModels = [
|
||||
{ value: 'claude-opus-4-6', label: 'Claude Opus 4.6' },
|
||||
{ value: 'claude-sonnet-4-6', label: 'Claude Sonnet 4.6' },
|
||||
{ value: 'claude-opus-4-5-20251101', label: 'Claude Opus 4.5' },
|
||||
{ value: 'claude-sonnet-4-20250514', label: 'Claude Sonnet 4' },
|
||||
{ value: 'claude-sonnet-4-5-20250929', label: 'Claude Sonnet 4.5' },
|
||||
{ value: 'claude-3-5-haiku-20241022', label: 'Claude 3.5 Haiku' },
|
||||
{ value: 'claude-haiku-4-5-20251001', label: 'Claude Haiku 4.5' },
|
||||
{ value: 'claude-3-opus-20240229', label: 'Claude 3 Opus' },
|
||||
{ value: 'claude-3-5-sonnet-20241022', label: 'Claude 3.5 Sonnet' },
|
||||
{ value: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku' },
|
||||
{ value: 'gpt-5.3-codex', label: 'GPT-5.3 Codex' },
|
||||
{ value: 'gpt-5.3-codex-spark', label: 'GPT-5.3 Codex Spark' },
|
||||
{ value: 'gpt-5.4', label: 'GPT-5.4' },
|
||||
{ value: 'gpt-5.2-2025-12-11', label: 'GPT-5.2' },
|
||||
{ value: 'gpt-5.2-codex', label: 'GPT-5.2 Codex' },
|
||||
{ value: 'gpt-5.1-codex-max', label: 'GPT-5.1 Codex Max' },
|
||||
{ value: 'gpt-5.1-codex', label: 'GPT-5.1 Codex' },
|
||||
{ value: 'gpt-5.1-2025-11-13', label: 'GPT-5.1' },
|
||||
{ value: 'gpt-5.1-codex-mini', label: 'GPT-5.1 Codex Mini' },
|
||||
{ value: 'gpt-5-2025-08-07', label: 'GPT-5' },
|
||||
{ value: 'gemini-3.1-flash-image', label: 'Gemini 3.1 Flash Image' },
|
||||
{ value: 'gemini-2.5-flash-image', label: 'Gemini 2.5 Flash Image' },
|
||||
{ value: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' },
|
||||
{ value: 'gemini-2.5-flash', label: 'Gemini 2.5 Flash' },
|
||||
{ value: 'gemini-2.5-pro', label: 'Gemini 2.5 Pro' },
|
||||
{ value: 'gemini-3-pro-image', label: 'Gemini 3 Pro Image (Legacy)' },
|
||||
{ value: 'gemini-3-flash-preview', label: 'Gemini 3 Flash Preview' },
|
||||
{ value: 'gemini-3-pro-preview', label: 'Gemini 3 Pro Preview' }
|
||||
]
|
||||
|
||||
// Preset mappings (combined Anthropic + OpenAI + Gemini)
|
||||
const presetMappings = [
|
||||
{
|
||||
label: 'Sonnet 4',
|
||||
from: 'claude-sonnet-4-20250514',
|
||||
to: 'claude-sonnet-4-20250514',
|
||||
color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet 4.5',
|
||||
from: 'claude-sonnet-4-5-20250929',
|
||||
to: 'claude-sonnet-4-5-20250929',
|
||||
color:
|
||||
'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus 4.5',
|
||||
from: 'claude-opus-4-5-20251101',
|
||||
to: 'claude-opus-4-5-20251101',
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus 4.6',
|
||||
from: 'claude-opus-4-6',
|
||||
to: 'claude-opus-4-6-thinking',
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus 4.6-thinking',
|
||||
from: 'claude-opus-4-6-thinking',
|
||||
to: 'claude-opus-4-6-thinking',
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet 4.6',
|
||||
from: 'claude-sonnet-4-6',
|
||||
to: 'claude-sonnet-4-6',
|
||||
color:
|
||||
'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet4→4.6',
|
||||
from: 'claude-sonnet-4-20250514',
|
||||
to: 'claude-sonnet-4-6',
|
||||
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet4.5→4.6',
|
||||
from: 'claude-sonnet-4-5-20250929',
|
||||
to: 'claude-sonnet-4-6',
|
||||
color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400'
|
||||
},
|
||||
{
|
||||
label: 'Sonnet3.5→4.6',
|
||||
from: 'claude-3-5-sonnet-20241022',
|
||||
to: 'claude-sonnet-4-6',
|
||||
color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus4.5→4.6',
|
||||
from: 'claude-opus-4-5-20251101',
|
||||
to: 'claude-opus-4-6-thinking',
|
||||
color:
|
||||
'bg-violet-100 text-violet-700 hover:bg-violet-200 dark:bg-violet-900/30 dark:text-violet-400'
|
||||
},
|
||||
{
|
||||
label: 'Opus->Sonnet',
|
||||
from: 'claude-opus-4-5-20251101',
|
||||
to: 'claude-sonnet-4-5-20250929',
|
||||
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
|
||||
},
|
||||
{
|
||||
label: 'Gemini 2.5 Image',
|
||||
from: 'gemini-2.5-flash-image',
|
||||
to: 'gemini-2.5-flash-image',
|
||||
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
|
||||
},
|
||||
{
|
||||
label: 'Gemini 3.1 Image',
|
||||
from: 'gemini-3.1-flash-image',
|
||||
to: 'gemini-3.1-flash-image',
|
||||
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
|
||||
},
|
||||
{
|
||||
label: 'G3 Image→3.1',
|
||||
from: 'gemini-3-pro-image',
|
||||
to: 'gemini-3.1-flash-image',
|
||||
color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.3 Codex',
|
||||
from: 'gpt-5.3-codex',
|
||||
to: 'gpt-5.3-codex',
|
||||
color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.3 Spark',
|
||||
from: 'gpt-5.3-codex-spark',
|
||||
to: 'gpt-5.3-codex-spark',
|
||||
color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.4',
|
||||
from: 'gpt-5.4',
|
||||
to: 'gpt-5.4',
|
||||
color: 'bg-rose-100 text-rose-700 hover:bg-rose-200 dark:bg-rose-900/30 dark:text-rose-400'
|
||||
},
|
||||
{
|
||||
label: '5.2→5.3',
|
||||
from: 'gpt-5.2-codex',
|
||||
to: 'gpt-5.3-codex',
|
||||
color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.2',
|
||||
from: 'gpt-5.2-2025-12-11',
|
||||
to: 'gpt-5.2-2025-12-11',
|
||||
color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400'
|
||||
},
|
||||
{
|
||||
label: 'GPT-5.2 Codex',
|
||||
from: 'gpt-5.2-codex',
|
||||
to: 'gpt-5.2-codex',
|
||||
color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400'
|
||||
},
|
||||
{
|
||||
label: 'Max->Codex',
|
||||
from: 'gpt-5.1-codex-max',
|
||||
to: 'gpt-5.1-codex',
|
||||
color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400'
|
||||
},
|
||||
{
|
||||
label: '3-Pro-Preview→3.1-Pro-High',
|
||||
from: 'gemini-3-pro-preview',
|
||||
to: 'gemini-3.1-pro-high',
|
||||
color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400'
|
||||
},
|
||||
{
|
||||
label: '3-Pro-High→3.1-Pro-High',
|
||||
from: 'gemini-3-pro-high',
|
||||
to: 'gemini-3.1-pro-high',
|
||||
color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400'
|
||||
},
|
||||
{
|
||||
label: '3-Pro-Low→3.1-Pro-Low',
|
||||
from: 'gemini-3-pro-low',
|
||||
to: 'gemini-3.1-pro-low',
|
||||
color: 'bg-yellow-100 text-yellow-700 hover:bg-yellow-200 dark:bg-yellow-900/30 dark:text-yellow-400'
|
||||
},
|
||||
{
|
||||
label: '3-Flash透传',
|
||||
from: 'gemini-3-flash',
|
||||
to: 'gemini-3-flash',
|
||||
color: 'bg-lime-100 text-lime-700 hover:bg-lime-200 dark:bg-lime-900/30 dark:text-lime-400'
|
||||
},
|
||||
{
|
||||
label: '2.5-Flash-Lite透传',
|
||||
from: 'gemini-2.5-flash-lite',
|
||||
to: 'gemini-2.5-flash-lite',
|
||||
color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400'
|
||||
}
|
||||
]
|
||||
|
||||
// Common HTTP error codes
|
||||
const commonErrorCodes = [
|
||||
{ value: 401, label: 'Unauthorized' },
|
||||
|
||||
@@ -323,35 +323,6 @@
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
@click="accountCategory = 'bedrock-apikey'"
|
||||
:class="[
|
||||
'flex items-center gap-3 rounded-lg border-2 p-3 text-left transition-all',
|
||||
accountCategory === 'bedrock-apikey'
|
||||
? 'border-amber-500 bg-amber-50 dark:bg-amber-900/20'
|
||||
: 'border-gray-200 hover:border-amber-300 dark:border-dark-600 dark:hover:border-amber-700'
|
||||
]"
|
||||
>
|
||||
<div
|
||||
:class="[
|
||||
'flex h-8 w-8 shrink-0 items-center justify-center rounded-lg',
|
||||
accountCategory === 'bedrock-apikey'
|
||||
? 'bg-amber-500 text-white'
|
||||
: 'bg-gray-100 text-gray-500 dark:bg-dark-600 dark:text-gray-400'
|
||||
]"
|
||||
>
|
||||
<Icon name="key" size="sm" />
|
||||
</div>
|
||||
<div>
|
||||
<span class="block text-sm font-medium text-gray-900 dark:text-white">{{
|
||||
t('admin.accounts.bedrockApiKeyLabel')
|
||||
}}</span>
|
||||
<span class="text-xs text-gray-500 dark:text-gray-400">{{
|
||||
t('admin.accounts.bedrockApiKeyDesc')
|
||||
}}</span>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -956,7 +927,7 @@
|
||||
</div>
|
||||
|
||||
<!-- API Key input (only for apikey type, excluding Antigravity which has its own fields) -->
|
||||
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity' && accountCategory !== 'bedrock-apikey'" class="space-y-4">
|
||||
<div v-if="form.type === 'apikey' && form.platform !== 'antigravity'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.baseUrl') }}</label>
|
||||
<input
|
||||
@@ -1341,34 +1312,75 @@
|
||||
|
||||
<!-- Bedrock credentials (only for Anthropic Bedrock type) -->
|
||||
<div v-if="form.platform === 'anthropic' && accountCategory === 'bedrock'" class="space-y-4">
|
||||
<!-- Auth Mode Radio -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
|
||||
<input
|
||||
v-model="bedrockAccessKeyId"
|
||||
type="text"
|
||||
required
|
||||
class="input font-mono"
|
||||
placeholder="AKIA..."
|
||||
/>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAuthMode') }}</label>
|
||||
<div class="mt-2 flex gap-4">
|
||||
<label class="flex cursor-pointer items-center">
|
||||
<input
|
||||
v-model="bedrockAuthMode"
|
||||
type="radio"
|
||||
value="sigv4"
|
||||
class="mr-2 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockAuthModeSigv4') }}</span>
|
||||
</label>
|
||||
<label class="flex cursor-pointer items-center">
|
||||
<input
|
||||
v-model="bedrockAuthMode"
|
||||
type="radio"
|
||||
value="apikey"
|
||||
class="mr-2 text-primary-600 focus:ring-primary-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockAuthModeApikey') }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
|
||||
|
||||
<!-- SigV4 fields -->
|
||||
<template v-if="bedrockAuthMode === 'sigv4'">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
|
||||
<input
|
||||
v-model="bedrockAccessKeyId"
|
||||
type="text"
|
||||
required
|
||||
class="input font-mono"
|
||||
placeholder="AKIA..."
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
|
||||
<input
|
||||
v-model="bedrockSecretAccessKey"
|
||||
type="password"
|
||||
required
|
||||
class="input font-mono"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
|
||||
<input
|
||||
v-model="bedrockSessionToken"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- API Key field -->
|
||||
<div v-if="bedrockAuthMode === 'apikey'">
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
|
||||
<input
|
||||
v-model="bedrockSecretAccessKey"
|
||||
v-model="bedrockApiKeyValue"
|
||||
type="password"
|
||||
required
|
||||
class="input font-mono"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
|
||||
<input
|
||||
v-model="bedrockSessionToken"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Shared: Region -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<select v-model="bedrockRegion" class="input">
|
||||
@@ -1408,6 +1420,8 @@
|
||||
</select>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Shared: Force Global -->
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
@@ -1488,142 +1502,62 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bedrock API Key credentials (only for Anthropic Bedrock API Key type) -->
|
||||
<div v-if="form.platform === 'anthropic' && accountCategory === 'bedrock-apikey'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
|
||||
<input
|
||||
v-model="bedrockApiKeyValue"
|
||||
type="password"
|
||||
required
|
||||
class="input font-mono"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<select v-model="bedrockApiKeyRegion" class="input">
|
||||
<optgroup label="US">
|
||||
<option value="us-east-1">us-east-1 (N. Virginia)</option>
|
||||
<option value="us-east-2">us-east-2 (Ohio)</option>
|
||||
<option value="us-west-1">us-west-1 (N. California)</option>
|
||||
<option value="us-west-2">us-west-2 (Oregon)</option>
|
||||
<option value="us-gov-east-1">us-gov-east-1 (GovCloud US-East)</option>
|
||||
<option value="us-gov-west-1">us-gov-west-1 (GovCloud US-West)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Europe">
|
||||
<option value="eu-west-1">eu-west-1 (Ireland)</option>
|
||||
<option value="eu-west-2">eu-west-2 (London)</option>
|
||||
<option value="eu-west-3">eu-west-3 (Paris)</option>
|
||||
<option value="eu-central-1">eu-central-1 (Frankfurt)</option>
|
||||
<option value="eu-central-2">eu-central-2 (Zurich)</option>
|
||||
<option value="eu-south-1">eu-south-1 (Milan)</option>
|
||||
<option value="eu-south-2">eu-south-2 (Spain)</option>
|
||||
<option value="eu-north-1">eu-north-1 (Stockholm)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Asia Pacific">
|
||||
<option value="ap-northeast-1">ap-northeast-1 (Tokyo)</option>
|
||||
<option value="ap-northeast-2">ap-northeast-2 (Seoul)</option>
|
||||
<option value="ap-northeast-3">ap-northeast-3 (Osaka)</option>
|
||||
<option value="ap-south-1">ap-south-1 (Mumbai)</option>
|
||||
<option value="ap-south-2">ap-south-2 (Hyderabad)</option>
|
||||
<option value="ap-southeast-1">ap-southeast-1 (Singapore)</option>
|
||||
<option value="ap-southeast-2">ap-southeast-2 (Sydney)</option>
|
||||
</optgroup>
|
||||
<optgroup label="Canada">
|
||||
<option value="ca-central-1">ca-central-1 (Canada)</option>
|
||||
</optgroup>
|
||||
<optgroup label="South America">
|
||||
<option value="sa-east-1">sa-east-1 (São Paulo)</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
v-model="bedrockApiKeyForceGlobal"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
|
||||
</label>
|
||||
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction Section for Bedrock API Key -->
|
||||
<!-- Pool Mode Section for Bedrock -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<!-- Mode Toggle -->
|
||||
<div class="mb-4 flex gap-2">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.poolMode') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.poolModeHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'whitelist'"
|
||||
@click="poolModeEnabled = !poolModeEnabled"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'whitelist'
|
||||
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
poolModeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelWhitelist') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'mapping'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'mapping'
|
||||
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelMapping') }}
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
poolModeEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
|
||||
<div v-if="poolModeEnabled" class="rounded-lg bg-blue-50 p-3 dark:bg-blue-900/20">
|
||||
<p class="text-xs text-blue-700 dark:text-blue-400">
|
||||
<Icon name="exclamationCircle" size="sm" class="mr-1 inline" :stroke-width="2" />
|
||||
{{ t('admin.accounts.poolModeInfo') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Mapping Mode -->
|
||||
<div v-else class="space-y-3">
|
||||
<div v-for="(mapping, index) in modelMappings" :key="index" class="flex items-center gap-2">
|
||||
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
|
||||
<span class="text-gray-400">→</span>
|
||||
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
|
||||
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
|
||||
<Icon name="trash" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
|
||||
+ {{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
<!-- Bedrock Preset Mappings -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in bedrockPresets"
|
||||
:key="preset.from"
|
||||
type="button"
|
||||
@click="addPresetMapping(preset.from, preset.to)"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="poolModeEnabled" class="mt-3">
|
||||
<label class="input-label">{{ t('admin.accounts.poolModeRetryCount') }}</label>
|
||||
<input
|
||||
v-model.number="poolModeRetryCount"
|
||||
type="number"
|
||||
min="0"
|
||||
:max="MAX_POOL_MODE_RETRY_COUNT"
|
||||
step="1"
|
||||
class="input"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{
|
||||
t('admin.accounts.poolModeRetryCountHint', {
|
||||
default: DEFAULT_POOL_MODE_RETRY_COUNT,
|
||||
max: MAX_POOL_MODE_RETRY_COUNT
|
||||
})
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- API Key 账号配额限制 -->
|
||||
<div v-if="form.type === 'apikey'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<!-- API Key / Bedrock 账号配额限制 -->
|
||||
<div v-if="form.type === 'apikey' || form.type === 'bedrock'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<div class="mb-3">
|
||||
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
@@ -1634,9 +1568,21 @@
|
||||
:totalLimit="editQuotaLimit"
|
||||
:dailyLimit="editQuotaDailyLimit"
|
||||
:weeklyLimit="editQuotaWeeklyLimit"
|
||||
:dailyResetMode="editDailyResetMode"
|
||||
:dailyResetHour="editDailyResetHour"
|
||||
:weeklyResetMode="editWeeklyResetMode"
|
||||
:weeklyResetDay="editWeeklyResetDay"
|
||||
:weeklyResetHour="editWeeklyResetHour"
|
||||
:resetTimezone="editResetTimezone"
|
||||
@update:totalLimit="editQuotaLimit = $event"
|
||||
@update:dailyLimit="editQuotaDailyLimit = $event"
|
||||
@update:weeklyLimit="editQuotaWeeklyLimit = $event"
|
||||
@update:dailyResetMode="editDailyResetMode = $event"
|
||||
@update:dailyResetHour="editDailyResetHour = $event"
|
||||
@update:weeklyResetMode="editWeeklyResetMode = $event"
|
||||
@update:weeklyResetDay="editWeeklyResetDay = $event"
|
||||
@update:weeklyResetHour="editWeeklyResetHour = $event"
|
||||
@update:resetTimezone="editResetTimezone = $event"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -3014,13 +2960,19 @@ interface TempUnschedRuleForm {
|
||||
// State
|
||||
const step = ref(1)
|
||||
const submitting = ref(false)
|
||||
const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'bedrock-apikey'>('oauth-based') // UI selection for account category
|
||||
const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock'>('oauth-based') // UI selection for account category
|
||||
const addMethod = ref<AddMethod>('oauth') // For oauth-based: 'oauth' or 'setup-token'
|
||||
const apiKeyBaseUrl = ref('https://api.anthropic.com')
|
||||
const apiKeyValue = ref('')
|
||||
const editQuotaLimit = ref<number | null>(null)
|
||||
const editQuotaDailyLimit = ref<number | null>(null)
|
||||
const editQuotaWeeklyLimit = ref<number | null>(null)
|
||||
const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null)
|
||||
const editDailyResetHour = ref<number | null>(null)
|
||||
const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null)
|
||||
const editWeeklyResetDay = ref<number | null>(null)
|
||||
const editWeeklyResetHour = ref<number | null>(null)
|
||||
const editResetTimezone = ref<string | null>(null)
|
||||
const modelMappings = ref<ModelMapping[]>([])
|
||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const allowedModels = ref<string[]>([])
|
||||
@@ -3050,16 +3002,13 @@ const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('an
|
||||
const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock'))
|
||||
|
||||
// Bedrock credentials
|
||||
const bedrockAuthMode = ref<'sigv4' | 'apikey'>('sigv4')
|
||||
const bedrockAccessKeyId = ref('')
|
||||
const bedrockSecretAccessKey = ref('')
|
||||
const bedrockSessionToken = ref('')
|
||||
const bedrockRegion = ref('us-east-1')
|
||||
const bedrockForceGlobal = ref(false)
|
||||
|
||||
// Bedrock API Key credentials
|
||||
const bedrockApiKeyValue = ref('')
|
||||
const bedrockApiKeyRegion = ref('us-east-1')
|
||||
const bedrockApiKeyForceGlobal = ref(false)
|
||||
const tempUnschedEnabled = ref(false)
|
||||
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
||||
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
|
||||
@@ -3343,7 +3292,8 @@ watch(
|
||||
bedrockSessionToken.value = ''
|
||||
bedrockRegion.value = 'us-east-1'
|
||||
bedrockForceGlobal.value = false
|
||||
bedrockApiKeyForceGlobal.value = false
|
||||
bedrockAuthMode.value = 'sigv4'
|
||||
bedrockApiKeyValue.value = ''
|
||||
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
|
||||
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
|
||||
interceptWarmupRequests.value = false
|
||||
@@ -3719,6 +3669,12 @@ const resetForm = () => {
|
||||
editQuotaLimit.value = null
|
||||
editQuotaDailyLimit.value = null
|
||||
editQuotaWeeklyLimit.value = null
|
||||
editDailyResetMode.value = null
|
||||
editDailyResetHour.value = null
|
||||
editWeeklyResetMode.value = null
|
||||
editWeeklyResetDay.value = null
|
||||
editWeeklyResetHour.value = null
|
||||
editResetTimezone.value = null
|
||||
modelMappings.value = []
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
allowedModels.value = [...claudeModels] // Default fill related models
|
||||
@@ -3919,27 +3875,34 @@ const handleSubmit = async () => {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
|
||||
return
|
||||
}
|
||||
if (!bedrockAccessKeyId.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired'))
|
||||
return
|
||||
}
|
||||
if (!bedrockSecretAccessKey.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired'))
|
||||
return
|
||||
}
|
||||
if (!bedrockRegion.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockRegionRequired'))
|
||||
return
|
||||
}
|
||||
|
||||
const credentials: Record<string, unknown> = {
|
||||
aws_access_key_id: bedrockAccessKeyId.value.trim(),
|
||||
aws_secret_access_key: bedrockSecretAccessKey.value.trim(),
|
||||
aws_region: bedrockRegion.value.trim(),
|
||||
auth_mode: bedrockAuthMode.value,
|
||||
aws_region: bedrockRegion.value.trim() || 'us-east-1',
|
||||
}
|
||||
if (bedrockSessionToken.value.trim()) {
|
||||
credentials.aws_session_token = bedrockSessionToken.value.trim()
|
||||
|
||||
if (bedrockAuthMode.value === 'sigv4') {
|
||||
if (!bedrockAccessKeyId.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired'))
|
||||
return
|
||||
}
|
||||
if (!bedrockSecretAccessKey.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired'))
|
||||
return
|
||||
}
|
||||
credentials.aws_access_key_id = bedrockAccessKeyId.value.trim()
|
||||
credentials.aws_secret_access_key = bedrockSecretAccessKey.value.trim()
|
||||
if (bedrockSessionToken.value.trim()) {
|
||||
credentials.aws_session_token = bedrockSessionToken.value.trim()
|
||||
}
|
||||
} else {
|
||||
if (!bedrockApiKeyValue.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockApiKeyRequired'))
|
||||
return
|
||||
}
|
||||
credentials.api_key = bedrockApiKeyValue.value.trim()
|
||||
}
|
||||
|
||||
if (bedrockForceGlobal.value) {
|
||||
credentials.aws_force_global = 'true'
|
||||
}
|
||||
@@ -3952,45 +3915,18 @@ const handleSubmit = async () => {
|
||||
credentials.model_mapping = modelMapping
|
||||
}
|
||||
|
||||
// Pool mode
|
||||
if (poolModeEnabled.value) {
|
||||
credentials.pool_mode = true
|
||||
credentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value)
|
||||
}
|
||||
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
|
||||
await createAccountAndFinish('anthropic', 'bedrock' as AccountType, credentials)
|
||||
return
|
||||
}
|
||||
|
||||
// For Bedrock API Key type, create directly
|
||||
if (form.platform === 'anthropic' && accountCategory.value === 'bedrock-apikey') {
|
||||
if (!form.name.trim()) {
|
||||
appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
|
||||
return
|
||||
}
|
||||
if (!bedrockApiKeyValue.value.trim()) {
|
||||
appStore.showError(t('admin.accounts.bedrockApiKeyRequired'))
|
||||
return
|
||||
}
|
||||
|
||||
const credentials: Record<string, unknown> = {
|
||||
api_key: bedrockApiKeyValue.value.trim(),
|
||||
aws_region: bedrockApiKeyRegion.value.trim() || 'us-east-1',
|
||||
}
|
||||
if (bedrockApiKeyForceGlobal.value) {
|
||||
credentials.aws_force_global = 'true'
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
const modelMapping = buildModelMappingObject(
|
||||
modelRestrictionMode.value, allowedModels.value, modelMappings.value
|
||||
)
|
||||
if (modelMapping) {
|
||||
credentials.model_mapping = modelMapping
|
||||
}
|
||||
|
||||
applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create')
|
||||
|
||||
await createAccountAndFinish('anthropic', 'bedrock-apikey' as AccountType, credentials)
|
||||
return
|
||||
}
|
||||
|
||||
// For Antigravity upstream type, create directly
|
||||
if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') {
|
||||
if (!form.name.trim()) {
|
||||
@@ -4233,9 +4169,9 @@ const createAccountAndFinish = async (
|
||||
if (!applyTempUnschedConfig(credentials)) {
|
||||
return
|
||||
}
|
||||
// Inject quota limits for apikey accounts
|
||||
// Inject quota limits for apikey/bedrock accounts
|
||||
let finalExtra = extra
|
||||
if (type === 'apikey') {
|
||||
if (type === 'apikey' || type === 'bedrock') {
|
||||
const quotaExtra: Record<string, unknown> = { ...(extra || {}) }
|
||||
if (editQuotaLimit.value != null && editQuotaLimit.value > 0) {
|
||||
quotaExtra.quota_limit = editQuotaLimit.value
|
||||
@@ -4246,6 +4182,19 @@ const createAccountAndFinish = async (
|
||||
if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) {
|
||||
quotaExtra.quota_weekly_limit = editQuotaWeeklyLimit.value
|
||||
}
|
||||
// Quota reset mode config
|
||||
if (editDailyResetMode.value === 'fixed') {
|
||||
quotaExtra.quota_daily_reset_mode = 'fixed'
|
||||
quotaExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0
|
||||
}
|
||||
if (editWeeklyResetMode.value === 'fixed') {
|
||||
quotaExtra.quota_weekly_reset_mode = 'fixed'
|
||||
quotaExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1
|
||||
quotaExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 0
|
||||
}
|
||||
if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') {
|
||||
quotaExtra.quota_reset_timezone = editResetTimezone.value || 'UTC'
|
||||
}
|
||||
if (Object.keys(quotaExtra).length > 0) {
|
||||
finalExtra = quotaExtra
|
||||
}
|
||||
|
||||
@@ -563,37 +563,54 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bedrock fields (only for bedrock type) -->
|
||||
<!-- Bedrock fields (for bedrock type, both SigV4 and API Key modes) -->
|
||||
<div v-if="account.type === 'bedrock'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
|
||||
<!-- SigV4 fields -->
|
||||
<template v-if="!isBedrockAPIKeyMode">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockAccessKeyId') }}</label>
|
||||
<input
|
||||
v-model="editBedrockAccessKeyId"
|
||||
type="text"
|
||||
class="input font-mono"
|
||||
placeholder="AKIA..."
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
|
||||
<input
|
||||
v-model="editBedrockSecretAccessKey"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSecretKeyLeaveEmpty') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
|
||||
<input
|
||||
v-model="editBedrockSessionToken"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<!-- API Key field -->
|
||||
<div v-if="isBedrockAPIKeyMode">
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
|
||||
<input
|
||||
v-model="editBedrockAccessKeyId"
|
||||
type="text"
|
||||
class="input font-mono"
|
||||
placeholder="AKIA..."
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSecretAccessKey') }}</label>
|
||||
<input
|
||||
v-model="editBedrockSecretAccessKey"
|
||||
v-model="editBedrockApiKeyValue"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
|
||||
:placeholder="t('admin.accounts.bedrockApiKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSecretKeyLeaveEmpty') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockSessionToken') }}</label>
|
||||
<input
|
||||
v-model="editBedrockSessionToken"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockSecretKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockSessionTokenHint') }}</p>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Shared: Region -->
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<input
|
||||
@@ -604,6 +621,8 @@
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Shared: Force Global -->
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
@@ -684,108 +703,56 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Bedrock API Key fields (only for bedrock-apikey type) -->
|
||||
<div v-if="account.type === 'bedrock-apikey'" class="space-y-4">
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockApiKeyInput') }}</label>
|
||||
<input
|
||||
v-model="editBedrockApiKeyValue"
|
||||
type="password"
|
||||
class="input font-mono"
|
||||
:placeholder="t('admin.accounts.bedrockApiKeyLeaveEmpty')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="input-label">{{ t('admin.accounts.bedrockRegion') }}</label>
|
||||
<input
|
||||
v-model="editBedrockApiKeyRegion"
|
||||
type="text"
|
||||
class="input"
|
||||
placeholder="us-east-1"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.accounts.bedrockRegionHint') }}</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
v-model="editBedrockApiKeyForceGlobal"
|
||||
type="checkbox"
|
||||
class="rounded border-gray-300 text-primary-600 focus:ring-primary-500 dark:border-dark-500"
|
||||
/>
|
||||
<span class="text-sm text-gray-700 dark:text-gray-300">{{ t('admin.accounts.bedrockForceGlobal') }}</span>
|
||||
</label>
|
||||
<p class="input-hint mt-1">{{ t('admin.accounts.bedrockForceGlobalHint') }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Restriction for Bedrock API Key -->
|
||||
<!-- Pool Mode Section for Bedrock -->
|
||||
<div class="border-t border-gray-200 pt-4 dark:border-dark-600">
|
||||
<label class="input-label">{{ t('admin.accounts.modelRestriction') }}</label>
|
||||
|
||||
<!-- Mode Toggle -->
|
||||
<div class="mb-4 flex gap-2">
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<div>
|
||||
<label class="input-label mb-0">{{ t('admin.accounts.poolMode') }}</label>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.poolModeHint') }}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'whitelist'"
|
||||
@click="poolModeEnabled = !poolModeEnabled"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'whitelist'
|
||||
? 'bg-primary-100 text-primary-700 dark:bg-primary-900/30 dark:text-primary-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
|
||||
poolModeEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelWhitelist') }}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
@click="modelRestrictionMode = 'mapping'"
|
||||
:class="[
|
||||
'flex-1 rounded-lg px-4 py-2 text-sm font-medium transition-all',
|
||||
modelRestrictionMode === 'mapping'
|
||||
? 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
||||
: 'bg-gray-100 text-gray-600 hover:bg-gray-200 dark:bg-dark-600 dark:text-gray-400 dark:hover:bg-dark-500'
|
||||
]"
|
||||
>
|
||||
{{ t('admin.accounts.modelMapping') }}
|
||||
<span
|
||||
:class="[
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
poolModeEnabled ? 'translate-x-5' : 'translate-x-0'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Whitelist Mode -->
|
||||
<div v-if="modelRestrictionMode === 'whitelist'">
|
||||
<ModelWhitelistSelector v-model="allowedModels" platform="anthropic" />
|
||||
<p class="text-xs text-gray-500 dark:text-gray-400">
|
||||
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
|
||||
<span v-if="allowedModels.length === 0">{{ t('admin.accounts.supportsAllModels') }}</span>
|
||||
<div v-if="poolModeEnabled" class="rounded-lg bg-blue-50 p-3 dark:bg-blue-900/20">
|
||||
<p class="text-xs text-blue-700 dark:text-blue-400">
|
||||
<Icon name="exclamationCircle" size="sm" class="mr-1 inline" :stroke-width="2" />
|
||||
{{ t('admin.accounts.poolModeInfo') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Mapping Mode -->
|
||||
<div v-else class="space-y-3">
|
||||
<div v-for="(mapping, index) in modelMappings" :key="getModelMappingKey(mapping)" class="flex items-center gap-2">
|
||||
<input v-model="mapping.from" type="text" class="input flex-1" :placeholder="t('admin.accounts.fromModel')" />
|
||||
<span class="text-gray-400">→</span>
|
||||
<input v-model="mapping.to" type="text" class="input flex-1" :placeholder="t('admin.accounts.toModel')" />
|
||||
<button type="button" @click="modelMappings.splice(index, 1)" class="text-red-500 hover:text-red-700">
|
||||
<Icon name="trash" size="sm" />
|
||||
</button>
|
||||
</div>
|
||||
<button type="button" @click="modelMappings.push({ from: '', to: '' })" class="btn btn-secondary text-sm">
|
||||
+ {{ t('admin.accounts.addMapping') }}
|
||||
</button>
|
||||
<!-- Bedrock Preset Mappings -->
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<button
|
||||
v-for="preset in bedrockPresets"
|
||||
:key="preset.from"
|
||||
type="button"
|
||||
@click="modelMappings.push({ from: preset.from, to: preset.to })"
|
||||
:class="['rounded-lg px-3 py-1 text-xs transition-colors', preset.color]"
|
||||
>
|
||||
+ {{ preset.label }}
|
||||
</button>
|
||||
</div>
|
||||
<div v-if="poolModeEnabled" class="mt-3">
|
||||
<label class="input-label">{{ t('admin.accounts.poolModeRetryCount') }}</label>
|
||||
<input
|
||||
v-model.number="poolModeRetryCount"
|
||||
type="number"
|
||||
min="0"
|
||||
:max="MAX_POOL_MODE_RETRY_COUNT"
|
||||
step="1"
|
||||
class="input"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
{{
|
||||
t('admin.accounts.poolModeRetryCountHint', {
|
||||
default: DEFAULT_POOL_MODE_RETRY_COUNT,
|
||||
max: MAX_POOL_MODE_RETRY_COUNT
|
||||
})
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -1182,8 +1149,8 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- API Key 账号配额限制 -->
|
||||
<div v-if="account?.type === 'apikey'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<!-- API Key / Bedrock 账号配额限制 -->
|
||||
<div v-if="account?.type === 'apikey' || account?.type === 'bedrock'" class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4">
|
||||
<div class="mb-3">
|
||||
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaLimit') }}</h3>
|
||||
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
|
||||
@@ -1194,9 +1161,21 @@
|
||||
:totalLimit="editQuotaLimit"
|
||||
:dailyLimit="editQuotaDailyLimit"
|
||||
:weeklyLimit="editQuotaWeeklyLimit"
|
||||
:dailyResetMode="editDailyResetMode"
|
||||
:dailyResetHour="editDailyResetHour"
|
||||
:weeklyResetMode="editWeeklyResetMode"
|
||||
:weeklyResetDay="editWeeklyResetDay"
|
||||
:weeklyResetHour="editWeeklyResetHour"
|
||||
:resetTimezone="editResetTimezone"
|
||||
@update:totalLimit="editQuotaLimit = $event"
|
||||
@update:dailyLimit="editQuotaDailyLimit = $event"
|
||||
@update:weeklyLimit="editQuotaWeeklyLimit = $event"
|
||||
@update:dailyResetMode="editDailyResetMode = $event"
|
||||
@update:dailyResetHour="editDailyResetHour = $event"
|
||||
@update:weeklyResetMode="editWeeklyResetMode = $event"
|
||||
@update:weeklyResetDay="editWeeklyResetDay = $event"
|
||||
@update:weeklyResetHour="editWeeklyResetHour = $event"
|
||||
@update:resetTimezone="editResetTimezone = $event"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -1781,11 +1760,11 @@ const editBedrockSecretAccessKey = ref('')
|
||||
const editBedrockSessionToken = ref('')
|
||||
const editBedrockRegion = ref('')
|
||||
const editBedrockForceGlobal = ref(false)
|
||||
|
||||
// Bedrock API Key credentials
|
||||
const editBedrockApiKeyValue = ref('')
|
||||
const editBedrockApiKeyRegion = ref('')
|
||||
const editBedrockApiKeyForceGlobal = ref(false)
|
||||
const isBedrockAPIKeyMode = computed(() =>
|
||||
props.account?.type === 'bedrock' &&
|
||||
(props.account?.credentials as Record<string, unknown>)?.auth_mode === 'apikey'
|
||||
)
|
||||
const modelMappings = ref<ModelMapping[]>([])
|
||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const allowedModels = ref<string[]>([])
|
||||
@@ -1847,6 +1826,12 @@ const anthropicPassthroughEnabled = ref(false)
|
||||
const editQuotaLimit = ref<number | null>(null)
|
||||
const editQuotaDailyLimit = ref<number | null>(null)
|
||||
const editQuotaWeeklyLimit = ref<number | null>(null)
|
||||
const editDailyResetMode = ref<'rolling' | 'fixed' | null>(null)
|
||||
const editDailyResetHour = ref<number | null>(null)
|
||||
const editWeeklyResetMode = ref<'rolling' | 'fixed' | null>(null)
|
||||
const editWeeklyResetDay = ref<number | null>(null)
|
||||
const editWeeklyResetHour = ref<number | null>(null)
|
||||
const editResetTimezone = ref<string | null>(null)
|
||||
const openAIWSModeOptions = computed(() => [
|
||||
{ value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') },
|
||||
// TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复
|
||||
@@ -2026,18 +2011,31 @@ watch(
|
||||
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
|
||||
}
|
||||
|
||||
// Load quota limit for apikey accounts
|
||||
if (newAccount.type === 'apikey') {
|
||||
// Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above)
|
||||
if (newAccount.type === 'apikey' || newAccount.type === 'bedrock') {
|
||||
const quotaVal = extra?.quota_limit as number | undefined
|
||||
editQuotaLimit.value = (quotaVal && quotaVal > 0) ? quotaVal : null
|
||||
const dailyVal = extra?.quota_daily_limit as number | undefined
|
||||
editQuotaDailyLimit.value = (dailyVal && dailyVal > 0) ? dailyVal : null
|
||||
const weeklyVal = extra?.quota_weekly_limit as number | undefined
|
||||
editQuotaWeeklyLimit.value = (weeklyVal && weeklyVal > 0) ? weeklyVal : null
|
||||
// Load quota reset mode config
|
||||
editDailyResetMode.value = (extra?.quota_daily_reset_mode as 'rolling' | 'fixed') || null
|
||||
editDailyResetHour.value = (extra?.quota_daily_reset_hour as number) ?? null
|
||||
editWeeklyResetMode.value = (extra?.quota_weekly_reset_mode as 'rolling' | 'fixed') || null
|
||||
editWeeklyResetDay.value = (extra?.quota_weekly_reset_day as number) ?? null
|
||||
editWeeklyResetHour.value = (extra?.quota_weekly_reset_hour as number) ?? null
|
||||
editResetTimezone.value = (extra?.quota_reset_timezone as string) || null
|
||||
} else {
|
||||
editQuotaLimit.value = null
|
||||
editQuotaDailyLimit.value = null
|
||||
editQuotaWeeklyLimit.value = null
|
||||
editDailyResetMode.value = null
|
||||
editDailyResetHour.value = null
|
||||
editWeeklyResetMode.value = null
|
||||
editWeeklyResetDay.value = null
|
||||
editWeeklyResetHour.value = null
|
||||
editResetTimezone.value = null
|
||||
}
|
||||
|
||||
// Load antigravity model mapping (Antigravity 只支持映射模式)
|
||||
@@ -2130,11 +2128,28 @@ watch(
|
||||
}
|
||||
} else if (newAccount.type === 'bedrock' && newAccount.credentials) {
|
||||
const bedrockCreds = newAccount.credentials as Record<string, unknown>
|
||||
editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || ''
|
||||
const authMode = (bedrockCreds.auth_mode as string) || 'sigv4'
|
||||
editBedrockRegion.value = (bedrockCreds.aws_region as string) || ''
|
||||
editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true'
|
||||
editBedrockSecretAccessKey.value = ''
|
||||
editBedrockSessionToken.value = ''
|
||||
|
||||
if (authMode === 'apikey') {
|
||||
editBedrockApiKeyValue.value = ''
|
||||
} else {
|
||||
editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || ''
|
||||
editBedrockSecretAccessKey.value = ''
|
||||
editBedrockSessionToken.value = ''
|
||||
}
|
||||
|
||||
// Load pool mode for bedrock
|
||||
poolModeEnabled.value = bedrockCreds.pool_mode === true
|
||||
const retryCount = bedrockCreds.pool_mode_retry_count
|
||||
poolModeRetryCount.value = (typeof retryCount === 'number' && retryCount >= 0) ? retryCount : DEFAULT_POOL_MODE_RETRY_COUNT
|
||||
|
||||
// Load quota limits for bedrock
|
||||
const bedrockExtra = (newAccount.extra as Record<string, unknown>) || {}
|
||||
editQuotaLimit.value = typeof bedrockExtra.quota_limit === 'number' ? bedrockExtra.quota_limit : null
|
||||
editQuotaDailyLimit.value = typeof bedrockExtra.quota_daily_limit === 'number' ? bedrockExtra.quota_daily_limit : null
|
||||
editQuotaWeeklyLimit.value = typeof bedrockExtra.quota_weekly_limit === 'number' ? bedrockExtra.quota_weekly_limit : null
|
||||
|
||||
// Load model mappings for bedrock
|
||||
const existingMappings = bedrockCreds.model_mapping as Record<string, string> | undefined
|
||||
@@ -2155,31 +2170,6 @@ watch(
|
||||
modelMappings.value = []
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else if (newAccount.type === 'bedrock-apikey' && newAccount.credentials) {
|
||||
const bedrockApiKeyCreds = newAccount.credentials as Record<string, unknown>
|
||||
editBedrockApiKeyRegion.value = (bedrockApiKeyCreds.aws_region as string) || 'us-east-1'
|
||||
editBedrockApiKeyForceGlobal.value = (bedrockApiKeyCreds.aws_force_global as string) === 'true'
|
||||
editBedrockApiKeyValue.value = ''
|
||||
|
||||
// Load model mappings for bedrock-apikey
|
||||
const existingMappings = bedrockApiKeyCreds.model_mapping as Record<string, string> | undefined
|
||||
if (existingMappings && typeof existingMappings === 'object') {
|
||||
const entries = Object.entries(existingMappings)
|
||||
const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
|
||||
if (isWhitelistMode) {
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
allowedModels.value = entries.map(([from]) => from)
|
||||
modelMappings.value = []
|
||||
} else {
|
||||
modelRestrictionMode.value = 'mapping'
|
||||
modelMappings.value = entries.map(([from, to]) => ({ from, to }))
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else {
|
||||
modelRestrictionMode.value = 'whitelist'
|
||||
modelMappings.value = []
|
||||
allowedModels.value = []
|
||||
}
|
||||
} else if (newAccount.type === 'upstream' && newAccount.credentials) {
|
||||
const credentials = newAccount.credentials as Record<string, unknown>
|
||||
editBaseUrl.value = (credentials.base_url as string) || ''
|
||||
@@ -2727,7 +2717,6 @@ const handleSubmit = async () => {
|
||||
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
||||
const newCredentials: Record<string, unknown> = { ...currentCredentials }
|
||||
|
||||
newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim()
|
||||
newCredentials.aws_region = editBedrockRegion.value.trim()
|
||||
if (editBedrockForceGlobal.value) {
|
||||
newCredentials.aws_force_global = 'true'
|
||||
@@ -2735,42 +2724,29 @@ const handleSubmit = async () => {
|
||||
delete newCredentials.aws_force_global
|
||||
}
|
||||
|
||||
// Only update secrets if user provided new values
|
||||
if (editBedrockSecretAccessKey.value.trim()) {
|
||||
newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim()
|
||||
}
|
||||
if (editBedrockSessionToken.value.trim()) {
|
||||
newCredentials.aws_session_token = editBedrockSessionToken.value.trim()
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
|
||||
if (modelMapping) {
|
||||
newCredentials.model_mapping = modelMapping
|
||||
if (isBedrockAPIKeyMode.value) {
|
||||
// API Key mode: only update api_key if user provided new value
|
||||
if (editBedrockApiKeyValue.value.trim()) {
|
||||
newCredentials.api_key = editBedrockApiKeyValue.value.trim()
|
||||
}
|
||||
} else {
|
||||
delete newCredentials.model_mapping
|
||||
// SigV4 mode
|
||||
newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim()
|
||||
if (editBedrockSecretAccessKey.value.trim()) {
|
||||
newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim()
|
||||
}
|
||||
if (editBedrockSessionToken.value.trim()) {
|
||||
newCredentials.aws_session_token = editBedrockSessionToken.value.trim()
|
||||
}
|
||||
}
|
||||
|
||||
applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
|
||||
if (!applyTempUnschedConfig(newCredentials)) {
|
||||
return
|
||||
}
|
||||
|
||||
updatePayload.credentials = newCredentials
|
||||
} else if (props.account.type === 'bedrock-apikey') {
|
||||
const currentCredentials = (props.account.credentials as Record<string, unknown>) || {}
|
||||
const newCredentials: Record<string, unknown> = { ...currentCredentials }
|
||||
|
||||
newCredentials.aws_region = editBedrockApiKeyRegion.value.trim() || 'us-east-1'
|
||||
if (editBedrockApiKeyForceGlobal.value) {
|
||||
newCredentials.aws_force_global = 'true'
|
||||
// Pool mode
|
||||
if (poolModeEnabled.value) {
|
||||
newCredentials.pool_mode = true
|
||||
newCredentials.pool_mode_retry_count = normalizePoolModeRetryCount(poolModeRetryCount.value)
|
||||
} else {
|
||||
delete newCredentials.aws_force_global
|
||||
}
|
||||
|
||||
// Only update API key if user provided new value
|
||||
if (editBedrockApiKeyValue.value.trim()) {
|
||||
newCredentials.api_key = editBedrockApiKeyValue.value.trim()
|
||||
delete newCredentials.pool_mode
|
||||
delete newCredentials.pool_mode_retry_count
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
@@ -2980,8 +2956,8 @@ const handleSubmit = async () => {
|
||||
updatePayload.extra = newExtra
|
||||
}
|
||||
|
||||
// For apikey accounts, handle quota_limit in extra
|
||||
if (props.account.type === 'apikey') {
|
||||
// For apikey/bedrock accounts, handle quota_limit in extra
|
||||
if (props.account.type === 'apikey' || props.account.type === 'bedrock') {
|
||||
const currentExtra = (updatePayload.extra as Record<string, unknown>) ||
|
||||
(props.account.extra as Record<string, unknown>) || {}
|
||||
const newExtra: Record<string, unknown> = { ...currentExtra }
|
||||
@@ -3000,6 +2976,28 @@ const handleSubmit = async () => {
|
||||
} else {
|
||||
delete newExtra.quota_weekly_limit
|
||||
}
|
||||
// Quota reset mode config
|
||||
if (editDailyResetMode.value === 'fixed') {
|
||||
newExtra.quota_daily_reset_mode = 'fixed'
|
||||
newExtra.quota_daily_reset_hour = editDailyResetHour.value ?? 0
|
||||
} else {
|
||||
delete newExtra.quota_daily_reset_mode
|
||||
delete newExtra.quota_daily_reset_hour
|
||||
}
|
||||
if (editWeeklyResetMode.value === 'fixed') {
|
||||
newExtra.quota_weekly_reset_mode = 'fixed'
|
||||
newExtra.quota_weekly_reset_day = editWeeklyResetDay.value ?? 1
|
||||
newExtra.quota_weekly_reset_hour = editWeeklyResetHour.value ?? 0
|
||||
} else {
|
||||
delete newExtra.quota_weekly_reset_mode
|
||||
delete newExtra.quota_weekly_reset_day
|
||||
delete newExtra.quota_weekly_reset_hour
|
||||
}
|
||||
if (editDailyResetMode.value === 'fixed' || editWeeklyResetMode.value === 'fixed') {
|
||||
newExtra.quota_reset_timezone = editResetTimezone.value || 'UTC'
|
||||
} else {
|
||||
delete newExtra.quota_reset_timezone
|
||||
}
|
||||
updatePayload.extra = newExtra
|
||||
}
|
||||
|
||||
|
||||
@@ -131,7 +131,8 @@ const { t } = useI18n()
|
||||
|
||||
const props = defineProps<{
|
||||
modelValue: string[]
|
||||
platform: string
|
||||
platform?: string
|
||||
platforms?: string[]
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
@@ -144,11 +145,36 @@ const showDropdown = ref(false)
|
||||
const searchQuery = ref('')
|
||||
const customModel = ref('')
|
||||
const isComposing = ref(false)
|
||||
const normalizedPlatforms = computed(() => {
|
||||
const rawPlatforms =
|
||||
props.platforms && props.platforms.length > 0
|
||||
? props.platforms
|
||||
: props.platform
|
||||
? [props.platform]
|
||||
: []
|
||||
|
||||
return Array.from(
|
||||
new Set(
|
||||
rawPlatforms
|
||||
.map(platform => platform?.trim())
|
||||
.filter((platform): platform is string => Boolean(platform))
|
||||
)
|
||||
)
|
||||
})
|
||||
|
||||
const availableOptions = computed(() => {
|
||||
if (props.platform === 'sora') {
|
||||
return getModelsByPlatform('sora').map(m => ({ value: m, label: m }))
|
||||
if (normalizedPlatforms.value.length === 0) {
|
||||
return allModels
|
||||
}
|
||||
return allModels
|
||||
|
||||
const allowedModels = new Set<string>()
|
||||
for (const platform of normalizedPlatforms.value) {
|
||||
for (const model of getModelsByPlatform(platform)) {
|
||||
allowedModels.add(model)
|
||||
}
|
||||
}
|
||||
|
||||
return allModels.filter(model => allowedModels.has(model.value))
|
||||
})
|
||||
|
||||
const filteredModels = computed(() => {
|
||||
@@ -192,10 +218,13 @@ const handleEnter = () => {
|
||||
}
|
||||
|
||||
const fillRelated = () => {
|
||||
const models = getModelsByPlatform(props.platform)
|
||||
const newModels = [...props.modelValue]
|
||||
for (const model of models) {
|
||||
if (!newModels.includes(model)) newModels.push(model)
|
||||
for (const platform of normalizedPlatforms.value) {
|
||||
for (const model of getModelsByPlatform(platform)) {
|
||||
if (!newModels.includes(model)) {
|
||||
newModels.push(model)
|
||||
}
|
||||
}
|
||||
}
|
||||
emit('update:modelValue', newModels)
|
||||
}
|
||||
|
||||
@@ -8,12 +8,24 @@ const props = defineProps<{
|
||||
totalLimit: number | null
|
||||
dailyLimit: number | null
|
||||
weeklyLimit: number | null
|
||||
dailyResetMode: 'rolling' | 'fixed' | null
|
||||
dailyResetHour: number | null
|
||||
weeklyResetMode: 'rolling' | 'fixed' | null
|
||||
weeklyResetDay: number | null
|
||||
weeklyResetHour: number | null
|
||||
resetTimezone: string | null
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:totalLimit': [value: number | null]
|
||||
'update:dailyLimit': [value: number | null]
|
||||
'update:weeklyLimit': [value: number | null]
|
||||
'update:dailyResetMode': [value: 'rolling' | 'fixed' | null]
|
||||
'update:dailyResetHour': [value: number | null]
|
||||
'update:weeklyResetMode': [value: 'rolling' | 'fixed' | null]
|
||||
'update:weeklyResetDay': [value: number | null]
|
||||
'update:weeklyResetHour': [value: number | null]
|
||||
'update:resetTimezone': [value: string | null]
|
||||
}>()
|
||||
|
||||
const enabled = computed(() =>
|
||||
@@ -35,9 +47,56 @@ watch(localEnabled, (val) => {
|
||||
emit('update:totalLimit', null)
|
||||
emit('update:dailyLimit', null)
|
||||
emit('update:weeklyLimit', null)
|
||||
emit('update:dailyResetMode', null)
|
||||
emit('update:dailyResetHour', null)
|
||||
emit('update:weeklyResetMode', null)
|
||||
emit('update:weeklyResetDay', null)
|
||||
emit('update:weeklyResetHour', null)
|
||||
emit('update:resetTimezone', null)
|
||||
}
|
||||
})
|
||||
|
||||
// Whether any fixed mode is active (to show timezone selector)
|
||||
const hasFixedMode = computed(() =>
|
||||
props.dailyResetMode === 'fixed' || props.weeklyResetMode === 'fixed'
|
||||
)
|
||||
|
||||
// Common timezone options
|
||||
const timezoneOptions = [
|
||||
'UTC',
|
||||
'Asia/Shanghai',
|
||||
'Asia/Tokyo',
|
||||
'Asia/Seoul',
|
||||
'Asia/Singapore',
|
||||
'Asia/Kolkata',
|
||||
'Asia/Dubai',
|
||||
'Europe/London',
|
||||
'Europe/Paris',
|
||||
'Europe/Berlin',
|
||||
'Europe/Moscow',
|
||||
'America/New_York',
|
||||
'America/Chicago',
|
||||
'America/Denver',
|
||||
'America/Los_Angeles',
|
||||
'America/Sao_Paulo',
|
||||
'Australia/Sydney',
|
||||
'Pacific/Auckland',
|
||||
]
|
||||
|
||||
// Hours for dropdown (0-23)
|
||||
const hourOptions = Array.from({ length: 24 }, (_, i) => i)
|
||||
|
||||
// Day of week options
|
||||
const dayOptions = [
|
||||
{ value: 1, key: 'monday' },
|
||||
{ value: 2, key: 'tuesday' },
|
||||
{ value: 3, key: 'wednesday' },
|
||||
{ value: 4, key: 'thursday' },
|
||||
{ value: 5, key: 'friday' },
|
||||
{ value: 6, key: 'saturday' },
|
||||
{ value: 0, key: 'sunday' },
|
||||
]
|
||||
|
||||
const onTotalInput = (e: Event) => {
|
||||
const raw = (e.target as HTMLInputElement).valueAsNumber
|
||||
emit('update:totalLimit', Number.isNaN(raw) ? null : raw)
|
||||
@@ -50,6 +109,25 @@ const onWeeklyInput = (e: Event) => {
|
||||
const raw = (e.target as HTMLInputElement).valueAsNumber
|
||||
emit('update:weeklyLimit', Number.isNaN(raw) ? null : raw)
|
||||
}
|
||||
|
||||
const onDailyModeChange = (e: Event) => {
|
||||
const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed'
|
||||
emit('update:dailyResetMode', val)
|
||||
if (val === 'fixed') {
|
||||
if (props.dailyResetHour == null) emit('update:dailyResetHour', 0)
|
||||
if (!props.resetTimezone) emit('update:resetTimezone', 'UTC')
|
||||
}
|
||||
}
|
||||
|
||||
const onWeeklyModeChange = (e: Event) => {
|
||||
const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed'
|
||||
emit('update:weeklyResetMode', val)
|
||||
if (val === 'fixed') {
|
||||
if (props.weeklyResetDay == null) emit('update:weeklyResetDay', 1)
|
||||
if (props.weeklyResetHour == null) emit('update:weeklyResetHour', 0)
|
||||
if (!props.resetTimezone) emit('update:resetTimezone', 'UTC')
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -94,7 +172,37 @@ const onWeeklyInput = (e: Event) => {
|
||||
:placeholder="t('admin.accounts.quotaLimitPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<p class="input-hint">{{ t('admin.accounts.quotaDailyLimitHint') }}</p>
|
||||
<!-- 日配额重置模式 -->
|
||||
<div class="mt-2 flex items-center gap-2">
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetMode') }}</label>
|
||||
<select
|
||||
:value="dailyResetMode || 'rolling'"
|
||||
@change="onDailyModeChange"
|
||||
class="input py-1 text-xs"
|
||||
>
|
||||
<option value="rolling">{{ t('admin.accounts.quotaResetModeRolling') }}</option>
|
||||
<option value="fixed">{{ t('admin.accounts.quotaResetModeFixed') }}</option>
|
||||
</select>
|
||||
</div>
|
||||
<!-- 固定模式:小时选择 -->
|
||||
<div v-if="dailyResetMode === 'fixed'" class="mt-2 flex items-center gap-2">
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetHour') }}</label>
|
||||
<select
|
||||
:value="dailyResetHour ?? 0"
|
||||
@change="emit('update:dailyResetHour', Number(($event.target as HTMLSelectElement).value))"
|
||||
class="input py-1 text-xs w-24"
|
||||
>
|
||||
<option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option>
|
||||
</select>
|
||||
</div>
|
||||
<p class="input-hint">
|
||||
<template v-if="dailyResetMode === 'fixed'">
|
||||
{{ t('admin.accounts.quotaDailyLimitHintFixed', { hour: String(dailyResetHour ?? 0).padStart(2, '0'), timezone: resetTimezone || 'UTC' }) }}
|
||||
</template>
|
||||
<template v-else>
|
||||
{{ t('admin.accounts.quotaDailyLimitHint') }}
|
||||
</template>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 周配额 -->
|
||||
@@ -112,7 +220,57 @@ const onWeeklyInput = (e: Event) => {
|
||||
:placeholder="t('admin.accounts.quotaLimitPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
<p class="input-hint">{{ t('admin.accounts.quotaWeeklyLimitHint') }}</p>
|
||||
<!-- 周配额重置模式 -->
|
||||
<div class="mt-2 flex items-center gap-2">
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetMode') }}</label>
|
||||
<select
|
||||
:value="weeklyResetMode || 'rolling'"
|
||||
@change="onWeeklyModeChange"
|
||||
class="input py-1 text-xs"
|
||||
>
|
||||
<option value="rolling">{{ t('admin.accounts.quotaResetModeRolling') }}</option>
|
||||
<option value="fixed">{{ t('admin.accounts.quotaResetModeFixed') }}</option>
|
||||
</select>
|
||||
</div>
|
||||
<!-- 固定模式:星期几 + 小时 -->
|
||||
<div v-if="weeklyResetMode === 'fixed'" class="mt-2 flex items-center gap-2 flex-wrap">
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaWeeklyResetDay') }}</label>
|
||||
<select
|
||||
:value="weeklyResetDay ?? 1"
|
||||
@change="emit('update:weeklyResetDay', Number(($event.target as HTMLSelectElement).value))"
|
||||
class="input py-1 text-xs w-28"
|
||||
>
|
||||
<option v-for="d in dayOptions" :key="d.value" :value="d.value">{{ t('admin.accounts.dayOfWeek.' + d.key) }}</option>
|
||||
</select>
|
||||
<label class="text-xs text-gray-500 dark:text-gray-400 whitespace-nowrap">{{ t('admin.accounts.quotaResetHour') }}</label>
|
||||
<select
|
||||
:value="weeklyResetHour ?? 0"
|
||||
@change="emit('update:weeklyResetHour', Number(($event.target as HTMLSelectElement).value))"
|
||||
class="input py-1 text-xs w-24"
|
||||
>
|
||||
<option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option>
|
||||
</select>
|
||||
</div>
|
||||
<p class="input-hint">
|
||||
<template v-if="weeklyResetMode === 'fixed'">
|
||||
{{ t('admin.accounts.quotaWeeklyLimitHintFixed', { day: t('admin.accounts.dayOfWeek.' + (dayOptions.find(d => d.value === (weeklyResetDay ?? 1))?.key || 'monday')), hour: String(weeklyResetHour ?? 0).padStart(2, '0'), timezone: resetTimezone || 'UTC' }) }}
|
||||
</template>
|
||||
<template v-else>
|
||||
{{ t('admin.accounts.quotaWeeklyLimitHint') }}
|
||||
</template>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 时区选择(当任一维度使用固定模式时显示) -->
|
||||
<div v-if="hasFixedMode">
|
||||
<label class="input-label">{{ t('admin.accounts.quotaResetTimezone') }}</label>
|
||||
<select
|
||||
:value="resetTimezone || 'UTC'"
|
||||
@change="emit('update:resetTimezone', ($event.target as HTMLSelectElement).value)"
|
||||
class="input text-sm"
|
||||
>
|
||||
<option v-for="tz in timezoneOptions" :key="tz" :value="tz">{{ tz }}</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<!-- 总配额 -->
|
||||
|
||||
@@ -1,5 +1,29 @@
|
||||
<template>
|
||||
<div>
|
||||
<!-- Window stats row (above progress bar) -->
|
||||
<div
|
||||
v-if="windowStats"
|
||||
class="mb-0.5 flex items-center"
|
||||
>
|
||||
<div class="flex items-center gap-1.5 text-[9px] text-gray-500 dark:text-gray-400">
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
|
||||
{{ formatRequests }} req
|
||||
</span>
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
|
||||
{{ formatTokens }}
|
||||
</span>
|
||||
<span class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800">
|
||||
A ${{ formatAccountCost }}
|
||||
</span>
|
||||
<span
|
||||
v-if="windowStats?.user_cost != null"
|
||||
class="rounded bg-gray-100 px-1.5 py-0.5 dark:bg-gray-800"
|
||||
>
|
||||
U ${{ formatUserCost }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Progress bar row -->
|
||||
<div class="flex items-center gap-1">
|
||||
<!-- Label badge (fixed width for alignment) -->
|
||||
@@ -108,4 +132,32 @@ const formatResetTime = computed(() => {
|
||||
}
|
||||
})
|
||||
|
||||
// Window stats formatters
|
||||
const formatRequests = computed(() => {
|
||||
if (!props.windowStats) return ''
|
||||
const r = props.windowStats.requests
|
||||
if (r >= 1000000) return `${(r / 1000000).toFixed(1)}M`
|
||||
if (r >= 1000) return `${(r / 1000).toFixed(1)}K`
|
||||
return r.toString()
|
||||
})
|
||||
|
||||
const formatTokens = computed(() => {
|
||||
if (!props.windowStats) return ''
|
||||
const t = props.windowStats.tokens
|
||||
if (t >= 1000000000) return `${(t / 1000000000).toFixed(1)}B`
|
||||
if (t >= 1000000) return `${(t / 1000000).toFixed(1)}M`
|
||||
if (t >= 1000) return `${(t / 1000).toFixed(1)}K`
|
||||
return t.toString()
|
||||
})
|
||||
|
||||
const formatAccountCost = computed(() => {
|
||||
if (!props.windowStats) return '0.00'
|
||||
return props.windowStats.cost.toFixed(2)
|
||||
})
|
||||
|
||||
const formatUserCost = computed(() => {
|
||||
if (!props.windowStats || props.windowStats.user_cost == null) return '0.00'
|
||||
return props.windowStats.user_cost.toFixed(2)
|
||||
})
|
||||
|
||||
</script>
|
||||
|
||||
@@ -76,7 +76,7 @@ const hasRecoverableState = computed(() => {
|
||||
return props.account?.status === 'error' || Boolean(isRateLimited.value) || Boolean(isOverloaded.value) || Boolean(isTempUnschedulable.value)
|
||||
})
|
||||
const hasQuotaLimit = computed(() => {
|
||||
return props.account?.type === 'apikey' && (
|
||||
return (props.account?.type === 'apikey' || props.account?.type === 'bedrock') && (
|
||||
(props.account?.quota_limit ?? 0) > 0 ||
|
||||
(props.account?.quota_daily_limit ?? 0) > 0 ||
|
||||
(props.account?.quota_weekly_limit ?? 0) > 0
|
||||
|
||||
@@ -410,6 +410,18 @@
|
||||
|
||||
<!-- Model Distribution -->
|
||||
<ModelDistributionChart :model-stats="stats.models" :loading="false" />
|
||||
|
||||
<EndpointDistributionChart
|
||||
:endpoint-stats="stats.endpoints || []"
|
||||
:loading="false"
|
||||
:title="t('usage.inboundEndpoint')"
|
||||
/>
|
||||
|
||||
<EndpointDistributionChart
|
||||
:endpoint-stats="stats.upstream_endpoints || []"
|
||||
:loading="false"
|
||||
:title="t('usage.upstreamEndpoint')"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<!-- No Data State -->
|
||||
@@ -453,6 +465,7 @@ import { Line } from 'vue-chartjs'
|
||||
import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import LoadingSpinner from '@/components/common/LoadingSpinner.vue'
|
||||
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'
|
||||
import EndpointDistributionChart from '@/components/charts/EndpointDistributionChart.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Account, AccountUsageStatsResponse } from '@/types'
|
||||
|
||||
@@ -24,7 +24,7 @@ const updateType = (value: string | number | boolean | null) => { emit('update:f
|
||||
const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) }
|
||||
const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) }
|
||||
const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }, { value: 'sora', label: 'Sora' }])
|
||||
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }])
|
||||
const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }])
|
||||
const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }])
|
||||
const gOpts = computed(() => [{ value: '', label: t('admin.accounts.allGroups') }, ...(props.groups || []).map(g => ({ value: String(g.id), label: g.name }))])
|
||||
</script>
|
||||
|
||||
@@ -35,6 +35,19 @@
|
||||
</span>
|
||||
</template>
|
||||
|
||||
<template #cell-endpoint="{ row }">
|
||||
<div class="max-w-[320px] space-y-1 text-xs">
|
||||
<div class="break-all text-gray-700 dark:text-gray-300">
|
||||
<span class="font-medium text-gray-500 dark:text-gray-400">{{ t('usage.inbound') }}:</span>
|
||||
<span class="ml-1">{{ row.inbound_endpoint?.trim() || '-' }}</span>
|
||||
</div>
|
||||
<div class="break-all text-gray-700 dark:text-gray-300">
|
||||
<span class="font-medium text-gray-500 dark:text-gray-400">{{ t('usage.upstream') }}:</span>
|
||||
<span class="ml-1">{{ row.upstream_endpoint?.trim() || '-' }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<template #cell-group="{ row }">
|
||||
<span v-if="row.group" class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium bg-indigo-100 text-indigo-800 dark:bg-indigo-900 dark:text-indigo-200">
|
||||
{{ row.group.name }}
|
||||
@@ -328,6 +341,7 @@ const getRequestTypeBadgeClass = (row: AdminUsageLog): string => {
|
||||
if (requestType === 'sync') return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200'
|
||||
return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200'
|
||||
}
|
||||
|
||||
const formatCacheTokens = (tokens: number): string => {
|
||||
if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M`
|
||||
if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K`
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user