mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-05 16:00:21 +08:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fb883f0092 | ||
|
|
64b8219245 | ||
|
|
2004230b66 | ||
|
|
0026e871f0 | ||
|
|
19d0ee130d | ||
|
|
942c3e1529 | ||
|
|
caa8c47b68 |
@@ -216,7 +216,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxAccountSwitches = 3
|
const maxAccountSwitches = 10
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
lastFailoverStatus := 0
|
lastFailoverStatus := 0
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) {
|
|||||||
|
|
||||||
// 自动迁移(始终执行,确保数据库结构与代码同步)
|
// 自动迁移(始终执行,确保数据库结构与代码同步)
|
||||||
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
|
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
|
||||||
if err := repository.AutoMigrate(db); err != nil {
|
if err := repository.AutoMigrate(db, cfg.RunMode); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ var maxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
|
|||||||
|
|
||||||
// AutoMigrate runs schema migrations for all repository persistence models.
|
// AutoMigrate runs schema migrations for all repository persistence models.
|
||||||
// Persistence models are defined within individual `*_repo.go` files.
|
// Persistence models are defined within individual `*_repo.go` files.
|
||||||
func AutoMigrate(db *gorm.DB) error {
|
// runMode: "standard" or "simple" - determines whether to create default groups
|
||||||
|
func AutoMigrate(db *gorm.DB, runMode string) error {
|
||||||
err := db.AutoMigrate(
|
err := db.AutoMigrate(
|
||||||
&userModel{},
|
&userModel{},
|
||||||
&apiKeyModel{},
|
&apiKeyModel{},
|
||||||
@@ -31,7 +32,7 @@ func AutoMigrate(db *gorm.DB) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建默认分组(简易模式支持)
|
// 创建默认分组(简易模式支持)
|
||||||
if err := ensureDefaultGroups(db); err != nil {
|
if err := ensureDefaultGroups(db, runMode); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,7 +56,13 @@ func fixInvalidExpiresAt(db *gorm.DB) error {
|
|||||||
|
|
||||||
// ensureDefaultGroups 确保默认分组存在(简易模式支持)
|
// ensureDefaultGroups 确保默认分组存在(简易模式支持)
|
||||||
// 为每个平台创建一个默认分组,配置最大权限以确保简易模式下不受限制
|
// 为每个平台创建一个默认分组,配置最大权限以确保简易模式下不受限制
|
||||||
func ensureDefaultGroups(db *gorm.DB) error {
|
// runMode: "standard" 时跳过创建, "simple" 时创建/恢复默认分组
|
||||||
|
func ensureDefaultGroups(db *gorm.DB, runMode string) error {
|
||||||
|
// 标准版不创建默认分组
|
||||||
|
if runMode == "standard" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
defaultGroups := []struct {
|
defaultGroups := []struct {
|
||||||
name string
|
name string
|
||||||
platform string
|
platform string
|
||||||
@@ -79,12 +86,34 @@ func ensureDefaultGroups(db *gorm.DB) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, dg := range defaultGroups {
|
for _, dg := range defaultGroups {
|
||||||
var count int64
|
// 步骤1: 检查是否有软删除的记录
|
||||||
if err := db.Model(&groupModel{}).Where("name = ?", dg.name).Count(&count).Error; err != nil {
|
var softDeletedCount int64
|
||||||
|
if err := db.Unscoped().Model(&groupModel{}).
|
||||||
|
Where("name = ? AND deleted_at IS NOT NULL", dg.name).
|
||||||
|
Count(&softDeletedCount).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if count == 0 {
|
if softDeletedCount > 0 {
|
||||||
|
// 恢复软删除的记录
|
||||||
|
if err := db.Unscoped().Model(&groupModel{}).
|
||||||
|
Where("name = ?", dg.name).
|
||||||
|
Update("deleted_at", nil).Error; err != nil {
|
||||||
|
log.Printf("[AutoMigrate] Failed to restore default group %s: %v", dg.name, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("[AutoMigrate] Restored default group: %s (platform: %s)", dg.name, dg.platform)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤2: 检查是否有活跃记录
|
||||||
|
var activeCount int64
|
||||||
|
if err := db.Model(&groupModel{}).Where("name = ?", dg.name).Count(&activeCount).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if activeCount == 0 {
|
||||||
|
// 创建新分组
|
||||||
group := &groupModel{
|
group := &groupModel{
|
||||||
Name: dg.name,
|
Name: dg.name,
|
||||||
Description: dg.description,
|
Description: dg.description,
|
||||||
|
|||||||
@@ -93,7 +93,8 @@ func TestMain(m *testing.M) {
|
|||||||
log.Printf("failed to open gorm db: %v", err)
|
log.Printf("failed to open gorm db: %v", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
if err := AutoMigrate(integrationDB); err != nil {
|
// 使用 simple 模式以便测试默认分组功能
|
||||||
|
if err := AutoMigrate(integrationDB, "simple"); err != nil {
|
||||||
log.Printf("failed to automigrate db: %v", err)
|
log.Printf("failed to automigrate db: %v", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ type userModel struct {
|
|||||||
Concurrency int `gorm:"default:5;not null"`
|
Concurrency int `gorm:"default:5;not null"`
|
||||||
Status string `gorm:"size:20;default:active;not null"`
|
Status string `gorm:"size:20;default:active;not null"`
|
||||||
AllowedGroups pq.Int64Array `gorm:"type:bigint[]"`
|
AllowedGroups pq.Int64Array `gorm:"type:bigint[]"`
|
||||||
|
TokenVersion int64 `gorm:"default:0;not null"` // Incremented on password change
|
||||||
CreatedAt time.Time `gorm:"not null"`
|
CreatedAt time.Time `gorm:"not null"`
|
||||||
UpdatedAt time.Time `gorm:"not null"`
|
UpdatedAt time.Time `gorm:"not null"`
|
||||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||||
@@ -221,6 +222,7 @@ func userModelToService(m *userModel) *service.User {
|
|||||||
Concurrency: m.Concurrency,
|
Concurrency: m.Concurrency,
|
||||||
Status: m.Status,
|
Status: m.Status,
|
||||||
AllowedGroups: []int64(m.AllowedGroups),
|
AllowedGroups: []int64(m.AllowedGroups),
|
||||||
|
TokenVersion: m.TokenVersion,
|
||||||
CreatedAt: m.CreatedAt,
|
CreatedAt: m.CreatedAt,
|
||||||
UpdatedAt: m.UpdatedAt,
|
UpdatedAt: m.UpdatedAt,
|
||||||
}
|
}
|
||||||
@@ -242,6 +244,7 @@ func userModelFromService(u *service.User) *userModel {
|
|||||||
Concurrency: u.Concurrency,
|
Concurrency: u.Concurrency,
|
||||||
Status: u.Status,
|
Status: u.Status,
|
||||||
AllowedGroups: pq.Int64Array(u.AllowedGroups),
|
AllowedGroups: pq.Int64Array(u.AllowedGroups),
|
||||||
|
TokenVersion: u.TokenVersion,
|
||||||
CreatedAt: u.CreatedAt,
|
CreatedAt: u.CreatedAt,
|
||||||
UpdatedAt: u.UpdatedAt,
|
UpdatedAt: u.UpdatedAt,
|
||||||
}
|
}
|
||||||
@@ -252,6 +255,7 @@ func applyUserModelToService(dst *service.User, src *userModel) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
dst.ID = src.ID
|
dst.ID = src.ID
|
||||||
|
dst.TokenVersion = src.TokenVersion
|
||||||
dst.CreatedAt = src.CreatedAt
|
dst.CreatedAt = src.CreatedAt
|
||||||
dst.UpdatedAt = src.UpdatedAt
|
dst.UpdatedAt = src.UpdatedAt
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,6 +61,13 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Security: Validate TokenVersion to ensure token hasn't been invalidated
|
||||||
|
// This check ensures tokens issued before a password change are rejected
|
||||||
|
if claims.TokenVersion != user.TokenVersion {
|
||||||
|
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.Set(string(ContextKeyUser), AuthSubject{
|
c.Set(string(ContextKeyUser), AuthSubject{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
Concurrency: user.Concurrency,
|
Concurrency: user.Concurrency,
|
||||||
|
|||||||
@@ -51,23 +51,27 @@ var antigravityModelMapping = map[string]string{
|
|||||||
"claude-haiku-4-5": "gemini-3-flash",
|
"claude-haiku-4-5": "gemini-3-flash",
|
||||||
"claude-3-haiku-20240307": "gemini-3-flash",
|
"claude-3-haiku-20240307": "gemini-3-flash",
|
||||||
"claude-haiku-4-5-20251001": "gemini-3-flash",
|
"claude-haiku-4-5-20251001": "gemini-3-flash",
|
||||||
|
// 生图模型:官方名 → Antigravity 内部名
|
||||||
|
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||||
}
|
}
|
||||||
|
|
||||||
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
||||||
type AntigravityGatewayService struct {
|
type AntigravityGatewayService struct {
|
||||||
|
accountRepo AccountRepository
|
||||||
tokenProvider *AntigravityTokenProvider
|
tokenProvider *AntigravityTokenProvider
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAntigravityGatewayService(
|
func NewAntigravityGatewayService(
|
||||||
_ AccountRepository,
|
accountRepo AccountRepository,
|
||||||
_ GatewayCache,
|
_ GatewayCache,
|
||||||
tokenProvider *AntigravityTokenProvider,
|
tokenProvider *AntigravityTokenProvider,
|
||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
httpUpstream HTTPUpstream,
|
httpUpstream HTTPUpstream,
|
||||||
) *AntigravityGatewayService {
|
) *AntigravityGatewayService {
|
||||||
return &AntigravityGatewayService{
|
return &AntigravityGatewayService{
|
||||||
|
accountRepo: accountRepo,
|
||||||
tokenProvider: tokenProvider,
|
tokenProvider: tokenProvider,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
@@ -402,14 +406,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode == 429 {
|
|
||||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
|
||||||
}
|
|
||||||
if attempt < antigravityMaxRetries {
|
if attempt < antigravityMaxRetries {
|
||||||
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
|
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||||
sleepAntigravityBackoff(attempt)
|
sleepAntigravityBackoff(attempt)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// 所有重试都失败,标记限流状态
|
||||||
|
if resp.StatusCode == 429 {
|
||||||
|
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
}
|
||||||
if action == "countTokens" {
|
if action == "countTokens" {
|
||||||
estimated := estimateGeminiCountTokens(body)
|
estimated := estimateGeminiCountTokens(body)
|
||||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||||
@@ -526,6 +531,23 @@ func sleepAntigravityBackoff(attempt int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||||
|
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
||||||
|
if statusCode == 429 {
|
||||||
|
resetAt := ParseGeminiRateLimitResetTime(body)
|
||||||
|
if resetAt == nil {
|
||||||
|
// 解析失败:Gemini 有重试时间用 5 分钟,Claude 没有用 1 分钟
|
||||||
|
defaultDur := 1 * time.Minute
|
||||||
|
if bytes.Contains(body, []byte("Please retry in")) || bytes.Contains(body, []byte("retryDelay")) {
|
||||||
|
defaultDur = 5 * time.Minute
|
||||||
|
}
|
||||||
|
ra := time.Now().Add(defaultDur)
|
||||||
|
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 其他错误码继续使用 rateLimitService
|
||||||
if s.rateLimitService == nil {
|
if s.rateLimitService == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ var (
|
|||||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||||
|
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||||
@@ -27,9 +28,10 @@ var (
|
|||||||
|
|
||||||
// JWTClaims JWT载荷数据
|
// JWTClaims JWT载荷数据
|
||||||
type JWTClaims struct {
|
type JWTClaims struct {
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
|
TokenVersion int64 `json:"token_version"` // Used to invalidate tokens on password change
|
||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -311,9 +313,10 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
|
|||||||
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
|
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
|
||||||
|
|
||||||
claims := &JWTClaims{
|
claims := &JWTClaims{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
Role: user.Role,
|
Role: user.Role,
|
||||||
|
TokenVersion: user.TokenVersion,
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||||
IssuedAt: jwt.NewNumericDate(now),
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
@@ -368,6 +371,12 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
|||||||
return "", ErrUserNotActive
|
return "", ErrUserNotActive
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Security: Check TokenVersion to prevent refreshing revoked tokens
|
||||||
|
// This ensures tokens issued before a password change cannot be refreshed
|
||||||
|
if claims.TokenVersion != user.TokenVersion {
|
||||||
|
return "", ErrTokenRevoked
|
||||||
|
}
|
||||||
|
|
||||||
// 生成新token
|
// 生成新token
|
||||||
return s.GenerateToken(user)
|
return s.GenerateToken(user)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -695,6 +695,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
if req.Stream {
|
if req.Stream {
|
||||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model)
|
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err.Error() == "have error in stream" {
|
||||||
|
return nil, &UpstreamFailoverError{
|
||||||
|
StatusCode: 403,
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
usage = streamResult.usage
|
usage = streamResult.usage
|
||||||
@@ -969,6 +974,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
|
if line == "event: error" {
|
||||||
|
return nil, errors.New("have error in stream")
|
||||||
|
}
|
||||||
|
|
||||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||||
if sseDataRe.MatchString(line) {
|
if sseDataRe.MatchString(line) {
|
||||||
|
|||||||
@@ -1883,7 +1883,7 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
|
|||||||
if statusCode != 429 {
|
if statusCode != 429 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resetAt := parseGeminiRateLimitResetTime(body)
|
resetAt := ParseGeminiRateLimitResetTime(body)
|
||||||
if resetAt == nil {
|
if resetAt == nil {
|
||||||
ra := time.Now().Add(5 * time.Minute)
|
ra := time.Now().Add(5 * time.Minute)
|
||||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
||||||
@@ -1892,7 +1892,8 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
|
|||||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
|
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseGeminiRateLimitResetTime(body []byte) *int64 {
|
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||||
|
func ParseGeminiRateLimitResetTime(body []byte) *int64 {
|
||||||
// Try to parse metadata.quotaResetDelay like "12.345s"
|
// Try to parse metadata.quotaResetDelay like "12.345s"
|
||||||
var parsed map[string]any
|
var parsed map[string]any
|
||||||
if err := json.Unmarshal(body, &parsed); err == nil {
|
if err := json.Unmarshal(body, &parsed); err == nil {
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type User struct {
|
|||||||
Concurrency int
|
Concurrency int
|
||||||
Status string
|
Status string
|
||||||
AllowedGroups []int64
|
AllowedGroups []int64
|
||||||
|
TokenVersion int64 // Incremented on password change to invalidate existing tokens
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ChangePassword 修改密码
|
// ChangePassword 修改密码
|
||||||
|
// Security: Increments TokenVersion to invalidate all existing JWT tokens
|
||||||
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
|
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
|
||||||
user, err := s.userRepo.GetByID(ctx, userID)
|
user, err := s.userRepo.GetByID(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -131,6 +132,10 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
|
|||||||
return fmt.Errorf("set password: %w", err)
|
return fmt.Errorf("set password: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Increment TokenVersion to invalidate all existing tokens
|
||||||
|
// This ensures that any tokens issued before the password change become invalid
|
||||||
|
user.TokenVersion++
|
||||||
|
|
||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return fmt.Errorf("update user: %w", err)
|
return fmt.Errorf("update user: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -271,7 +271,9 @@ func initializeDatabase(cfg *SetupConfig) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return repository.AutoMigrate(db)
|
// setup 阶段使用 standard 模式(不创建默认分组)
|
||||||
|
// 默认分组将在正常启动时根据实际 runMode 决定是否创建
|
||||||
|
return repository.AutoMigrate(db, "standard")
|
||||||
}
|
}
|
||||||
|
|
||||||
func createAdminUser(cfg *SetupConfig) error {
|
func createAdminUser(cfg *SetupConfig) error {
|
||||||
|
|||||||
@@ -335,12 +335,59 @@
|
|||||||
>
|
>
|
||||||
<div>
|
<div>
|
||||||
<label class="input-label">{{ t('admin.subscriptions.form.user') }}</label>
|
<label class="input-label">{{ t('admin.subscriptions.form.user') }}</label>
|
||||||
<Select
|
<div class="relative">
|
||||||
v-model="assignForm.user_id"
|
<input
|
||||||
:options="userOptions"
|
v-model="userSearchKeyword"
|
||||||
:placeholder="t('admin.subscriptions.selectUser')"
|
type="text"
|
||||||
searchable
|
class="input pr-8"
|
||||||
/>
|
:placeholder="t('admin.usage.searchUserPlaceholder')"
|
||||||
|
@input="debounceSearchUsers"
|
||||||
|
@focus="showUserDropdown = true"
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
v-if="selectedUser"
|
||||||
|
@click="clearUserSelection"
|
||||||
|
type="button"
|
||||||
|
class="absolute right-2 top-1/2 -translate-y-1/2 text-gray-400 hover:text-gray-600 dark:hover:text-gray-300"
|
||||||
|
>
|
||||||
|
<svg class="h-4 w-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||||
|
<path
|
||||||
|
stroke-linecap="round"
|
||||||
|
stroke-linejoin="round"
|
||||||
|
stroke-width="2"
|
||||||
|
d="M6 18L18 6M6 6l12 12"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
<!-- User Dropdown -->
|
||||||
|
<div
|
||||||
|
v-if="showUserDropdown && (userSearchResults.length > 0 || userSearchKeyword)"
|
||||||
|
class="absolute z-50 mt-1 max-h-60 w-full overflow-auto rounded-lg border border-gray-200 bg-white shadow-lg dark:border-gray-700 dark:bg-gray-800"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
v-if="userSearchLoading"
|
||||||
|
class="px-4 py-3 text-sm text-gray-500 dark:text-gray-400"
|
||||||
|
>
|
||||||
|
{{ t('common.loading') }}
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
v-else-if="userSearchResults.length === 0 && userSearchKeyword"
|
||||||
|
class="px-4 py-3 text-sm text-gray-500 dark:text-gray-400"
|
||||||
|
>
|
||||||
|
{{ t('common.noOptionsFound') }}
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
v-for="user in userSearchResults"
|
||||||
|
:key="user.id"
|
||||||
|
type="button"
|
||||||
|
@click="selectUser(user)"
|
||||||
|
class="w-full px-4 py-2 text-left text-sm hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||||
|
>
|
||||||
|
<span class="font-medium text-gray-900 dark:text-white">{{ user.email }}</span>
|
||||||
|
<span class="ml-2 text-gray-500 dark:text-gray-400">#{{ user.id }}</span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<label class="input-label">{{ t('admin.subscriptions.form.group') }}</label>
|
<label class="input-label">{{ t('admin.subscriptions.form.group') }}</label>
|
||||||
@@ -462,11 +509,12 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, reactive, computed, onMounted } from 'vue'
|
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
import type { UserSubscription, Group, User } from '@/types'
|
import type { UserSubscription, Group } from '@/types'
|
||||||
|
import type { SimpleUser } from '@/api/admin/usage'
|
||||||
import type { Column } from '@/components/common/types'
|
import type { Column } from '@/components/common/types'
|
||||||
import { formatDateOnly } from '@/utils/format'
|
import { formatDateOnly } from '@/utils/format'
|
||||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||||
@@ -501,9 +549,17 @@ const statusOptions = computed(() => [
|
|||||||
|
|
||||||
const subscriptions = ref<UserSubscription[]>([])
|
const subscriptions = ref<UserSubscription[]>([])
|
||||||
const groups = ref<Group[]>([])
|
const groups = ref<Group[]>([])
|
||||||
const users = ref<User[]>([])
|
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
let abortController: AbortController | null = null
|
let abortController: AbortController | null = null
|
||||||
|
|
||||||
|
// User search state
|
||||||
|
const userSearchKeyword = ref('')
|
||||||
|
const userSearchResults = ref<SimpleUser[]>([])
|
||||||
|
const userSearchLoading = ref(false)
|
||||||
|
const showUserDropdown = ref(false)
|
||||||
|
const selectedUser = ref<SimpleUser | null>(null)
|
||||||
|
let userSearchTimeout: ReturnType<typeof setTimeout> | null = null
|
||||||
|
|
||||||
const filters = reactive({
|
const filters = reactive({
|
||||||
status: '',
|
status: '',
|
||||||
group_id: ''
|
group_id: ''
|
||||||
@@ -545,9 +601,6 @@ const subscriptionGroupOptions = computed(() =>
|
|||||||
.map((g) => ({ value: g.id, label: g.name }))
|
.map((g) => ({ value: g.id, label: g.name }))
|
||||||
)
|
)
|
||||||
|
|
||||||
// User options for assign
|
|
||||||
const userOptions = computed(() => users.value.map((u) => ({ value: u.id, label: u.email })))
|
|
||||||
|
|
||||||
const loadSubscriptions = async () => {
|
const loadSubscriptions = async () => {
|
||||||
if (abortController) {
|
if (abortController) {
|
||||||
abortController.abort()
|
abortController.abort()
|
||||||
@@ -590,13 +643,51 @@ const loadGroups = async () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const loadUsers = async () => {
|
// User search with debounce
|
||||||
try {
|
const debounceSearchUsers = () => {
|
||||||
const response = await adminAPI.users.list(1, 1000)
|
if (userSearchTimeout) {
|
||||||
users.value = response.items
|
clearTimeout(userSearchTimeout)
|
||||||
} catch (error) {
|
|
||||||
console.error('Error loading users:', error)
|
|
||||||
}
|
}
|
||||||
|
userSearchTimeout = setTimeout(searchUsers, 300)
|
||||||
|
}
|
||||||
|
|
||||||
|
const searchUsers = async () => {
|
||||||
|
const keyword = userSearchKeyword.value.trim()
|
||||||
|
|
||||||
|
// Clear selection if user modified the search keyword
|
||||||
|
if (selectedUser.value && keyword !== selectedUser.value.email) {
|
||||||
|
selectedUser.value = null
|
||||||
|
assignForm.user_id = null
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!keyword) {
|
||||||
|
userSearchResults.value = []
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userSearchLoading.value = true
|
||||||
|
try {
|
||||||
|
userSearchResults.value = await adminAPI.usage.searchUsers(keyword)
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to search users:', error)
|
||||||
|
userSearchResults.value = []
|
||||||
|
} finally {
|
||||||
|
userSearchLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const selectUser = (user: SimpleUser) => {
|
||||||
|
selectedUser.value = user
|
||||||
|
userSearchKeyword.value = user.email
|
||||||
|
showUserDropdown.value = false
|
||||||
|
assignForm.user_id = user.id
|
||||||
|
}
|
||||||
|
|
||||||
|
const clearUserSelection = () => {
|
||||||
|
selectedUser.value = null
|
||||||
|
userSearchKeyword.value = ''
|
||||||
|
userSearchResults.value = []
|
||||||
|
assignForm.user_id = null
|
||||||
}
|
}
|
||||||
|
|
||||||
const handlePageChange = (page: number) => {
|
const handlePageChange = (page: number) => {
|
||||||
@@ -615,6 +706,11 @@ const closeAssignModal = () => {
|
|||||||
assignForm.user_id = null
|
assignForm.user_id = null
|
||||||
assignForm.group_id = null
|
assignForm.group_id = null
|
||||||
assignForm.validity_days = 30
|
assignForm.validity_days = 30
|
||||||
|
// Clear user search state
|
||||||
|
selectedUser.value = null
|
||||||
|
userSearchKeyword.value = ''
|
||||||
|
userSearchResults.value = []
|
||||||
|
showUserDropdown.value = false
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleAssignSubscription = async () => {
|
const handleAssignSubscription = async () => {
|
||||||
@@ -754,10 +850,25 @@ const formatResetTime = (windowStart: string, period: 'daily' | 'weekly' | 'mont
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle click outside to close user dropdown
|
||||||
|
const handleClickOutside = (event: MouseEvent) => {
|
||||||
|
const target = event.target as HTMLElement
|
||||||
|
if (!target.closest('.relative')) {
|
||||||
|
showUserDropdown.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadSubscriptions()
|
loadSubscriptions()
|
||||||
loadGroups()
|
loadGroups()
|
||||||
loadUsers()
|
document.addEventListener('click', handleClickOutside)
|
||||||
|
})
|
||||||
|
|
||||||
|
onUnmounted(() => {
|
||||||
|
document.removeEventListener('click', handleClickOutside)
|
||||||
|
if (userSearchTimeout) {
|
||||||
|
clearTimeout(userSearchTimeout)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user