mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-14 20:04:46 +08:00
The upstream v0.1.90 changed GetAccountConcurrencyBatch from individual Lua script calls (which swallowed per-account errors) to a Redis pipeline approach that propagates errors from rdb.Time() or pipe.Exec(). When the HTTP request context is cancelled (e.g., browser abort), the entire batch fails and the handler silently shows all concurrency as 0. Fix: use context.WithTimeout(context.Background(), 3s) for the Redis call so HTTP request cancellation doesn't affect the read-only concurrency query.
356 lines
12 KiB
Go
356 lines
12 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"encoding/binary"
|
||
"os"
|
||
"strconv"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||
)
|
||
|
||
// ConcurrencyCache 定义并发控制的缓存接口
|
||
// 使用有序集合存储槽位,按时间戳清理过期条目
|
||
type ConcurrencyCache interface {
|
||
// 账号槽位管理
|
||
// 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID)
|
||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||
GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
|
||
|
||
// 账号等待队列(账号级)
|
||
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
|
||
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
|
||
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
|
||
|
||
// 用户槽位管理
|
||
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
|
||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
|
||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||
|
||
// 等待队列计数(只在首次创建时设置 TTL)
|
||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||
|
||
// 批量负载查询(只读)
|
||
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
|
||
GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error)
|
||
|
||
// 清理过期槽位(后台任务)
|
||
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||
}
|
||
|
||
var (
|
||
requestIDPrefix = initRequestIDPrefix()
|
||
requestIDCounter atomic.Uint64
|
||
)
|
||
|
||
// initRequestIDPrefix builds the per-process request ID prefix: the letter "r"
// followed by a base-36 encoded 64-bit value.
//
// The value comes from crypto/rand when available. If reading random bytes
// fails, it falls back to the current Unix time in nanoseconds XORed with the
// process ID shifted left by 16 bits.
func initRequestIDPrefix() string {
	var raw [8]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		// Entropy source unavailable: derive a best-effort value from time and PID.
		seed := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
		return "r" + strconv.FormatUint(seed, 36)
	}
	value := binary.BigEndian.Uint64(raw[:])
	return "r" + strconv.FormatUint(value, 36)
}
|
||
|
||
// generateRequestID generates a unique request ID for concurrency slot tracking.
|
||
// Format: {process_random_prefix}-{base36_counter}
|
||
func generateRequestID() string {
|
||
seq := requestIDCounter.Add(1)
|
||
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
|
||
}
|
||
|
||
// defaultExtraWaitSlots is the default number of extra wait slots allowed
// beyond the concurrency limit.
const defaultExtraWaitSlots = 20
|
||
|
||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||
type ConcurrencyService struct {
|
||
cache ConcurrencyCache
|
||
}
|
||
|
||
// NewConcurrencyService creates a new ConcurrencyService
|
||
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
|
||
return &ConcurrencyService{cache: cache}
|
||
}
|
||
|
||
// AcquireResult reports the outcome of an attempt to acquire a concurrency slot.
type AcquireResult struct {
	// Acquired is true when a slot was obtained.
	Acquired bool
	// ReleaseFunc gives the slot back. It must be called when done (typically
	// via defer). The Acquire* methods in this file leave it nil when Acquired
	// is false.
	ReleaseFunc func()
}
|
||
|
||
// AccountWithConcurrency pairs an account ID with its concurrency limit.
// It is the per-account input element of GetAccountsLoadBatch.
type AccountWithConcurrency struct {
	// ID is the account identifier.
	ID int64
	// MaxConcurrency is the concurrency limit configured for the account.
	MaxConcurrency int
}
|
||
|
||
// UserWithConcurrency pairs a user ID with its concurrency limit.
// It is the per-user input element of GetUsersLoadBatch.
type UserWithConcurrency struct {
	// ID is the user identifier.
	ID int64
	// MaxConcurrency is the concurrency limit configured for the user.
	MaxConcurrency int
}
|
||
|
||
// AccountLoadInfo is the per-account load snapshot returned by GetAccountsLoadBatch.
type AccountLoadInfo struct {
	// AccountID identifies the account this snapshot belongs to.
	AccountID int64
	// CurrentConcurrency is the number of slots currently in use.
	CurrentConcurrency int
	// WaitingCount is the current size of the account's wait queue.
	WaitingCount int
	// LoadRate is a percentage, 0-100+ (it may exceed 100).
	LoadRate int
}
|
||
|
||
// UserLoadInfo is the per-user load snapshot returned by GetUsersLoadBatch.
type UserLoadInfo struct {
	// UserID identifies the user this snapshot belongs to.
	UserID int64
	// CurrentConcurrency is the number of slots currently in use.
	CurrentConcurrency int
	// WaitingCount is the current size of the user's wait queue.
	WaitingCount int
	// LoadRate is a percentage, 0-100+ (it may exceed 100).
	LoadRate int
}
|
||
|
||
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
|
||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||
// Returns a release function that MUST be called when the request completes.
|
||
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||
// If maxConcurrency is 0 or negative, no limit
|
||
if maxConcurrency <= 0 {
|
||
return &AcquireResult{
|
||
Acquired: true,
|
||
ReleaseFunc: func() {}, // no-op
|
||
}, nil
|
||
}
|
||
|
||
// Generate unique request ID for this slot
|
||
requestID := generateRequestID()
|
||
|
||
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if acquired {
|
||
return &AcquireResult{
|
||
Acquired: true,
|
||
ReleaseFunc: func() {
|
||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
|
||
}
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
return &AcquireResult{
|
||
Acquired: false,
|
||
ReleaseFunc: nil,
|
||
}, nil
|
||
}
|
||
|
||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||
// Returns a release function that MUST be called when the request completes.
|
||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||
// If maxConcurrency is 0 or negative, no limit
|
||
if maxConcurrency <= 0 {
|
||
return &AcquireResult{
|
||
Acquired: true,
|
||
ReleaseFunc: func() {}, // no-op
|
||
}, nil
|
||
}
|
||
|
||
// Generate unique request ID for this slot
|
||
requestID := generateRequestID()
|
||
|
||
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency, requestID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if acquired {
|
||
return &AcquireResult{
|
||
Acquired: true,
|
||
ReleaseFunc: func() {
|
||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
|
||
}
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
return &AcquireResult{
|
||
Acquired: false,
|
||
ReleaseFunc: nil,
|
||
}, nil
|
||
}
|
||
|
||
// ============================================
|
||
// Wait Queue Count Methods
|
||
// ============================================
|
||
|
||
// IncrementWaitCount attempts to increment the wait queue counter for a user.
|
||
// Returns true if successful, false if the wait queue is full.
|
||
// maxWait should be user.Concurrency + defaultExtraWaitSlots
|
||
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||
if s.cache == nil {
|
||
// Redis not available, allow request
|
||
return true, nil
|
||
}
|
||
|
||
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
|
||
if err != nil {
|
||
// On error, allow the request to proceed (fail open)
|
||
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for user %d: %v", userID, err)
|
||
return true, nil
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// DecrementWaitCount decrements the wait queue counter for a user.
|
||
// Should be called when a request completes or exits the wait queue.
|
||
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
|
||
// Use background context to ensure decrement even if original context is cancelled
|
||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for user %d: %v", userID, err)
|
||
}
|
||
}
|
||
|
||
// IncrementAccountWaitCount increments the wait queue counter for an account.
|
||
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||
if s.cache == nil {
|
||
return true, nil
|
||
}
|
||
|
||
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||
if err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for account %d: %v", accountID, err)
|
||
return true, nil
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// DecrementAccountWaitCount decrements the wait queue counter for an account.
|
||
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
|
||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for account %d: %v", accountID, err)
|
||
}
|
||
}
|
||
|
||
// GetAccountWaitingCount gets current wait queue count for an account.
|
||
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||
if s.cache == nil {
|
||
return 0, nil
|
||
}
|
||
return s.cache.GetAccountWaitingCount(ctx, accountID)
|
||
}
|
||
|
||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||
func CalculateMaxWait(userConcurrency int) int {
|
||
if userConcurrency <= 0 {
|
||
userConcurrency = 1
|
||
}
|
||
return userConcurrency + defaultExtraWaitSlots
|
||
}
|
||
|
||
// GetAccountsLoadBatch returns load info for multiple accounts.
|
||
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||
if s.cache == nil {
|
||
return map[int64]*AccountLoadInfo{}, nil
|
||
}
|
||
return s.cache.GetAccountsLoadBatch(ctx, accounts)
|
||
}
|
||
|
||
// GetUsersLoadBatch returns load info for multiple users.
|
||
func (s *ConcurrencyService) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||
if s.cache == nil {
|
||
return map[int64]*UserLoadInfo{}, nil
|
||
}
|
||
return s.cache.GetUsersLoadBatch(ctx, users)
|
||
}
|
||
|
||
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
|
||
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||
if s.cache == nil {
|
||
return nil
|
||
}
|
||
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
|
||
}
|
||
|
||
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
|
||
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
|
||
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
|
||
return
|
||
}
|
||
|
||
runCleanup := func() {
|
||
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
accounts, err := accountRepo.ListSchedulable(listCtx)
|
||
cancel()
|
||
if err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: list schedulable accounts failed: %v", err)
|
||
return
|
||
}
|
||
for _, account := range accounts {
|
||
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
|
||
accountCancel()
|
||
if err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
|
||
}
|
||
}
|
||
}
|
||
|
||
go func() {
|
||
ticker := time.NewTicker(interval)
|
||
defer ticker.Stop()
|
||
|
||
runCleanup()
|
||
for range ticker.C {
|
||
runCleanup()
|
||
}
|
||
}()
|
||
}
|
||
|
||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
|
||
// Uses a detached context with timeout to prevent HTTP request cancellation from
|
||
// causing the entire batch to fail (which would show all concurrency as 0).
|
||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||
if len(accountIDs) == 0 {
|
||
return map[int64]int{}, nil
|
||
}
|
||
if s.cache == nil {
|
||
result := make(map[int64]int, len(accountIDs))
|
||
for _, accountID := range accountIDs {
|
||
result[accountID] = 0
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// Use a detached context so that a cancelled HTTP request doesn't cause
|
||
// the Redis pipeline to fail and return all-zero concurrency counts.
|
||
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||
defer cancel()
|
||
|
||
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
|
||
}
|