This commit is contained in:
yangjianbo
2026-02-06 06:56:23 +08:00
94 changed files with 10861 additions and 858 deletions

View File

@@ -145,12 +145,24 @@ type PricingConfig struct {
}
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表CIDR/IP
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表CIDR/IP
MaxRequestBodySize int64 `mapstructure:"max_request_body_size"` // 全局最大请求体限制
H2C H2CConfig `mapstructure:"h2c"` // HTTP/2 Cleartext 配置
}
// H2CConfig HTTP/2 Cleartext 配置
type H2CConfig struct {
Enabled bool `mapstructure:"enabled"` // 是否启用 H2C
MaxConcurrentStreams uint32 `mapstructure:"max_concurrent_streams"` // 最大并发流数量
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲超时(秒)
MaxReadFrameSize int `mapstructure:"max_read_frame_size"` // 最大帧大小(字节)
MaxUploadBufferPerConnection int `mapstructure:"max_upload_buffer_per_connection"` // 每个连接的上传缓冲区(字节)
MaxUploadBufferPerStream int `mapstructure:"max_upload_buffer_per_stream"` // 每个流的上传缓冲区(字节)
}
type CORSConfig struct {
@@ -532,6 +544,13 @@ type OpsMetricsCollectorCacheConfig struct {
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpireHour int `mapstructure:"expire_hour"`
// AccessTokenExpireMinutes: Access Token有效期分钟默认15分钟
// 短有效期减少被盗用风险配合Refresh Token实现无感续期
AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"`
// RefreshTokenExpireDays: Refresh Token有效期默认30天
RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"`
// RefreshWindowMinutes: 刷新窗口分钟在Access Token过期前多久开始允许刷新
RefreshWindowMinutes int `mapstructure:"refresh_window_minutes"`
}
// TotpConfig TOTP 双因素认证配置
@@ -752,6 +771,14 @@ func setDefaults() {
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{})
viper.SetDefault("server.max_request_body_size", int64(100*1024*1024))
// H2C 默认配置
viper.SetDefault("server.h2c.enabled", false)
viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流
viper.SetDefault("server.h2c.idle_timeout", 75) // 75 秒
viper.SetDefault("server.h2c.max_read_frame_size", 1<<20) // 1MB够用
viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB
viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB
// CORS
viper.SetDefault("cors.allowed_origins", []string{})
@@ -848,6 +875,9 @@ func setDefaults() {
// JWT
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期
viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期
viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新
// TOTP
viper.SetDefault("totp.encryption_key", "")
@@ -1009,6 +1039,22 @@ func (c *Config) Validate() error {
if c.JWT.ExpireHour > 24 {
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
}
// JWT Refresh Token配置验证
if c.JWT.AccessTokenExpireMinutes <= 0 {
return fmt.Errorf("jwt.access_token_expire_minutes must be positive")
}
if c.JWT.AccessTokenExpireMinutes > 720 {
log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes)
}
if c.JWT.RefreshTokenExpireDays <= 0 {
return fmt.Errorf("jwt.refresh_token_expire_days must be positive")
}
if c.JWT.RefreshTokenExpireDays > 90 {
log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays)
}
if c.JWT.RefreshWindowMinutes < 0 {
return fmt.Errorf("jwt.refresh_window_minutes must be non-negative")
}
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
}

View File

@@ -0,0 +1,273 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ErrorPassthroughHandler 处理错误透传规则的 HTTP 请求
type ErrorPassthroughHandler struct {
service *service.ErrorPassthroughService
}
// NewErrorPassthroughHandler 创建错误透传规则处理器
func NewErrorPassthroughHandler(service *service.ErrorPassthroughService) *ErrorPassthroughHandler {
return &ErrorPassthroughHandler{service: service}
}
// CreateErrorPassthroughRuleRequest 创建规则请求
type CreateErrorPassthroughRuleRequest struct {
Name string `json:"name" binding:"required"`
Enabled *bool `json:"enabled"`
Priority int `json:"priority"`
ErrorCodes []int `json:"error_codes"`
Keywords []string `json:"keywords"`
MatchMode string `json:"match_mode"`
Platforms []string `json:"platforms"`
PassthroughCode *bool `json:"passthrough_code"`
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
Description *string `json:"description"`
}
// UpdateErrorPassthroughRuleRequest 更新规则请求(部分更新,所有字段可选)
type UpdateErrorPassthroughRuleRequest struct {
Name *string `json:"name"`
Enabled *bool `json:"enabled"`
Priority *int `json:"priority"`
ErrorCodes []int `json:"error_codes"`
Keywords []string `json:"keywords"`
MatchMode *string `json:"match_mode"`
Platforms []string `json:"platforms"`
PassthroughCode *bool `json:"passthrough_code"`
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
Description *string `json:"description"`
}
// List 获取所有规则
// GET /api/v1/admin/error-passthrough-rules
func (h *ErrorPassthroughHandler) List(c *gin.Context) {
rules, err := h.service.List(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, rules)
}
// GetByID 根据 ID 获取规则
// GET /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) GetByID(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
rule, err := h.service.GetByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
if rule == nil {
response.NotFound(c, "Rule not found")
return
}
response.Success(c, rule)
}
// Create 创建规则
// POST /api/v1/admin/error-passthrough-rules
func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
var req CreateErrorPassthroughRuleRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
rule := &model.ErrorPassthroughRule{
Name: req.Name,
Priority: req.Priority,
ErrorCodes: req.ErrorCodes,
Keywords: req.Keywords,
Platforms: req.Platforms,
}
// 设置默认值
if req.Enabled != nil {
rule.Enabled = *req.Enabled
} else {
rule.Enabled = true
}
if req.MatchMode != "" {
rule.MatchMode = req.MatchMode
} else {
rule.MatchMode = model.MatchModeAny
}
if req.PassthroughCode != nil {
rule.PassthroughCode = *req.PassthroughCode
} else {
rule.PassthroughCode = true
}
if req.PassthroughBody != nil {
rule.PassthroughBody = *req.PassthroughBody
} else {
rule.PassthroughBody = true
}
rule.ResponseCode = req.ResponseCode
rule.CustomMessage = req.CustomMessage
rule.Description = req.Description
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
created, err := h.service.Create(c.Request.Context(), rule)
if err != nil {
if _, ok := err.(*model.ValidationError); ok {
response.BadRequest(c, err.Error())
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, created)
}
// Update 更新规则(支持部分更新)
// PUT /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
var req UpdateErrorPassthroughRuleRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 先获取现有规则
existing, err := h.service.GetByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
if existing == nil {
response.NotFound(c, "Rule not found")
return
}
// 部分更新:只更新请求中提供的字段
rule := &model.ErrorPassthroughRule{
ID: id,
Name: existing.Name,
Enabled: existing.Enabled,
Priority: existing.Priority,
ErrorCodes: existing.ErrorCodes,
Keywords: existing.Keywords,
MatchMode: existing.MatchMode,
Platforms: existing.Platforms,
PassthroughCode: existing.PassthroughCode,
ResponseCode: existing.ResponseCode,
PassthroughBody: existing.PassthroughBody,
CustomMessage: existing.CustomMessage,
Description: existing.Description,
}
// 应用请求中提供的更新
if req.Name != nil {
rule.Name = *req.Name
}
if req.Enabled != nil {
rule.Enabled = *req.Enabled
}
if req.Priority != nil {
rule.Priority = *req.Priority
}
if req.ErrorCodes != nil {
rule.ErrorCodes = req.ErrorCodes
}
if req.Keywords != nil {
rule.Keywords = req.Keywords
}
if req.MatchMode != nil {
rule.MatchMode = *req.MatchMode
}
if req.Platforms != nil {
rule.Platforms = req.Platforms
}
if req.PassthroughCode != nil {
rule.PassthroughCode = *req.PassthroughCode
}
if req.ResponseCode != nil {
rule.ResponseCode = req.ResponseCode
}
if req.PassthroughBody != nil {
rule.PassthroughBody = *req.PassthroughBody
}
if req.CustomMessage != nil {
rule.CustomMessage = req.CustomMessage
}
if req.Description != nil {
rule.Description = req.Description
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
updated, err := h.service.Update(c.Request.Context(), rule)
if err != nil {
if _, ok := err.(*model.ValidationError); ok {
response.BadRequest(c, err.Error())
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, updated)
}
// Delete 删除规则
// DELETE /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) Delete(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
if err := h.service.Delete(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rule deleted successfully"})
}

View File

@@ -45,6 +45,9 @@ type UpdateUserRequest struct {
Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64 `json:"group_rates"`
}
// UpdateBalanceRequest represents balance update request
@@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) {
Concurrency: req.Concurrency,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
})
if err != nil {
response.ErrorFrom(c, err)

View File

@@ -243,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
}
response.Success(c, out)
}
// GetUserGroupRates 获取当前用户的专属分组倍率配置
// GET /api/v1/groups/rates
func (h *APIKeyHandler) GetUserGroupRates(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
rates, err := h.apiKeyService.GetUserGroupRates(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, rates)
}

View File

@@ -68,9 +68,39 @@ type LoginRequest struct {
// AuthResponse 认证响应格式(匹配前端期望)
type AuthResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
User *dto.User `json:"user"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"` // 新增Refresh Token
ExpiresIn int `json:"expires_in,omitempty"` // 新增Access Token有效期
TokenType string `json:"token_type"`
User *dto.User `json:"user"`
}
// respondWithTokenPair 生成 Token 对并返回认证响应
// 如果 Token 对生成失败,回退到只返回 Access Token向后兼容
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
// 回退到只返回Access Token
token, tokenErr := h.authService.GenerateToken(user)
if tokenErr != nil {
response.InternalError(c, "Failed to generate token")
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
return
}
response.Success(c, AuthResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
ExpiresIn: tokenPair.ExpiresIn,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
}
// Register handles user registration
@@ -90,17 +120,13 @@ func (h *AuthHandler) Register(c *gin.Context) {
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
h.respondWithTokenPair(c, user)
}
// SendVerifyCode 发送邮箱验证码
@@ -150,6 +176,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
@@ -168,11 +195,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
h.respondWithTokenPair(c, user)
}
// TotpLoginResponse represents the response when 2FA is required
@@ -238,18 +261,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
return
}
// Generate the JWT token
token, err := h.authService.GenerateToken(user)
if err != nil {
response.InternalError(c, "Failed to generate token")
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: dto.UserFromService(user),
})
h.respondWithTokenPair(c, user)
}
// GetCurrentUser handles getting current authenticated user
@@ -491,3 +503,96 @@ func (h *AuthHandler) ResetPassword(c *gin.Context) {
Message: "Your password has been reset successfully. You can now log in with your new password.",
})
}
// ==================== Token Refresh Endpoints ====================
// RefreshTokenRequest 刷新Token请求
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
}
// RefreshTokenResponse 刷新Token响应
type RefreshTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"` // Access Token有效期
TokenType string `json:"token_type"`
}
// RefreshToken 刷新Token
// POST /api/v1/auth/refresh
func (h *AuthHandler) RefreshToken(c *gin.Context) {
var req RefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, RefreshTokenResponse{
AccessToken: tokenPair.AccessToken,
RefreshToken: tokenPair.RefreshToken,
ExpiresIn: tokenPair.ExpiresIn,
TokenType: "Bearer",
})
}
// LogoutRequest 登出请求
type LogoutRequest struct {
RefreshToken string `json:"refresh_token,omitempty"` // 可选撤销指定的Refresh Token
}
// LogoutResponse 登出响应
type LogoutResponse struct {
Message string `json:"message"`
}
// Logout 用户登出
// POST /api/v1/auth/logout
func (h *AuthHandler) Logout(c *gin.Context) {
var req LogoutRequest
// 允许空请求体(向后兼容)
_ = c.ShouldBindJSON(&req)
// 如果提供了Refresh Token撤销它
if req.RefreshToken != "" {
if err := h.authService.RevokeRefreshToken(c.Request.Context(), req.RefreshToken); err != nil {
slog.Debug("failed to revoke refresh token", "error", err)
// 不影响登出流程
}
}
response.Success(c, LogoutResponse{
Message: "Logged out successfully",
})
}
// RevokeAllSessionsResponse 撤销所有会话响应
type RevokeAllSessionsResponse struct {
Message string `json:"message"`
}
// RevokeAllSessions 撤销当前用户的所有会话
// POST /api/v1/auth/revoke-all-sessions
func (h *AuthHandler) RevokeAllSessions(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil {
slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err)
response.InternalError(c, "Failed to revoke sessions")
return
}
response.Success(c, RevokeAllSessionsResponse{
Message: "All sessions have been revoked. Please log in again.",
})
}

View File

@@ -211,7 +211,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
email = linuxDoSyntheticEmail(subject)
}
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username)
if err != nil {
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
@@ -219,7 +219,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
}
fragment := url.Values{}
fragment.Set("access_token", jwtToken)
fragment.Set("access_token", tokenPair.AccessToken)
fragment.Set("refresh_token", tokenPair.RefreshToken)
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)

View File

@@ -58,8 +58,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return nil
}
return &AdminUser{
User: *base,
Notes: u.Notes,
User: *base,
Notes: u.Notes,
GroupRates: u.GroupRates,
}
}

View File

@@ -29,6 +29,9 @@ type AdminUser struct {
User
Notes string `json:"notes"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
}
type APIKey struct {

View File

@@ -33,6 +33,7 @@ type GatewayHandler struct {
billingCacheService *service.BillingCacheService
usageService *service.UsageService
apiKeyService *service.APIKeyService
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
@@ -49,6 +50,7 @@ func NewGatewayHandler(
billingCacheService *service.BillingCacheService,
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
@@ -71,6 +73,7 @@ func NewGatewayHandler(
billingCacheService: billingCacheService,
usageService: usageService,
apiKeyService: apiKeyService,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
@@ -203,7 +206,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverErr *service.UpstreamFailoverError
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
@@ -212,7 +215,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
@@ -303,9 +310,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
return
}
switchCount++
@@ -354,7 +361,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverErr *service.UpstreamFailoverError
retryWithFallback := false
for {
@@ -365,7 +372,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
@@ -489,9 +500,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
return
}
switchCount++
@@ -629,10 +640,10 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return
}
// Best-effort: 获取用量统计,失败不影响基础响应
// Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应
var usageData gin.H
if h.usageService != nil {
dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
dashStats, err := h.usageService.GetAPIKeyDashboardStats(c.Request.Context(), apiKey.ID)
if err == nil && dashStats != nil {
usageData = gin.H{
"today": gin.H{
@@ -768,7 +779,37 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
if rule := h.errorPassthroughService.MatchRule(platform, statusCode, responseBody); rule != nil {
// 确定响应状态码
respCode := statusCode
if !rule.PassthroughCode && rule.ResponseCode != nil {
respCode = *rule.ResponseCode
}
// 确定响应消息
msg := service.ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
msg = *rule.CustomMessage
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return
}
}
// 使用默认的错误映射
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

View File

@@ -253,7 +253,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverErr *service.UpstreamFailoverError
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
@@ -262,7 +262,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
handleGeminiFailoverExhausted(c, lastFailoverStatus)
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return
}
account := selection.Account
@@ -353,11 +353,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
handleGeminiFailoverExhausted(c, lastFailoverStatus)
lastFailoverErr = failoverErr
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return
}
lastFailoverStatus = failoverErr.StatusCode
lastFailoverErr = failoverErr
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
@@ -414,7 +414,36 @@ func parseGeminiModelAction(rest string) (model string, action string, err error
return "", "", &pathParseError{"invalid model action path"}
}
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError) {
if failoverErr == nil {
googleError(c, http.StatusBadGateway, "Upstream request failed")
return
}
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
if rule := h.errorPassthroughService.MatchRule(service.PlatformGemini, statusCode, responseBody); rule != nil {
// 确定响应状态码
respCode := statusCode
if !rule.PassthroughCode && rule.ResponseCode != nil {
respCode = *rule.ResponseCode
}
// 确定响应消息
msg := service.ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
msg = *rule.CustomMessage
}
googleError(c, respCode, msg)
return
}
}
// 使用默认的错误映射
status, message := mapGeminiUpstreamError(statusCode)
googleError(c, status, message)
}

View File

@@ -24,6 +24,7 @@ type AdminHandlers struct {
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
}
// Handlers contains all HTTP handlers

View File

@@ -22,11 +22,12 @@ import (
// OpenAIGatewayHandler handles OpenAI API gateway requests
type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -35,6 +36,7 @@ func NewOpenAIGatewayHandler(
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
apiKeyService *service.APIKeyService,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
@@ -46,11 +48,12 @@ func NewOpenAIGatewayHandler(
}
}
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
}
}
@@ -201,7 +204,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverErr *service.UpstreamFailoverError
for {
// Select account supporting the requested model
@@ -213,7 +216,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
@@ -278,12 +285,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
@@ -324,7 +330,37 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
if rule := h.errorPassthroughService.MatchRule("openai", statusCode, responseBody); rule != nil {
// 确定响应状态码
respCode := statusCode
if !rule.PassthroughCode && rule.ResponseCode != nil {
respCode = *rule.ResponseCode
}
// 确定响应消息
msg := service.ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
msg = *rule.CustomMessage
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return
}
}
// 使用默认的错误映射
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

View File

@@ -27,6 +27,7 @@ func ProvideAdminHandlers(
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -47,6 +48,7 @@ func ProvideAdminHandlers(
Subscription: subscriptionHandler,
Usage: usageHandler,
UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
}
}
@@ -128,6 +130,7 @@ var ProviderSet = wire.NewSet(
admin.NewSubscriptionHandler,
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,

View File

@@ -0,0 +1,74 @@
// Package model 定义服务层使用的数据模型。
package model
import "time"
// ErrorPassthroughRule 全局错误透传规则
// 用于控制上游错误如何返回给客户端
type ErrorPassthroughRule struct {
ID int64 `json:"id"`
Name string `json:"name"` // 规则名称
Enabled bool `json:"enabled"` // 是否启用
Priority int `json:"priority"` // 优先级(数字越小优先级越高)
ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表OR关系
Keywords []string `json:"keywords"` // 匹配的关键词列表OR关系
MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件)
Platforms []string `json:"platforms"` // 适用平台列表
PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码
ResponseCode *int `json:"response_code"` // 自定义状态码passthrough_code=false 时使用)
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
CustomMessage *string `json:"custom_message"` // 自定义错误信息passthrough_body=false 时使用)
Description *string `json:"description"` // 规则描述
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// MatchModeAny 表示任一条件匹配即可
const MatchModeAny = "any"
// MatchModeAll 表示所有条件都必须匹配
const MatchModeAll = "all"
// 支持的平台常量
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
)
// AllPlatforms 返回所有支持的平台列表
func AllPlatforms() []string {
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
}
// Validate 验证规则配置的有效性
func (r *ErrorPassthroughRule) Validate() error {
if r.Name == "" {
return &ValidationError{Field: "name", Message: "name is required"}
}
if r.MatchMode != MatchModeAny && r.MatchMode != MatchModeAll {
return &ValidationError{Field: "match_mode", Message: "match_mode must be 'any' or 'all'"}
}
// 至少需要配置一个匹配条件(错误码或关键词)
if len(r.ErrorCodes) == 0 && len(r.Keywords) == 0 {
return &ValidationError{Field: "conditions", Message: "at least one error_code or keyword is required"}
}
if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) {
return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"}
}
if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") {
return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"}
}
return nil
}
// ValidationError 表示验证错误
type ValidationError struct {
Field string
Message string
}
func (e *ValidationError) Error() string {
return e.Field + ": " + e.Message
}

View File

@@ -0,0 +1,109 @@
// Package googleapi provides helpers for Google-style API responses.
package googleapi
import (
"encoding/json"
"fmt"
"strings"
)
// ErrorResponse represents a Google API error response
type ErrorResponse struct {
Error ErrorDetail `json:"error"`
}
// ErrorDetail contains the error details from Google API
type ErrorDetail struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
Details []json.RawMessage `json:"details,omitempty"`
}
// ErrorDetailInfo contains additional error information
type ErrorDetailInfo struct {
Type string `json:"@type"`
Reason string `json:"reason,omitempty"`
Domain string `json:"domain,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// ErrorHelp contains help links
type ErrorHelp struct {
Type string `json:"@type"`
Links []HelpLink `json:"links,omitempty"`
}
// HelpLink represents a help link
type HelpLink struct {
Description string `json:"description"`
URL string `json:"url"`
}
// ParseError parses a Google API error response and extracts key information
func ParseError(body string) (*ErrorResponse, error) {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return nil, fmt.Errorf("failed to parse error response: %w", err)
}
return &errResp, nil
}
// ExtractActivationURL extracts the API activation URL from error details
func ExtractActivationURL(body string) string {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return ""
}
// Check error details for activation URL
for _, detailRaw := range errResp.Error.Details {
// Parse as ErrorDetailInfo
var info ErrorDetailInfo
if err := json.Unmarshal(detailRaw, &info); err == nil {
if info.Metadata != nil {
if activationURL, ok := info.Metadata["activationUrl"]; ok && activationURL != "" {
return activationURL
}
}
}
// Parse as ErrorHelp
var help ErrorHelp
if err := json.Unmarshal(detailRaw, &help); err == nil {
for _, link := range help.Links {
if strings.Contains(link.Description, "activation") ||
strings.Contains(link.Description, "API activation") ||
strings.Contains(link.URL, "/apis/api/") {
return link.URL
}
}
}
}
return ""
}
// IsServiceDisabledError checks if the error is a SERVICE_DISABLED error
func IsServiceDisabledError(body string) bool {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return false
}
// Check if it's a 403 PERMISSION_DENIED with SERVICE_DISABLED reason
if errResp.Error.Code != 403 || errResp.Error.Status != "PERMISSION_DENIED" {
return false
}
for _, detailRaw := range errResp.Error.Details {
var info ErrorDetailInfo
if err := json.Unmarshal(detailRaw, &info); err == nil {
if info.Reason == "SERVICE_DISABLED" {
return true
}
}
}
return false
}

View File

@@ -0,0 +1,143 @@
package googleapi
import (
"testing"
)
func TestExtractActivationURL(t *testing.T) {
// Test case from the user's error message
errorBody := `{
"error": {
"code": 403,
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry.",
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "SERVICE_DISABLED",
"domain": "googleapis.com",
"metadata": {
"service": "cloudaicompanion.googleapis.com",
"activationUrl": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843",
"consumer": "projects/project-6eca5881-ab73-4736-843",
"serviceTitle": "Gemini for Google Cloud API",
"containerInfo": "project-6eca5881-ab73-4736-843"
}
},
{
"@type": "type.googleapis.com/google.rpc.LocalizedMessage",
"locale": "en-US",
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry."
},
{
"@type": "type.googleapis.com/google.rpc.Help",
"links": [
{
"description": "Google developers console API activation",
"url": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
}
]
}
]
}
}`
activationURL := ExtractActivationURL(errorBody)
expectedURL := "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
if activationURL != expectedURL {
t.Errorf("Expected activation URL %s, got %s", expectedURL, activationURL)
}
}
func TestIsServiceDisabledError(t *testing.T) {
tests := []struct {
name string
body string
expected bool
}{
{
name: "SERVICE_DISABLED error",
body: `{
"error": {
"code": 403,
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "SERVICE_DISABLED"
}
]
}
}`,
expected: true,
},
{
name: "Other 403 error",
body: `{
"error": {
"code": 403,
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "OTHER_REASON"
}
]
}
}`,
expected: false,
},
{
name: "404 error",
body: `{
"error": {
"code": 404,
"status": "NOT_FOUND"
}
}`,
expected: false,
},
{
name: "Invalid JSON",
body: `invalid json`,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsServiceDisabledError(tt.body)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
func TestParseError(t *testing.T) {
errorBody := `{
"error": {
"code": 403,
"message": "API not enabled",
"status": "PERMISSION_DENIED"
}
}`
errResp, err := ParseError(errorBody)
if err != nil {
t.Fatalf("Failed to parse error: %v", err)
}
if errResp.Error.Code != 403 {
t.Errorf("Expected code 403, got %d", errResp.Error.Code)
}
if errResp.Error.Status != "PERMISSION_DENIED" {
t.Errorf("Expected status PERMISSION_DENIED, got %s", errResp.Error.Status)
}
if errResp.Error.Message != "API not enabled" {
t.Errorf("Expected message 'API not enabled', got %s", errResp.Error.Message)
}
}

View File

@@ -0,0 +1,128 @@
package repository
import (
"context"
"encoding/json"
"log"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
errorPassthroughCacheKey = "error_passthrough_rules"
errorPassthroughPubSubKey = "error_passthrough_rules_updated"
errorPassthroughCacheTTL = 24 * time.Hour
)
type errorPassthroughCache struct {
rdb *redis.Client
localCache []*model.ErrorPassthroughRule
localMu sync.RWMutex
}
// NewErrorPassthroughCache 创建错误透传规则缓存
func NewErrorPassthroughCache(rdb *redis.Client) service.ErrorPassthroughCache {
return &errorPassthroughCache{
rdb: rdb,
}
}
// Get 从缓存获取规则列表
func (c *errorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) {
// 先检查本地缓存
c.localMu.RLock()
if c.localCache != nil {
rules := c.localCache
c.localMu.RUnlock()
return rules, true
}
c.localMu.RUnlock()
// 从 Redis 获取
data, err := c.rdb.Get(ctx, errorPassthroughCacheKey).Bytes()
if err != nil {
if err != redis.Nil {
log.Printf("[ErrorPassthroughCache] Failed to get from Redis: %v", err)
}
return nil, false
}
var rules []*model.ErrorPassthroughRule
if err := json.Unmarshal(data, &rules); err != nil {
log.Printf("[ErrorPassthroughCache] Failed to unmarshal rules: %v", err)
return nil, false
}
// 更新本地缓存
c.localMu.Lock()
c.localCache = rules
c.localMu.Unlock()
return rules, true
}
// Set 设置缓存
func (c *errorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error {
data, err := json.Marshal(rules)
if err != nil {
return err
}
if err := c.rdb.Set(ctx, errorPassthroughCacheKey, data, errorPassthroughCacheTTL).Err(); err != nil {
return err
}
// 更新本地缓存
c.localMu.Lock()
c.localCache = rules
c.localMu.Unlock()
return nil
}
// Invalidate 使缓存失效
func (c *errorPassthroughCache) Invalidate(ctx context.Context) error {
// 清除本地缓存
c.localMu.Lock()
c.localCache = nil
c.localMu.Unlock()
// 清除 Redis 缓存
return c.rdb.Del(ctx, errorPassthroughCacheKey).Err()
}
// NotifyUpdate 通知其他实例刷新缓存
func (c *errorPassthroughCache) NotifyUpdate(ctx context.Context) error {
return c.rdb.Publish(ctx, errorPassthroughPubSubKey, "refresh").Err()
}
// SubscribeUpdates 订阅缓存更新通知
func (c *errorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) {
go func() {
sub := c.rdb.Subscribe(ctx, errorPassthroughPubSubKey)
defer func() { _ = sub.Close() }()
ch := sub.Channel()
for {
select {
case <-ctx.Done():
return
case msg := <-ch:
if msg == nil {
return
}
// 清除本地缓存,下次访问时会从 Redis 或数据库重新加载
c.localMu.Lock()
c.localCache = nil
c.localMu.Unlock()
// 调用处理函数
handler()
}
}
}()
}

View File

@@ -0,0 +1,178 @@
package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type errorPassthroughRepository struct {
client *ent.Client
}
// NewErrorPassthroughRepository 创建错误透传规则仓库
func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository {
return &errorPassthroughRepository{client: client}
}
// List 获取所有规则
func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
rules, err := r.client.ErrorPassthroughRule.Query().
Order(ent.Asc(errorpassthroughrule.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]*model.ErrorPassthroughRule, len(rules))
for i, rule := range rules {
result[i] = r.toModel(rule)
}
return result, nil
}
// GetByID 根据 ID 获取规则
func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
rule, err := r.client.ErrorPassthroughRule.Get(ctx, id)
if err != nil {
if ent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return r.toModel(rule), nil
}
// Create 创建规则
func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
builder := r.client.ErrorPassthroughRule.Create().
SetName(rule.Name).
SetEnabled(rule.Enabled).
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
}
if len(rule.Keywords) > 0 {
builder.SetKeywords(rule.Keywords)
}
if len(rule.Platforms) > 0 {
builder.SetPlatforms(rule.Platforms)
}
if rule.ResponseCode != nil {
builder.SetResponseCode(*rule.ResponseCode)
}
if rule.CustomMessage != nil {
builder.SetCustomMessage(*rule.CustomMessage)
}
if rule.Description != nil {
builder.SetDescription(*rule.Description)
}
created, err := builder.Save(ctx)
if err != nil {
return nil, err
}
return r.toModel(created), nil
}
// Update 更新规则
func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID).
SetName(rule.Name).
SetEnabled(rule.Enabled).
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
// 处理可选字段
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
} else {
builder.ClearErrorCodes()
}
if len(rule.Keywords) > 0 {
builder.SetKeywords(rule.Keywords)
} else {
builder.ClearKeywords()
}
if len(rule.Platforms) > 0 {
builder.SetPlatforms(rule.Platforms)
} else {
builder.ClearPlatforms()
}
if rule.ResponseCode != nil {
builder.SetResponseCode(*rule.ResponseCode)
} else {
builder.ClearResponseCode()
}
if rule.CustomMessage != nil {
builder.SetCustomMessage(*rule.CustomMessage)
} else {
builder.ClearCustomMessage()
}
if rule.Description != nil {
builder.SetDescription(*rule.Description)
} else {
builder.ClearDescription()
}
updated, err := builder.Save(ctx)
if err != nil {
return nil, err
}
return r.toModel(updated), nil
}
// Delete 删除规则
func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error {
return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx)
}
// toModel 将 Ent 实体转换为服务模型
func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule {
rule := &model.ErrorPassthroughRule{
ID: int64(e.ID),
Name: e.Name,
Enabled: e.Enabled,
Priority: e.Priority,
ErrorCodes: e.ErrorCodes,
Keywords: e.Keywords,
MatchMode: e.MatchMode,
Platforms: e.Platforms,
PassthroughCode: e.PassthroughCode,
PassthroughBody: e.PassthroughBody,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}
if e.ResponseCode != nil {
rule.ResponseCode = e.ResponseCode
}
if e.CustomMessage != nil {
rule.CustomMessage = e.CustomMessage
}
if e.Description != nil {
rule.Description = e.Description
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
return rule
}

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
@@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String())
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body)
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body)
body := resp.String()
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
// Check if this is a SERVICE_DISABLED error and extract activation URL
if googleapi.IsServiceDisabledError(body) {
activationURL := googleapi.ExtractActivationURL(body)
if activationURL != "" {
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
}
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
}
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
}
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil
@@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String())
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body)
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body)
body := resp.String()
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
// Check if this is a SERVICE_DISABLED error and extract activation URL
if googleapi.IsServiceDisabledError(body) {
activationURL := googleapi.ExtractActivationURL(body)
if activationURL != "" {
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
}
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
}
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
}
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil

View File

@@ -0,0 +1,158 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
refreshTokenKeyPrefix = "refresh_token:"
userRefreshTokensPrefix = "user_refresh_tokens:"
tokenFamilyPrefix = "token_family:"
)
// refreshTokenKey generates the Redis key for a refresh token.
func refreshTokenKey(tokenHash string) string {
return refreshTokenKeyPrefix + tokenHash
}
// userRefreshTokensKey generates the Redis key for user's token set.
func userRefreshTokensKey(userID int64) string {
return fmt.Sprintf("%s%d", userRefreshTokensPrefix, userID)
}
// tokenFamilyKey generates the Redis key for token family set.
func tokenFamilyKey(familyID string) string {
return tokenFamilyPrefix + familyID
}
type refreshTokenCache struct {
rdb *redis.Client
}
// NewRefreshTokenCache creates a new RefreshTokenCache implementation.
func NewRefreshTokenCache(rdb *redis.Client) service.RefreshTokenCache {
return &refreshTokenCache{rdb: rdb}
}
func (c *refreshTokenCache) StoreRefreshToken(ctx context.Context, tokenHash string, data *service.RefreshTokenData, ttl time.Duration) error {
key := refreshTokenKey(tokenHash)
val, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshal refresh token data: %w", err)
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *refreshTokenCache) GetRefreshToken(ctx context.Context, tokenHash string) (*service.RefreshTokenData, error) {
key := refreshTokenKey(tokenHash)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return nil, service.ErrRefreshTokenNotFound
}
return nil, err
}
var data service.RefreshTokenData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, fmt.Errorf("unmarshal refresh token data: %w", err)
}
return &data, nil
}
func (c *refreshTokenCache) DeleteRefreshToken(ctx context.Context, tokenHash string) error {
key := refreshTokenKey(tokenHash)
return c.rdb.Del(ctx, key).Err()
}
func (c *refreshTokenCache) DeleteUserRefreshTokens(ctx context.Context, userID int64) error {
// Get all token hashes for this user
tokenHashes, err := c.GetUserTokenHashes(ctx, userID)
if err != nil && err != redis.Nil {
return fmt.Errorf("get user token hashes: %w", err)
}
if len(tokenHashes) == 0 {
return nil
}
// Build keys to delete
keys := make([]string, 0, len(tokenHashes)+1)
for _, hash := range tokenHashes {
keys = append(keys, refreshTokenKey(hash))
}
keys = append(keys, userRefreshTokensKey(userID))
// Delete all keys in a pipeline
pipe := c.rdb.Pipeline()
for _, key := range keys {
pipe.Del(ctx, key)
}
_, err = pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) DeleteTokenFamily(ctx context.Context, familyID string) error {
// Get all token hashes in this family
tokenHashes, err := c.GetFamilyTokenHashes(ctx, familyID)
if err != nil && err != redis.Nil {
return fmt.Errorf("get family token hashes: %w", err)
}
if len(tokenHashes) == 0 {
return nil
}
// Build keys to delete
keys := make([]string, 0, len(tokenHashes)+1)
for _, hash := range tokenHashes {
keys = append(keys, refreshTokenKey(hash))
}
keys = append(keys, tokenFamilyKey(familyID))
// Delete all keys in a pipeline
pipe := c.rdb.Pipeline()
for _, key := range keys {
pipe.Del(ctx, key)
}
_, err = pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error {
key := userRefreshTokensKey(userID)
pipe := c.rdb.Pipeline()
pipe.SAdd(ctx, key, tokenHash)
pipe.Expire(ctx, key, ttl)
_, err := pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error {
key := tokenFamilyKey(familyID)
pipe := c.rdb.Pipeline()
pipe.SAdd(ctx, key, tokenHash)
pipe.Expire(ctx, key, ttl)
_, err := pipe.Exec(ctx)
return err
}
func (c *refreshTokenCache) GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) {
key := userRefreshTokensKey(userID)
return c.rdb.SMembers(ctx, key).Result()
}
func (c *refreshTokenCache) GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) {
key := tokenFamilyKey(familyID)
return c.rdb.SMembers(ctx, key).Result()
}
func (c *refreshTokenCache) IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) {
key := tokenFamilyKey(familyID)
return c.rdb.SIsMember(ctx, key, tokenHash).Result()
}

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"fmt"
"log"
"strconv"
"time"
@@ -153,6 +154,21 @@ func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) serv
if defaultIdleTimeoutMinutes <= 0 {
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
}
// 预加载 Lua 脚本到 Redis避免 Pipeline 中出现 NOSCRIPT 错误
ctx := context.Background()
scripts := []*redis.Script{
registerSessionScript,
refreshSessionScript,
getActiveSessionCountScript,
isSessionActiveScript,
}
for _, script := range scripts {
if err := script.Load(ctx, rdb).Err(); err != nil {
log.Printf("[SessionLimitCache] Failed to preload Lua script: %v", err)
}
}
return &sessionLimitCache{
rdb: rdb,
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,

View File

@@ -1128,6 +1128,107 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
return stats, nil
}
// getPerformanceStatsByAPIKey 获取指定 API Key 的 RPM 和 TPM近5分钟平均值
func (r *usageLogRepository) getPerformanceStatsByAPIKey(ctx context.Context, apiKeyID int64) (rpm, tpm int64, err error) {
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
query := `
SELECT
COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as token_count
FROM usage_logs
WHERE created_at >= $1 AND api_key_id = $2`
args := []any{fiveMinutesAgo, apiKeyID}
var requestCount int64
var tokenCount int64
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
return 0, 0, err
}
return requestCount / 5, tokenCount / 5, nil
}
// GetAPIKeyDashboardStats 获取指定 API Key 的仪表盘统计(按 api_key_id 过滤)
func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*UserDashboardStats, error) {
stats := &UserDashboardStats{}
today := timezone.Today()
// API Key 维度不需要统计 key 数量,设为 1
stats.TotalAPIKeys = 1
stats.ActiveAPIKeys = 1
// 累计 Token 统计
totalStatsQuery := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
WHERE api_key_id = $1
`
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
[]any{apiKeyID},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
// 今日 Token 统计
todayStatsQuery := `
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
FROM usage_logs
WHERE api_key_id = $1 AND created_at >= $2
`
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{apiKeyID, today},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
// 性能指标RPM 和 TPM最近5分钟按 API Key 过滤)
rpm, tpm, err := r.getPerformanceStatsByAPIKey(ctx, apiKeyID)
if err != nil {
return nil, err
}
stats.Rpm = rpm
stats.Tpm = tpm
return stats, nil
}
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"

View File

@@ -0,0 +1,113 @@
package repository
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type userGroupRateRepository struct {
sql sqlExecutor
}
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB}
}
// GetByUserID 获取用户的所有专属分组倍率
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
result := make(map[int64]float64)
for rows.Next() {
var groupID int64
var rate float64
if err := rows.Scan(&groupID, &rate); err != nil {
return nil, err
}
result[groupID] = rate
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rate float64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &rate, nil
}
// SyncUserGroupRates 同步用户的分组专属倍率
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 {
// 如果传入空 map删除该用户的所有专属倍率
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}
// 分离需要删除和需要 upsert 的记录
var toDelete []int64
toUpsert := make(map[int64]float64)
for groupID, rate := range rates {
if rate == nil {
toDelete = append(toDelete, groupID)
} else {
toUpsert[groupID] = *rate
}
}
// 删除指定的记录
for _, groupID := range toDelete {
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
userID, groupID)
if err != nil {
return err
}
}
// Upsert 记录
now := time.Now()
for groupID, rate := range toUpsert {
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
VALUES ($1, $2, $3, $4, $4)
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
`, userID, groupID, rate, now)
if err != nil {
return err
}
}
return nil
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err
}
// DeleteByUserID 删除指定用户的所有专属倍率
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}

View File

@@ -67,6 +67,8 @@ var ProviderSet = wire.NewSet(
NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository,
NewUserGroupRateRepository,
NewErrorPassthroughRepository,
// Cache implementations
NewGatewayCache,
@@ -86,6 +88,8 @@ var ProviderSet = wire.NewSet(
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
NewTotpCache,
NewRefreshTokenCache,
NewErrorPassthroughCache,
// Encryptors
NewAESEncryptor,

View File

@@ -598,7 +598,7 @@ func newContractDeps(t *testing.T) *contractDeps {
}
userService := service.NewUserService(userRepo, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
@@ -612,7 +612,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
@@ -1619,6 +1619,10 @@ func (r *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}

View File

@@ -58,10 +58,39 @@ func ProvideRouter(
// ProvideHTTPServer 提供 HTTP 服务器
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
handler := h2c.NewHandler(router, &http2.Server{})
httpHandler := http.Handler(router)
globalMaxSize := cfg.Server.MaxRequestBodySize
if globalMaxSize <= 0 {
globalMaxSize = cfg.Gateway.MaxBodySize
}
if globalMaxSize > 0 {
httpHandler = http.MaxBytesHandler(httpHandler, globalMaxSize)
log.Printf("Global max request body size: %d bytes (%.2f MB)", globalMaxSize, float64(globalMaxSize)/(1<<20))
}
// 根据配置决定是否启用 H2C
if cfg.Server.H2C.Enabled {
h2cConfig := cfg.Server.H2C
httpHandler = h2c.NewHandler(router, &http2.Server{
MaxConcurrentStreams: h2cConfig.MaxConcurrentStreams,
IdleTimeout: time.Duration(h2cConfig.IdleTimeout) * time.Second,
MaxReadFrameSize: uint32(h2cConfig.MaxReadFrameSize),
MaxUploadBufferPerConnection: int32(h2cConfig.MaxUploadBufferPerConnection),
MaxUploadBufferPerStream: int32(h2cConfig.MaxUploadBufferPerStream),
})
log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d",
h2cConfig.MaxConcurrentStreams,
h2cConfig.IdleTimeout,
h2cConfig.MaxReadFrameSize,
h2cConfig.MaxUploadBufferPerConnection,
h2cConfig.MaxUploadBufferPerStream,
)
}
return &http.Server{
Addr: cfg.Server.Address(),
Handler: handler,
Handler: httpHandler,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源

View File

@@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
nil, // userRepo (unused in GetByKey)
nil, // groupRepo
nil, // userSubRepo
nil, // userGroupRateRepo
nil, // cache
&config.Config{},
)
@@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
nil,
nil,
nil,
nil,
&config.Config{RunMode: config.RunModeSimple},
)

View File

@@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
@@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
now := time.Now()
sub := &service.UserSubscription{
@@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
@@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))

View File

@@ -34,12 +34,16 @@ func Logger() gin.HandlerFunc {
// 客户端IP
clientIP := c.ClientIP()
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
// 协议版本
protocol := c.Request.Proto
// 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s",
endTime.Format("2006/01/02 - 15:04:05"),
statusCode,
latency,
clientIP,
protocol,
method,
path,
)

View File

@@ -67,6 +67,9 @@ func RegisterAdminRoutes(
// 用户属性管理
registerUserAttributeRoutes(admin, h)
// 错误透传规则管理
registerErrorPassthroughRoutes(admin, h)
}
}
@@ -387,3 +390,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
}
}
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
rules := admin.Group("/error-passthrough-rules")
{
rules.GET("", h.Admin.ErrorPassthrough.List)
rules.GET("/:id", h.Admin.ErrorPassthrough.GetByID)
rules.POST("", h.Admin.ErrorPassthrough.Create)
rules.PUT("/:id", h.Admin.ErrorPassthrough.Update)
rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete)
}
}

View File

@@ -28,6 +28,12 @@ func RegisterAuthRoutes(
auth.POST("/login", h.Auth.Login)
auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// Token刷新接口添加速率限制每分钟最多 30 次Redis 故障时 fail-close
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.RefreshToken)
// 登出接口公开允许未认证用户调用以撤销Refresh Token
auth.POST("/logout", h.Auth.Logout)
// 优惠码验证接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
@@ -59,5 +65,7 @@ func RegisterAuthRoutes(
authenticated.Use(gin.HandlerFunc(jwtAuth))
{
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
}
}

View File

@@ -49,6 +49,7 @@ func RegisterUserRoutes(
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
groups.GET("/rates", h.APIKey.GetUserGroupRates)
}
// 使用记录

View File

@@ -41,6 +41,7 @@ type UsageLogRepository interface {
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error)
GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error)
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)

View File

@@ -93,6 +93,9 @@ type UpdateUserInput struct {
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
}
type CreateGroupInput struct {
@@ -304,6 +307,7 @@ type adminServiceImpl struct {
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
@@ -319,6 +323,7 @@ func NewAdminService(
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
@@ -332,6 +337,7 @@ func NewAdminService(
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
@@ -346,11 +352,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil {
return nil, 0, err
}
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
}
users[i].GroupRates = rates
}
}
return users, result.Total, nil
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id)
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
} else {
user.GroupRates = rates
}
}
return user, nil
}
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
@@ -419,6 +449,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
}
}
if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
@@ -974,6 +1012,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
if err != nil {
return err
}
// 注意user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理
// 事务成功后,异步失效受影响用户的订阅缓存
if len(affectedUserIDs) > 0 && s.billingCacheService != nil {

View File

@@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
@@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
contentType := resp.Header.Get("Content-Type")
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body因此用内存副本重新包装。
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
@@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
}
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}

View File

@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct {
// APIKeyService API Key服务
type APIKeyService struct {
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
}
// NewAPIKeyService 创建API Key服务实例
@@ -132,16 +133,18 @@ func NewAPIKeyService(
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache APIKeyCache,
cfg *config.Config,
) *APIKeyService {
svc := &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
cfg: cfg,
}
svc.initAuthCache(cfg)
return svc
@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
return keys, nil
}
// GetUserGroupRates 获取用户的专属分组倍率配置
// 返回 map[groupID]rateMultiplier
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
if s.userGroupRateRepo == nil {
return nil, nil
}
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user group rates: %w", err)
}
return rates, nil
}
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {

View File

@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{
@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{NotFound: true}, nil
}
@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
L1TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
require.NotNil(t, svc.authCacheL1)
_, err := svc.GetByKey(context.Background(), "k-l1")
@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
require.Len(t, cache.deleteAuthKeys, 2)
@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
require.Len(t, cache.deleteAuthKeys, 2)
@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
require.Len(t, cache.deleteAuthKeys, 1)
@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
Singleflight: true,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
start := make(chan struct{})
wg := sync.WaitGroup{}

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
@@ -25,8 +26,12 @@ var (
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
@@ -37,6 +42,9 @@ var (
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
const maxTokenLength = 8192
// refreshTokenPrefix is the prefix for refresh tokens to distinguish them from access tokens.
const refreshTokenPrefix = "rt_"
// JWTClaims JWT载荷数据
type JWTClaims struct {
UserID int64 `json:"user_id"`
@@ -50,6 +58,7 @@ type JWTClaims struct {
type AuthService struct {
userRepo UserRepository
redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache
cfg *config.Config
settingService *SettingService
emailService *EmailService
@@ -62,6 +71,7 @@ type AuthService struct {
func NewAuthService(
userRepo UserRepository,
redeemRepo RedeemCodeRepository,
refreshTokenCache RefreshTokenCache,
cfg *config.Config,
settingService *SettingService,
emailService *EmailService,
@@ -72,6 +82,7 @@ func NewAuthService(
return &AuthService{
userRepo: userRepo,
redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache,
cfg: cfg,
settingService: settingService,
emailService: emailService,
@@ -481,6 +492,100 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return token, user, nil
}
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
}
email = strings.TrimSpace(email)
if email == "" || len(email) > 255 {
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
username = strings.TrimSpace(username)
if len([]rune(username)) > 100 {
username = string([]rune(username)[:100])
}
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// OAuth 首次登录视为注册
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return nil, nil, ErrRegDisabled
}
randomPassword, err := randomHexString(32)
if err != nil {
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
return nil, nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(randomPassword)
if err != nil {
return nil, nil, fmt.Errorf("hash password: %w", err)
}
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Database error getting user after conflict: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
log.Printf("[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
user = newUser
}
} else {
log.Printf("[Auth] Database error during oauth login: %v", err)
return nil, nil, ErrServiceUnavailable
}
}
if !user.IsActive() {
return nil, nil, ErrUserNotActive
}
if user.Username == "" && username != "" {
user.Username = username
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
}
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token降低 DoS 风险。
@@ -539,10 +644,17 @@ func isReservedEmail(email string) bool {
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT token
// GenerateToken 生成JWT access token
// 使用新的access_token_expire_minutes配置项如果配置了否则回退到expire_hour
func (s *AuthService) GenerateToken(user *User) (string, error) {
now := time.Now()
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
var expiresAt time.Time
if s.cfg.JWT.AccessTokenExpireMinutes > 0 {
expiresAt = now.Add(time.Duration(s.cfg.JWT.AccessTokenExpireMinutes) * time.Minute)
} else {
// 向后兼容使用旧的expire_hour配置
expiresAt = now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
}
claims := &JWTClaims{
UserID: user.ID,
@@ -565,6 +677,15 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
return tokenString, nil
}
// GetAccessTokenExpiresIn 返回Access Token的有效期
// 用于前端设置刷新定时器
func (s *AuthService) GetAccessTokenExpiresIn() int {
if s.cfg.JWT.AccessTokenExpireMinutes > 0 {
return s.cfg.JWT.AccessTokenExpireMinutes * 60
}
return s.cfg.JWT.ExpireHour * 3600
}
// HashPassword 使用bcrypt加密密码
func (s *AuthService) HashPassword(password string) (string, error) {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
@@ -755,6 +876,198 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo
return ErrServiceUnavailable
}
// Also revoke all refresh tokens for this user
if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil {
log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err)
// Don't return error - password was already changed successfully
}
log.Printf("[Auth] Password reset successful for user: %s", email)
return nil
}
// ==================== Refresh Token Methods ====================
// TokenPair 包含Access Token和Refresh Token
type TokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"` // Access Token有效期
}
// GenerateTokenPair 生成Access Token和Refresh Token对
// familyID: 可选的Token家族ID用于Token轮转时保持家族关系
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, errors.New("refresh token cache not configured")
}
// 生成Access Token
accessToken, err := s.GenerateToken(user)
if err != nil {
return nil, fmt.Errorf("generate access token: %w", err)
}
// 生成Refresh Token
refreshToken, err := s.generateRefreshToken(ctx, user, familyID)
if err != nil {
return nil, fmt.Errorf("generate refresh token: %w", err)
}
return &TokenPair{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: s.GetAccessTokenExpiresIn(),
}, nil
}
// generateRefreshToken 生成并存储Refresh Token
func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, familyID string) (string, error) {
// 生成随机Token
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
rawToken := refreshTokenPrefix + hex.EncodeToString(tokenBytes)
// 计算Token哈希存储哈希而非原始Token
tokenHash := hashToken(rawToken)
// 如果没有提供familyID生成新的
if familyID == "" {
familyBytes := make([]byte, 16)
if _, err := rand.Read(familyBytes); err != nil {
return "", fmt.Errorf("generate family id: %w", err)
}
familyID = hex.EncodeToString(familyBytes)
}
now := time.Now()
ttl := time.Duration(s.cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour
data := &RefreshTokenData{
UserID: user.ID,
TokenVersion: user.TokenVersion,
FamilyID: familyID,
CreatedAt: now,
ExpiresAt: now.Add(ttl),
}
// 存储Token数据
if err := s.refreshTokenCache.StoreRefreshToken(ctx, tokenHash, data, ttl); err != nil {
return "", fmt.Errorf("store refresh token: %w", err)
}
// 添加到用户Token集合
if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil {
log.Printf("[Auth] Failed to add token to user set: %v", err)
// 不影响主流程
}
// 添加到家族Token集合
if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil {
log.Printf("[Auth] Failed to add token to family set: %v", err)
// 不影响主流程
}
return rawToken, nil
}
// RefreshTokenPair 使用Refresh Token刷新Token对
// 实现Token轮转每次刷新都会生成新的Refresh Token旧Token立即失效
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, ErrRefreshTokenInvalid
}
// 验证Token格式
if !strings.HasPrefix(refreshToken, refreshTokenPrefix) {
return nil, ErrRefreshTokenInvalid
}
tokenHash := hashToken(refreshToken)
// 获取Token数据
data, err := s.refreshTokenCache.GetRefreshToken(ctx, tokenHash)
if err != nil {
if errors.Is(err, ErrRefreshTokenNotFound) {
// Token不存在可能是已被使用Token轮转或已过期
log.Printf("[Auth] Refresh token not found, possible reuse attack")
return nil, ErrRefreshTokenInvalid
}
log.Printf("[Auth] Error getting refresh token: %v", err)
return nil, ErrServiceUnavailable
}
// 检查Token是否过期
if time.Now().After(data.ExpiresAt) {
// 删除过期Token
_ = s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash)
return nil, ErrRefreshTokenExpired
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, data.UserID)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// 用户已删除撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrRefreshTokenInvalid
}
log.Printf("[Auth] Database error getting user for token refresh: %v", err)
return nil, ErrServiceUnavailable
}
// 检查用户状态
if !user.IsActive() {
// 用户被禁用撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrUserNotActive
}
// 检查TokenVersion密码更改后所有Token失效
if data.TokenVersion != user.TokenVersion {
// TokenVersion不匹配撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrTokenRevoked
}
// Token轮转立即使旧Token失效
if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil {
log.Printf("[Auth] Failed to delete old refresh token: %v", err)
// 继续处理,不影响主流程
}
// 生成新的Token对保持同一个家族ID
return s.GenerateTokenPair(ctx, user, data.FamilyID)
}
// RevokeRefreshToken 撤销单个Refresh Token
func (s *AuthService) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
if s.refreshTokenCache == nil {
return nil // No-op if cache not configured
}
if !strings.HasPrefix(refreshToken, refreshTokenPrefix) {
return ErrRefreshTokenInvalid
}
tokenHash := hashToken(refreshToken)
return s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash)
}
// RevokeAllUserSessions 撤销用户的所有会话所有Refresh Token
// 用于密码更改或用户主动登出所有设备
func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) error {
if s.refreshTokenCache == nil {
return nil // No-op if cache not configured
}
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
}
// hashToken 计算Token的SHA256哈希
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}

View File

@@ -116,6 +116,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
return NewAuthService(
repo,
nil, // redeemRepo
nil, // refreshTokenCache
cfg,
settingService,
emailService,

View File

@@ -0,0 +1,300 @@
package service
import (
"context"
"log"
"sort"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
type ErrorPassthroughRepository interface {
// List 获取所有规则
List(ctx context.Context) ([]*model.ErrorPassthroughRule, error)
// GetByID 根据 ID 获取规则
GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error)
// Create 创建规则
Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
// Update 更新规则
Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
// Delete 删除规则
Delete(ctx context.Context, id int64) error
}
// ErrorPassthroughCache 定义错误透传规则的缓存接口
type ErrorPassthroughCache interface {
// Get 从缓存获取规则列表
Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool)
// Set 设置缓存
Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error
// Invalidate 使缓存失效
Invalidate(ctx context.Context) error
// NotifyUpdate 通知其他实例刷新缓存
NotifyUpdate(ctx context.Context) error
// SubscribeUpdates 订阅缓存更新通知
SubscribeUpdates(ctx context.Context, handler func())
}
// ErrorPassthroughService 错误透传规则服务
type ErrorPassthroughService struct {
repo ErrorPassthroughRepository
cache ErrorPassthroughCache
// 本地内存缓存,用于快速匹配
localCache []*model.ErrorPassthroughRule
localCacheMu sync.RWMutex
}
// NewErrorPassthroughService 创建错误透传规则服务
func NewErrorPassthroughService(
repo ErrorPassthroughRepository,
cache ErrorPassthroughCache,
) *ErrorPassthroughService {
svc := &ErrorPassthroughService{
repo: repo,
cache: cache,
}
// 启动时加载规则到本地缓存
ctx := context.Background()
if err := svc.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err)
}
// 订阅缓存更新通知
if cache != nil {
cache.SubscribeUpdates(ctx, func() {
if err := svc.refreshLocalCache(context.Background()); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
}
})
}
return svc
}
// List 获取所有规则
func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
return s.repo.List(ctx)
}
// GetByID 根据 ID 获取规则
func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
return s.repo.GetByID(ctx, id)
}
// Create 创建规则
func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if err := rule.Validate(); err != nil {
return nil, err
}
created, err := s.repo.Create(ctx, rule)
if err != nil {
return nil, err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return created, nil
}
// Update 更新规则
func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if err := rule.Validate(); err != nil {
return nil, err
}
updated, err := s.repo.Update(ctx, rule)
if err != nil {
return nil, err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return updated, nil
}
// Delete 删除规则
func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
if err := s.repo.Delete(ctx, id); err != nil {
return err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return nil
}
// MatchRule 匹配透传规则
// 返回第一个匹配的规则,如果没有匹配则返回 nil
func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, body []byte) *model.ErrorPassthroughRule {
rules := s.getCachedRules()
if len(rules) == 0 {
return nil
}
bodyStr := strings.ToLower(string(body))
for _, rule := range rules {
if !rule.Enabled {
continue
}
if !s.platformMatches(rule, platform) {
continue
}
if s.ruleMatches(rule, statusCode, bodyStr) {
return rule
}
}
return nil
}
// getCachedRules 获取缓存的规则列表(按优先级排序)
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
s.localCacheMu.RLock()
rules := s.localCache
s.localCacheMu.RUnlock()
if rules != nil {
return rules
}
// 如果本地缓存为空,尝试刷新
ctx := context.Background()
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err)
return nil
}
s.localCacheMu.RLock()
defer s.localCacheMu.RUnlock()
return s.localCache
}
// refreshLocalCache 刷新本地缓存
func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
// 先尝试从 Redis 缓存获取
if s.cache != nil {
if rules, ok := s.cache.Get(ctx); ok {
s.setLocalCache(rules)
return nil
}
}
// 从数据库加载repo.List 已按 priority 排序)
rules, err := s.repo.List(ctx)
if err != nil {
return err
}
// 更新 Redis 缓存
if s.cache != nil {
if err := s.cache.Set(ctx, rules); err != nil {
log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err)
}
}
// 更新本地缓存setLocalCache 内部会确保排序)
s.setLocalCache(rules)
return nil
}
// setLocalCache 设置本地缓存
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
// 按优先级排序
sorted := make([]*model.ErrorPassthroughRule, len(rules))
copy(sorted, rules)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
})
s.localCacheMu.Lock()
s.localCache = sorted
s.localCacheMu.Unlock()
}
// invalidateAndNotify 使缓存失效并通知其他实例
func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 刷新本地缓存
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
}
// 通知其他实例
if s.cache != nil {
if err := s.cache.NotifyUpdate(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err)
}
}
}
// platformMatches 检查平台是否匹配
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
// 如果没有配置平台限制,则匹配所有平台
if len(rule.Platforms) == 0 {
return true
}
platform = strings.ToLower(platform)
for _, p := range rule.Platforms {
if strings.ToLower(p) == platform {
return true
}
}
return false
}
// ruleMatches 检查规则是否匹配
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
hasErrorCodes := len(rule.ErrorCodes) > 0
hasKeywords := len(rule.Keywords) > 0
// 如果没有配置任何条件,不匹配
if !hasErrorCodes && !hasKeywords {
return false
}
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
if rule.MatchMode == model.MatchModeAll {
// "all" 模式:所有配置的条件都必须满足
return codeMatch && keywordMatch
}
// "any" 模式:任一条件满足即可
if hasErrorCodes && hasKeywords {
return codeMatch || keywordMatch
}
return codeMatch && keywordMatch
}
// containsInt 检查切片是否包含指定整数
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
for _, v := range slice {
if v == val {
return true
}
}
return false
}
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
for _, kw := range keywords {
if strings.Contains(bodyLower, strings.ToLower(kw)) {
return true
}
}
return false
}

View File

@@ -0,0 +1,755 @@
//go:build unit
package service
import (
"context"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockErrorPassthroughRepo 用于测试的 mock repository
type mockErrorPassthroughRepo struct {
rules []*model.ErrorPassthroughRule
}
func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
return m.rules, nil
}
func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
for _, r := range m.rules {
if r.ID == id {
return r, nil
}
}
return nil, nil
}
func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
rule.ID = int64(len(m.rules) + 1)
m.rules = append(m.rules, rule)
return rule, nil
}
func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
for i, r := range m.rules {
if r.ID == rule.ID {
m.rules[i] = rule
return rule, nil
}
}
return rule, nil
}
func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error {
for i, r := range m.rules {
if r.ID == id {
m.rules = append(m.rules[:i], m.rules[i+1:]...)
return nil
}
}
return nil
}
// newTestService 创建测试用的服务实例
func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughService {
repo := &mockErrorPassthroughRepo{rules: rules}
svc := &ErrorPassthroughService{
repo: repo,
cache: nil, // 不使用缓存
}
// 直接设置本地缓存,避免调用 refreshLocalCache
svc.setLocalCache(rules)
return svc
}
// =============================================================================
// 测试 ruleMatches 核心匹配逻辑
// =============================================================================
func TestRuleMatches_NoConditions(t *testing.T) {
// 没有配置任何条件时,不应该匹配
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{},
MatchMode: model.MatchModeAny,
}
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
"没有配置条件时不应该匹配")
}
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
}{
{"状态码匹配 422", 422, "any message", true},
{"状态码匹配 400", 400, "any message", true},
{"状态码不匹配 500", 500, "any message", false},
{"状态码不匹配 429", 429, "any message", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result)
})
}
}
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{"context limit", "model not supported"},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
}{
{"关键词匹配 context limit", 500, "error: context limit reached", true},
{"关键词匹配 model not supported", 400, "the model not supported here", true},
{"关键词不匹配", 422, "some other error", false},
// 注意ruleMatches 接收的 body 参数应该是已经转换为小写的
// 实际使用时MatchRule 会先将 body 转换为小写再传给 ruleMatches
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 MatchRule 的行为:先转换为小写
bodyLower := strings.ToLower(tt.body)
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
assert.Equal(t, tt.expected, result)
})
}
}
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
// any 模式:错误码 OR 关键词
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
reason string
}{
{
name: "状态码和关键词都匹配",
statusCode: 422,
body: "context limit reached",
expected: true,
reason: "both match",
},
{
name: "只有状态码匹配",
statusCode: 422,
body: "some other error",
expected: true,
reason: "code matches, keyword doesn't - OR mode should match",
},
{
name: "只有关键词匹配",
statusCode: 500,
body: "context limit exceeded",
expected: true,
reason: "keyword matches, code doesn't - OR mode should match",
},
{
name: "都不匹配",
statusCode: 500,
body: "some other error",
expected: false,
reason: "neither matches",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
}
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
// all 模式:错误码 AND 关键词
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAll,
}
tests := []struct {
name string
statusCode int
body string
expected bool
reason string
}{
{
name: "状态码和关键词都匹配",
statusCode: 422,
body: "context limit reached",
expected: true,
reason: "both match - AND mode should match",
},
{
name: "只有状态码匹配",
statusCode: 422,
body: "some other error",
expected: false,
reason: "code matches but keyword doesn't - AND mode should NOT match",
},
{
name: "只有关键词匹配",
statusCode: 500,
body: "context limit exceeded",
expected: false,
reason: "keyword matches but code doesn't - AND mode should NOT match",
},
{
name: "都不匹配",
statusCode: 500,
body: "some other error",
expected: false,
reason: "neither matches",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
}
// =============================================================================
// 测试 platformMatches 平台匹配逻辑
// =============================================================================
func TestPlatformMatches(t *testing.T) {
svc := newTestService(nil)
tests := []struct {
name string
rulePlatforms []string
requestPlatform string
expected bool
}{
{
name: "空平台列表匹配所有",
rulePlatforms: []string{},
requestPlatform: "anthropic",
expected: true,
},
{
name: "nil平台列表匹配所有",
rulePlatforms: nil,
requestPlatform: "openai",
expected: true,
},
{
name: "精确匹配 anthropic",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "anthropic",
expected: true,
},
{
name: "精确匹配 openai",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "openai",
expected: true,
},
{
name: "不匹配 gemini",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "gemini",
expected: false,
},
{
name: "大小写不敏感",
rulePlatforms: []string{"Anthropic", "OpenAI"},
requestPlatform: "anthropic",
expected: true,
},
{
name: "匹配 antigravity",
rulePlatforms: []string{"antigravity"},
requestPlatform: "antigravity",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rule := &model.ErrorPassthroughRule{
Platforms: tt.rulePlatforms,
}
result := svc.platformMatches(rule, tt.requestPlatform)
assert.Equal(t, tt.expected, result)
})
}
}
// =============================================================================
// 测试 MatchRule 完整匹配流程
// =============================================================================
func TestMatchRule_Priority(t *testing.T) {
// 测试规则按优先级排序,优先级小的先匹配
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Low Priority",
Enabled: true,
Priority: 10,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "High Priority",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID, "应该匹配优先级更高(数值更小)的规则")
assert.Equal(t, "High Priority", matched.Name)
}
func TestMatchRule_DisabledRule(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Disabled Rule",
Enabled: false,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "Enabled Rule",
Enabled: true,
Priority: 10,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID, "应该跳过禁用的规则")
}
func TestMatchRule_PlatformFilter(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Anthropic Only",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
Platforms: []string{"anthropic"},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "OpenAI Only",
Enabled: true,
Priority: 2,
ErrorCodes: []int{422},
Platforms: []string{"openai"},
MatchMode: model.MatchModeAny,
},
{
ID: 3,
Name: "All Platforms",
Enabled: true,
Priority: 3,
ErrorCodes: []int{422},
Platforms: []string{},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
t.Run("Anthropic 请求匹配 Anthropic 规则", func(t *testing.T) {
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(1), matched.ID)
})
t.Run("OpenAI 请求匹配 OpenAI 规则", func(t *testing.T) {
matched := svc.MatchRule("openai", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID)
})
t.Run("Gemini 请求匹配全平台规则", func(t *testing.T) {
matched := svc.MatchRule("gemini", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(3), matched.ID)
})
t.Run("Antigravity 请求匹配全平台规则", func(t *testing.T) {
matched := svc.MatchRule("antigravity", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(3), matched.ID)
})
}
func TestMatchRule_NoMatch(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Rule for 422",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 500, []byte("error"))
assert.Nil(t, matched, "不匹配任何规则时应返回 nil")
}
func TestMatchRule_EmptyRules(t *testing.T) {
svc := newTestService([]*model.ErrorPassthroughRule{})
matched := svc.MatchRule("anthropic", 422, []byte("error"))
assert.Nil(t, matched, "没有规则时应返回 nil")
}
func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Context Limit",
Enabled: true,
Priority: 1,
Keywords: []string{"Context Limit"},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
tests := []struct {
name string
body string
expected bool
}{
{"完全匹配", "Context Limit reached", true},
{"小写匹配", "context limit reached", true},
{"大写匹配", "CONTEXT LIMIT REACHED", true},
{"混合大小写", "ConTeXt LiMiT error", true},
{"不匹配", "some other error", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matched := svc.MatchRule("anthropic", 500, []byte(tt.body))
if tt.expected {
assert.NotNil(t, matched)
} else {
assert.Nil(t, matched)
}
})
}
}
// =============================================================================
// 测试真实场景
// =============================================================================
func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) {
// 场景:上游返回 422 + "context limit has been reached",需要透传给客户端
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Context Limit Passthrough",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAll, // 必须同时满足
Platforms: []string{"anthropic", "antigravity"},
PassthroughCode: true,
PassthroughBody: true,
},
}
svc := newTestService(rules)
// 测试 Anthropic 平台
t.Run("Anthropic 422 with context limit", func(t *testing.T) {
body := []byte(`{"type":"error","error":{"type":"invalid_request","message":"The context limit has been reached"}}`)
matched := svc.MatchRule("anthropic", 422, body)
require.NotNil(t, matched)
assert.True(t, matched.PassthroughCode)
assert.True(t, matched.PassthroughBody)
})
// 测试 Antigravity 平台
t.Run("Antigravity 422 with context limit", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("antigravity", 422, body)
require.NotNil(t, matched)
})
// 测试 OpenAI 平台(不在规则的平台列表中)
t.Run("OpenAI should not match", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("openai", 422, body)
assert.Nil(t, matched, "OpenAI 不在规则的平台列表中")
})
// 测试状态码不匹配
t.Run("Wrong status code", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("anthropic", 400, body)
assert.Nil(t, matched, "状态码不匹配")
})
// 测试关键词不匹配
t.Run("Wrong keyword", func(t *testing.T) {
body := []byte(`{"error":"rate limit exceeded"}`)
matched := svc.MatchRule("anthropic", 422, body)
assert.Nil(t, matched, "关键词不匹配")
})
}
func TestMatchRule_RealWorldScenario_CustomErrorMessage(t *testing.T) {
// 场景:某些错误需要返回自定义消息,隐藏上游详细信息
customMsg := "Service temporarily unavailable, please try again later"
responseCode := 503
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Hide Internal Errors",
Enabled: true,
Priority: 1,
ErrorCodes: []int{500, 502, 503},
MatchMode: model.MatchModeAny,
PassthroughCode: false,
ResponseCode: &responseCode,
PassthroughBody: false,
CustomMessage: &customMsg,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 500, []byte("internal server error"))
require.NotNil(t, matched)
assert.False(t, matched.PassthroughCode)
assert.Equal(t, 503, *matched.ResponseCode)
assert.False(t, matched.PassthroughBody)
assert.Equal(t, customMsg, *matched.CustomMessage)
}
// =============================================================================
// 测试 model.Validate
// =============================================================================
func TestErrorPassthroughRule_Validate(t *testing.T) {
tests := []struct {
name string
rule *model.ErrorPassthroughRule
expectError bool
errorField string
}{
{
name: "有效规则 - 透传模式(含错误码)",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: false,
},
{
name: "有效规则 - 透传模式(含关键词)",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAny,
Keywords: []string{"context limit"},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: false,
},
{
name: "有效规则 - 自定义响应",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAll,
ErrorCodes: []int{500},
Keywords: []string{"internal error"},
PassthroughCode: false,
ResponseCode: testIntPtr(503),
PassthroughBody: false,
CustomMessage: testStrPtr("Custom error"),
},
expectError: false,
},
{
name: "缺少名称",
rule: &model.ErrorPassthroughRule{
Name: "",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "name",
},
{
name: "无效的匹配模式",
rule: &model.ErrorPassthroughRule{
Name: "Invalid Mode",
MatchMode: "invalid",
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "match_mode",
},
{
name: "缺少匹配条件(错误码和关键词都为空)",
rule: &model.ErrorPassthroughRule{
Name: "No Conditions",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{},
Keywords: []string{},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "conditions",
},
{
name: "缺少匹配条件nil切片",
rule: &model.ErrorPassthroughRule{
Name: "Nil Conditions",
MatchMode: model.MatchModeAny,
ErrorCodes: nil,
Keywords: nil,
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "conditions",
},
{
name: "自定义状态码但未提供值",
rule: &model.ErrorPassthroughRule{
Name: "Missing Code",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: false,
ResponseCode: nil,
PassthroughBody: true,
},
expectError: true,
errorField: "response_code",
},
{
name: "自定义消息但未提供值",
rule: &model.ErrorPassthroughRule{
Name: "Missing Message",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: false,
CustomMessage: nil,
},
expectError: true,
errorField: "custom_message",
},
{
name: "自定义消息为空字符串",
rule: &model.ErrorPassthroughRule{
Name: "Empty Message",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: false,
CustomMessage: testStrPtr(""),
},
expectError: true,
errorField: "custom_message",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.rule.Validate()
if tt.expectError {
require.Error(t, err)
validationErr, ok := err.(*model.ValidationError)
require.True(t, ok, "应该返回 ValidationError")
assert.Equal(t, tt.errorField, validationErr.Field)
} else {
assert.NoError(t, err)
}
})
}
}
// Helper functions
func testIntPtr(i int) *int { return &i }
func testStrPtr(s string) *string { return &s }

View File

@@ -20,7 +20,6 @@ import (
"strings"
"sync/atomic"
"time"
"unicode"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
@@ -375,7 +374,8 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
}
func (e *UpstreamFailoverError) Error() string {
@@ -389,6 +389,7 @@ type GatewayService struct {
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
@@ -410,6 +411,7 @@ func NewGatewayService(
usageLogRepo UsageLogRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache GatewayCache,
cfg *config.Config,
schedulerSnapshot *SchedulerSnapshotService,
@@ -429,6 +431,7 @@ func NewGatewayService(
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
@@ -624,35 +627,6 @@ func stripToolPrefix(value string) string {
return toolPrefixRe.ReplaceAllString(value, "")
}
func toPascalCase(value string) string {
if value == "" {
return value
}
normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
tokens := make([]string, 0)
for _, token := range strings.Fields(normalized) {
expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
parts := strings.Fields(expanded)
if len(parts) > 0 {
tokens = append(tokens, parts...)
}
}
if len(tokens) == 0 {
return value
}
var builder strings.Builder
for _, token := range tokens {
lower := strings.ToLower(token)
if lower == "" {
continue
}
runes := []rune(lower)
runes[0] = unicode.ToUpper(runes[0])
_, _ = builder.WriteString(string(runes))
}
return builder.String()
}
func toSnakeCase(value string) string {
if value == "" {
return value
@@ -668,16 +642,15 @@ func normalizeToolNameForClaude(name string, cache map[string]string) string {
return name
}
stripped := stripToolPrefix(name)
// 只对已知的工具名进行映射,未知工具名保持原样
// 避免破坏 Anthropic 特殊工具(如 text_editor_20250728
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
if !ok {
mapped = toPascalCase(stripped)
}
if mapped != "" && cache != nil && mapped != stripped {
cache[mapped] = stripped
}
if mapped == "" {
return stripped
}
if cache != nil && mapped != stripped {
cache[mapped] = stripped
}
return mapped
}
@@ -686,15 +659,18 @@ func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
return name
}
stripped := stripToolPrefix(name)
// 优先从请求时建立的映射中查找
if cache != nil {
if mapped, ok := cache[stripped]; ok {
return mapped
}
}
// 已知工具名的硬编码映射
if mapped, ok := openCodeToolOverrides[stripped]; ok {
return mapped
}
return toSnakeCase(stripped)
// 未知工具名保持原样,避免破坏 Anthropic 特殊工具
return stripped
}
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
@@ -3313,7 +3289,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
@@ -3343,10 +3319,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
@@ -3390,7 +3364,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
}
return s.handleErrorResponse(ctx, resp, c, account)
@@ -3787,6 +3761,12 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
return false
}
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
func ExtractUpstreamErrorMessage(body []byte) string {
return extractUpstreamErrorMessage(body)
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
@@ -3854,7 +3834,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
}
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
@@ -4641,10 +4621,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
account := input.Account
subscription := input.Subscription
// 获取费率倍数
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
}
var cost *CostBreakdown
@@ -4827,10 +4814,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
account := input.Account
subscription := input.Subscription
// 获取费率倍数
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
}
var cost *CostBreakdown

View File

@@ -864,7 +864,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader)
@@ -891,7 +891,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
@@ -1301,7 +1301,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody)
@@ -1325,7 +1325,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody}
}
respBody = unwrapIfNeeded(isOAuth, respBody)

View File

@@ -944,6 +944,32 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
}
// 关键逻辑:对齐 Gemini CLI 对“已注册用户”的处理方式。
// 当 LoadCodeAssist 返回了 currentTier / paidTier表示账号已注册但没有返回 cloudaicompanionProject 时:
// - 不要再调用 onboardUser通常不会再分配 project_id且可能触发 INVALID_ARGUMENT
// - 先尝试从 Cloud Resource Manager 获取可用项目;仍失败则提示用户手动填写 project_id
if loadResp != nil {
registeredTierID := strings.TrimSpace(loadResp.GetTier())
if registeredTierID != "" {
// 已注册但未返回 cloudaicompanionProject这在 Google One 用户中较常见:需要用户自行提供 project_id。
log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID)
// Try to get project from Cloud Resource Manager
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback)
return strings.TrimSpace(fallback), tierID, nil
}
// No project found - user must provide project_id manually
log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually")
return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID)
}
}
// 未检测到 currentTier/paidTier视为新用户继续调用 onboardUser
log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID)
req := &geminicli.OnboardUserRequest{
TierID: tierID,
Metadata: geminicli.LoadCodeAssistMetadata{

View File

@@ -846,10 +846,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
}
// Remove prompt_cache_retention (not supported by upstream OpenAI API)
if _, has := reqBody["prompt_cache_retention"]; has {
delete(reqBody, "prompt_cache_retention")
bodyModified = true
// Remove unsupported fields (not supported by upstream OpenAI API)
for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} {
if _, has := reqBody[unsupportedField]; has {
delete(reqBody, unsupportedField)
bodyModified = true
}
}
}
@@ -938,7 +940,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return s.handleErrorResponse(ctx, resp, c, account)
}
@@ -1129,7 +1131,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
Detail: upstreamDetail,
})
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
}
// Return appropriate error response

View File

@@ -0,0 +1,73 @@
package service
import (
"context"
"errors"
"time"
)
// ErrRefreshTokenNotFound is returned when a refresh token is not found in cache.
// This is used to abstract away the underlying cache implementation (e.g., redis.Nil).
var ErrRefreshTokenNotFound = errors.New("refresh token not found")
// RefreshTokenData 存储在Redis中的Refresh Token数据
type RefreshTokenData struct {
UserID int64 `json:"user_id"`
TokenVersion int64 `json:"token_version"` // 用于检测密码更改后的Token失效
FamilyID string `json:"family_id"` // Token家族ID用于防重放攻击
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
}
// RefreshTokenCache 管理Refresh Token的Redis缓存
// 用于JWT Token刷新机制支持Token轮转和防重放攻击
//
// Key 格式:
// - refresh_token:{token_hash} -> RefreshTokenData (JSON)
// - user_refresh_tokens:{user_id} -> Set<token_hash>
// - token_family:{family_id} -> Set<token_hash>
type RefreshTokenCache interface {
// StoreRefreshToken 存储Refresh Token
// tokenHash: Token的SHA256哈希值不存储原始Token
// data: Token关联的数据
// ttl: Token过期时间
StoreRefreshToken(ctx context.Context, tokenHash string, data *RefreshTokenData, ttl time.Duration) error
// GetRefreshToken 获取Refresh Token数据
// 返回 (data, nil) 如果Token存在
// 返回 (nil, ErrRefreshTokenNotFound) 如果Token不存在
// 返回 (nil, err) 如果发生其他错误
GetRefreshToken(ctx context.Context, tokenHash string) (*RefreshTokenData, error)
// DeleteRefreshToken 删除单个Refresh Token
// 用于Token轮转时使旧Token失效
DeleteRefreshToken(ctx context.Context, tokenHash string) error
// DeleteUserRefreshTokens 删除用户的所有Refresh Token
// 用于密码更改或用户主动登出所有设备
DeleteUserRefreshTokens(ctx context.Context, userID int64) error
// DeleteTokenFamily 删除整个Token家族
// 用于检测到Token重放攻击时撤销整个会话链
DeleteTokenFamily(ctx context.Context, familyID string) error
// AddToUserTokenSet 将Token添加到用户的Token集合
// 用于跟踪用户的所有活跃Refresh Token
AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error
// AddToFamilyTokenSet 将Token添加到家族Token集合
// 用于跟踪同一登录会话的所有Token
AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error
// GetUserTokenHashes 获取用户的所有Token哈希
// 用于批量删除用户Token
GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error)
// GetFamilyTokenHashes 获取家族的所有Token哈希
// 用于批量删除家族Token
GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error)
// IsTokenInFamily 检查Token是否属于指定家族
// 用于验证Token家族关系
IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error)
}

View File

@@ -288,6 +288,15 @@ func (s *UsageService) GetUserDashboardStats(ctx context.Context, userID int64)
return stats, nil
}
// GetAPIKeyDashboardStats returns dashboard summary stats filtered by API Key.
func (s *UsageService) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
stats, err := s.usageRepo.GetAPIKeyDashboardStats(ctx, apiKeyID)
if err != nil {
return nil, fmt.Errorf("get api key dashboard stats: %w", err)
}
return stats, nil
}
// GetUserUsageTrendByUserID returns per-user usage trend.
func (s *UsageService) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUserUsageTrendByUserID(ctx, userID, startTime, endTime, granularity)

View File

@@ -21,6 +21,10 @@ type User struct {
CreatedAt time.Time
UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64
// TOTP 双因素认证字段
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
TotpEnabled bool // 是否启用 TOTP
@@ -40,18 +44,20 @@ func (u *User) IsActive() bool {
// CanBindGroup checks whether a user can bind to a given group.
// For standard groups:
// - If AllowedGroups is non-empty, only allow binding to IDs in that list.
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group.
// - Public groups (non-exclusive): all users can bind
// - Exclusive groups: only users with the group in AllowedGroups can bind
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
// 公开分组(非专属):所有用户都可以绑定
if !isExclusive {
return true
}
return !isExclusive
// 专属分组:需要在 AllowedGroups 中
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
func (u *User) SetPassword(password string) error {

View File

@@ -0,0 +1,25 @@
package service
import "context"
// UserGroupRateRepository 用户专属分组倍率仓储接口
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
type UserGroupRateRepository interface {
// GetByUserID 获取用户的所有专属分组倍率
// 返回 map[groupID]rateMultiplier
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
// GetByUserAndGroup 获取用户在特定分组的专属倍率
// 如果未设置专属倍率,返回 nil
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
// SyncUserGroupRates 同步用户的分组专属倍率
// rates: map[groupID]*rateMultipliernil 表示删除该分组的专属倍率
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
// DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
DeleteByGroupID(ctx context.Context, groupID int64) error
// DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用)
DeleteByUserID(ctx context.Context, userID int64) error
}

View File

@@ -294,4 +294,5 @@ var ProviderSet = wire.NewSet(
NewUserAttributeService,
NewUsageCache,
NewTotpService,
NewErrorPassthroughService,
)