mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-03 06:52:13 +08:00
The updateRateLimitUsageScript Lua script previously performed unconditional HINCRBYFLOAT on all usage counters without checking whether the rate limit window had expired. This caused usage to accumulate across window boundaries in Redis while the DB correctly reset on expiration, leading to incorrect 429 rate limiting that could persist for up to 24 hours. The Lua script now checks each window timestamp before incrementing: - If the window has expired, usage is reset to the current cost and the window timestamp is updated (matching DB-side semantics) - If the window is still valid, usage is accumulated normally This also resolves the async race condition where stale HINCRBYFLOAT tasks from the worker queue could pollute a freshly rebuilt cache after invalidation, since the script now self-corrects expired windows. Closes #1049
331 lines
10 KiB
Go
331 lines
10 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"log"
|
||
"math/rand/v2"
|
||
"strconv"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
const (
|
||
billingBalanceKeyPrefix = "billing:balance:"
|
||
billingSubKeyPrefix = "billing:sub:"
|
||
billingRateLimitKeyPrefix = "apikey:rate:"
|
||
billingCacheTTL = 5 * time.Minute
|
||
billingCacheJitter = 30 * time.Second
|
||
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
|
||
|
||
// Rate limit window durations — must match service.RateLimitWindow* constants.
|
||
rateLimitWindow5h = 5 * time.Hour
|
||
rateLimitWindow1d = 24 * time.Hour
|
||
rateLimitWindow7d = 7 * 24 * time.Hour
|
||
)
|
||
|
||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||
func jitteredTTL() time.Duration {
|
||
// 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL(避免上界预期被打破)。
|
||
if billingCacheJitter <= 0 {
|
||
return billingCacheTTL
|
||
}
|
||
jitter := time.Duration(rand.IntN(int(billingCacheJitter)))
|
||
return billingCacheTTL - jitter
|
||
}
|
||
|
||
// billingBalanceKey generates the Redis key for user balance cache.
|
||
func billingBalanceKey(userID int64) string {
|
||
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||
}
|
||
|
||
// billingSubKey generates the Redis key for subscription cache.
|
||
func billingSubKey(userID, groupID int64) string {
|
||
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||
}
|
||
|
||
const (
|
||
subFieldStatus = "status"
|
||
subFieldExpiresAt = "expires_at"
|
||
subFieldDailyUsage = "daily_usage"
|
||
subFieldWeeklyUsage = "weekly_usage"
|
||
subFieldMonthlyUsage = "monthly_usage"
|
||
subFieldVersion = "version"
|
||
)
|
||
|
||
// billingRateLimitKey generates the Redis key for API key rate limit cache.
|
||
func billingRateLimitKey(keyID int64) string {
|
||
return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID)
|
||
}
|
||
|
||
const (
|
||
rateLimitFieldUsage5h = "usage_5h"
|
||
rateLimitFieldUsage1d = "usage_1d"
|
||
rateLimitFieldUsage7d = "usage_7d"
|
||
rateLimitFieldWindow5h = "window_5h"
|
||
rateLimitFieldWindow1d = "window_1d"
|
||
rateLimitFieldWindow7d = "window_7d"
|
||
)
|
||
|
||
var (
|
||
deductBalanceScript = redis.NewScript(`
|
||
local current = redis.call('GET', KEYS[1])
|
||
if current == false then
|
||
return 0
|
||
end
|
||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||
redis.call('SET', KEYS[1], newVal)
|
||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||
return 1
|
||
`)
|
||
|
||
updateSubUsageScript = redis.NewScript(`
|
||
local exists = redis.call('EXISTS', KEYS[1])
|
||
if exists == 0 then
|
||
return 0
|
||
end
|
||
local cost = tonumber(ARGV[1])
|
||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||
return 1
|
||
`)
|
||
|
||
// updateRateLimitUsageScript atomically increments all three rate limit usage counters
|
||
// with window expiration checking. If a window has expired, its usage is reset to cost
|
||
// (instead of accumulated) and the window timestamp is updated, matching the DB-side
|
||
// IncrementRateLimitUsage semantics.
|
||
//
|
||
// ARGV: [1]=cost, [2]=ttl_seconds, [3]=now_unix, [4]=window_5h_seconds, [5]=window_1d_seconds, [6]=window_7d_seconds
|
||
updateRateLimitUsageScript = redis.NewScript(`
|
||
local exists = redis.call('EXISTS', KEYS[1])
|
||
if exists == 0 then
|
||
return 0
|
||
end
|
||
local cost = tonumber(ARGV[1])
|
||
local now = tonumber(ARGV[3])
|
||
local win5h = tonumber(ARGV[4])
|
||
local win1d = tonumber(ARGV[5])
|
||
local win7d = tonumber(ARGV[6])
|
||
|
||
-- Helper: check if window is expired and update usage + window accordingly
|
||
-- Returns nothing, modifies the hash in-place.
|
||
local function update_window(usage_field, window_field, window_duration)
|
||
local w = tonumber(redis.call('HGET', KEYS[1], window_field) or 0)
|
||
if w == 0 or (now - w) >= window_duration then
|
||
-- Window expired or never started: reset usage to cost, start new window
|
||
redis.call('HSET', KEYS[1], usage_field, tostring(cost))
|
||
redis.call('HSET', KEYS[1], window_field, tostring(now))
|
||
else
|
||
-- Window still valid: accumulate
|
||
redis.call('HINCRBYFLOAT', KEYS[1], usage_field, cost)
|
||
end
|
||
end
|
||
|
||
update_window('usage_5h', 'window_5h', win5h)
|
||
update_window('usage_1d', 'window_1d', win1d)
|
||
update_window('usage_7d', 'window_7d', win7d)
|
||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||
return 1
|
||
`)
|
||
)
|
||
|
||
type billingCache struct {
|
||
rdb *redis.Client
|
||
}
|
||
|
||
func NewBillingCache(rdb *redis.Client) service.BillingCache {
|
||
return &billingCache{rdb: rdb}
|
||
}
|
||
|
||
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||
key := billingBalanceKey(userID)
|
||
val, err := c.rdb.Get(ctx, key).Result()
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return strconv.ParseFloat(val, 64)
|
||
}
|
||
|
||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||
key := billingBalanceKey(userID)
|
||
return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err()
|
||
}
|
||
|
||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||
key := billingBalanceKey(userID)
|
||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
|
||
if err != nil && !errors.Is(err, redis.Nil) {
|
||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||
key := billingBalanceKey(userID)
|
||
return c.rdb.Del(ctx, key).Err()
|
||
}
|
||
|
||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
|
||
key := billingSubKey(userID, groupID)
|
||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(result) == 0 {
|
||
return nil, redis.Nil
|
||
}
|
||
return c.parseSubscriptionCache(result)
|
||
}
|
||
|
||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
|
||
result := &service.SubscriptionCacheData{}
|
||
|
||
result.Status = data[subFieldStatus]
|
||
if result.Status == "" {
|
||
return nil, errors.New("invalid cache: missing status")
|
||
}
|
||
|
||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||
if err == nil {
|
||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||
}
|
||
}
|
||
|
||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||
}
|
||
|
||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||
}
|
||
|
||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||
}
|
||
|
||
if versionStr, ok := data[subFieldVersion]; ok {
|
||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
|
||
if data == nil {
|
||
return nil
|
||
}
|
||
|
||
key := billingSubKey(userID, groupID)
|
||
|
||
fields := map[string]any{
|
||
subFieldStatus: data.Status,
|
||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||
subFieldDailyUsage: data.DailyUsage,
|
||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||
subFieldVersion: data.Version,
|
||
}
|
||
|
||
pipe := c.rdb.Pipeline()
|
||
pipe.HSet(ctx, key, fields)
|
||
pipe.Expire(ctx, key, jitteredTTL())
|
||
_, err := pipe.Exec(ctx)
|
||
return err
|
||
}
|
||
|
||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||
key := billingSubKey(userID, groupID)
|
||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result()
|
||
if err != nil && !errors.Is(err, redis.Nil) {
|
||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||
key := billingSubKey(userID, groupID)
|
||
return c.rdb.Del(ctx, key).Err()
|
||
}
|
||
|
||
func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) {
|
||
key := billingRateLimitKey(keyID)
|
||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(result) == 0 {
|
||
return nil, redis.Nil
|
||
}
|
||
data := &service.APIKeyRateLimitCacheData{}
|
||
if v, ok := result[rateLimitFieldUsage5h]; ok {
|
||
data.Usage5h, _ = strconv.ParseFloat(v, 64)
|
||
}
|
||
if v, ok := result[rateLimitFieldUsage1d]; ok {
|
||
data.Usage1d, _ = strconv.ParseFloat(v, 64)
|
||
}
|
||
if v, ok := result[rateLimitFieldUsage7d]; ok {
|
||
data.Usage7d, _ = strconv.ParseFloat(v, 64)
|
||
}
|
||
if v, ok := result[rateLimitFieldWindow5h]; ok {
|
||
data.Window5h, _ = strconv.ParseInt(v, 10, 64)
|
||
}
|
||
if v, ok := result[rateLimitFieldWindow1d]; ok {
|
||
data.Window1d, _ = strconv.ParseInt(v, 10, 64)
|
||
}
|
||
if v, ok := result[rateLimitFieldWindow7d]; ok {
|
||
data.Window7d, _ = strconv.ParseInt(v, 10, 64)
|
||
}
|
||
return data, nil
|
||
}
|
||
|
||
func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error {
|
||
if data == nil {
|
||
return nil
|
||
}
|
||
key := billingRateLimitKey(keyID)
|
||
fields := map[string]any{
|
||
rateLimitFieldUsage5h: data.Usage5h,
|
||
rateLimitFieldUsage1d: data.Usage1d,
|
||
rateLimitFieldUsage7d: data.Usage7d,
|
||
rateLimitFieldWindow5h: data.Window5h,
|
||
rateLimitFieldWindow1d: data.Window1d,
|
||
rateLimitFieldWindow7d: data.Window7d,
|
||
}
|
||
pipe := c.rdb.Pipeline()
|
||
pipe.HSet(ctx, key, fields)
|
||
pipe.Expire(ctx, key, rateLimitCacheTTL)
|
||
_, err := pipe.Exec(ctx)
|
||
return err
|
||
}
|
||
|
||
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||
key := billingRateLimitKey(keyID)
|
||
now := time.Now().Unix()
|
||
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key},
|
||
cost,
|
||
int(rateLimitCacheTTL.Seconds()),
|
||
now,
|
||
int(rateLimitWindow5h.Seconds()),
|
||
int(rateLimitWindow1d.Seconds()),
|
||
int(rateLimitWindow7d.Seconds()),
|
||
).Result()
|
||
if err != nil && !errors.Is(err, redis.Nil) {
|
||
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||
key := billingRateLimitKey(keyID)
|
||
return c.rdb.Del(ctx, key).Err()
|
||
}
|