merge upstream main

This commit is contained in:
song
2026-02-03 15:36:17 +08:00
67 changed files with 3014 additions and 283 deletions

View File

@@ -33,7 +33,7 @@ func main() {
}() }()
userRepo := repository.NewUserRepository(client, sqlDB) userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(userRepo, cfg, nil, nil, nil, nil, nil) authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()

View File

@@ -43,6 +43,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, err return nil, err
} }
userRepository := repository.NewUserRepository(client, db) userRepository := repository.NewUserRepository(client, db)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
settingRepository := repository.NewSettingRepository(client) settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig) settingService := service.NewSettingService(settingRepository, configConfig)
redisClient := repository.ProvideRedis(configConfig) redisClient := repository.ProvideRedis(configConfig)
@@ -61,24 +62,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig) secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
totpCache := repository.NewTotpCache(redisClient) totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
userHandler := handler.NewUserHandler(userService) userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db) usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemHandler := handler.NewRedeemHandler(redeemService) redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
announcementRepository := repository.NewAnnouncementRepository(client) announcementRepository := repository.NewAnnouncementRepository(client)

View File

@@ -37,6 +37,7 @@ const (
RedeemTypeBalance = "balance" RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency" RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription" RedeemTypeSubscription = "subscription"
RedeemTypeInvitation = "invitation"
) )
// PromoCode status constants // PromoCode status constants

View File

@@ -47,6 +47,8 @@ type CreateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"` MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"` SupportedModelScopes []string `json:"supported_model_scopes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
// UpdateGroupRequest represents update group request // UpdateGroupRequest represents update group request
@@ -74,6 +76,8 @@ type UpdateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"` MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"` SupportedModelScopes *[]string `json:"supported_model_scopes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
} }
// List handles listing all groups with pagination // List handles listing all groups with pagination
@@ -182,6 +186,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled, ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject, MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes, SupportedModelScopes: req.SupportedModelScopes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
@@ -227,6 +232,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled, ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject, MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes, SupportedModelScopes: req.SupportedModelScopes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)

View File

@@ -29,7 +29,7 @@ func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
// GenerateRedeemCodesRequest represents generate redeem codes request // GenerateRedeemCodesRequest represents generate redeem codes request
type GenerateRedeemCodesRequest struct { type GenerateRedeemCodesRequest struct {
Count int `json:"count" binding:"required,min=1,max=100"` Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"` Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"min=0"` Value float64 `json:"value" binding:"min=0"`
GroupID *int64 `json:"group_id"` // 订阅类型必填 GroupID *int64 `json:"group_id"` // 订阅类型必填
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用默认30天最大100年 ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用默认30天最大100年

View File

@@ -49,6 +49,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled, PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled, PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled, TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost, SMTPHost: settings.SMTPHost,
@@ -98,6 +99,7 @@ type UpdateSettingsRequest struct {
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置 // 邮件服务设置
@@ -291,6 +293,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EmailVerifyEnabled: req.EmailVerifyEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled, PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled, PasswordResetEnabled: req.PasswordResetEnabled,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled, TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost, SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort, SMTPPort: req.SMTPPort,
@@ -370,6 +373,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled, PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled, PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
TotpEnabled: updatedSettings.TotpEnabled, TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost, SMTPHost: updatedSettings.SMTPHost,

View File

@@ -20,17 +20,19 @@ type AuthHandler struct {
userService *service.UserService userService *service.UserService
settingSvc *service.SettingService settingSvc *service.SettingService
promoService *service.PromoService promoService *service.PromoService
redeemService *service.RedeemService
totpService *service.TotpService totpService *service.TotpService
} }
// NewAuthHandler creates a new AuthHandler // NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler { func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
cfg: cfg, cfg: cfg,
authService: authService, authService: authService,
userService: userService, userService: userService,
settingSvc: settingService, settingSvc: settingService,
promoService: promoService, promoService: promoService,
redeemService: redeemService,
totpService: totpService, totpService: totpService,
} }
} }
@@ -42,6 +44,7 @@ type RegisterRequest struct {
VerifyCode string `json:"verify_code"` VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"` TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码 PromoCode string `json:"promo_code"` // 注册优惠码
InvitationCode string `json:"invitation_code"` // 邀请码
} }
// SendVerifyCodeRequest 发送验证码请求 // SendVerifyCodeRequest 发送验证码请求
@@ -87,7 +90,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
} }
} }
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode) token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -346,6 +349,66 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
}) })
} }
// ValidateInvitationCodeRequest 验证邀请码请求
type ValidateInvitationCodeRequest struct {
Code string `json:"code" binding:"required"`
}
// ValidateInvitationCodeResponse 验证邀请码响应
type ValidateInvitationCodeResponse struct {
Valid bool `json:"valid"`
ErrorCode string `json:"error_code,omitempty"`
}
// ValidateInvitationCode 验证邀请码(公开接口,注册前调用)
// POST /api/v1/auth/validate-invitation-code
func (h *AuthHandler) ValidateInvitationCode(c *gin.Context) {
// 检查邀请码功能是否启用
if h.settingSvc == nil || !h.settingSvc.IsInvitationCodeEnabled(c.Request.Context()) {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_DISABLED",
})
return
}
var req ValidateInvitationCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 验证邀请码
redeemCode, err := h.redeemService.GetByCode(c.Request.Context(), req.Code)
if err != nil {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_NOT_FOUND",
})
return
}
// 检查类型和状态
if redeemCode.Type != service.RedeemTypeInvitation {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_INVALID",
})
return
}
if redeemCode.Status != service.StatusUnused {
response.Success(c, ValidateInvitationCodeResponse{
Valid: false,
ErrorCode: "INVITATION_CODE_USED",
})
return
}
response.Success(c, ValidateInvitationCodeResponse{
Valid: true,
})
}
// ForgotPasswordRequest 忘记密码请求 // ForgotPasswordRequest 忘记密码请求
type ForgotPasswordRequest struct { type ForgotPasswordRequest struct {
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`

View File

@@ -381,6 +381,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
AccountID: l.AccountID, AccountID: l.AccountID,
RequestID: l.RequestID, RequestID: l.RequestID,
Model: l.Model, Model: l.Model,
ReasoningEffort: l.ReasoningEffort,
GroupID: l.GroupID, GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID, SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens, InputTokens: l.InputTokens,

View File

@@ -6,6 +6,7 @@ type SystemSettings struct {
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
@@ -63,6 +64,7 @@ type PublicSettings struct {
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"` TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSiteKey string `json:"turnstile_site_key"`

View File

@@ -237,6 +237,9 @@ type UsageLog struct {
AccountID int64 `json:"account_id"` AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"` RequestID string `json:"request_id"`
Model string `json:"model"` Model string `json:"model"`
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API).
// nil means not provided / not applicable.
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"` SubscriptionID *int64 `json:"subscription_id"`

View File

@@ -596,7 +596,6 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service
cloned.Group = group cloned.Group = group
return &cloned return &cloned
} }
// Usage handles getting account balance and usage statistics for CC Switch integration // Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage // GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) { func (h *GatewayHandler) Usage(c *gin.Context) {
@@ -849,6 +848,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return return
} }
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body) setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body) parsedReq, err := service.ParseGatewayRequest(body)

View File

@@ -371,11 +371,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
// 6) record usage async // 6) record usage async (Gemini 使用长上下文双倍计费)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
@@ -383,6 +384,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: ua,
IPAddress: ip, IPAddress: ip,
LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }

View File

@@ -36,6 +36,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled, PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled, PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled, TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled, TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, TurnstileSiteKey: settings.TurnstileSiteKey,

View File

@@ -81,7 +81,6 @@ func ForwardBaseURLs() []string {
} }
return reordered return reordered
} }
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级) // URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
type URLAvailability struct { type URLAvailability struct {
mu sync.RWMutex mu sync.RWMutex

View File

@@ -9,11 +9,26 @@ const (
BetaClaudeCode = "claude-code-20250219" BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01"
) )
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
//
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
// even if the request doesn't use tools, otherwise upstream may reject the
// request as a non-Claude-Code API request.
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders 是 Claude Code 客户端默认请求头。 // DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{ var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)", // Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
"User-Agent": "claude-cli/2.1.22 (external, cli)",
"X-Stainless-Lang": "js", "X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0", "X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux", "X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64", "X-Stainless-Arch": "arm64",
"X-Stainless-Runtime": "node", "X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0", "X-Stainless-Runtime-Version": "v24.13.0",
"X-Stainless-Retry-Count": "0", "X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60", "X-Stainless-Timeout": "600",
"X-App": "cli", "X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true", "Anthropic-Dangerous-Direct-Browser-Access": "true",
} }
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel 测试时使用的默认模型 // DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929" const DefaultTestModel = "claude-sonnet-4-5-20250929"
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
var ModelIDOverrides = map[string]string{
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
"claude-opus-4-5": "claude-opus-4-5-20251101",
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
}
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
var ModelIDReverseOverrides = map[string]string{
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
"claude-opus-4-5-20251101": "claude-opus-4-5",
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
}
// NormalizeModelID 根据 Claude OAuth 规则映射模型
func NormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDOverrides[id]; ok {
return mapped
}
return id
}
// DenormalizeModelID 将上游模型 ID 转换为短名
func DenormalizeModelID(id string) string {
if id == "" {
return id
}
if mapped, ok := ModelIDReverseOverrides[id]; ok {
return mapped
}
return id
}

View File

@@ -439,3 +439,61 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
return counts, nil return counts, nil
} }
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID去重
func (r *groupRepository) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
if len(groupIDs) == 0 {
return nil, nil
}
rows, err := r.sql.QueryContext(
ctx,
"SELECT DISTINCT account_id FROM account_groups WHERE group_id = ANY($1) ORDER BY account_id",
pq.Array(groupIDs),
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
var accountIDs []int64
for rows.Next() {
var accountID int64
if err := rows.Scan(&accountID); err != nil {
return nil, err
}
accountIDs = append(accountIDs, accountID)
}
if err := rows.Err(); err != nil {
return nil, err
}
return accountIDs, nil
}
// BindAccountsToGroup 将多个账号绑定到指定分组(批量插入,忽略已存在的绑定)
func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
if len(accountIDs) == 0 {
return nil
}
// 使用 INSERT ... ON CONFLICT DO NOTHING 忽略已存在的绑定
_, err := r.sql.ExecContext(
ctx,
`INSERT INTO account_groups (account_id, group_id, priority, created_at)
SELECT unnest($1::bigint[]), $2, 50, NOW()
ON CONFLICT (account_id, group_id) DO NOTHING`,
pq.Array(accountIDs),
groupID,
)
if err != nil {
return err
}
// 发送调度器事件
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
}
return nil
}

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
type usageLogRepository struct { type usageLogRepository struct {
client *dbent.Client client *dbent.Client
@@ -114,6 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ip_address, ip_address,
image_count, image_count,
image_size, image_size,
reasoning_effort,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $1, $2, $3, $4, $5,
@@ -121,7 +122,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11, $8, $9, $10, $11,
$12, $13, $12, $13,
$14, $15, $16, $17, $18, $19, $14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30 $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
@@ -134,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
userAgent := nullString(log.UserAgent) userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress) ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize) imageSize := nullString(log.ImageSize)
reasoningEffort := nullString(log.ReasoningEffort)
var requestIDArg any var requestIDArg any
if requestID != "" { if requestID != "" {
@@ -170,6 +172,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
ipAddress, ipAddress,
log.ImageCount, log.ImageCount,
imageSize, imageSize,
reasoningEffort,
createdAt, createdAt,
} }
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
@@ -2090,6 +2093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress sql.NullString ipAddress sql.NullString
imageCount int imageCount int
imageSize sql.NullString imageSize sql.NullString
reasoningEffort sql.NullString
createdAt time.Time createdAt time.Time
) )
@@ -2124,6 +2128,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&ipAddress, &ipAddress,
&imageCount, &imageCount,
&imageSize, &imageSize,
&reasoningEffort,
&createdAt, &createdAt,
); err != nil { ); err != nil {
return nil, err return nil, err
@@ -2183,6 +2188,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if imageSize.Valid { if imageSize.Valid {
log.ImageSize = &imageSize.String log.ImageSize = &imageSize.String
} }
if reasoningEffort.Valid {
log.ReasoningEffort = &reasoningEffort.String
}
return log, nil return log, nil
} }

View File

@@ -488,6 +488,7 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o", "fallback_model_openai": "gpt-4o",
"enable_identity_patch": true, "enable_identity_patch": true,
"identity_patch_prompt": "", "identity_patch_prompt": "",
"invitation_code_enabled": false,
"home_content": "", "home_content": "",
"hide_ccs_import_button": false, "hide_ccs_import_button": false,
"purchase_subscription_enabled": false, "purchase_subscription_enabled": false,
@@ -600,7 +601,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService := service.NewSettingService(settingRepo, cfg) settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
@@ -880,6 +881,14 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
return 0, errors.New("not implemented") return 0, errors.New("not implemented")
} }
func (stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return errors.New("not implemented")
}
func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
type stubAccountRepo struct { type stubAccountRepo struct {
bulkUpdateIDs []int64 bulkUpdateIDs []int64
} }

View File

@@ -32,6 +32,10 @@ func RegisterAuthRoutes(
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode) }), h.Auth.ValidatePromoCode)
// 邀请码验证接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/validate-invitation-code", rateLimiter.LimitWithOptions("validate-invitation", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidateInvitationCode)
// 忘记密码接口添加速率限制:每分钟最多 5 次Redis 故障时 fail-close // 忘记密码接口添加速率限制:每分钟最多 5 次Redis 故障时 fail-close
auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{ auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,

View File

@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
return "" return ""
} }
func (a *Account) GetClaudeUserID() string {
if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
return v
}
return ""
}
func (a *Account) IsCustomErrorCodesEnabled() bool { func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil { if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false return false

View File

@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
"system": []map[string]any{ "system": []map[string]any{
{ {
"type": "text", "type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.", "text": claudeCodeSystemPrompt,
"cache_control": map[string]string{ "cache_control": map[string]string{
"type": "ephemeral", "type": "ephemeral",
}, },

View File

@@ -115,6 +115,8 @@ type CreateGroupInput struct {
MCPXMLInject *bool MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string SupportedModelScopes []string
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
} }
type UpdateGroupInput struct { type UpdateGroupInput struct {
@@ -142,6 +144,8 @@ type UpdateGroupInput struct {
MCPXMLInject *bool MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string SupportedModelScopes *[]string
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
} }
type CreateAccountInput struct { type CreateAccountInput struct {
@@ -598,6 +602,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
mcpXMLInject = *input.MCPXMLInject mcpXMLInject = *input.MCPXMLInject
} }
// 如果指定了复制账号的源分组,先获取账号 ID 列表
var accountIDsToCopy []int64
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
seen := make(map[int64]struct{})
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
if _, exists := seen[srcGroupID]; !exists {
seen[srcGroupID] = struct{}{}
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
}
}
// 校验源分组的平台是否与新分组一致
for _, srcGroupID := range uniqueSourceGroupIDs {
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
if err != nil {
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
}
if srcGroup.Platform != platform {
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform)
}
}
// 获取所有源分组的账号(去重)
var err error
accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
if err != nil {
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
}
}
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
@@ -622,6 +658,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
} }
// 如果有需要复制的账号,绑定到新分组
if len(accountIDsToCopy) > 0 {
if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil {
return nil, fmt.Errorf("failed to bind accounts to new group: %w", err)
}
group.AccountCount = int64(len(accountIDsToCopy))
}
return group, nil return group, nil
} }
@@ -810,6 +855,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
} }
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
seen := make(map[int64]struct{})
uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs))
for _, srcGroupID := range input.CopyAccountsFromGroupIDs {
// 校验:源分组不能是自身
if srcGroupID == id {
return nil, fmt.Errorf("cannot copy accounts from self")
}
// 去重
if _, exists := seen[srcGroupID]; !exists {
seen[srcGroupID] = struct{}{}
uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID)
}
}
// 校验源分组的平台是否与当前分组一致
for _, srcGroupID := range uniqueSourceGroupIDs {
srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID)
if err != nil {
return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err)
}
if srcGroup.Platform != group.Platform {
return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform)
}
}
// 获取所有源分组的账号(去重)
accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs)
if err != nil {
return nil, fmt.Errorf("failed to get accounts from source groups: %w", err)
}
// 先清空当前分组的所有账号绑定
if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil {
return nil, fmt.Errorf("failed to clear existing account bindings: %w", err)
}
// 再绑定源分组的账号
if len(accountIDsToCopy) > 0 {
if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil {
return nil, fmt.Errorf("failed to bind accounts to group: %w", err)
}
}
}
if s.authCacheInvalidator != nil { if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
} }

View File

@@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
panic("unexpected DeleteAccountGroupsByGroupID call") panic("unexpected DeleteAccountGroupsByGroupID call")
} }
func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
type proxyRepoStub struct { type proxyRepoStub struct {
deleteErr error deleteErr error
countErr error countErr error

View File

@@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context,
panic("unexpected DeleteAccountGroupsByGroupID call") panic("unexpected DeleteAccountGroupsByGroupID call")
} }
func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递 // TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) { func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{} repo := &groupRepoStubForAdmin{}
@@ -379,6 +387,14 @@ func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.C
panic("unexpected DeleteAccountGroupsByGroupID call") panic("unexpected DeleteAccountGroupsByGroupID call")
} }
func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
type groupRepoStubForInvalidRequestFallback struct { type groupRepoStubForInvalidRequestFallback struct {
groups map[int64]*Group groups map[int64]*Group
created *Group created *Group

View File

@@ -302,13 +302,11 @@ func logPrefix(sessionID, accountName string) string {
} }
// Antigravity 直接支持的模型(精确匹配透传) // Antigravity 直接支持的模型(精确匹配透传)
// 注意gemini-2.5 系列已移除,统一映射到 gemini-3 系列
var antigravitySupportedModels = map[string]bool{ var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true, "claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true, "claude-sonnet-4-5": true,
"claude-sonnet-4-5-thinking": true, "claude-sonnet-4-5-thinking": true,
"gemini-2.5-flash": true,
"gemini-2.5-flash-lite": true,
"gemini-2.5-flash-thinking": true,
"gemini-3-flash": true, "gemini-3-flash": true,
"gemini-3-pro-low": true, "gemini-3-pro-low": true,
"gemini-3-pro-high": true, "gemini-3-pro-high": true,
@@ -317,14 +315,24 @@ var antigravitySupportedModels = map[string]bool{
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先) // Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀) // 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
// gemini-2.5 系列统一映射到 gemini-3 系列Antigravity 上游不再支持 2.5
var antigravityPrefixMapping = []struct { var antigravityPrefixMapping = []struct {
prefix string prefix string
target string target string
}{ }{
// 长前缀优先 // gemini-2.5 → gemini-3 映射(长前缀优先
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image {"gemini-2.5-flash-thinking", "gemini-3-flash"}, // gemini-2.5-flash-thinking → gemini-3-flash
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → gemini-3-pro-image
{"gemini-2.5-flash-lite", "gemini-3-flash"}, // gemini-2.5-flash-lite → gemini-3-flash
{"gemini-2.5-flash", "gemini-3-flash"}, // gemini-2.5-flash → gemini-3-flash
{"gemini-2.5-pro-preview", "gemini-3-pro-high"}, // gemini-2.5-pro-preview → gemini-3-pro-high
{"gemini-2.5-pro-exp", "gemini-3-pro-high"}, // gemini-2.5-pro-exp → gemini-3-pro-high
{"gemini-2.5-pro", "gemini-3-pro-high"}, // gemini-2.5-pro → gemini-3-pro-high
// gemini-3 前缀映射
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash {"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
// Claude 映射
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
@@ -333,7 +341,6 @@ var antigravityPrefixMapping = []struct {
{"claude-sonnet-4", "claude-sonnet-4-5"}, {"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet {"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
{"claude-opus-4", "claude-opus-4-5-thinking"}, {"claude-opus-4", "claude-opus-4-5-thinking"},
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
} }
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发 // AntigravityGatewayService 处理 Antigravity 平台的 API 转发

View File

@@ -103,6 +103,10 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.
return s.resp, s.err return s.resp, s.err
} }
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
return s.resp, s.err
}
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder() writer := httptest.NewRecorder()

View File

@@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-sonnet-4-5", expected: "claude-sonnet-4-5",
}, },
// 3. Gemini 透传 // 3. Gemini 2.5 → 3 映射
{ {
name: "Gemini透传 - gemini-2.5-flash", name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash",
requestedModel: "gemini-2.5-flash", requestedModel: "gemini-2.5-flash",
accountMapping: nil, accountMapping: nil,
expected: "gemini-2.5-flash", expected: "gemini-3-flash",
}, },
{ {
name: "Gemini透传 - gemini-2.5-pro", name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high",
requestedModel: "gemini-2.5-pro", requestedModel: "gemini-2.5-pro",
accountMapping: nil, accountMapping: nil,
expected: "gemini-2.5-pro", expected: "gemini-3-pro-high",
}, },
{ {
name: "Gemini透传 - gemini-future-model", name: "Gemini透传 - gemini-future-model",

View File

@@ -30,6 +30,8 @@ var (
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code")
) )
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
@@ -47,6 +49,7 @@ type JWTClaims struct {
// AuthService 认证服务 // AuthService 认证服务
type AuthService struct { type AuthService struct {
userRepo UserRepository userRepo UserRepository
redeemRepo RedeemCodeRepository
cfg *config.Config cfg *config.Config
settingService *SettingService settingService *SettingService
emailService *EmailService emailService *EmailService
@@ -58,6 +61,7 @@ type AuthService struct {
// NewAuthService 创建认证服务实例 // NewAuthService 创建认证服务实例
func NewAuthService( func NewAuthService(
userRepo UserRepository, userRepo UserRepository,
redeemRepo RedeemCodeRepository,
cfg *config.Config, cfg *config.Config,
settingService *SettingService, settingService *SettingService,
emailService *EmailService, emailService *EmailService,
@@ -67,6 +71,7 @@ func NewAuthService(
) *AuthService { ) *AuthService {
return &AuthService{ return &AuthService{
userRepo: userRepo, userRepo: userRepo,
redeemRepo: redeemRepo,
cfg: cfg, cfg: cfg,
settingService: settingService, settingService: settingService,
emailService: emailService, emailService: emailService,
@@ -78,11 +83,11 @@ func NewAuthService(
// Register 用户注册返回token和用户 // Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "") return s.RegisterWithVerification(ctx, email, password, "", "", "")
} }
// RegisterWithVerification 用户注册(支持邮件验证优惠码返回token和用户 // RegisterWithVerification 用户注册(支持邮件验证优惠码和邀请码返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) { func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
// 检查是否开放注册默认关闭settingService 未配置时不允许注册) // 检查是否开放注册默认关闭settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled return "", nil, ErrRegDisabled
@@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrEmailReserved return "", nil, ErrEmailReserved
} }
// 检查是否需要邀请码
var invitationRedeemCode *RedeemCode
if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) {
if invitationCode == "" {
return "", nil, ErrInvitationCodeRequired
}
// 验证邀请码
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err)
return "", nil, ErrInvitationCodeInvalid
}
// 检查类型和状态
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
return "", nil, ErrInvitationCodeInvalid
}
invitationRedeemCode = redeemCode
}
// 检查是否需要邮件验证 // 检查是否需要邮件验证
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
// 如果邮件验证已开启但邮件服务未配置,拒绝注册 // 如果邮件验证已开启但邮件服务未配置,拒绝注册
@@ -153,6 +178,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
// 邀请码标记失败不影响注册,只记录日志
log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
}
}
// 应用优惠码(如果提供且功能已启用) // 应用优惠码(如果提供且功能已启用)
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) { if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {

View File

@@ -115,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
return NewAuthService( return NewAuthService(
repo, repo,
nil, // redeemRepo
cfg, cfg,
settingService, settingService,
emailService, emailService,
@@ -152,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil) }, nil)
// 应返回服务不可用错误,而不是允许绕过验证 // 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
require.ErrorIs(t, err, ErrServiceUnavailable) require.ErrorIs(t, err, ErrServiceUnavailable)
} }
@@ -164,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true", SettingKeyEmailVerifyEnabled: "true",
}, cache) }, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired) require.ErrorIs(t, err, ErrEmailVerifyRequired)
} }
@@ -178,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true", SettingKeyEmailVerifyEnabled: "true",
}, cache) }, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "") _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode) require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code") require.ErrorContains(t, err, "verify code")
} }

View File

@@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
return s.CalculateCost(model, tokens, multiplier) return s.CalculateCost(model, tokens, multiplier)
} }
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
// threshold: 阈值(如 200000超过此值的部分按 extraMultiplier 倍计费
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
//
// 示例:缓存 210k + 输入 10k = 220k阈值 200k倍率 2.0
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
// 范围内正常计费,范围外 × 2 计费
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
// 未启用长上下文计费,直接走正常计费
if threshold <= 0 || extraMultiplier <= 1 {
return s.CalculateCost(model, tokens, rateMultiplier)
}
// 计算总输入 token缓存读取 + 新输入)
total := tokens.CacheReadTokens + tokens.InputTokens
if total <= threshold {
return s.CalculateCost(model, tokens, rateMultiplier)
}
// 拆分成范围内和范围外
var inRangeCacheTokens, inRangeInputTokens int
var outRangeCacheTokens, outRangeInputTokens int
if tokens.CacheReadTokens >= threshold {
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
inRangeCacheTokens = threshold
inRangeInputTokens = 0
outRangeCacheTokens = tokens.CacheReadTokens - threshold
outRangeInputTokens = tokens.InputTokens
} else {
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
inRangeCacheTokens = tokens.CacheReadTokens
inRangeInputTokens = threshold - tokens.CacheReadTokens
outRangeCacheTokens = 0
outRangeInputTokens = tokens.InputTokens - inRangeInputTokens
}
// 范围内部分:正常计费
inRangeTokens := UsageTokens{
InputTokens: inRangeInputTokens,
OutputTokens: tokens.OutputTokens, // 输出只算一次
CacheCreationTokens: tokens.CacheCreationTokens,
CacheReadTokens: inRangeCacheTokens,
}
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil {
return nil, err
}
// 范围外部分:× extraMultiplier 计费
outRangeTokens := UsageTokens{
InputTokens: outRangeInputTokens,
CacheReadTokens: outRangeCacheTokens,
}
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
if err != nil {
return inRangeCost, nil // 出错时返回范围内成本
}
// 合并成本
return &CostBreakdown{
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
OutputCost: inRangeCost.OutputCost,
CacheCreationCost: inRangeCost.CacheCreationCost,
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost,
}, nil
}
// ListSupportedModels 列出所有支持的模型现在总是返回true因为有模糊匹配 // ListSupportedModels 列出所有支持的模型现在总是返回true因为有模糊匹配
func (s *BillingService) ListSupportedModels() []string { func (s *BillingService) ListSupportedModels() []string {
models := make([]string, 0) models := make([]string, 0)

View File

@@ -39,6 +39,7 @@ const (
RedeemTypeBalance = domain.RedeemTypeBalance RedeemTypeBalance = domain.RedeemTypeBalance
RedeemTypeConcurrency = domain.RedeemTypeConcurrency RedeemTypeConcurrency = domain.RedeemTypeConcurrency
RedeemTypeSubscription = domain.RedeemTypeSubscription RedeemTypeSubscription = domain.RedeemTypeSubscription
RedeemTypeInvitation = domain.RedeemTypeInvitation
) )
// PromoCode status constants // PromoCode status constants
@@ -76,6 +77,7 @@ const (
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
// 邮件服务设置 // 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址

View File

@@ -0,0 +1,23 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMergeAnthropicBeta(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"foo, oauth-2025-04-20,bar, foo",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got)
}
func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
}

View File

@@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte
return 0, nil return 0, nil
} }
func (m *mockGroupRepoForGateway) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
func ptr[T any](v T) *T { func ptr[T any](v T) *T {
return &v return &v
} }

View File

@@ -0,0 +1,62 @@
package service
import (
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
System: nil,
Messages: nil,
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id
}
fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format
got := svc.buildOAuthMetadataUserID(parsed, account, fp)
require.NotEmpty(t, got)
// Legacy format: user_{client}_account__session_{uuid}
re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}
func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{
"account_uuid": "acc-uuid",
"claude_user_id": "clientid123",
"anthropic_user_id": "",
},
}
got := svc.buildOAuthMetadataUserID(parsed, account, nil)
require.NotEmpty(t, got)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}

View File

@@ -2,6 +2,7 @@ package service
import ( import (
"encoding/json" "encoding/json"
"strings"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
} }
func TestInjectClaudeCodePrompt(t *testing.T) { func TestInjectClaudeCodePrompt(t *testing.T) {
claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
tests := []struct { tests := []struct {
name string name string
body string body string
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system: "Custom prompt", system: "Custom prompt",
wantSystemLen: 2, wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt, wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom prompt", wantSecondText: claudePrefix + "\n\nCustom prompt",
}, },
{ {
name: "string system equals Claude Code prompt", name: "string system equals Claude Code prompt",
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2 // Claude Code + Custom = 2
wantSystemLen: 2, wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt, wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom", wantSecondText: claudePrefix + "\n\nCustom",
}, },
{ {
name: "array system with existing Claude Code prompt (should dedupe)", name: "array system with existing Claude Code prompt (should dedupe)",
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped) // Claude Code at start + Other = 2 (deduped)
wantSystemLen: 2, wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt, wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Other", wantSecondText: claudePrefix + "\n\nOther",
}, },
{ {
name: "empty array", name: "empty array",

View File

@@ -0,0 +1,21 @@
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
in := "You are OpenCode, the best coding agent on the planet."
got := sanitizeSystemText(in)
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
}
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
in := "OpenCode and opencode are mentioned."
got := sanitizeToolDescription(in)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require.Equal(t, in, got)
}

File diff suppressed because it is too large Load Diff

View File

@@ -36,6 +36,11 @@ const (
geminiRetryMaxDelay = 16 * time.Second geminiRetryMaxDelay = 16 * time.Second
) )
// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`.
// Many clients don't send it; we inject a known dummy signature to satisfy the validator.
// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures
const geminiDummyThoughtSignature = "skip_thought_signature_validator"
type GeminiMessagesCompatService struct { type GeminiMessagesCompatService struct {
accountRepo AccountRepository accountRepo AccountRepository
groupRepo GroupRepository groupRepo GroupRepository
@@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if err != nil { if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
} }
geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq)
originalClaudeBody := body originalClaudeBody := body
proxyURL := "" proxyURL := ""
@@ -983,6 +989,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
} }
// Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a
// `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s.
body = ensureGeminiFunctionCallThoughtSignatures(body)
mappedModel := originalModel mappedModel := originalModel
if account.Type == AccountTypeAPIKey { if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel) mappedModel = account.GetMappedModel(originalModel)
@@ -2662,6 +2672,58 @@ func nextGeminiDailyResetUnix() *int64 {
return &ts return &ts
} }
func ensureGeminiFunctionCallThoughtSignatures(body []byte) []byte {
// Fast path: only run when functionCall is present.
if !bytes.Contains(body, []byte(`"functionCall"`)) {
return body
}
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return body
}
contentsAny, ok := payload["contents"].([]any)
if !ok || len(contentsAny) == 0 {
return body
}
modified := false
for _, c := range contentsAny {
cm, ok := c.(map[string]any)
if !ok {
continue
}
partsAny, ok := cm["parts"].([]any)
if !ok || len(partsAny) == 0 {
continue
}
for _, p := range partsAny {
pm, ok := p.(map[string]any)
if !ok || pm == nil {
continue
}
if fc, ok := pm["functionCall"].(map[string]any); !ok || fc == nil {
continue
}
ts, _ := pm["thoughtSignature"].(string)
if strings.TrimSpace(ts) == "" {
pm["thoughtSignature"] = geminiDummyThoughtSignature
modified = true
}
}
}
if !modified {
return body
}
b, err := json.Marshal(payload)
if err != nil {
return body
}
return b
}
func extractGeminiFinishReason(geminiResp map[string]any) string { func extractGeminiFinishReason(geminiResp map[string]any) string {
if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok { if cand, ok := candidates[0].(map[string]any); ok {
@@ -2861,7 +2923,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" { if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" {
toolUseIDToName[id] = name toolUseIDToName[id] = name
} }
signature, _ := bm["signature"].(string)
signature = strings.TrimSpace(signature)
if signature == "" {
signature = geminiDummyThoughtSignature
}
parts = append(parts, map[string]any{ parts = append(parts, map[string]any{
"thoughtSignature": signature,
"functionCall": map[string]any{ "functionCall": map[string]any{
"name": name, "name": name,
"args": bm["input"], "args": bm["input"],

View File

@@ -1,6 +1,8 @@
package service package service
import ( import (
"encoding/json"
"strings"
"testing" "testing"
) )
@@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
}) })
} }
} }
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
claudeReq := map[string]any{
"model": "claude-haiku-4-5-20251001",
"max_tokens": 10,
"messages": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "hi"},
},
},
map[string]any{
"role": "assistant",
"content": []any{
map[string]any{"type": "text", "text": "ok"},
map[string]any{
"type": "tool_use",
"id": "toolu_123",
"name": "default_api:write_file",
"input": map[string]any{"path": "a.txt", "content": "x"},
// no signature on purpose
},
},
},
},
"tools": []any{
map[string]any{
"name": "default_api:write_file",
"description": "write file",
"input_schema": map[string]any{
"type": "object",
"properties": map[string]any{"path": map[string]any{"type": "string"}},
},
},
},
}
b, _ := json.Marshal(claudeReq)
out, err := convertClaudeMessagesToGeminiGenerateContent(b)
if err != nil {
t.Fatalf("convert failed: %v", err)
}
s := string(out)
if !strings.Contains(s, "\"functionCall\"") {
t.Fatalf("expected functionCall in output, got: %s", s)
}
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
}
}
func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing.T) {
geminiReq := map[string]any{
"contents": []any{
map[string]any{
"role": "user",
"parts": []any{
map[string]any{
"functionCall": map[string]any{
"name": "default_api:write_file",
"args": map[string]any{"path": "a.txt"},
},
},
},
},
},
}
b, _ := json.Marshal(geminiReq)
out := ensureGeminiFunctionCallThoughtSignatures(b)
s := string(out)
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
}
}

View File

@@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex
return 0, nil return 0, nil
} }
func (m *mockGroupRepoForGemini) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
return nil
}
func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
return nil, nil
}
var _ GroupRepository = (*mockGroupRepoForGemini)(nil) var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock // mockGatewayCacheForGemini Gemini 测试用的 cache mock

View File

@@ -29,6 +29,10 @@ type GroupRepository interface {
ExistsByName(ctx context.Context, name string) (bool, error) ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error) GetAccountCount(ctx context.Context, groupID int64) (int64, error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID去重
GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error)
// BindAccountsToGroup 将多个账号绑定到指定分组
BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error
} }
// CreateGroupRequest 创建分组请求 // CreateGroupRequest 创建分组请求

View File

@@ -26,13 +26,13 @@ var (
// 默认指纹值(当客户端未提供时使用) // 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{ var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.0.62 (external, cli)", UserAgent: "claude-cli/2.1.22 (external, cli)",
StainlessLang: "js", StainlessLang: "js",
StainlessPackageVersion: "0.52.0", StainlessPackageVersion: "0.70.0",
StainlessOS: "Linux", StainlessOS: "Linux",
StainlessArch: "x64", StainlessArch: "arm64",
StainlessRuntime: "node", StainlessRuntime: "node",
StainlessRuntimeVersion: "v22.14.0", StainlessRuntimeVersion: "v24.13.0",
} }
// Fingerprint represents account fingerprint data // Fingerprint represents account fingerprint data
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
} }
// parseUserAgentVersion 解析user-agent版本号 // parseUserAgentVersion 解析user-agent版本号
// 例如claude-cli/2.0.62 -> (2, 0, 62) // 例如claude-cli/2.1.2 -> (2, 1, 2)
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) { func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
// 匹配 xxx/x.y.z 格式 // 匹配 xxx/x.y.z 格式
matches := userAgentVersionRegex.FindStringSubmatch(ua) matches := userAgentVersionRegex.FindStringSubmatch(ua)

View File

@@ -159,6 +159,9 @@ type OpenAIForwardResult struct {
RequestID string RequestID string
Usage OpenAIUsage Usage OpenAIUsage
Model string Model string
// ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix.
// Stored for usage records display; nil means not provided / not applicable.
ReasoningEffort *string
Stream bool Stream bool
Duration time.Duration Duration time.Duration
FirstTokenMs *int FirstTokenMs *int
@@ -958,10 +961,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
} }
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
return &OpenAIForwardResult{ return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
ReasoningEffort: reasoningEffort,
Stream: reqStream, Stream: reqStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
@@ -1260,16 +1266,30 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率 // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt := time.Now() lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) // 仅发送一次错误事件,避免多次写入导致协议混乱
// 注意OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema
// 否则下游 SDK例如 OpenCode会因为类型校验失败而报错。
errorEventSent := false errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sendErrorEvent := func(reason string) { sendErrorEvent := func(reason string) {
if errorEventSent { if errorEventSent || clientDisconnected {
return return
} }
errorEventSent = true errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) payload := map[string]any{
"type": "error",
"sequence_number": 0,
"error": map[string]any{
"type": "upstream_error",
"message": reason,
"code": reason,
},
}
if b, err := json.Marshal(payload); err == nil {
_, _ = fmt.Fprintf(w, "data: %s\n\n", b)
flusher.Flush() flusher.Flush()
} }
}
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
@@ -1280,6 +1300,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
if ev.err != nil { if ev.err != nil {
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event避免下游 SDK 解析失败。
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) { if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large") sendErrorEvent("response_too_large")
@@ -1303,15 +1334,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// Correct Codex tool calls if needed (apply_patch -> edit, etc.) // Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
data = correctedData
line = "data: " + correctedData line = "data: " + correctedData
} }
// Forward line // 写入客户端(客户端断开后继续 drain 上游)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed") clientDisconnected = true
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} } else {
flusher.Flush() flusher.Flush()
}
}
// Record first token time // Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" { if firstTokenMs == nil && data != "" && data != "[DONE]" {
@@ -1321,18 +1356,25 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
s.parseSSEUsage(data, usage) s.parseSSEUsage(data, usage)
} else { } else {
// Forward non-data lines as-is // Forward non-data lines as-is
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed") clientDisconnected = true
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} } else {
flusher.Flush() flusher.Flush()
} }
}
}
case <-intervalCh: case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval { if time.Since(lastRead) < streamInterval {
continue continue
} }
if clientDisconnected {
log.Printf("Upstream timeout after client disconnect, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态 // 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil { if s.rateLimitService != nil {
@@ -1342,11 +1384,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh: case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval { if time.Since(lastDataAt) < keepaliveInterval {
continue continue
} }
if _, err := fmt.Fprint(w, ":\n\n"); err != nil { if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
continue
} }
flusher.Flush() flusher.Flush()
} }
@@ -1687,6 +1734,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: result.RequestID,
Model: result.Model, Model: result.Model,
ReasoningEffort: result.ReasoningEffort,
InputTokens: actualInputTokens, InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
@@ -1881,3 +1929,86 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}() }()
} }
func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) {
if reqBody == nil {
return "", false
}
// Primary: reasoning.effort
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
if effort, ok := reasoning["effort"].(string); ok {
return normalizeOpenAIReasoningEffort(effort), true
}
}
// Fallback: some clients may use a flat field.
if effort, ok := reqBody["reasoning_effort"].(string); ok {
return normalizeOpenAIReasoningEffort(effort), true
}
return "", false
}
func deriveOpenAIReasoningEffortFromModel(model string) string {
if strings.TrimSpace(model) == "" {
return ""
}
modelID := strings.TrimSpace(model)
if strings.Contains(modelID, "/") {
parts := strings.Split(modelID, "/")
modelID = parts[len(parts)-1]
}
parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool {
switch r {
case '-', '_', ' ':
return true
default:
return false
}
})
if len(parts) == 0 {
return ""
}
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
}
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
if value == "" {
return nil
}
return &value
}
value := deriveOpenAIReasoningEffortFromModel(requestedModel)
if value == "" {
return nil
}
return &value
}
func normalizeOpenAIReasoningEffort(raw string) string {
value := strings.ToLower(strings.TrimSpace(raw))
if value == "" {
return ""
}
// Normalize separators for "x-high"/"x_high" variants.
value = strings.NewReplacer("-", "", "_", "", " ", "").Replace(value)
switch value {
case "none", "minimal":
return ""
case "low", "medium", "high":
return value
case "xhigh", "extrahigh":
return "xhigh"
default:
// Only store known effort levels for now to keep UI consistent.
return ""
}
}

View File

@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
skipDefaultLoad bool skipDefaultLoad bool
} }
type cancelReadCloser struct{}
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
func (c cancelReadCloser) Close() error { return nil }
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
writes int
}
func (w *failingGinWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed")
}
w.writes++
return w.ResponseWriter.Write(p)
}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if c.acquireResults != nil { if c.acquireResults != nil {
if result, ok := c.acquireResults[accountID]; ok { if result, ok := c.acquireResults[accountID]; ok {
@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") { if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
t.Fatalf("expected stream timeout error, got %v", err) t.Fatalf("expected stream timeout error, got %v", err)
} }
if !strings.Contains(rec.Body.String(), "stream_timeout") { if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") {
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String()) t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: cancelReadCloser{},
Header: http.Header{},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if result == nil || result.usage == nil {
t.Fatalf("expected usage result")
}
if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
t.Fatalf("unexpected usage: %+v", *result.usage)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
} }
} }
@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
if !errors.Is(err, bufio.ErrTooLong) { if !errors.Is(err, bufio.ErrTooLong) {
t.Fatalf("expected ErrTooLong, got %v", err) t.Fatalf("expected ErrTooLong, got %v", err)
} }
if !strings.Contains(rec.Body.String(), "response_too_large") { if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") {
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String()) t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
} }
} }

View File

@@ -126,7 +126,8 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
return nil, errors.New("count must be greater than 0") return nil, errors.New("count must be greater than 0")
} }
if req.Value <= 0 { // 邀请码类型不需要数值,其他类型需要
if req.Type != RedeemTypeInvitation && req.Value <= 0 {
return nil, errors.New("value must be greater than 0") return nil, errors.New("value must be greater than 0")
} }
@@ -139,6 +140,12 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
codeType = RedeemTypeBalance codeType = RedeemTypeBalance
} }
// 邀请码类型的 value 设为 0
value := req.Value
if codeType == RedeemTypeInvitation {
value = 0
}
codes := make([]RedeemCode, 0, req.Count) codes := make([]RedeemCode, 0, req.Count)
for i := 0; i < req.Count; i++ { for i := 0; i < req.Count; i++ {
code, err := s.GenerateRandomCode() code, err := s.GenerateRandomCode()
@@ -149,7 +156,7 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
codes = append(codes, RedeemCode{ codes = append(codes, RedeemCode{
Code: code, Code: code,
Type: codeType, Type: codeType,
Value: req.Value, Value: value,
Status: StatusUnused, Status: StatusUnused,
}) })
} }

View File

@@ -62,6 +62,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyEmailVerifyEnabled, SettingKeyEmailVerifyEnabled,
SettingKeyPromoCodeEnabled, SettingKeyPromoCodeEnabled,
SettingKeyPasswordResetEnabled, SettingKeyPasswordResetEnabled,
SettingKeyInvitationCodeEnabled,
SettingKeyTotpEnabled, SettingKeyTotpEnabled,
SettingKeyTurnstileEnabled, SettingKeyTurnstileEnabled,
SettingKeyTurnstileSiteKey, SettingKeyTurnstileSiteKey,
@@ -99,6 +100,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
EmailVerifyEnabled: emailVerifyEnabled, EmailVerifyEnabled: emailVerifyEnabled,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: passwordResetEnabled, PasswordResetEnabled: passwordResetEnabled,
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
TotpEnabled: settings[SettingKeyTotpEnabled] == "true", TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
@@ -141,6 +143,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` TotpEnabled bool `json:"totp_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"` TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
@@ -161,6 +164,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled, PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled, PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled, TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled, TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey, TurnstileSiteKey: settings.TurnstileSiteKey,
@@ -188,6 +192,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled) updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled) updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled)
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled) updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
// 邮件服务设置(只有非空才更新密码) // 邮件服务设置(只有非空才更新密码)
@@ -286,6 +291,14 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
return value != "false" return value != "false"
} }
// IsInvitationCodeEnabled 检查是否启用邀请码注册功能
func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyInvitationCodeEnabled)
if err != nil {
return false // 默认关闭
}
return value == "true"
}
// IsPasswordResetEnabled 检查是否启用密码重置功能 // IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证 // 要求:必须同时开启邮件验证
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool { func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
@@ -401,6 +414,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
EmailVerifyEnabled: emailVerifyEnabled, EmailVerifyEnabled: emailVerifyEnabled,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true", PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true",
TotpEnabled: settings[SettingKeyTotpEnabled] == "true", TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost], SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername], SMTPUsername: settings[SettingKeySMTPUsername],

View File

@@ -5,6 +5,7 @@ type SystemSettings struct {
EmailVerifyEnabled bool EmailVerifyEnabled bool
PromoCodeEnabled bool PromoCodeEnabled bool
PasswordResetEnabled bool PasswordResetEnabled bool
InvitationCodeEnabled bool
TotpEnabled bool // TOTP 双因素认证 TotpEnabled bool // TOTP 双因素认证
SMTPHost string SMTPHost string
@@ -65,6 +66,7 @@ type PublicSettings struct {
EmailVerifyEnabled bool EmailVerifyEnabled bool
PromoCodeEnabled bool PromoCodeEnabled bool
PasswordResetEnabled bool PasswordResetEnabled bool
InvitationCodeEnabled bool
TotpEnabled bool // TOTP 双因素认证 TotpEnabled bool // TOTP 双因素认证
TurnstileEnabled bool TurnstileEnabled bool
TurnstileSiteKey string TurnstileSiteKey string

View File

@@ -14,6 +14,9 @@ type UsageLog struct {
AccountID int64 AccountID int64
RequestID string RequestID string
Model string Model string
// ReasoningEffort is the request's reasoning effort level (OpenAI Responses API),
// e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable.
ReasoningEffort *string
GroupID *int64 GroupID *int64
SubscriptionID *int64 SubscriptionID *int64

View File

@@ -0,0 +1,4 @@
-- Add reasoning_effort field to usage_logs for OpenAI/Codex requests.
-- This stores the request's reasoning effort level (e.g. low/medium/high/xhigh).
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS reasoning_effort VARCHAR(20);

View File

@@ -24,4 +24,11 @@ WHERE filename = '044_add_group_mcp_xml_inject.sql'
SELECT 1 FROM schema_migrations WHERE filename = '044b_add_group_mcp_xml_inject.sql' SELECT 1 FROM schema_migrations WHERE filename = '044b_add_group_mcp_xml_inject.sql'
); );
UPDATE schema_migrations
SET filename = '046b_add_group_supported_model_scopes.sql'
WHERE filename = '046_add_group_supported_model_scopes.sql'
AND NOT EXISTS (
SELECT 1 FROM schema_migrations WHERE filename = '046b_add_group_supported_model_scopes.sql'
);
COMMIT; COMMIT;

View File

@@ -14,6 +14,7 @@ export interface SystemSettings {
email_verify_enabled: boolean email_verify_enabled: boolean
promo_code_enabled: boolean promo_code_enabled: boolean
password_reset_enabled: boolean password_reset_enabled: boolean
invitation_code_enabled: boolean
totp_enabled: boolean // TOTP 双因素认证 totp_enabled: boolean // TOTP 双因素认证
totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置 totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置
// Default settings // Default settings
@@ -72,6 +73,7 @@ export interface UpdateSettingsRequest {
email_verify_enabled?: boolean email_verify_enabled?: boolean
promo_code_enabled?: boolean promo_code_enabled?: boolean
password_reset_enabled?: boolean password_reset_enabled?: boolean
invitation_code_enabled?: boolean
totp_enabled?: boolean // TOTP 双因素认证 totp_enabled?: boolean // TOTP 双因素认证
default_balance?: number default_balance?: number
default_concurrency?: number default_concurrency?: number

View File

@@ -164,6 +164,24 @@ export async function validatePromoCode(code: string): Promise<ValidatePromoCode
return data return data
} }
/**
* Validate invitation code response
*/
export interface ValidateInvitationCodeResponse {
valid: boolean
error_code?: string
}
/**
* Validate invitation code (public endpoint, no auth required)
* @param code - Invitation code to validate
* @returns Validation result
*/
export async function validateInvitationCode(code: string): Promise<ValidateInvitationCodeResponse> {
const { data } = await apiClient.post<ValidateInvitationCodeResponse>('/auth/validate-invitation-code', { code })
return data
}
/** /**
* Forgot password request * Forgot password request
*/ */
@@ -229,6 +247,7 @@ export const authAPI = {
getPublicSettings, getPublicSettings,
sendVerifyCode, sendVerifyCode,
validatePromoCode, validatePromoCode,
validateInvitationCode,
forgotPassword, forgotPassword,
resetPassword resetPassword
} }

View File

@@ -21,6 +21,12 @@
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span> <span class="font-medium text-gray-900 dark:text-white">{{ value }}</span>
</template> </template>
<template #cell-reasoning_effort="{ row }">
<span class="text-sm text-gray-900 dark:text-white">
{{ formatReasoningEffort(row.reasoning_effort) }}
</span>
</template>
<template #cell-group="{ row }"> <template #cell-group="{ row }">
<span v-if="row.group" class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium bg-indigo-100 text-indigo-800 dark:bg-indigo-900 dark:text-indigo-200"> <span v-if="row.group" class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium bg-indigo-100 text-indigo-800 dark:bg-indigo-900 dark:text-indigo-200">
{{ row.group.name }} {{ row.group.name }}
@@ -235,7 +241,7 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed } from 'vue' import { ref, computed } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { formatDateTime } from '@/utils/format' import { formatDateTime, formatReasoningEffort } from '@/utils/format'
import DataTable from '@/components/common/DataTable.vue' import DataTable from '@/components/common/DataTable.vue'
import EmptyState from '@/components/common/EmptyState.vue' import EmptyState from '@/components/common/EmptyState.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
@@ -259,6 +265,7 @@ const cols = computed(() => [
{ key: 'api_key', label: t('usage.apiKeyFilter'), sortable: false }, { key: 'api_key', label: t('usage.apiKeyFilter'), sortable: false },
{ key: 'account', label: t('admin.usage.account'), sortable: false }, { key: 'account', label: t('admin.usage.account'), sortable: false },
{ key: 'model', label: t('usage.model'), sortable: true }, { key: 'model', label: t('usage.model'), sortable: true },
{ key: 'reasoning_effort', label: t('usage.reasoningEffort'), sortable: false },
{ key: 'group', label: t('admin.usage.group'), sortable: false }, { key: 'group', label: t('admin.usage.group'), sortable: false },
{ key: 'stream', label: t('usage.type'), sortable: false }, { key: 'stream', label: t('usage.type'), sortable: false },
{ key: 'tokens', label: t('usage.tokens'), sortable: false }, { key: 'tokens', label: t('usage.tokens'), sortable: false },

View File

@@ -265,6 +265,13 @@ export default {
promoCodeAlreadyUsed: 'You have already used this promo code', promoCodeAlreadyUsed: 'You have already used this promo code',
promoCodeValidating: 'Promo code is being validated, please wait', promoCodeValidating: 'Promo code is being validated, please wait',
promoCodeInvalidCannotRegister: 'Invalid promo code. Please check and try again or clear the promo code field', promoCodeInvalidCannotRegister: 'Invalid promo code. Please check and try again or clear the promo code field',
invitationCodeLabel: 'Invitation Code',
invitationCodePlaceholder: 'Enter invitation code',
invitationCodeRequired: 'Invitation code is required',
invitationCodeValid: 'Invitation code is valid',
invitationCodeInvalid: 'Invalid or used invitation code',
invitationCodeValidating: 'Validating invitation code...',
invitationCodeInvalidCannotRegister: 'Invalid invitation code. Please check and try again',
linuxdo: { linuxdo: {
signIn: 'Continue with Linux.do', signIn: 'Continue with Linux.do',
orContinue: 'or continue with email', orContinue: 'or continue with email',
@@ -495,6 +502,7 @@ export default {
exporting: 'Exporting...', exporting: 'Exporting...',
preparingExport: 'Preparing export...', preparingExport: 'Preparing export...',
model: 'Model', model: 'Model',
reasoningEffort: 'Reasoning Effort',
type: 'Type', type: 'Type',
tokens: 'Tokens', tokens: 'Tokens',
cost: 'Cost', cost: 'Cost',
@@ -1009,6 +1017,14 @@ export default {
hint: 'Triggered only when upstream explicitly returns prompt too long. Leave empty to disable fallback.', hint: 'Triggered only when upstream explicitly returns prompt too long. Leave empty to disable fallback.',
noFallback: 'No Fallback' noFallback: 'No Fallback'
}, },
copyAccounts: {
title: 'Copy Accounts from Groups',
tooltip: 'Select one or more groups of the same platform. After creation, all accounts from these groups will be automatically bound to the new group (deduplicated).',
tooltipEdit: 'Select one or more groups of the same platform. After saving, current group accounts will be replaced with accounts from these groups (deduplicated).',
selectPlaceholder: 'Select groups to copy accounts from...',
hint: 'Multiple groups can be selected, accounts will be deduplicated',
hintEdit: '⚠️ Warning: This will replace all existing account bindings'
},
modelRouting: { modelRouting: {
title: 'Model Routing', title: 'Model Routing',
tooltip: 'Configure specific model requests to be routed to designated accounts. Supports wildcard matching, e.g., claude-opus-* matches all opus models.', tooltip: 'Configure specific model requests to be routed to designated accounts. Supports wildcard matching, e.g., claude-opus-* matches all opus models.',
@@ -1922,6 +1938,8 @@ export default {
balance: 'Balance', balance: 'Balance',
concurrency: 'Concurrency', concurrency: 'Concurrency',
subscription: 'Subscription', subscription: 'Subscription',
invitation: 'Invitation',
invitationHint: 'Invitation codes are used to restrict user registration. They are automatically marked as used after use.',
unused: 'Unused', unused: 'Unused',
used: 'Used', used: 'Used',
columns: { columns: {
@@ -1968,6 +1986,7 @@ export default {
balance: 'Balance', balance: 'Balance',
concurrency: 'Concurrency', concurrency: 'Concurrency',
subscription: 'Subscription', subscription: 'Subscription',
invitation: 'Invitation',
// Admin adjustment types (created when admin modifies user balance/concurrency) // Admin adjustment types (created when admin modifies user balance/concurrency)
admin_balance: 'Balance (Admin)', admin_balance: 'Balance (Admin)',
admin_concurrency: 'Concurrency (Admin)' admin_concurrency: 'Concurrency (Admin)'
@@ -2925,6 +2944,8 @@ export default {
emailVerificationHint: 'Require email verification for new registrations', emailVerificationHint: 'Require email verification for new registrations',
promoCode: 'Promo Code', promoCode: 'Promo Code',
promoCodeHint: 'Allow users to use promo codes during registration', promoCodeHint: 'Allow users to use promo codes during registration',
invitationCode: 'Invitation Code Registration',
invitationCodeHint: 'When enabled, users must enter a valid invitation code to register',
passwordReset: 'Password Reset', passwordReset: 'Password Reset',
passwordResetHint: 'Allow users to reset their password via email', passwordResetHint: 'Allow users to reset their password via email',
totp: 'Two-Factor Authentication (2FA)', totp: 'Two-Factor Authentication (2FA)',

View File

@@ -262,6 +262,13 @@ export default {
promoCodeAlreadyUsed: '您已使用过此优惠码', promoCodeAlreadyUsed: '您已使用过此优惠码',
promoCodeValidating: '优惠码正在验证中,请稍候', promoCodeValidating: '优惠码正在验证中,请稍候',
promoCodeInvalidCannotRegister: '优惠码无效,请检查后重试或清空优惠码', promoCodeInvalidCannotRegister: '优惠码无效,请检查后重试或清空优惠码',
invitationCodeLabel: '邀请码',
invitationCodePlaceholder: '请输入邀请码',
invitationCodeRequired: '请输入邀请码',
invitationCodeValid: '邀请码有效',
invitationCodeInvalid: '邀请码无效或已被使用',
invitationCodeValidating: '正在验证邀请码...',
invitationCodeInvalidCannotRegister: '邀请码无效,请检查后重试',
linuxdo: { linuxdo: {
signIn: '使用 Linux.do 登录', signIn: '使用 Linux.do 登录',
orContinue: '或使用邮箱密码继续', orContinue: '或使用邮箱密码继续',
@@ -491,6 +498,7 @@ export default {
exporting: '导出中...', exporting: '导出中...',
preparingExport: '正在准备导出...', preparingExport: '正在准备导出...',
model: '模型', model: '模型',
reasoningEffort: '推理强度',
type: '类型', type: '类型',
tokens: 'Token', tokens: 'Token',
cost: '费用', cost: '费用',
@@ -1084,6 +1092,14 @@ export default {
hint: '仅当上游明确返回 prompt too long 时才会触发,留空表示不兜底', hint: '仅当上游明确返回 prompt too long 时才会触发,留空表示不兜底',
noFallback: '不兜底' noFallback: '不兜底'
}, },
copyAccounts: {
title: '从分组复制账号',
tooltip: '选择一个或多个相同平台的分组,创建后会自动将这些分组的所有账号绑定到新分组(去重)。',
tooltipEdit: '选择一个或多个相同平台的分组,保存后当前分组的账号会被替换为这些分组的账号(去重)。',
selectPlaceholder: '选择分组以复制其账号...',
hint: '可选多个分组,账号会自动去重',
hintEdit: '⚠️ 注意:这会替换当前分组的所有账号绑定'
},
modelRouting: { modelRouting: {
title: '模型路由配置', title: '模型路由配置',
tooltip: '配置特定模型请求优先路由到指定账号。支持通配符匹配,如 claude-opus-* 匹配所有 opus 模型。', tooltip: '配置特定模型请求优先路由到指定账号。支持通配符匹配,如 claude-opus-* 匹配所有 opus 模型。',
@@ -2045,6 +2061,7 @@ export default {
balance: '余额', balance: '余额',
concurrency: '并发数', concurrency: '并发数',
subscription: '订阅', subscription: '订阅',
invitation: '邀请码',
// 管理员在用户管理页面调整余额/并发时产生的记录 // 管理员在用户管理页面调整余额/并发时产生的记录
admin_balance: '余额(管理员)', admin_balance: '余额(管理员)',
admin_concurrency: '并发数(管理员)' admin_concurrency: '并发数(管理员)'
@@ -2053,6 +2070,8 @@ export default {
balance: '余额', balance: '余额',
concurrency: '并发数', concurrency: '并发数',
subscription: '订阅', subscription: '订阅',
invitation: '邀请码',
invitationHint: '邀请码用于限制用户注册,使用后自动标记为已使用。',
allTypes: '全部类型', allTypes: '全部类型',
allStatus: '全部状态', allStatus: '全部状态',
unused: '未使用', unused: '未使用',
@@ -3078,6 +3097,8 @@ export default {
emailVerificationHint: '新用户注册时需要验证邮箱', emailVerificationHint: '新用户注册时需要验证邮箱',
promoCode: '优惠码', promoCode: '优惠码',
promoCodeHint: '允许用户在注册时使用优惠码', promoCodeHint: '允许用户在注册时使用优惠码',
invitationCode: '邀请码注册',
invitationCodeHint: '开启后,用户注册时需要填写有效的邀请码',
passwordReset: '忘记密码', passwordReset: '忘记密码',
passwordResetHint: '允许用户通过邮箱重置密码', passwordResetHint: '允许用户通过邮箱重置密码',
totp: '双因素认证 (2FA)', totp: '双因素认证 (2FA)',

View File

@@ -314,6 +314,7 @@ export const useAppStore = defineStore('app', () => {
email_verify_enabled: false, email_verify_enabled: false,
promo_code_enabled: true, promo_code_enabled: true,
password_reset_enabled: false, password_reset_enabled: false,
invitation_code_enabled: false,
turnstile_enabled: false, turnstile_enabled: false,
turnstile_site_key: '', turnstile_site_key: '',
site_name: siteName.value, site_name: siteName.value,

View File

@@ -55,6 +55,7 @@ export interface RegisterRequest {
verify_code?: string verify_code?: string
turnstile_token?: string turnstile_token?: string
promo_code?: string promo_code?: string
invitation_code?: string
} }
export interface SendVerifyCodeRequest { export interface SendVerifyCodeRequest {
@@ -72,6 +73,7 @@ export interface PublicSettings {
email_verify_enabled: boolean email_verify_enabled: boolean
promo_code_enabled: boolean promo_code_enabled: boolean
password_reset_enabled: boolean password_reset_enabled: boolean
invitation_code_enabled: boolean
turnstile_enabled: boolean turnstile_enabled: boolean
turnstile_site_key: string turnstile_site_key: string
site_name: string site_name: string
@@ -419,7 +421,10 @@ export interface CreateGroupRequest {
claude_code_only?: boolean claude_code_only?: boolean
fallback_group_id?: number | null fallback_group_id?: number | null
fallback_group_id_on_invalid_request?: number | null fallback_group_id_on_invalid_request?: number | null
mcp_xml_inject?: boolean
supported_model_scopes?: string[] supported_model_scopes?: string[]
// 从指定分组复制账号
copy_accounts_from_group_ids?: number[]
} }
export interface UpdateGroupRequest { export interface UpdateGroupRequest {
@@ -439,7 +444,9 @@ export interface UpdateGroupRequest {
claude_code_only?: boolean claude_code_only?: boolean
fallback_group_id?: number | null fallback_group_id?: number | null
fallback_group_id_on_invalid_request?: number | null fallback_group_id_on_invalid_request?: number | null
mcp_xml_inject?: boolean
supported_model_scopes?: string[] supported_model_scopes?: string[]
copy_accounts_from_group_ids?: number[]
} }
// ==================== Account & Proxy Types ==================== // ==================== Account & Proxy Types ====================
@@ -712,7 +719,7 @@ export interface UpdateProxyRequest {
// ==================== Usage & Redeem Types ==================== // ==================== Usage & Redeem Types ====================
export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription' export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription' | 'invitation'
export interface UsageLog { export interface UsageLog {
id: number id: number
@@ -721,6 +728,7 @@ export interface UsageLog {
account_id: number | null account_id: number | null
request_id: string request_id: string
model: string model: string
reasoning_effort?: string | null
group_id: number | null group_id: number | null
subscription_id: number | null subscription_id: number | null

View File

@@ -174,6 +174,35 @@ export function parseDateTimeLocalInput(value: string): number | null {
return Math.floor(date.getTime() / 1000) return Math.floor(date.getTime() / 1000)
} }
/**
* 格式化 OpenAI reasoning effort用于使用记录展示
* @param effort 原始 effort如 "low" / "medium" / "high" / "xhigh"
* @returns 格式化后的字符串Low / Medium / High / Xhigh无值返回 "-"
*/
export function formatReasoningEffort(effort: string | null | undefined): string {
const raw = (effort ?? '').toString().trim()
if (!raw) return '-'
const normalized = raw.toLowerCase().replace(/[-_\s]/g, '')
switch (normalized) {
case 'low':
return 'Low'
case 'medium':
return 'Medium'
case 'high':
return 'High'
case 'xhigh':
case 'extrahigh':
return 'Xhigh'
case 'none':
case 'minimal':
return '-'
default:
// best-effort: Title-case first letter
return raw.length > 1 ? raw[0].toUpperCase() + raw.slice(1) : raw.toUpperCase()
}
}
/** /**
* 格式化时间(只显示时分) * 格式化时间(只显示时分)
* @param date 日期字符串或 Date 对象 * @param date 日期字符串或 Date 对象

View File

@@ -240,9 +240,73 @@
v-model="createForm.platform" v-model="createForm.platform"
:options="platformOptions" :options="platformOptions"
data-tour="group-form-platform" data-tour="group-form-platform"
@change="createForm.copy_accounts_from_group_ids = []"
/> />
<p class="input-hint">{{ t('admin.groups.platformHint') }}</p> <p class="input-hint">{{ t('admin.groups.platformHint') }}</p>
</div> </div>
<!-- 从分组复制账号 -->
<div v-if="copyAccountsGroupOptions.length > 0">
<div class="mb-1.5 flex items-center gap-1">
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.groups.copyAccounts.title') }}
</label>
<div class="group relative inline-flex">
<Icon
name="questionCircle"
size="sm"
:stroke-width="2"
class="cursor-help text-gray-400 transition-colors hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400"
/>
<div class="pointer-events-none absolute bottom-full left-0 z-50 mb-2 w-72 opacity-0 transition-all duration-200 group-hover:pointer-events-auto group-hover:opacity-100">
<div class="rounded-lg bg-gray-900 p-3 text-white shadow-lg dark:bg-gray-800">
<p class="text-xs leading-relaxed text-gray-300">
{{ t('admin.groups.copyAccounts.tooltip') }}
</p>
<div class="absolute -bottom-1.5 left-3 h-3 w-3 rotate-45 bg-gray-900 dark:bg-gray-800"></div>
</div>
</div>
</div>
</div>
<!-- 已选分组标签 -->
<div v-if="createForm.copy_accounts_from_group_ids.length > 0" class="flex flex-wrap gap-1.5 mb-2">
<span
v-for="groupId in createForm.copy_accounts_from_group_ids"
:key="groupId"
class="inline-flex items-center gap-1 rounded-full bg-primary-100 px-2.5 py-1 text-xs font-medium text-primary-700 dark:bg-primary-900/30 dark:text-primary-300"
>
{{ copyAccountsGroupOptions.find(o => o.value === groupId)?.label || `#${groupId}` }}
<button
type="button"
@click="createForm.copy_accounts_from_group_ids = createForm.copy_accounts_from_group_ids.filter(id => id !== groupId)"
class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200"
>
<Icon name="x" size="xs" />
</button>
</span>
</div>
<!-- 分组选择下拉 -->
<select
class="input"
@change="(e) => {
const val = Number((e.target as HTMLSelectElement).value)
if (val && !createForm.copy_accounts_from_group_ids.includes(val)) {
createForm.copy_accounts_from_group_ids.push(val)
}
(e.target as HTMLSelectElement).value = ''
}"
>
<option value="">{{ t('admin.groups.copyAccounts.selectPlaceholder') }}</option>
<option
v-for="opt in copyAccountsGroupOptions"
:key="opt.value"
:value="opt.value"
:disabled="createForm.copy_accounts_from_group_ids.includes(opt.value)"
>
{{ opt.label }}
</option>
</select>
<p class="input-hint">{{ t('admin.groups.copyAccounts.hint') }}</p>
</div>
<div> <div>
<label class="input-label">{{ t('admin.groups.form.rateMultiplier') }}</label> <label class="input-label">{{ t('admin.groups.form.rateMultiplier') }}</label>
<input <input
@@ -795,6 +859,69 @@
/> />
<p class="input-hint">{{ t('admin.groups.platformNotEditable') }}</p> <p class="input-hint">{{ t('admin.groups.platformNotEditable') }}</p>
</div> </div>
<!-- 从分组复制账号编辑时 -->
<div v-if="copyAccountsGroupOptionsForEdit.length > 0">
<div class="mb-1.5 flex items-center gap-1">
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.groups.copyAccounts.title') }}
</label>
<div class="group relative inline-flex">
<Icon
name="questionCircle"
size="sm"
:stroke-width="2"
class="cursor-help text-gray-400 transition-colors hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400"
/>
<div class="pointer-events-none absolute bottom-full left-0 z-50 mb-2 w-72 opacity-0 transition-all duration-200 group-hover:pointer-events-auto group-hover:opacity-100">
<div class="rounded-lg bg-gray-900 p-3 text-white shadow-lg dark:bg-gray-800">
<p class="text-xs leading-relaxed text-gray-300">
{{ t('admin.groups.copyAccounts.tooltipEdit') }}
</p>
<div class="absolute -bottom-1.5 left-3 h-3 w-3 rotate-45 bg-gray-900 dark:bg-gray-800"></div>
</div>
</div>
</div>
</div>
<!-- 已选分组标签 -->
<div v-if="editForm.copy_accounts_from_group_ids.length > 0" class="flex flex-wrap gap-1.5 mb-2">
<span
v-for="groupId in editForm.copy_accounts_from_group_ids"
:key="groupId"
class="inline-flex items-center gap-1 rounded-full bg-primary-100 px-2.5 py-1 text-xs font-medium text-primary-700 dark:bg-primary-900/30 dark:text-primary-300"
>
{{ copyAccountsGroupOptionsForEdit.find(o => o.value === groupId)?.label || `#${groupId}` }}
<button
type="button"
@click="editForm.copy_accounts_from_group_ids = editForm.copy_accounts_from_group_ids.filter(id => id !== groupId)"
class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200"
>
<Icon name="x" size="xs" />
</button>
</span>
</div>
<!-- 分组选择下拉 -->
<select
class="input"
@change="(e) => {
const val = Number((e.target as HTMLSelectElement).value)
if (val && !editForm.copy_accounts_from_group_ids.includes(val)) {
editForm.copy_accounts_from_group_ids.push(val)
}
(e.target as HTMLSelectElement).value = ''
}"
>
<option value="">{{ t('admin.groups.copyAccounts.selectPlaceholder') }}</option>
<option
v-for="opt in copyAccountsGroupOptionsForEdit"
:key="opt.value"
:value="opt.value"
:disabled="editForm.copy_accounts_from_group_ids.includes(opt.value)"
>
{{ opt.label }}
</option>
</select>
<p class="input-hint">{{ t('admin.groups.copyAccounts.hintEdit') }}</p>
</div>
<div> <div>
<label class="input-label">{{ t('admin.groups.form.rateMultiplier') }}</label> <label class="input-label">{{ t('admin.groups.form.rateMultiplier') }}</label>
<input <input
@@ -1470,6 +1597,29 @@ const invalidRequestFallbackOptionsForEdit = computed(() => {
return options return options
}) })
// 复制账号的源分组选项(创建时)- 仅包含相同平台且有账号的分组
const copyAccountsGroupOptions = computed(() => {
const eligibleGroups = groups.value.filter(
(g) => g.platform === createForm.platform && (g.account_count || 0) > 0
)
return eligibleGroups.map((g) => ({
value: g.id,
label: `${g.name} (${g.account_count || 0} 个账号)`
}))
})
// 复制账号的源分组选项(编辑时)- 仅包含相同平台且有账号的分组,排除自身
const copyAccountsGroupOptionsForEdit = computed(() => {
const currentId = editingGroup.value?.id
const eligibleGroups = groups.value.filter(
(g) => g.platform === editForm.platform && (g.account_count || 0) > 0 && g.id !== currentId
)
return eligibleGroups.map((g) => ({
value: g.id,
label: `${g.name} (${g.account_count || 0} 个账号)`
}))
})
const groups = ref<AdminGroup[]>([]) const groups = ref<AdminGroup[]>([])
const loading = ref(false) const loading = ref(false)
const searchQuery = ref('') const searchQuery = ref('')
@@ -1517,7 +1667,9 @@ const createForm = reactive({
// 支持的模型系列(仅 antigravity 平台) // 支持的模型系列(仅 antigravity 平台)
supported_model_scopes: ['claude', 'gemini_text', 'gemini_image'] as string[], supported_model_scopes: ['claude', 'gemini_text', 'gemini_image'] as string[],
// MCP XML 协议注入开关(仅 antigravity 平台) // MCP XML 协议注入开关(仅 antigravity 平台)
mcp_xml_inject: true mcp_xml_inject: true,
// 从分组复制账号
copy_accounts_from_group_ids: [] as number[]
}) })
// 简单账号类型(用于模型路由选择) // 简单账号类型(用于模型路由选择)
@@ -1713,7 +1865,9 @@ const editForm = reactive({
// 支持的模型系列(仅 antigravity 平台) // 支持的模型系列(仅 antigravity 平台)
supported_model_scopes: ['claude', 'gemini_text', 'gemini_image'] as string[], supported_model_scopes: ['claude', 'gemini_text', 'gemini_image'] as string[],
// MCP XML 协议注入开关(仅 antigravity 平台) // MCP XML 协议注入开关(仅 antigravity 平台)
mcp_xml_inject: true mcp_xml_inject: true,
// 从分组复制账号
copy_accounts_from_group_ids: [] as number[]
}) })
// 根据分组类型返回不同的删除确认消息 // 根据分组类型返回不同的删除确认消息
@@ -1798,6 +1952,7 @@ const closeCreateModal = () => {
createForm.fallback_group_id_on_invalid_request = null createForm.fallback_group_id_on_invalid_request = null
createForm.supported_model_scopes = ['claude', 'gemini_text', 'gemini_image'] createForm.supported_model_scopes = ['claude', 'gemini_text', 'gemini_image']
createForm.mcp_xml_inject = true createForm.mcp_xml_inject = true
createForm.copy_accounts_from_group_ids = []
createModelRoutingRules.value = [] createModelRoutingRules.value = []
} }
@@ -1851,6 +2006,7 @@ const handleEdit = async (group: AdminGroup) => {
editForm.model_routing_enabled = group.model_routing_enabled || false editForm.model_routing_enabled = group.model_routing_enabled || false
editForm.supported_model_scopes = group.supported_model_scopes || ['claude', 'gemini_text', 'gemini_image'] editForm.supported_model_scopes = group.supported_model_scopes || ['claude', 'gemini_text', 'gemini_image']
editForm.mcp_xml_inject = group.mcp_xml_inject ?? true editForm.mcp_xml_inject = group.mcp_xml_inject ?? true
editForm.copy_accounts_from_group_ids = [] // 复制账号字段每次编辑时重置为空
// 加载模型路由规则(异步加载账号名称) // 加载模型路由规则(异步加载账号名称)
editModelRoutingRules.value = await convertApiFormatToRoutingRules(group.model_routing) editModelRoutingRules.value = await convertApiFormatToRoutingRules(group.model_routing)
showEditModal.value = true showEditModal.value = true
@@ -1860,6 +2016,7 @@ const closeEditModal = () => {
showEditModal.value = false showEditModal.value = false
editingGroup.value = null editingGroup.value = null
editModelRoutingRules.value = [] editModelRoutingRules.value = []
editForm.copy_accounts_from_group_ids = []
} }
const handleUpdateGroup = async () => { const handleUpdateGroup = async () => {

View File

@@ -213,7 +213,7 @@
<Select v-model="generateForm.type" :options="typeOptions" /> <Select v-model="generateForm.type" :options="typeOptions" />
</div> </div>
<!-- 余额/并发类型显示数值输入 --> <!-- 余额/并发类型显示数值输入 -->
<div v-if="generateForm.type !== 'subscription'"> <div v-if="generateForm.type !== 'subscription' && generateForm.type !== 'invitation'">
<label class="input-label"> <label class="input-label">
{{ {{
generateForm.type === 'balance' generateForm.type === 'balance'
@@ -230,6 +230,12 @@
class="input" class="input"
/> />
</div> </div>
<!-- 邀请码类型显示提示信息 -->
<div v-if="generateForm.type === 'invitation'" class="rounded-lg bg-blue-50 p-3 dark:bg-blue-900/20">
<p class="text-sm text-blue-700 dark:text-blue-300">
{{ t('admin.redeem.invitationHint') }}
</p>
</div>
<!-- 订阅类型显示分组选择和有效天数 --> <!-- 订阅类型显示分组选择和有效天数 -->
<template v-if="generateForm.type === 'subscription'"> <template v-if="generateForm.type === 'subscription'">
<div> <div>
@@ -387,7 +393,7 @@
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue' import { ref, reactive, computed, onMounted, onUnmounted, watch } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import { useClipboard } from '@/composables/useClipboard' import { useClipboard } from '@/composables/useClipboard'
@@ -499,14 +505,16 @@ const columns = computed<Column[]>(() => [
const typeOptions = computed(() => [ const typeOptions = computed(() => [
{ value: 'balance', label: t('admin.redeem.balance') }, { value: 'balance', label: t('admin.redeem.balance') },
{ value: 'concurrency', label: t('admin.redeem.concurrency') }, { value: 'concurrency', label: t('admin.redeem.concurrency') },
{ value: 'subscription', label: t('admin.redeem.subscription') } { value: 'subscription', label: t('admin.redeem.subscription') },
{ value: 'invitation', label: t('admin.redeem.invitation') }
]) ])
const filterTypeOptions = computed(() => [ const filterTypeOptions = computed(() => [
{ value: '', label: t('admin.redeem.allTypes') }, { value: '', label: t('admin.redeem.allTypes') },
{ value: 'balance', label: t('admin.redeem.balance') }, { value: 'balance', label: t('admin.redeem.balance') },
{ value: 'concurrency', label: t('admin.redeem.concurrency') }, { value: 'concurrency', label: t('admin.redeem.concurrency') },
{ value: 'subscription', label: t('admin.redeem.subscription') } { value: 'subscription', label: t('admin.redeem.subscription') },
{ value: 'invitation', label: t('admin.redeem.invitation') }
]) ])
const filterStatusOptions = computed(() => [ const filterStatusOptions = computed(() => [
@@ -546,6 +554,18 @@ const generateForm = reactive({
validity_days: 30 validity_days: 30
}) })
// 监听类型变化,邀请码类型时自动设置 value 为 0
watch(
() => generateForm.type,
(newType) => {
if (newType === 'invitation') {
generateForm.value = 0
} else if (generateForm.value === 0) {
generateForm.value = 10
}
}
)
const loadCodes = async () => { const loadCodes = async () => {
if (abortController) { if (abortController) {
abortController.abort() abortController.abort()

View File

@@ -339,6 +339,20 @@
<Toggle v-model="form.promo_code_enabled" /> <Toggle v-model="form.promo_code_enabled" />
</div> </div>
<!-- Invitation Code -->
<div
class="flex items-center justify-between border-t border-gray-100 pt-4 dark:border-dark-700"
>
<div>
<label class="font-medium text-gray-900 dark:text-white">{{
t('admin.settings.registration.invitationCode')
}}</label>
<p class="text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.settings.registration.invitationCodeHint') }}
</p>
</div>
<Toggle v-model="form.invitation_code_enabled" />
</div>
<!-- Password Reset - Only show when email verification is enabled --> <!-- Password Reset - Only show when email verification is enabled -->
<div <div
v-if="form.email_verify_enabled" v-if="form.email_verify_enabled"
@@ -1115,6 +1129,7 @@ const form = reactive<SettingsForm>({
registration_enabled: true, registration_enabled: true,
email_verify_enabled: false, email_verify_enabled: false,
promo_code_enabled: true, promo_code_enabled: true,
invitation_code_enabled: false,
password_reset_enabled: false, password_reset_enabled: false,
totp_enabled: false, totp_enabled: false,
totp_encryption_key_configured: false, totp_encryption_key_configured: false,
@@ -1243,6 +1258,7 @@ async function saveSettings() {
registration_enabled: form.registration_enabled, registration_enabled: form.registration_enabled,
email_verify_enabled: form.email_verify_enabled, email_verify_enabled: form.email_verify_enabled,
promo_code_enabled: form.promo_code_enabled, promo_code_enabled: form.promo_code_enabled,
invitation_code_enabled: form.invitation_code_enabled,
password_reset_enabled: form.password_reset_enabled, password_reset_enabled: form.password_reset_enabled,
totp_enabled: form.totp_enabled, totp_enabled: form.totp_enabled,
default_balance: form.default_balance, default_balance: form.default_balance,

View File

@@ -37,6 +37,7 @@ import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { saveAs } from 'file-saver' import { saveAs } from 'file-saver'
import { useAppStore } from '@/stores/app'; import { adminAPI } from '@/api/admin'; import { adminUsageAPI } from '@/api/admin/usage' import { useAppStore } from '@/stores/app'; import { adminAPI } from '@/api/admin'; import { adminUsageAPI } from '@/api/admin/usage'
import { formatReasoningEffort } from '@/utils/format'
import AppLayout from '@/components/layout/AppLayout.vue'; import Pagination from '@/components/common/Pagination.vue'; import Select from '@/components/common/Select.vue' import AppLayout from '@/components/layout/AppLayout.vue'; import Pagination from '@/components/common/Pagination.vue'; import Select from '@/components/common/Select.vue'
import UsageStatsCards from '@/components/admin/usage/UsageStatsCards.vue'; import UsageFilters from '@/components/admin/usage/UsageFilters.vue' import UsageStatsCards from '@/components/admin/usage/UsageStatsCards.vue'; import UsageFilters from '@/components/admin/usage/UsageFilters.vue'
import UsageTable from '@/components/admin/usage/UsageTable.vue'; import UsageExportProgress from '@/components/admin/usage/UsageExportProgress.vue' import UsageTable from '@/components/admin/usage/UsageTable.vue'; import UsageExportProgress from '@/components/admin/usage/UsageExportProgress.vue'
@@ -104,7 +105,7 @@ const exportToExcel = async () => {
const XLSX = await import('xlsx') const XLSX = await import('xlsx')
const headers = [ const headers = [
t('usage.time'), t('admin.usage.user'), t('usage.apiKeyFilter'), t('usage.time'), t('admin.usage.user'), t('usage.apiKeyFilter'),
t('admin.usage.account'), t('usage.model'), t('admin.usage.group'), t('admin.usage.account'), t('usage.model'), t('usage.reasoningEffort'), t('admin.usage.group'),
t('usage.type'), t('usage.type'),
t('admin.usage.inputTokens'), t('admin.usage.outputTokens'), t('admin.usage.inputTokens'), t('admin.usage.outputTokens'),
t('admin.usage.cacheReadTokens'), t('admin.usage.cacheCreationTokens'), t('admin.usage.cacheReadTokens'), t('admin.usage.cacheCreationTokens'),
@@ -120,6 +121,7 @@ const exportToExcel = async () => {
log.api_key?.name || '', log.api_key?.name || '',
log.account?.name || '', log.account?.name || '',
log.model, log.model,
formatReasoningEffort(log.reasoning_effort),
log.group?.name || '', log.group?.name || '',
log.stream ? t('usage.stream') : t('usage.sync'), log.stream ? t('usage.stream') : t('usage.sync'),
log.input_tokens, log.input_tokens,

View File

@@ -201,6 +201,7 @@ const email = ref<string>('')
const password = ref<string>('') const password = ref<string>('')
const initialTurnstileToken = ref<string>('') const initialTurnstileToken = ref<string>('')
const promoCode = ref<string>('') const promoCode = ref<string>('')
const invitationCode = ref<string>('')
const hasRegisterData = ref<boolean>(false) const hasRegisterData = ref<boolean>(false)
// Public settings // Public settings
@@ -230,6 +231,7 @@ onMounted(async () => {
password.value = registerData.password || '' password.value = registerData.password || ''
initialTurnstileToken.value = registerData.turnstile_token || '' initialTurnstileToken.value = registerData.turnstile_token || ''
promoCode.value = registerData.promo_code || '' promoCode.value = registerData.promo_code || ''
invitationCode.value = registerData.invitation_code || ''
hasRegisterData.value = !!(email.value && password.value) hasRegisterData.value = !!(email.value && password.value)
} catch { } catch {
hasRegisterData.value = false hasRegisterData.value = false
@@ -384,7 +386,8 @@ async function handleVerify(): Promise<void> {
password: password.value, password: password.value,
verify_code: verifyCode.value.trim(), verify_code: verifyCode.value.trim(),
turnstile_token: initialTurnstileToken.value || undefined, turnstile_token: initialTurnstileToken.value || undefined,
promo_code: promoCode.value || undefined promo_code: promoCode.value || undefined,
invitation_code: invitationCode.value || undefined
}) })
// Clear session data // Clear session data

View File

@@ -95,6 +95,59 @@
</p> </p>
</div> </div>
<!-- Invitation Code Input (Required when enabled) -->
<div v-if="invitationCodeEnabled">
<label for="invitation_code" class="input-label">
{{ t('auth.invitationCodeLabel') }}
</label>
<div class="relative">
<div class="pointer-events-none absolute inset-y-0 left-0 flex items-center pl-3.5">
<Icon name="key" size="md" :class="invitationValidation.valid ? 'text-green-500' : 'text-gray-400 dark:text-dark-500'" />
</div>
<input
id="invitation_code"
v-model="formData.invitation_code"
type="text"
:disabled="isLoading"
class="input pl-11 pr-10"
:class="{
'border-green-500 focus:border-green-500 focus:ring-green-500': invitationValidation.valid,
'border-red-500 focus:border-red-500 focus:ring-red-500': invitationValidation.invalid || errors.invitation_code
}"
:placeholder="t('auth.invitationCodePlaceholder')"
@input="handleInvitationCodeInput"
/>
<!-- Validation indicator -->
<div v-if="invitationValidating" class="absolute inset-y-0 right-0 flex items-center pr-3.5">
<svg class="h-4 w-4 animate-spin text-gray-400" fill="none" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
</div>
<div v-else-if="invitationValidation.valid" class="absolute inset-y-0 right-0 flex items-center pr-3.5">
<Icon name="checkCircle" size="md" class="text-green-500" />
</div>
<div v-else-if="invitationValidation.invalid || errors.invitation_code" class="absolute inset-y-0 right-0 flex items-center pr-3.5">
<Icon name="exclamationCircle" size="md" class="text-red-500" />
</div>
</div>
<!-- Invitation code validation result -->
<transition name="fade">
<div v-if="invitationValidation.valid" class="mt-2 flex items-center gap-2 rounded-lg bg-green-50 px-3 py-2 dark:bg-green-900/20">
<Icon name="checkCircle" size="sm" class="text-green-600 dark:text-green-400" />
<span class="text-sm text-green-700 dark:text-green-400">
{{ t('auth.invitationCodeValid') }}
</span>
</div>
<p v-else-if="invitationValidation.invalid" class="input-error-text">
{{ invitationValidation.message }}
</p>
<p v-else-if="errors.invitation_code" class="input-error-text">
{{ errors.invitation_code }}
</p>
</transition>
</div>
<!-- Promo Code Input (Optional) --> <!-- Promo Code Input (Optional) -->
<div v-if="promoCodeEnabled"> <div v-if="promoCodeEnabled">
<label for="promo_code" class="input-label"> <label for="promo_code" class="input-label">
@@ -239,7 +292,7 @@ import LinuxDoOAuthSection from '@/components/auth/LinuxDoOAuthSection.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import TurnstileWidget from '@/components/TurnstileWidget.vue' import TurnstileWidget from '@/components/TurnstileWidget.vue'
import { useAuthStore, useAppStore } from '@/stores' import { useAuthStore, useAppStore } from '@/stores'
import { getPublicSettings, validatePromoCode } from '@/api/auth' import { getPublicSettings, validatePromoCode, validateInvitationCode } from '@/api/auth'
const { t } = useI18n() const { t } = useI18n()
@@ -261,6 +314,7 @@ const showPassword = ref<boolean>(false)
const registrationEnabled = ref<boolean>(true) const registrationEnabled = ref<boolean>(true)
const emailVerifyEnabled = ref<boolean>(false) const emailVerifyEnabled = ref<boolean>(false)
const promoCodeEnabled = ref<boolean>(true) const promoCodeEnabled = ref<boolean>(true)
const invitationCodeEnabled = ref<boolean>(false)
const turnstileEnabled = ref<boolean>(false) const turnstileEnabled = ref<boolean>(false)
const turnstileSiteKey = ref<string>('') const turnstileSiteKey = ref<string>('')
const siteName = ref<string>('Sub2API') const siteName = ref<string>('Sub2API')
@@ -280,16 +334,27 @@ const promoValidation = reactive({
}) })
let promoValidateTimeout: ReturnType<typeof setTimeout> | null = null let promoValidateTimeout: ReturnType<typeof setTimeout> | null = null
// Invitation code validation
const invitationValidating = ref<boolean>(false)
const invitationValidation = reactive({
valid: false,
invalid: false,
message: ''
})
let invitationValidateTimeout: ReturnType<typeof setTimeout> | null = null
const formData = reactive({ const formData = reactive({
email: '', email: '',
password: '', password: '',
promo_code: '' promo_code: '',
invitation_code: ''
}) })
const errors = reactive({ const errors = reactive({
email: '', email: '',
password: '', password: '',
turnstile: '' turnstile: '',
invitation_code: ''
}) })
// ==================== Lifecycle ==================== // ==================== Lifecycle ====================
@@ -300,6 +365,7 @@ onMounted(async () => {
registrationEnabled.value = settings.registration_enabled registrationEnabled.value = settings.registration_enabled
emailVerifyEnabled.value = settings.email_verify_enabled emailVerifyEnabled.value = settings.email_verify_enabled
promoCodeEnabled.value = settings.promo_code_enabled promoCodeEnabled.value = settings.promo_code_enabled
invitationCodeEnabled.value = settings.invitation_code_enabled
turnstileEnabled.value = settings.turnstile_enabled turnstileEnabled.value = settings.turnstile_enabled
turnstileSiteKey.value = settings.turnstile_site_key || '' turnstileSiteKey.value = settings.turnstile_site_key || ''
siteName.value = settings.site_name || 'Sub2API' siteName.value = settings.site_name || 'Sub2API'
@@ -325,6 +391,9 @@ onUnmounted(() => {
if (promoValidateTimeout) { if (promoValidateTimeout) {
clearTimeout(promoValidateTimeout) clearTimeout(promoValidateTimeout)
} }
if (invitationValidateTimeout) {
clearTimeout(invitationValidateTimeout)
}
}) })
// ==================== Promo Code Validation ==================== // ==================== Promo Code Validation ====================
@@ -400,6 +469,70 @@ function getPromoErrorMessage(errorCode?: string): string {
} }
} }
// ==================== Invitation Code Validation ====================
function handleInvitationCodeInput(): void {
const code = formData.invitation_code.trim()
// Clear previous validation
invitationValidation.valid = false
invitationValidation.invalid = false
invitationValidation.message = ''
errors.invitation_code = ''
if (!code) {
return
}
// Debounce validation
if (invitationValidateTimeout) {
clearTimeout(invitationValidateTimeout)
}
invitationValidateTimeout = setTimeout(() => {
validateInvitationCodeDebounced(code)
}, 500)
}
async function validateInvitationCodeDebounced(code: string): Promise<void> {
invitationValidating.value = true
try {
const result = await validateInvitationCode(code)
if (result.valid) {
invitationValidation.valid = true
invitationValidation.invalid = false
invitationValidation.message = ''
} else {
invitationValidation.valid = false
invitationValidation.invalid = true
invitationValidation.message = getInvitationErrorMessage(result.error_code)
}
} catch {
invitationValidation.valid = false
invitationValidation.invalid = true
invitationValidation.message = t('auth.invitationCodeInvalid')
} finally {
invitationValidating.value = false
}
}
function getInvitationErrorMessage(errorCode?: string): string {
switch (errorCode) {
case 'INVITATION_CODE_NOT_FOUND':
return t('auth.invitationCodeInvalid')
case 'INVITATION_CODE_INVALID':
return t('auth.invitationCodeInvalid')
case 'INVITATION_CODE_USED':
return t('auth.invitationCodeInvalid')
case 'INVITATION_CODE_DISABLED':
return t('auth.invitationCodeInvalid')
default:
return t('auth.invitationCodeInvalid')
}
}
// ==================== Turnstile Handlers ==================== // ==================== Turnstile Handlers ====================
function onTurnstileVerify(token: string): void { function onTurnstileVerify(token: string): void {
@@ -429,6 +562,7 @@ function validateForm(): boolean {
errors.email = '' errors.email = ''
errors.password = '' errors.password = ''
errors.turnstile = '' errors.turnstile = ''
errors.invitation_code = ''
let isValid = true let isValid = true
@@ -450,6 +584,14 @@ function validateForm(): boolean {
isValid = false isValid = false
} }
// Invitation code validation (required when enabled)
if (invitationCodeEnabled.value) {
if (!formData.invitation_code.trim()) {
errors.invitation_code = t('auth.invitationCodeRequired')
isValid = false
}
}
// Turnstile validation // Turnstile validation
if (turnstileEnabled.value && !turnstileToken.value) { if (turnstileEnabled.value && !turnstileToken.value) {
errors.turnstile = t('auth.completeVerification') errors.turnstile = t('auth.completeVerification')
@@ -484,6 +626,30 @@ async function handleRegister(): Promise<void> {
} }
} }
// Check invitation code validation status (if enabled and code provided)
if (invitationCodeEnabled.value) {
// If still validating, wait
if (invitationValidating.value) {
errorMessage.value = t('auth.invitationCodeValidating')
return
}
// If invitation code is invalid, block submission
if (invitationValidation.invalid) {
errorMessage.value = t('auth.invitationCodeInvalidCannotRegister')
return
}
// If invitation code is required but not validated yet
if (formData.invitation_code.trim() && !invitationValidation.valid) {
errorMessage.value = t('auth.invitationCodeValidating')
// Trigger validation
await validateInvitationCodeDebounced(formData.invitation_code.trim())
if (!invitationValidation.valid) {
errorMessage.value = t('auth.invitationCodeInvalidCannotRegister')
return
}
}
}
isLoading.value = true isLoading.value = true
try { try {
@@ -496,7 +662,8 @@ async function handleRegister(): Promise<void> {
email: formData.email, email: formData.email,
password: formData.password, password: formData.password,
turnstile_token: turnstileToken.value, turnstile_token: turnstileToken.value,
promo_code: formData.promo_code || undefined promo_code: formData.promo_code || undefined,
invitation_code: formData.invitation_code || undefined
}) })
) )
@@ -510,7 +677,8 @@ async function handleRegister(): Promise<void> {
email: formData.email, email: formData.email,
password: formData.password, password: formData.password,
turnstile_token: turnstileEnabled.value ? turnstileToken.value : undefined, turnstile_token: turnstileEnabled.value ? turnstileToken.value : undefined,
promo_code: formData.promo_code || undefined promo_code: formData.promo_code || undefined,
invitation_code: formData.invitation_code || undefined
}) })
// Show success toast // Show success toast

View File

@@ -237,6 +237,18 @@
</div> </div>
</div> </div>
<div class="flex items-center justify-between rounded-xl border border-gray-200 p-3 dark:border-dark-700">
<div>
<p class="text-sm font-medium text-gray-900 dark:text-white">
{{ t("setup.redis.enableTls") }}
</p>
<p class="text-xs text-gray-500 dark:text-dark-400">
{{ t("setup.redis.enableTlsHint") }}
</p>
</div>
<Toggle v-model="formData.redis.enable_tls" />
</div>
<button <button
@click="testRedisConnection" @click="testRedisConnection"
:disabled="testingRedis" :disabled="testingRedis"
@@ -482,6 +494,7 @@ import { ref, reactive, computed } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { testDatabase, testRedis, install, type InstallRequest } from '@/api/setup' import { testDatabase, testRedis, install, type InstallRequest } from '@/api/setup'
import Select from '@/components/common/Select.vue' import Select from '@/components/common/Select.vue'
import Toggle from '@/components/common/Toggle.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
const { t } = useI18n() const { t } = useI18n()

View File

@@ -157,6 +157,12 @@
<span class="font-medium text-gray-900 dark:text-white">{{ value }}</span> <span class="font-medium text-gray-900 dark:text-white">{{ value }}</span>
</template> </template>
<template #cell-reasoning_effort="{ row }">
<span class="text-sm text-gray-900 dark:text-white">
{{ formatReasoningEffort(row.reasoning_effort) }}
</span>
</template>
<template #cell-stream="{ row }"> <template #cell-stream="{ row }">
<span <span
class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium" class="inline-flex items-center rounded px-2 py-0.5 text-xs font-medium"
@@ -443,7 +449,7 @@ import DateRangePicker from '@/components/common/DateRangePicker.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import type { UsageLog, ApiKey, UsageQueryParams, UsageStatsResponse } from '@/types' import type { UsageLog, ApiKey, UsageQueryParams, UsageStatsResponse } from '@/types'
import type { Column } from '@/components/common/types' import type { Column } from '@/components/common/types'
import { formatDateTime } from '@/utils/format' import { formatDateTime, formatReasoningEffort } from '@/utils/format'
const { t } = useI18n() const { t } = useI18n()
const appStore = useAppStore() const appStore = useAppStore()
@@ -466,6 +472,7 @@ const usageStats = ref<UsageStatsResponse | null>(null)
const columns = computed<Column[]>(() => [ const columns = computed<Column[]>(() => [
{ key: 'api_key', label: t('usage.apiKeyFilter'), sortable: false }, { key: 'api_key', label: t('usage.apiKeyFilter'), sortable: false },
{ key: 'model', label: t('usage.model'), sortable: true }, { key: 'model', label: t('usage.model'), sortable: true },
{ key: 'reasoning_effort', label: t('usage.reasoningEffort'), sortable: false },
{ key: 'stream', label: t('usage.type'), sortable: false }, { key: 'stream', label: t('usage.type'), sortable: false },
{ key: 'tokens', label: t('usage.tokens'), sortable: false }, { key: 'tokens', label: t('usage.tokens'), sortable: false },
{ key: 'cost', label: t('usage.cost'), sortable: false }, { key: 'cost', label: t('usage.cost'), sortable: false },
@@ -723,6 +730,7 @@ const exportToCSV = async () => {
'Time', 'Time',
'API Key Name', 'API Key Name',
'Model', 'Model',
'Reasoning Effort',
'Type', 'Type',
'Input Tokens', 'Input Tokens',
'Output Tokens', 'Output Tokens',
@@ -739,6 +747,7 @@ const exportToCSV = async () => {
log.created_at, log.created_at,
log.api_key?.name || '', log.api_key?.name || '',
log.model, log.model,
formatReasoningEffort(log.reasoning_effort),
log.stream ? 'Stream' : 'Sync', log.stream ? 'Stream' : 'Sync',
log.input_tokens, log.input_tokens,
log.output_tokens, log.output_tokens,