mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 23:12:14 +08:00
187 lines
5.8 KiB
Go
187 lines
5.8 KiB
Go
|
|
package repository
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"errors"
|
|||
|
|
"fmt"
|
|||
|
|
"strconv"
|
|||
|
|
"strings"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|||
|
|
"github.com/redis/go-redis/v9"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// Redis key layout for the user message queue (umq). The accountID is wrapped in
// braces as a hash tag so that, under Redis Cluster, all keys of the same
// accountID land in the same slot.
// Format: umq:{accountID}:lock / umq:{accountID}:last
const (
	umqKeyPrefix  = "umq:"
	umqLockSuffix = ":lock" // STRING (requestID), PX lockTtlMs
	umqLastSuffix = ":last" // STRING (millisecond timestamp), EX 60s
)
|
|||
|
|
|
|||
|
|
// acquireLockScript atomically acquires the per-account serial lock
// (SET-if-absent with a PX expiry, re-entrant for the same owner).
//
// KEYS[1]: lock key. ARGV[1]: owner token (requestID). ARGV[2]: TTL in milliseconds.
// Returns 1 when the lock is newly set, or when it is already held by ARGV[1]
// (its TTL is then refreshed with PEXPIRE). Returns 0 when another owner holds it.
//
// NOTE(review): ARGV[2] should be positive. Per Redis EXPIRE semantics, a
// non-positive PEXPIRE deletes the key instead of refreshing it, while this
// script would still return 1.
var acquireLockScript = redis.NewScript(`
local cur = redis.call('GET', KEYS[1])
if cur == ARGV[1] then
redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[2]))
return 1
end
if cur ~= false then return 0 end
redis.call('SET', KEYS[1], ARGV[1], 'PX', tonumber(ARGV[2]))
return 1
`)
|
|||
|
|
|
|||
|
|
// releaseLockScript atomically releases the lock and records the completion time.
//
// KEYS[1]: lock key. KEYS[2]: last-completed key. ARGV[1]: owner token (requestID).
// Only when KEYS[1] still holds ARGV[1] does it:
//   - delete the lock,
//   - store the Redis server time in milliseconds (built from TIME's seconds and
//     microseconds) in KEYS[2] with a 60s expiry,
//   - return 1.
// Otherwise nothing is modified and 0 is returned.
// Using Redis TIME instead of the caller's clock avoids clock skew between callers.
var releaseLockScript = redis.NewScript(`
local cur = redis.call('GET', KEYS[1])
if cur == ARGV[1] then
redis.call('DEL', KEYS[1])
local t = redis.call('TIME')
local ms = tonumber(t[1])*1000 + math.floor(tonumber(t[2])/1000)
redis.call('SET', KEYS[2], ms, 'EX', 60)
return 1
end
return 0
`)
|
|||
|
|
|
|||
|
|
// forceReleaseLockScript atomically removes an orphaned lock. KEYS[1] is deleted
// only if its PTTL is -1, i.e. the key exists but has no expiry. Returns 1 if the
// key was deleted, 0 otherwise.
// Checking PTTL and deleting inside one script closes the TOCTOU window in which
// a legitimate, TTL-bearing lock could be deleted by mistake.
var forceReleaseLockScript = redis.NewScript(`
local pttl = redis.call('PTTL', KEYS[1])
if pttl == -1 then
redis.call('DEL', KEYS[1])
return 1
end
return 0
`)
|
|||
|
|
|
|||
|
|
type userMsgQueueCache struct {
|
|||
|
|
rdb *redis.Client
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewUserMsgQueueCache 创建用户消息队列缓存
|
|||
|
|
func NewUserMsgQueueCache(rdb *redis.Client) service.UserMsgQueueCache {
|
|||
|
|
return &userMsgQueueCache{rdb: rdb}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func umqLockKey(accountID int64) string {
|
|||
|
|
// 格式: umq:{123}:lock — 花括号确保 Redis Cluster hash tag 生效
|
|||
|
|
return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLockSuffix
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func umqLastKey(accountID int64) string {
|
|||
|
|
// 格式: umq:{123}:last — 与 lockKey 同一 hash slot
|
|||
|
|
return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLastSuffix
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// umqScanPattern 用于 SCAN 扫描锁 key
|
|||
|
|
func umqScanPattern() string {
|
|||
|
|
return umqKeyPrefix + "{*}" + umqLockSuffix
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AcquireLock 尝试获取账号级串行锁
|
|||
|
|
func (c *userMsgQueueCache) AcquireLock(ctx context.Context, accountID int64, requestID string, lockTtlMs int) (bool, error) {
|
|||
|
|
key := umqLockKey(accountID)
|
|||
|
|
result, err := acquireLockScript.Run(ctx, c.rdb, []string{key}, requestID, lockTtlMs).Int()
|
|||
|
|
if err != nil {
|
|||
|
|
return false, fmt.Errorf("umq acquire lock: %w", err)
|
|||
|
|
}
|
|||
|
|
return result == 1, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ReleaseLock 释放锁并记录完成时间
|
|||
|
|
func (c *userMsgQueueCache) ReleaseLock(ctx context.Context, accountID int64, requestID string) (bool, error) {
|
|||
|
|
lockKey := umqLockKey(accountID)
|
|||
|
|
lastKey := umqLastKey(accountID)
|
|||
|
|
result, err := releaseLockScript.Run(ctx, c.rdb, []string{lockKey, lastKey}, requestID).Int()
|
|||
|
|
if err != nil {
|
|||
|
|
return false, fmt.Errorf("umq release lock: %w", err)
|
|||
|
|
}
|
|||
|
|
return result == 1, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetLastCompletedMs 获取上次完成时间(毫秒时间戳)
|
|||
|
|
func (c *userMsgQueueCache) GetLastCompletedMs(ctx context.Context, accountID int64) (int64, error) {
|
|||
|
|
key := umqLastKey(accountID)
|
|||
|
|
val, err := c.rdb.Get(ctx, key).Result()
|
|||
|
|
if errors.Is(err, redis.Nil) {
|
|||
|
|
return 0, nil
|
|||
|
|
}
|
|||
|
|
if err != nil {
|
|||
|
|
return 0, fmt.Errorf("umq get last completed: %w", err)
|
|||
|
|
}
|
|||
|
|
ms, err := strconv.ParseInt(val, 10, 64)
|
|||
|
|
if err != nil {
|
|||
|
|
return 0, fmt.Errorf("umq parse last completed: %w", err)
|
|||
|
|
}
|
|||
|
|
return ms, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ForceReleaseLock 原子清理孤儿锁(仅在 PTTL == -1 时删除,防止 TOCTOU 竞态误删合法锁)
|
|||
|
|
func (c *userMsgQueueCache) ForceReleaseLock(ctx context.Context, accountID int64) error {
|
|||
|
|
key := umqLockKey(accountID)
|
|||
|
|
_, err := forceReleaseLockScript.Run(ctx, c.rdb, []string{key}).Result()
|
|||
|
|
if err != nil && !errors.Is(err, redis.Nil) {
|
|||
|
|
return fmt.Errorf("umq force release lock: %w", err)
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ScanLockKeys 扫描所有锁 key,仅返回 PTTL == -1(无过期时间)的孤儿锁 accountID 列表
|
|||
|
|
// 正常的锁都有 PX 过期时间,PTTL == -1 表示异常状态(如 Redis 故障恢复后丢失 TTL)
|
|||
|
|
func (c *userMsgQueueCache) ScanLockKeys(ctx context.Context, maxCount int) ([]int64, error) {
|
|||
|
|
var accountIDs []int64
|
|||
|
|
var cursor uint64
|
|||
|
|
pattern := umqScanPattern()
|
|||
|
|
|
|||
|
|
for {
|
|||
|
|
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, 100).Result()
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("umq scan lock keys: %w", err)
|
|||
|
|
}
|
|||
|
|
for _, key := range keys {
|
|||
|
|
// 检查 PTTL:只清理 PTTL == -1(无过期时间)的异常锁
|
|||
|
|
pttl, err := c.rdb.PTTL(ctx, key).Result()
|
|||
|
|
if err != nil {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
// PTTL 返回值:-2 = key 不存在,-1 = 无过期时间,>0 = 剩余毫秒
|
|||
|
|
// go-redis 对哨兵值 -1/-2 不乘精度系数,直接返回 time.Duration(-1)/-2
|
|||
|
|
// 只删除 -1(无过期时间的异常锁),跳过正常持有的锁
|
|||
|
|
if pttl != time.Duration(-1) {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 从 key 中提取 accountID: umq:{123}:lock → 提取 {} 内的数字
|
|||
|
|
openBrace := strings.IndexByte(key, '{')
|
|||
|
|
closeBrace := strings.IndexByte(key, '}')
|
|||
|
|
if openBrace < 0 || closeBrace <= openBrace+1 {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
idStr := key[openBrace+1 : closeBrace]
|
|||
|
|
id, err := strconv.ParseInt(idStr, 10, 64)
|
|||
|
|
if err != nil {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
accountIDs = append(accountIDs, id)
|
|||
|
|
if len(accountIDs) >= maxCount {
|
|||
|
|
return accountIDs, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
cursor = nextCursor
|
|||
|
|
if cursor == 0 {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return accountIDs, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetCurrentTimeMs 通过 Redis TIME 命令获取当前服务器时间(毫秒),确保与锁记录的时间源一致
|
|||
|
|
func (c *userMsgQueueCache) GetCurrentTimeMs(ctx context.Context) (int64, error) {
|
|||
|
|
t, err := c.rdb.Time(ctx).Result()
|
|||
|
|
if err != nil {
|
|||
|
|
return 0, fmt.Errorf("umq get redis time: %w", err)
|
|||
|
|
}
|
|||
|
|
return t.UnixMilli(), nil
|
|||
|
|
}
|