Files
sub2api/backend/internal/service/subscription_service.go

1045 lines
32 KiB
Go
Raw Permalink Normal View History

2025-12-18 13:50:39 +08:00
package service
import (
"context"
"fmt"
"log"
"math/rand/v2"
"strconv"
"strings"
2025-12-18 13:50:39 +08:00
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
2025-12-24 21:07:21 +08:00
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
2025-12-18 13:50:39 +08:00
)
// MaxExpiresAt is the maximum allowed expiration date (year 2099)
// This prevents time.Time JSON serialization errors (RFC 3339 requires year <= 9999)
var MaxExpiresAt = time.Date(2099, 12, 31, 23, 59, 59, 0, time.UTC)
// MaxValidityDays is the maximum allowed validity days for subscriptions (100 years)
const MaxValidityDays = 36500
2025-12-18 13:50:39 +08:00
var (
ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
2026-03-10 23:59:39 +08:00
ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily or resetWeekly must be true")
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
2025-12-18 13:50:39 +08:00
)
// SubscriptionService 订阅服务
type SubscriptionService struct {
2025-12-25 17:15:01 +08:00
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
2025-12-18 13:50:39 +08:00
billingCacheService *BillingCacheService
entClient *dbent.Client
// L1 缓存:加速中间件热路径的订阅查询
subCacheL1 *ristretto.Cache
subCacheGroup singleflight.Group
subCacheTTL time.Duration
subCacheJitter int // 抖动百分比
maintenanceQueue *SubscriptionMaintenanceQueue
2025-12-18 13:50:39 +08:00
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, entClient *dbent.Client, cfg *config.Config) *SubscriptionService {
svc := &SubscriptionService{
groupRepo: groupRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService,
entClient: entClient,
}
svc.initSubCache(cfg)
svc.initMaintenanceQueue(cfg)
return svc
}
func (s *SubscriptionService) initMaintenanceQueue(cfg *config.Config) {
if cfg == nil {
return
}
mc := cfg.SubscriptionMaintenance
if mc.WorkerCount <= 0 || mc.QueueSize <= 0 {
return
}
s.maintenanceQueue = NewSubscriptionMaintenanceQueue(mc.WorkerCount, mc.QueueSize)
}
// Stop stops the maintenance worker pool.
func (s *SubscriptionService) Stop() {
if s == nil {
return
}
if s.maintenanceQueue != nil {
s.maintenanceQueue.Stop()
}
}
// initSubCache 初始化订阅 L1 缓存
func (s *SubscriptionService) initSubCache(cfg *config.Config) {
if cfg == nil {
return
}
sc := cfg.SubscriptionCache
if sc.L1Size <= 0 || sc.L1TTLSeconds <= 0 {
return
}
cache, err := ristretto.NewCache(&ristretto.Config{
NumCounters: int64(sc.L1Size) * 10,
MaxCost: int64(sc.L1Size),
BufferItems: 64,
})
if err != nil {
log.Printf("Warning: failed to init subscription L1 cache: %v", err)
return
}
s.subCacheL1 = cache
s.subCacheTTL = time.Duration(sc.L1TTLSeconds) * time.Second
s.subCacheJitter = sc.JitterPercent
}
// subCacheKey 生成订阅缓存 key热路径避免 fmt.Sprintf 开销)
func subCacheKey(userID, groupID int64) string {
return "sub:" + strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(groupID, 10)
}
// jitteredTTL 为 TTL 添加抖动,避免集中过期
func (s *SubscriptionService) jitteredTTL(ttl time.Duration) time.Duration {
if ttl <= 0 || s.subCacheJitter <= 0 {
return ttl
}
pct := s.subCacheJitter
if pct > 100 {
pct = 100
}
delta := float64(pct) / 100
factor := 1 - delta + rand.Float64()*(2*delta)
if factor <= 0 {
return ttl
}
return time.Duration(float64(ttl) * factor)
}
// InvalidateSubCache 失效指定用户+分组的订阅 L1 缓存
func (s *SubscriptionService) InvalidateSubCache(userID, groupID int64) {
if s.subCacheL1 == nil {
return
}
s.subCacheL1.Del(subCacheKey(userID, groupID))
2025-12-18 13:50:39 +08:00
}
// AssignSubscriptionInput 分配订阅输入
type AssignSubscriptionInput struct {
UserID int64
GroupID int64
ValidityDays int
AssignedBy int64
Notes string
}
// AssignSubscription 分配订阅给用户(不允许重复分配)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
sub, _, err := s.assignSubscriptionWithReuse(ctx, input)
2025-12-18 13:50:39 +08:00
if err != nil {
return nil, err
}
return sub, nil
}
// AssignOrExtendSubscription 分配或续期订阅(用于兑换码等场景)
// 如果用户已有同分组的订阅:
// - 未过期:从当前过期时间累加天数
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
//
2025-12-18 13:50:39 +08:00
// 如果没有订阅:创建新订阅
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
2025-12-18 13:50:39 +08:00
// 检查分组是否存在且为订阅类型
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
2025-12-18 13:50:39 +08:00
if err != nil {
return nil, false, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, false, ErrGroupNotSubscriptionType
}
// 查询是否已有订阅
existingSub, err := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
2025-12-18 13:50:39 +08:00
if err != nil {
// 不存在记录是正常情况,其他错误需要返回
existingSub = nil
}
validityDays := input.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
if validityDays > MaxValidityDays {
validityDays = MaxValidityDays
}
2025-12-18 13:50:39 +08:00
// 已有订阅,执行续期(在事务中完成所有更新)
2025-12-18 13:50:39 +08:00
if existingSub != nil {
now := time.Now()
var newExpiresAt time.Time
if existingSub.ExpiresAt.After(now) {
// 未过期:从当前过期时间累加
newExpiresAt = existingSub.ExpiresAt.AddDate(0, 0, validityDays)
} else {
// 已过期:从当前时间开始计算
newExpiresAt = now.AddDate(0, 0, validityDays)
}
// 确保不超过最大过期时间
if newExpiresAt.After(MaxExpiresAt) {
newExpiresAt = MaxExpiresAt
}
// 开启事务ExtendExpiry + UpdateStatus + UpdateNotes 在同一事务中完成
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, false, fmt.Errorf("begin transaction: %w", err)
}
txCtx := dbent.NewTxContext(ctx, tx)
2025-12-18 13:50:39 +08:00
// 更新过期时间
if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil {
_ = tx.Rollback()
2025-12-18 13:50:39 +08:00
return nil, false, fmt.Errorf("extend subscription: %w", err)
}
// 如果订阅已过期或被暂停恢复为active状态
if existingSub.Status != SubscriptionStatusActive {
if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil {
_ = tx.Rollback()
2025-12-18 13:50:39 +08:00
return nil, false, fmt.Errorf("update subscription status: %w", err)
}
}
// 追加备注
if input.Notes != "" {
newNotes := existingSub.Notes
if newNotes != "" {
newNotes += "\n"
}
newNotes += input.Notes
if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, newNotes); err != nil {
_ = tx.Rollback()
return nil, false, fmt.Errorf("update subscription notes: %w", err)
2025-12-18 13:50:39 +08:00
}
}
// 提交事务
if err := tx.Commit(); err != nil {
return nil, false, fmt.Errorf("commit transaction: %w", err)
}
2025-12-18 13:50:39 +08:00
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
2025-12-18 13:50:39 +08:00
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
2025-12-18 13:50:39 +08:00
}()
}
// 返回更新后的订阅
sub, err := s.userSubRepo.GetByID(ctx, existingSub.ID)
2025-12-18 13:50:39 +08:00
return sub, true, err // true 表示是续期
}
// 没有订阅,创建新订阅
sub, err := s.createSubscription(ctx, input)
if err != nil {
return nil, false, err
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
2025-12-18 13:50:39 +08:00
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
2025-12-18 13:50:39 +08:00
}()
}
return sub, false, nil // false 表示是新建
}
// createSubscription 创建新订阅(内部方法)
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
2025-12-18 13:50:39 +08:00
validityDays := input.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
if validityDays > MaxValidityDays {
validityDays = MaxValidityDays
}
2025-12-18 13:50:39 +08:00
now := time.Now()
expiresAt := now.AddDate(0, 0, validityDays)
if expiresAt.After(MaxExpiresAt) {
expiresAt = MaxExpiresAt
}
sub := &UserSubscription{
UserID: input.UserID,
GroupID: input.GroupID,
StartsAt: now,
ExpiresAt: expiresAt,
Status: SubscriptionStatusActive,
2025-12-18 13:50:39 +08:00
AssignedAt: now,
Notes: input.Notes,
CreatedAt: now,
UpdatedAt: now,
2025-12-18 13:50:39 +08:00
}
// 只有当 AssignedBy > 0 时才设置0 表示系统分配,如兑换码)
if input.AssignedBy > 0 {
sub.AssignedBy = &input.AssignedBy
}
if err := s.userSubRepo.Create(ctx, sub); err != nil {
2025-12-18 13:50:39 +08:00
return nil, err
}
// 重新获取完整订阅信息(包含关联)
return s.userSubRepo.GetByID(ctx, sub.ID)
2025-12-18 13:50:39 +08:00
}
// BulkAssignSubscriptionInput 批量分配订阅输入
type BulkAssignSubscriptionInput struct {
UserIDs []int64
GroupID int64
ValidityDays int
AssignedBy int64
Notes string
}
// BulkAssignResult 批量分配结果
type BulkAssignResult struct {
SuccessCount int
CreatedCount int
ReusedCount int
FailedCount int
Subscriptions []UserSubscription
Errors []string
Statuses map[int64]string
2025-12-18 13:50:39 +08:00
}
// BulkAssignSubscription 批量分配订阅
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
result := &BulkAssignResult{
Subscriptions: make([]UserSubscription, 0),
Errors: make([]string, 0),
Statuses: make(map[int64]string),
2025-12-18 13:50:39 +08:00
}
for _, userID := range input.UserIDs {
sub, reused, err := s.assignSubscriptionWithReuse(ctx, &AssignSubscriptionInput{
2025-12-18 13:50:39 +08:00
UserID: userID,
GroupID: input.GroupID,
ValidityDays: input.ValidityDays,
AssignedBy: input.AssignedBy,
Notes: input.Notes,
})
if err != nil {
result.FailedCount++
result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err))
result.Statuses[userID] = "failed"
2025-12-18 13:50:39 +08:00
} else {
result.SuccessCount++
result.Subscriptions = append(result.Subscriptions, *sub)
if reused {
result.ReusedCount++
result.Statuses[userID] = "reused"
} else {
result.CreatedCount++
result.Statuses[userID] = "created"
}
2025-12-18 13:50:39 +08:00
}
}
return result, nil
}
func (s *SubscriptionService) assignSubscriptionWithReuse(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
// 检查分组是否存在且为订阅类型
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil {
return nil, false, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, false, ErrGroupNotSubscriptionType
}
// 检查是否已存在订阅;若已存在,则按幂等成功返回现有订阅
exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil {
return nil, false, err
}
if exists {
sub, getErr := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if getErr != nil {
return nil, false, getErr
}
if conflictReason, conflict := detectAssignSemanticConflict(sub, input); conflict {
return nil, false, ErrSubscriptionAssignConflict.WithMetadata(map[string]string{
"conflict_reason": conflictReason,
})
}
return sub, true, nil
}
sub, err := s.createSubscription(ctx, input)
if err != nil {
return nil, false, err
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return sub, false, nil
}
func detectAssignSemanticConflict(existing *UserSubscription, input *AssignSubscriptionInput) (string, bool) {
if existing == nil || input == nil {
return "", false
}
normalizedDays := normalizeAssignValidityDays(input.ValidityDays)
if !existing.StartsAt.IsZero() {
expectedExpiresAt := existing.StartsAt.AddDate(0, 0, normalizedDays)
if expectedExpiresAt.After(MaxExpiresAt) {
expectedExpiresAt = MaxExpiresAt
}
if !existing.ExpiresAt.Equal(expectedExpiresAt) {
return "validity_days_mismatch", true
}
}
existingNotes := strings.TrimSpace(existing.Notes)
inputNotes := strings.TrimSpace(input.Notes)
if existingNotes != inputNotes {
return "notes_mismatch", true
}
return "", false
}
func normalizeAssignValidityDays(days int) int {
if days <= 0 {
days = 30
}
if days > MaxValidityDays {
days = MaxValidityDays
}
return days
}
2025-12-18 13:50:39 +08:00
// RevokeSubscription 撤销订阅
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
// 先获取订阅信息用于失效缓存
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
2025-12-18 13:50:39 +08:00
if err != nil {
return err
}
if err := s.userSubRepo.Delete(ctx, subscriptionID); err != nil {
2025-12-18 13:50:39 +08:00
return err
}
// 失效订阅缓存
s.InvalidateSubCache(sub.UserID, sub.GroupID)
2025-12-18 13:50:39 +08:00
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
2025-12-18 13:50:39 +08:00
}()
}
return nil
}
// ExtendSubscription 调整订阅时长(正数延长,负数缩短)
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) {
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
2025-12-18 13:50:39 +08:00
if err != nil {
return nil, ErrSubscriptionNotFound
}
// 限制调整天数范围
if days > MaxValidityDays {
days = MaxValidityDays
}
if days < -MaxValidityDays {
days = -MaxValidityDays
}
now := time.Now()
isExpired := !sub.ExpiresAt.After(now)
// 如果订阅已过期,不允许负向调整
if isExpired && days < 0 {
return nil, infraerrors.BadRequest("CANNOT_SHORTEN_EXPIRED", "cannot shorten an expired subscription")
}
2025-12-18 13:50:39 +08:00
// 计算新的过期时间
var newExpiresAt time.Time
if isExpired {
// 已过期:从当前时间开始增加天数
newExpiresAt = now.AddDate(0, 0, days)
} else {
// 未过期:从原过期时间增加/减少天数
newExpiresAt = sub.ExpiresAt.AddDate(0, 0, days)
}
if newExpiresAt.After(MaxExpiresAt) {
newExpiresAt = MaxExpiresAt
}
// 检查新的过期时间必须大于当前时间
if !newExpiresAt.After(now) {
return nil, ErrAdjustWouldExpire
}
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
2025-12-18 13:50:39 +08:00
return nil, err
}
// 如果订阅已过期恢复为active状态
if sub.Status == SubscriptionStatusExpired {
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, SubscriptionStatusActive); err != nil {
2025-12-18 13:50:39 +08:00
return nil, err
}
}
// 失效订阅缓存
s.InvalidateSubCache(sub.UserID, sub.GroupID)
2025-12-18 13:50:39 +08:00
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
2025-12-18 13:50:39 +08:00
}()
}
return s.userSubRepo.GetByID(ctx, subscriptionID)
2025-12-18 13:50:39 +08:00
}
// GetByID 根据ID获取订阅
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubscription, error) {
return s.userSubRepo.GetByID(ctx, id)
2025-12-18 13:50:39 +08:00
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
// 使用 L1 缓存 + singleflight 加速中间件热路径。
// 返回缓存对象的浅拷贝,调用方可安全修改字段而不会污染缓存或触发 data race。
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) {
key := subCacheKey(userID, groupID)
// L1 缓存命中:返回浅拷贝
if s.subCacheL1 != nil {
if v, ok := s.subCacheL1.Get(key); ok {
if sub, ok := v.(*UserSubscription); ok {
cp := *sub
return &cp, nil
}
}
}
// singleflight 防止并发击穿
value, err, _ := s.subCacheGroup.Do(key, func() (any, error) {
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, err // 直接透传 repo 已翻译的错误NotFound → ErrSubscriptionNotFound其他错误原样返回
}
// 写入 L1 缓存
if s.subCacheL1 != nil {
_ = s.subCacheL1.SetWithTTL(key, sub, 1, s.jitteredTTL(s.subCacheTTL))
}
return sub, nil
})
2025-12-18 13:50:39 +08:00
if err != nil {
return nil, err
2025-12-18 13:50:39 +08:00
}
// singleflight 返回的也是缓存指针,需要浅拷贝
sub, ok := value.(*UserSubscription)
if !ok || sub == nil {
return nil, ErrSubscriptionNotFound
}
cp := *sub
return &cp, nil
2025-12-18 13:50:39 +08:00
}
// ListUserSubscriptions 获取用户的所有订阅
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) {
subs, err := s.userSubRepo.ListByUserID(ctx, userID)
if err != nil {
return nil, err
}
normalizeExpiredWindows(subs)
normalizeSubscriptionStatus(subs)
return subs, nil
2025-12-18 13:50:39 +08:00
}
// ListActiveUserSubscriptions 获取用户的所有有效订阅
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) {
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, err
}
normalizeExpiredWindows(subs)
return subs, nil
2025-12-18 13:50:39 +08:00
}
// ListGroupSubscriptions 获取分组的所有订阅
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
return nil, nil, err
}
normalizeExpiredWindows(subs)
normalizeSubscriptionStatus(subs)
return subs, pag, nil
2025-12-18 13:50:39 +08:00
}
// List 获取所有订阅(分页,支持筛选和排序)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder)
if err != nil {
return nil, nil, err
}
normalizeExpiredWindows(subs)
normalizeSubscriptionStatus(subs)
return subs, pag, nil
}
// normalizeExpiredWindows 将已过期窗口的数据清零(仅影响返回数据,不影响数据库)
// 这确保前端显示正确的当前窗口状态,而不是过期窗口的历史数据
func normalizeExpiredWindows(subs []UserSubscription) {
for i := range subs {
sub := &subs[i]
// 日窗口过期:清零展示数据
if sub.NeedsDailyReset() {
sub.DailyWindowStart = nil
sub.DailyUsageUSD = 0
}
// 周窗口过期:清零展示数据
if sub.NeedsWeeklyReset() {
sub.WeeklyWindowStart = nil
sub.WeeklyUsageUSD = 0
}
// 月窗口过期:清零展示数据
if sub.NeedsMonthlyReset() {
sub.MonthlyWindowStart = nil
sub.MonthlyUsageUSD = 0
}
}
2025-12-18 13:50:39 +08:00
}
// normalizeSubscriptionStatus 根据实际过期时间修正状态(仅影响返回数据,不影响数据库)
// 这确保前端显示正确的状态,即使定时任务尚未更新数据库
func normalizeSubscriptionStatus(subs []UserSubscription) {
now := time.Now()
for i := range subs {
sub := &subs[i]
if sub.Status == SubscriptionStatusActive && !sub.ExpiresAt.After(now) {
sub.Status = SubscriptionStatusExpired
}
}
}
// startOfDay 返回给定时间所在日期的零点(保持原时区)
func startOfDay(t time.Time) time.Time {
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
}
2025-12-18 13:50:39 +08:00
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *UserSubscription) error {
2025-12-18 13:50:39 +08:00
if sub.IsWindowActivated() {
return nil
}
// 使用当天零点作为窗口起始时间
windowStart := startOfDay(time.Now())
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
2025-12-18 13:50:39 +08:00
}
// AdminResetQuota manually resets the daily and/or weekly usage windows.
// Uses startOfDay(now) as the new window start, matching automatic resets.
func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly bool) (*UserSubscription, error) {
2026-03-10 23:59:39 +08:00
if !resetDaily && !resetWeekly {
return nil, ErrInvalidInput
}
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return nil, err
}
windowStart := startOfDay(time.Now())
if resetDaily {
if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil {
return nil, err
}
}
if resetWeekly {
if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil {
return nil, err
}
}
// Invalidate caches, same as CheckAndResetWindows
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
}
// Return the refreshed subscription from DB
return s.userSubRepo.GetByID(ctx, subscriptionID)
}
2025-12-18 13:50:39 +08:00
// CheckAndResetWindows 检查并重置过期的窗口
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
// 使用当天零点作为新窗口起始时间
windowStart := startOfDay(time.Now())
needsInvalidateCache := false
2025-12-18 13:50:39 +08:00
// 日窗口重置24小时
if sub.NeedsDailyReset() {
if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil {
2025-12-18 13:50:39 +08:00
return err
}
sub.DailyWindowStart = &windowStart
2025-12-18 13:50:39 +08:00
sub.DailyUsageUSD = 0
needsInvalidateCache = true
2025-12-18 13:50:39 +08:00
}
// 周窗口重置7天
if sub.NeedsWeeklyReset() {
if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil {
2025-12-18 13:50:39 +08:00
return err
}
sub.WeeklyWindowStart = &windowStart
2025-12-18 13:50:39 +08:00
sub.WeeklyUsageUSD = 0
needsInvalidateCache = true
2025-12-18 13:50:39 +08:00
}
// 月窗口重置30天
if sub.NeedsMonthlyReset() {
if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, windowStart); err != nil {
2025-12-18 13:50:39 +08:00
return err
}
sub.MonthlyWindowStart = &windowStart
2025-12-18 13:50:39 +08:00
sub.MonthlyUsageUSD = 0
needsInvalidateCache = true
}
// 如果有窗口被重置,失效缓存以保持一致性
if needsInvalidateCache {
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
}
2025-12-18 13:50:39 +08:00
}
return nil
}
// CheckUsageLimits 检查使用限额(返回错误如果超限)
// 用于中间件的快速预检查additionalCost 通常为 0
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error {
2025-12-18 13:50:39 +08:00
if !sub.CheckDailyLimit(group, additionalCost) {
return ErrDailyLimitExceeded
}
if !sub.CheckWeeklyLimit(group, additionalCost) {
return ErrWeeklyLimitExceeded
}
if !sub.CheckMonthlyLimit(group, additionalCost) {
return ErrMonthlyLimitExceeded
}
return nil
}
// ValidateAndCheckLimits 合并验证+限额检查(中间件热路径专用)
// 仅做内存检查,不触发 DB 写入。窗口重置的 DB 写入由 DoWindowMaintenance 异步完成。
// 返回 needsMaintenance 表示是否需要异步执行窗口维护。
func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, group *Group) (needsMaintenance bool, err error) {
// 1. 验证订阅状态
if sub.Status == SubscriptionStatusExpired {
return false, ErrSubscriptionExpired
}
if sub.Status == SubscriptionStatusSuspended {
return false, ErrSubscriptionSuspended
}
if sub.IsExpired() {
return false, ErrSubscriptionExpired
}
// 2. 内存中修正过期窗口的用量,确保 CheckUsageLimits 不会误拒绝用户
// 实际的 DB 窗口重置由 DoWindowMaintenance 异步完成
if sub.NeedsDailyReset() {
sub.DailyUsageUSD = 0
needsMaintenance = true
}
if sub.NeedsWeeklyReset() {
sub.WeeklyUsageUSD = 0
needsMaintenance = true
}
if sub.NeedsMonthlyReset() {
sub.MonthlyUsageUSD = 0
needsMaintenance = true
}
if !sub.IsWindowActivated() {
needsMaintenance = true
}
// 3. 检查用量限额
if !sub.CheckDailyLimit(group, 0) {
return needsMaintenance, ErrDailyLimitExceeded
}
if !sub.CheckWeeklyLimit(group, 0) {
return needsMaintenance, ErrWeeklyLimitExceeded
}
if !sub.CheckMonthlyLimit(group, 0) {
return needsMaintenance, ErrMonthlyLimitExceeded
}
return needsMaintenance, nil
}
// DoWindowMaintenance 异步执行窗口维护(激活+重置)
// 使用独立 context不受请求取消影响。
// 注意:此方法仅在 ValidateAndCheckLimits 返回 needsMaintenance=true 时调用,
// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误,
// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。
func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) {
if s == nil {
return
}
if s.maintenanceQueue != nil {
err := s.maintenanceQueue.TryEnqueue(func() {
s.doWindowMaintenance(sub)
})
if err != nil {
log.Printf("Subscription maintenance enqueue failed: %v", err)
}
return
}
s.doWindowMaintenance(sub)
}
func (s *SubscriptionService) doWindowMaintenance(sub *UserSubscription) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 激活窗口(首次使用时)
if !sub.IsWindowActivated() {
if err := s.CheckAndActivateWindow(ctx, sub); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
}
// 重置过期窗口
if err := s.CheckAndResetWindows(ctx, sub); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 失效 L1 缓存,确保后续请求拿到更新后的数据
s.InvalidateSubCache(sub.UserID, sub.GroupID)
}
2025-12-18 13:50:39 +08:00
// RecordUsage 记录使用量到订阅
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)
2025-12-18 13:50:39 +08:00
}
// SubscriptionProgress 订阅进度
type SubscriptionProgress struct {
ID int64 `json:"id"`
GroupName string `json:"group_name"`
ExpiresAt time.Time `json:"expires_at"`
ExpiresInDays int `json:"expires_in_days"`
2025-12-18 13:50:39 +08:00
Daily *UsageWindowProgress `json:"daily,omitempty"`
Weekly *UsageWindowProgress `json:"weekly,omitempty"`
Monthly *UsageWindowProgress `json:"monthly,omitempty"`
}
// UsageWindowProgress 使用窗口进度
type UsageWindowProgress struct {
LimitUSD float64 `json:"limit_usd"`
UsedUSD float64 `json:"used_usd"`
RemainingUSD float64 `json:"remaining_usd"`
Percentage float64 `json:"percentage"`
WindowStart time.Time `json:"window_start"`
ResetsAt time.Time `json:"resets_at"`
ResetsInSeconds int64 `json:"resets_in_seconds"`
2025-12-18 13:50:39 +08:00
}
// GetSubscriptionProgress 获取订阅使用进度
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
2025-12-18 13:50:39 +08:00
if err != nil {
return nil, ErrSubscriptionNotFound
}
group := sub.Group
if group == nil {
group, err = s.groupRepo.GetByID(ctx, sub.GroupID)
2025-12-18 13:50:39 +08:00
if err != nil {
return nil, err
}
}
return s.calculateProgress(sub, group), nil
}
// calculateProgress 根据已加载的订阅和分组数据计算使用进度(纯内存计算,无 DB 查询)
func (s *SubscriptionService) calculateProgress(sub *UserSubscription, group *Group) *SubscriptionProgress {
2025-12-18 13:50:39 +08:00
progress := &SubscriptionProgress{
ID: sub.ID,
GroupName: group.Name,
ExpiresAt: sub.ExpiresAt,
ExpiresInDays: sub.DaysRemaining(),
}
// 日进度
if group.HasDailyLimit() && sub.DailyWindowStart != nil {
limit := *group.DailyLimitUSD
resetsAt := sub.DailyWindowStart.Add(24 * time.Hour)
progress.Daily = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.DailyUsageUSD,
RemainingUSD: limit - sub.DailyUsageUSD,
Percentage: (sub.DailyUsageUSD / limit) * 100,
WindowStart: *sub.DailyWindowStart,
ResetsAt: resetsAt,
2025-12-18 13:50:39 +08:00
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Daily.RemainingUSD < 0 {
progress.Daily.RemainingUSD = 0
}
if progress.Daily.Percentage > 100 {
progress.Daily.Percentage = 100
}
if progress.Daily.ResetsInSeconds < 0 {
progress.Daily.ResetsInSeconds = 0
}
}
// 周进度
if group.HasWeeklyLimit() && sub.WeeklyWindowStart != nil {
limit := *group.WeeklyLimitUSD
resetsAt := sub.WeeklyWindowStart.Add(7 * 24 * time.Hour)
progress.Weekly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.WeeklyUsageUSD,
RemainingUSD: limit - sub.WeeklyUsageUSD,
Percentage: (sub.WeeklyUsageUSD / limit) * 100,
WindowStart: *sub.WeeklyWindowStart,
ResetsAt: resetsAt,
2025-12-18 13:50:39 +08:00
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Weekly.RemainingUSD < 0 {
progress.Weekly.RemainingUSD = 0
}
if progress.Weekly.Percentage > 100 {
progress.Weekly.Percentage = 100
}
if progress.Weekly.ResetsInSeconds < 0 {
progress.Weekly.ResetsInSeconds = 0
}
}
// 月进度
if group.HasMonthlyLimit() && sub.MonthlyWindowStart != nil {
limit := *group.MonthlyLimitUSD
resetsAt := sub.MonthlyWindowStart.Add(30 * 24 * time.Hour)
progress.Monthly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.MonthlyUsageUSD,
RemainingUSD: limit - sub.MonthlyUsageUSD,
Percentage: (sub.MonthlyUsageUSD / limit) * 100,
WindowStart: *sub.MonthlyWindowStart,
ResetsAt: resetsAt,
2025-12-18 13:50:39 +08:00
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Monthly.RemainingUSD < 0 {
progress.Monthly.RemainingUSD = 0
}
if progress.Monthly.Percentage > 100 {
progress.Monthly.Percentage = 100
}
if progress.Monthly.ResetsInSeconds < 0 {
progress.Monthly.ResetsInSeconds = 0
}
}
return progress
2025-12-18 13:50:39 +08:00
}
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
// ListActiveByUserID 已使用 .WithGroup() eager-load Group 关联1 次查询获取所有数据
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
2025-12-18 13:50:39 +08:00
if err != nil {
return nil, err
}
progresses := make([]SubscriptionProgress, 0, len(subs))
for i := range subs {
sub := &subs[i]
group := sub.Group
if group == nil {
2025-12-18 13:50:39 +08:00
continue
}
progresses = append(progresses, *s.calculateProgress(sub, group))
2025-12-18 13:50:39 +08:00
}
return progresses, nil
}
// ValidateSubscription 验证订阅是否有效
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *UserSubscription) error {
if sub.Status == SubscriptionStatusExpired {
2025-12-18 13:50:39 +08:00
return ErrSubscriptionExpired
}
if sub.Status == SubscriptionStatusSuspended {
2025-12-18 13:50:39 +08:00
return ErrSubscriptionSuspended
}
if sub.IsExpired() {
// 更新状态
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, SubscriptionStatusExpired)
2025-12-18 13:50:39 +08:00
return ErrSubscriptionExpired
}
return nil
}