2025-12-19 23:39:28 +08:00
|
|
|
|
package repository
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
2026-01-01 04:01:51 +08:00
|
|
|
|
"errors"
|
2025-12-19 23:39:28 +08:00
|
|
|
|
"fmt"
|
2026-01-01 04:01:51 +08:00
|
|
|
|
"strconv"
|
2025-12-19 23:39:28 +08:00
|
|
|
|
|
2025-12-25 17:15:01 +08:00
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
2025-12-19 23:39:28 +08:00
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2025-12-31 08:50:12 +08:00
|
|
|
|
// Constants for the Redis-backed concurrency-control cache.
//
// Performance note:
// The previous implementation used SCAN to walk one key per slot
// (concurrency:account:{id}:{requestID}). Under high concurrency SCAN needs
// several round trips and degrades noticeably as the number of keys grows.
//
// The current implementation uses Redis sorted sets instead:
//  1. One key per account/user; members are request IDs, scores are timestamps.
//  2. ZCARD returns the concurrency count atomically in O(1).
//  3. ZREMRANGEBYSCORE evicts expired slots, so no per-slot TTL is managed by hand.
//  4. Counting completes in a single Redis call, reducing network round trips.
const (
	// Concurrency slot key prefixes (sorted sets).
	// Format: concurrency:account:{accountID}
	accountSlotKeyPrefix = "concurrency:account:"
	// Format: concurrency:user:{userID}
	userSlotKeyPrefix = "concurrency:user:"
	// Per-user wait-queue counter. Format: concurrency:wait:{userID}
	waitQueueKeyPrefix = "concurrency:wait:"
	// Per-account wait-queue counter. Format: wait:account:{accountID}
	accountWaitKeyPrefix = "wait:account:"

	// Default slot expiry in minutes; can be overridden through configuration.
	defaultSlotTTLMinutes = 15
)
|
|
|
|
|
|
|
|
|
|
|
|
var (
|
2025-12-31 08:50:12 +08:00
|
|
|
|
// acquireScript 使用有序集合计数并在未达上限时添加槽位
|
|
|
|
|
|
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
|
|
|
|
|
|
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
|
2025-12-24 21:00:29 +08:00
|
|
|
|
// ARGV[1] = maxConcurrency
|
2025-12-31 08:50:12 +08:00
|
|
|
|
// ARGV[2] = TTL(秒)
|
|
|
|
|
|
// ARGV[3] = requestID
|
2025-12-19 23:39:28 +08:00
|
|
|
|
acquireScript = redis.NewScript(`
|
2025-12-31 08:50:12 +08:00
|
|
|
|
local key = KEYS[1]
|
2025-12-24 21:00:29 +08:00
|
|
|
|
local maxConcurrency = tonumber(ARGV[1])
|
|
|
|
|
|
local ttl = tonumber(ARGV[2])
|
2025-12-31 08:50:12 +08:00
|
|
|
|
local requestID = ARGV[3]
|
|
|
|
|
|
|
|
|
|
|
|
-- 使用 Redis 服务器时间,确保多实例时钟一致
|
|
|
|
|
|
local timeResult = redis.call('TIME')
|
|
|
|
|
|
local now = tonumber(timeResult[1])
|
|
|
|
|
|
local expireBefore = now - ttl
|
2025-12-24 21:00:29 +08:00
|
|
|
|
|
2025-12-31 08:50:12 +08:00
|
|
|
|
-- 清理过期槽位
|
|
|
|
|
|
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
|
|
|
|
|
|
|
|
|
|
|
-- 检查是否已存在(支持重试场景刷新时间戳)
|
|
|
|
|
|
local exists = redis.call('ZSCORE', key, requestID)
|
|
|
|
|
|
if exists ~= false then
|
|
|
|
|
|
redis.call('ZADD', key, now, requestID)
|
|
|
|
|
|
redis.call('EXPIRE', key, ttl)
|
|
|
|
|
|
return 1
|
|
|
|
|
|
end
|
2025-12-24 21:00:29 +08:00
|
|
|
|
|
2025-12-31 08:50:12 +08:00
|
|
|
|
-- 检查是否达到并发上限
|
|
|
|
|
|
local count = redis.call('ZCARD', key)
|
2025-12-24 21:00:29 +08:00
|
|
|
|
if count < maxConcurrency then
|
2025-12-31 08:50:12 +08:00
|
|
|
|
redis.call('ZADD', key, now, requestID)
|
|
|
|
|
|
redis.call('EXPIRE', key, ttl)
|
2025-12-19 23:39:28 +08:00
|
|
|
|
return 1
|
|
|
|
|
|
end
|
2025-12-24 21:00:29 +08:00
|
|
|
|
|
2025-12-19 23:39:28 +08:00
|
|
|
|
return 0
|
|
|
|
|
|
`)
|
|
|
|
|
|
|
2025-12-31 08:50:12 +08:00
|
|
|
|
// getCountScript 统计有序集合中的槽位数量并清理过期条目
|
|
|
|
|
|
// 使用 Redis TIME 命令获取服务器时间
|
|
|
|
|
|
// KEYS[1] = 有序集合键
|
|
|
|
|
|
// ARGV[1] = TTL(秒)
|
2025-12-24 21:00:29 +08:00
|
|
|
|
getCountScript = redis.NewScript(`
|
2025-12-31 08:50:12 +08:00
|
|
|
|
local key = KEYS[1]
|
|
|
|
|
|
local ttl = tonumber(ARGV[1])
|
|
|
|
|
|
|
|
|
|
|
|
-- 使用 Redis 服务器时间
|
|
|
|
|
|
local timeResult = redis.call('TIME')
|
|
|
|
|
|
local now = tonumber(timeResult[1])
|
|
|
|
|
|
local expireBefore = now - ttl
|
|
|
|
|
|
|
|
|
|
|
|
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
|
|
|
|
|
return redis.call('ZCARD', key)
|
2025-12-19 23:39:28 +08:00
|
|
|
|
`)
|
|
|
|
|
|
|
2026-01-09 20:52:57 +08:00
|
|
|
|
// incrementWaitScript - refreshes TTL on each increment to keep queue depth accurate
|
2026-01-03 06:32:51 -08:00
|
|
|
|
// KEYS[1] = wait queue key
|
|
|
|
|
|
// ARGV[1] = maxWait
|
|
|
|
|
|
// ARGV[2] = TTL in seconds
|
2025-12-19 23:39:28 +08:00
|
|
|
|
incrementWaitScript = redis.NewScript(`
|
2026-01-03 06:32:51 -08:00
|
|
|
|
local current = redis.call('GET', KEYS[1])
|
|
|
|
|
|
if current == false then
|
|
|
|
|
|
current = 0
|
|
|
|
|
|
else
|
|
|
|
|
|
current = tonumber(current)
|
2025-12-19 23:39:28 +08:00
|
|
|
|
end
|
2025-12-24 21:00:29 +08:00
|
|
|
|
|
2026-01-03 06:32:51 -08:00
|
|
|
|
if current >= tonumber(ARGV[1]) then
|
2025-12-19 23:39:28 +08:00
|
|
|
|
return 0
|
|
|
|
|
|
end
|
2025-12-24 21:00:29 +08:00
|
|
|
|
|
2026-01-03 06:32:51 -08:00
|
|
|
|
local newVal = redis.call('INCR', KEYS[1])
|
2025-12-24 21:00:29 +08:00
|
|
|
|
|
2026-01-09 20:52:57 +08:00
|
|
|
|
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
|
|
|
|
|
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
2025-12-24 21:00:29 +08:00
|
|
|
|
|
2026-01-01 04:01:51 +08:00
|
|
|
|
return 1
|
|
|
|
|
|
`)
|
|
|
|
|
|
|
2026-01-09 20:52:57 +08:00
|
|
|
|
// incrementAccountWaitScript - account-level wait queue count (refresh TTL on each increment)
|
2026-01-01 04:01:51 +08:00
|
|
|
|
incrementAccountWaitScript = redis.NewScript(`
|
|
|
|
|
|
local current = redis.call('GET', KEYS[1])
|
|
|
|
|
|
if current == false then
|
|
|
|
|
|
current = 0
|
|
|
|
|
|
else
|
|
|
|
|
|
current = tonumber(current)
|
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
|
|
if current >= tonumber(ARGV[1]) then
|
|
|
|
|
|
return 0
|
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
|
|
local newVal = redis.call('INCR', KEYS[1])
|
|
|
|
|
|
|
2026-01-09 20:52:57 +08:00
|
|
|
|
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
|
|
|
|
|
|
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
2026-01-01 04:01:51 +08:00
|
|
|
|
|
|
|
|
|
|
return 1
|
|
|
|
|
|
`)
|
2025-12-19 23:39:28 +08:00
|
|
|
|
|
2025-12-24 21:00:29 +08:00
|
|
|
|
// decrementWaitScript - same as before
|
2025-12-19 23:39:28 +08:00
|
|
|
|
decrementWaitScript = redis.NewScript(`
|
2026-01-01 04:01:51 +08:00
|
|
|
|
local current = redis.call('GET', KEYS[1])
|
|
|
|
|
|
if current ~= false and tonumber(current) > 0 then
|
|
|
|
|
|
redis.call('DECR', KEYS[1])
|
|
|
|
|
|
end
|
|
|
|
|
|
return 1
|
|
|
|
|
|
`)
|
|
|
|
|
|
|
2026-03-09 19:55:18 +08:00
|
|
|
|
// cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
|
|
|
|
|
|
// KEYS[1] = 有序集合键
|
|
|
|
|
|
// ARGV[1] = TTL(秒)
|
2026-01-01 04:01:51 +08:00
|
|
|
|
cleanupExpiredSlotsScript = redis.NewScript(`
|
2026-03-09 19:55:18 +08:00
|
|
|
|
local key = KEYS[1]
|
|
|
|
|
|
local ttl = tonumber(ARGV[1])
|
|
|
|
|
|
local timeResult = redis.call('TIME')
|
|
|
|
|
|
local now = tonumber(timeResult[1])
|
|
|
|
|
|
local expireBefore = now - ttl
|
|
|
|
|
|
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
|
|
|
|
|
if redis.call('ZCARD', key) == 0 then
|
|
|
|
|
|
redis.call('DEL', key)
|
|
|
|
|
|
else
|
|
|
|
|
|
redis.call('EXPIRE', key, ttl)
|
|
|
|
|
|
end
|
|
|
|
|
|
return 1
|
|
|
|
|
|
`)
|
|
|
|
|
|
|
|
|
|
|
|
// startupCleanupScript 清理非当前进程前缀的槽位成员。
|
|
|
|
|
|
// KEYS 是有序集合键列表,ARGV[1] 是当前进程前缀,ARGV[2] 是槽位 TTL。
|
|
|
|
|
|
// 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key,否则刷新 EXPIRE。
|
|
|
|
|
|
startupCleanupScript = redis.NewScript(`
|
|
|
|
|
|
local activePrefix = ARGV[1]
|
|
|
|
|
|
local slotTTL = tonumber(ARGV[2])
|
|
|
|
|
|
local removed = 0
|
|
|
|
|
|
for i = 1, #KEYS do
|
|
|
|
|
|
local key = KEYS[i]
|
|
|
|
|
|
local members = redis.call('ZRANGE', key, 0, -1)
|
|
|
|
|
|
for _, member in ipairs(members) do
|
|
|
|
|
|
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
|
|
|
|
|
|
removed = removed + redis.call('ZREM', key, member)
|
|
|
|
|
|
end
|
|
|
|
|
|
end
|
|
|
|
|
|
if redis.call('ZCARD', key) == 0 then
|
|
|
|
|
|
redis.call('DEL', key)
|
|
|
|
|
|
else
|
|
|
|
|
|
redis.call('EXPIRE', key, slotTTL)
|
|
|
|
|
|
end
|
|
|
|
|
|
end
|
|
|
|
|
|
return removed
|
|
|
|
|
|
`)
|
2025-12-19 23:39:28 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
type concurrencyCache struct {
|
2026-01-01 04:01:51 +08:00
|
|
|
|
rdb *redis.Client
|
|
|
|
|
|
slotTTLSeconds int // 槽位过期时间(秒)
|
|
|
|
|
|
waitQueueTTLSeconds int // 等待队列过期时间(秒)
|
2025-12-19 23:39:28 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-31 08:50:12 +08:00
|
|
|
|
// NewConcurrencyCache 创建并发控制缓存
|
|
|
|
|
|
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
|
2026-01-01 04:01:51 +08:00
|
|
|
|
// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL
|
|
|
|
|
|
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
|
2025-12-31 08:50:12 +08:00
|
|
|
|
if slotTTLMinutes <= 0 {
|
|
|
|
|
|
slotTTLMinutes = defaultSlotTTLMinutes
|
|
|
|
|
|
}
|
2026-01-01 04:01:51 +08:00
|
|
|
|
if waitQueueTTLSeconds <= 0 {
|
|
|
|
|
|
waitQueueTTLSeconds = slotTTLMinutes * 60
|
|
|
|
|
|
}
|
2025-12-31 08:50:12 +08:00
|
|
|
|
return &concurrencyCache{
|
2026-01-01 04:01:51 +08:00
|
|
|
|
rdb: rdb,
|
|
|
|
|
|
slotTTLSeconds: slotTTLMinutes * 60,
|
|
|
|
|
|
waitQueueTTLSeconds: waitQueueTTLSeconds,
|
2025-12-31 08:50:12 +08:00
|
|
|
|
}
|
2025-12-19 23:39:28 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-24 21:00:29 +08:00
|
|
|
|
// Helper functions for key generation
|
2025-12-31 08:50:12 +08:00
|
|
|
|
func accountSlotKey(accountID int64) string {
|
|
|
|
|
|
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
2025-12-24 21:00:29 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-31 08:50:12 +08:00
|
|
|
|
func userSlotKey(userID int64) string {
|
|
|
|
|
|
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
2025-12-24 21:00:29 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func waitQueueKey(userID int64) string {
|
2026-01-03 06:32:51 -08:00
|
|
|
|
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
2025-12-24 21:00:29 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-01 04:01:51 +08:00
|
|
|
|
func accountWaitKey(accountID int64) string {
|
|
|
|
|
|
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-24 21:00:29 +08:00
|
|
|
|
// Account slot operations
|
|
|
|
|
|
|
|
|
|
|
|
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
2025-12-31 08:50:12 +08:00
|
|
|
|
key := accountSlotKey(accountID)
|
|
|
|
|
|
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
|
|
|
|
|
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
2025-12-19 23:39:28 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return false, err
|
|
|
|
|
|
}
|
|
|
|
|
|
return result == 1, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-24 21:00:29 +08:00
|
|
|
|
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
2025-12-31 08:50:12 +08:00
|
|
|
|
key := accountSlotKey(accountID)
|
|
|
|
|
|
return c.rdb.ZRem(ctx, key, requestID).Err()
|
2025-12-19 23:39:28 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
2025-12-31 08:50:12 +08:00
|
|
|
|
key := accountSlotKey(accountID)
|
|
|
|
|
|
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
|
|
|
|
|
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
2025-12-24 21:00:29 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return 0, err
|
|
|
|
|
|
}
|
|
|
|
|
|
return result, nil
|
2025-12-19 23:39:28 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-28 15:01:20 +08:00
|
|
|
|
func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
|
|
|
|
|
if len(accountIDs) == 0 {
|
|
|
|
|
|
return map[int64]int{}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
now, err := c.rdb.Time(ctx).Result()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("redis TIME: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
|
|
|
|
|
|
|
|
|
|
|
|
pipe := c.rdb.Pipeline()
|
|
|
|
|
|
type accountCmd struct {
|
|
|
|
|
|
accountID int64
|
|
|
|
|
|
zcardCmd *redis.IntCmd
|
|
|
|
|
|
}
|
|
|
|
|
|
cmds := make([]accountCmd, 0, len(accountIDs))
|
|
|
|
|
|
for _, accountID := range accountIDs {
|
|
|
|
|
|
slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10)
|
|
|
|
|
|
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
|
|
|
|
|
|
cmds = append(cmds, accountCmd{
|
|
|
|
|
|
accountID: accountID,
|
|
|
|
|
|
zcardCmd: pipe.ZCard(ctx, slotKey),
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
|
|
|
|
|
return nil, fmt.Errorf("pipeline exec: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
result := make(map[int64]int, len(accountIDs))
|
|
|
|
|
|
for _, cmd := range cmds {
|
|
|
|
|
|
result[cmd.accountID] = int(cmd.zcardCmd.Val())
|
|
|
|
|
|
}
|
|
|
|
|
|
return result, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-24 21:00:29 +08:00
|
|
|
|
// User slot operations
|
|
|
|
|
|
|
|
|
|
|
|
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
2025-12-31 08:50:12 +08:00
|
|
|
|
key := userSlotKey(userID)
|
|
|
|
|
|
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
|
|
|
|
|
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
2025-12-19 23:39:28 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return false, err
|
|
|
|
|
|
}
|
|
|
|
|
|
return result == 1, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-24 21:00:29 +08:00
|
|
|
|
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
2025-12-31 08:50:12 +08:00
|
|
|
|
key := userSlotKey(userID)
|
|
|
|
|
|
return c.rdb.ZRem(ctx, key, requestID).Err()
|
2025-12-19 23:39:28 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
2025-12-31 08:50:12 +08:00
|
|
|
|
key := userSlotKey(userID)
|
|
|
|
|
|
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
|
|
|
|
|
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
2025-12-24 21:00:29 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return 0, err
|
|
|
|
|
|
}
|
|
|
|
|
|
return result, nil
|
2025-12-19 23:39:28 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-24 21:00:29 +08:00
|
|
|
|
// Wait queue operations
|
|
|
|
|
|
|
2025-12-19 23:39:28 +08:00
|
|
|
|
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
2026-01-03 06:32:51 -08:00
|
|
|
|
key := waitQueueKey(userID)
|
|
|
|
|
|
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
2025-12-19 23:39:28 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return false, err
|
|
|
|
|
|
}
|
|
|
|
|
|
return result == 1, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
2026-01-03 06:32:51 -08:00
|
|
|
|
key := waitQueueKey(userID)
|
|
|
|
|
|
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
2025-12-19 23:39:28 +08:00
|
|
|
|
return err
|
|
|
|
|
|
}
|
2026-01-01 04:01:51 +08:00
|
|
|
|
|
|
|
|
|
|
// Account wait queue operations
|
|
|
|
|
|
|
|
|
|
|
|
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
|
|
|
|
|
key := accountWaitKey(accountID)
|
|
|
|
|
|
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return false, err
|
|
|
|
|
|
}
|
|
|
|
|
|
return result == 1, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
|
|
|
|
|
key := accountWaitKey(accountID)
|
2026-01-03 06:32:51 -08:00
|
|
|
|
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
2026-01-01 04:01:51 +08:00
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
|
|
|
|
|
key := accountWaitKey(accountID)
|
|
|
|
|
|
val, err := c.rdb.Get(ctx, key).Int()
|
|
|
|
|
|
if err != nil && !errors.Is(err, redis.Nil) {
|
|
|
|
|
|
return 0, err
|
|
|
|
|
|
}
|
|
|
|
|
|
if errors.Is(err, redis.Nil) {
|
|
|
|
|
|
return 0, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
return val, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
|
|
|
|
|
if len(accounts) == 0 {
|
|
|
|
|
|
return map[int64]*service.AccountLoadInfo{}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-10 17:51:49 +08:00
|
|
|
|
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster(Lua 内动态拼 key 会 CROSSSLOT)。
|
|
|
|
|
|
// 每个账号执行 3 个命令:ZREMRANGEBYSCORE(清理过期)、ZCARD(并发数)、GET(等待数)。
|
|
|
|
|
|
now, err := c.rdb.Time(ctx).Result()
|
2026-01-01 04:01:51 +08:00
|
|
|
|
if err != nil {
|
2026-02-10 17:51:49 +08:00
|
|
|
|
return nil, fmt.Errorf("redis TIME: %w", err)
|
2026-01-01 04:01:51 +08:00
|
|
|
|
}
|
2026-02-10 17:51:49 +08:00
|
|
|
|
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
|
2026-01-01 04:01:51 +08:00
|
|
|
|
|
2026-02-10 17:51:49 +08:00
|
|
|
|
pipe := c.rdb.Pipeline()
|
|
|
|
|
|
|
|
|
|
|
|
type accountCmds struct {
|
|
|
|
|
|
id int64
|
|
|
|
|
|
maxConcurrency int
|
|
|
|
|
|
zcardCmd *redis.IntCmd
|
|
|
|
|
|
getCmd *redis.StringCmd
|
|
|
|
|
|
}
|
|
|
|
|
|
cmds := make([]accountCmds, 0, len(accounts))
|
|
|
|
|
|
for _, acc := range accounts {
|
|
|
|
|
|
slotKey := accountSlotKeyPrefix + strconv.FormatInt(acc.ID, 10)
|
|
|
|
|
|
waitKey := accountWaitKeyPrefix + strconv.FormatInt(acc.ID, 10)
|
|
|
|
|
|
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
|
|
|
|
|
|
ac := accountCmds{
|
|
|
|
|
|
id: acc.ID,
|
|
|
|
|
|
maxConcurrency: acc.MaxConcurrency,
|
|
|
|
|
|
zcardCmd: pipe.ZCard(ctx, slotKey),
|
|
|
|
|
|
getCmd: pipe.Get(ctx, waitKey),
|
2026-01-01 04:01:51 +08:00
|
|
|
|
}
|
2026-02-10 17:51:49 +08:00
|
|
|
|
cmds = append(cmds, ac)
|
|
|
|
|
|
}
|
2026-01-01 04:01:51 +08:00
|
|
|
|
|
2026-02-10 17:51:49 +08:00
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
|
|
|
|
|
return nil, fmt.Errorf("pipeline exec: %w", err)
|
|
|
|
|
|
}
|
2026-01-01 04:01:51 +08:00
|
|
|
|
|
2026-02-10 17:51:49 +08:00
|
|
|
|
loadMap := make(map[int64]*service.AccountLoadInfo, len(accounts))
|
|
|
|
|
|
for _, ac := range cmds {
|
|
|
|
|
|
currentConcurrency := int(ac.zcardCmd.Val())
|
|
|
|
|
|
waitingCount := 0
|
|
|
|
|
|
if v, err := ac.getCmd.Int(); err == nil {
|
|
|
|
|
|
waitingCount = v
|
|
|
|
|
|
}
|
|
|
|
|
|
loadRate := 0
|
|
|
|
|
|
if ac.maxConcurrency > 0 {
|
|
|
|
|
|
loadRate = (currentConcurrency + waitingCount) * 100 / ac.maxConcurrency
|
|
|
|
|
|
}
|
|
|
|
|
|
loadMap[ac.id] = &service.AccountLoadInfo{
|
|
|
|
|
|
AccountID: ac.id,
|
2026-01-01 04:01:51 +08:00
|
|
|
|
CurrentConcurrency: currentConcurrency,
|
|
|
|
|
|
WaitingCount: waitingCount,
|
|
|
|
|
|
LoadRate: loadRate,
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return loadMap, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-07 12:31:10 +08:00
|
|
|
|
func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
|
|
|
|
|
if len(users) == 0 {
|
|
|
|
|
|
return map[int64]*service.UserLoadInfo{}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-10 17:51:49 +08:00
|
|
|
|
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster。
|
|
|
|
|
|
now, err := c.rdb.Time(ctx).Result()
|
2026-02-07 12:31:10 +08:00
|
|
|
|
if err != nil {
|
2026-02-10 17:51:49 +08:00
|
|
|
|
return nil, fmt.Errorf("redis TIME: %w", err)
|
2026-02-07 12:31:10 +08:00
|
|
|
|
}
|
2026-02-10 17:51:49 +08:00
|
|
|
|
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
|
2026-02-07 12:31:10 +08:00
|
|
|
|
|
2026-02-10 17:51:49 +08:00
|
|
|
|
pipe := c.rdb.Pipeline()
|
|
|
|
|
|
|
|
|
|
|
|
type userCmds struct {
|
|
|
|
|
|
id int64
|
|
|
|
|
|
maxConcurrency int
|
|
|
|
|
|
zcardCmd *redis.IntCmd
|
|
|
|
|
|
getCmd *redis.StringCmd
|
|
|
|
|
|
}
|
|
|
|
|
|
cmds := make([]userCmds, 0, len(users))
|
|
|
|
|
|
for _, u := range users {
|
|
|
|
|
|
slotKey := userSlotKeyPrefix + strconv.FormatInt(u.ID, 10)
|
|
|
|
|
|
waitKey := waitQueueKeyPrefix + strconv.FormatInt(u.ID, 10)
|
|
|
|
|
|
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
|
|
|
|
|
|
uc := userCmds{
|
|
|
|
|
|
id: u.ID,
|
|
|
|
|
|
maxConcurrency: u.MaxConcurrency,
|
|
|
|
|
|
zcardCmd: pipe.ZCard(ctx, slotKey),
|
|
|
|
|
|
getCmd: pipe.Get(ctx, waitKey),
|
2026-02-07 12:31:10 +08:00
|
|
|
|
}
|
2026-02-10 17:51:49 +08:00
|
|
|
|
cmds = append(cmds, uc)
|
|
|
|
|
|
}
|
2026-02-07 12:31:10 +08:00
|
|
|
|
|
2026-02-10 17:51:49 +08:00
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
|
|
|
|
|
return nil, fmt.Errorf("pipeline exec: %w", err)
|
|
|
|
|
|
}
|
2026-02-07 12:31:10 +08:00
|
|
|
|
|
2026-02-10 17:51:49 +08:00
|
|
|
|
loadMap := make(map[int64]*service.UserLoadInfo, len(users))
|
|
|
|
|
|
for _, uc := range cmds {
|
|
|
|
|
|
currentConcurrency := int(uc.zcardCmd.Val())
|
|
|
|
|
|
waitingCount := 0
|
|
|
|
|
|
if v, err := uc.getCmd.Int(); err == nil {
|
|
|
|
|
|
waitingCount = v
|
|
|
|
|
|
}
|
|
|
|
|
|
loadRate := 0
|
|
|
|
|
|
if uc.maxConcurrency > 0 {
|
|
|
|
|
|
loadRate = (currentConcurrency + waitingCount) * 100 / uc.maxConcurrency
|
|
|
|
|
|
}
|
|
|
|
|
|
loadMap[uc.id] = &service.UserLoadInfo{
|
|
|
|
|
|
UserID: uc.id,
|
2026-02-07 12:31:10 +08:00
|
|
|
|
CurrentConcurrency: currentConcurrency,
|
|
|
|
|
|
WaitingCount: waitingCount,
|
|
|
|
|
|
LoadRate: loadRate,
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return loadMap, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-01 04:01:51 +08:00
|
|
|
|
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
|
|
|
|
|
key := accountSlotKey(accountID)
|
|
|
|
|
|
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
2026-03-09 19:55:18 +08:00
|
|
|
|
|
|
|
|
|
|
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
|
|
|
|
|
if activeRequestPrefix == "" {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 1. 清理有序集合中非当前进程前缀的成员
|
|
|
|
|
|
slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"}
|
|
|
|
|
|
for _, pattern := range slotPatterns {
|
|
|
|
|
|
if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 删除所有等待队列计数器(重启后计数器失效)
|
|
|
|
|
|
waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"}
|
|
|
|
|
|
for _, pattern := range waitPatterns {
|
|
|
|
|
|
if err := c.deleteKeysByPattern(ctx, pattern); err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。
|
|
|
|
|
|
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
|
|
|
|
|
|
const scanCount = 200
|
|
|
|
|
|
var cursor uint64
|
|
|
|
|
|
for {
|
|
|
|
|
|
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return fmt.Errorf("scan %s: %w", pattern, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if len(keys) > 0 {
|
|
|
|
|
|
_, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return fmt.Errorf("cleanup slots %s: %w", pattern, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
cursor = nextCursor
|
|
|
|
|
|
if cursor == 0 {
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// deleteKeysByPattern 扫描匹配 pattern 的键并删除。
|
|
|
|
|
|
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
|
|
|
|
|
|
const scanCount = 200
|
|
|
|
|
|
var cursor uint64
|
|
|
|
|
|
for {
|
|
|
|
|
|
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return fmt.Errorf("scan %s: %w", pattern, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if len(keys) > 0 {
|
|
|
|
|
|
if err := c.rdb.Del(ctx, keys...).Err(); err != nil {
|
|
|
|
|
|
return fmt.Errorf("del %s: %w", pattern, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
cursor = nextCursor
|
|
|
|
|
|
if cursor == 0 {
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|