diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index 4fbdae14..6922b4c8 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -20,6 +20,11 @@ const ( billingCacheTTL = 5 * time.Minute billingCacheJitter = 30 * time.Second rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window + + // Rate limit window durations — must match service.RateLimitWindow* constants. + rateLimitWindow5h = 5 * time.Hour + rateLimitWindow1d = 24 * time.Hour + rateLimitWindow7d = 7 * 24 * time.Hour ) // jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩 @@ -90,17 +95,40 @@ var ( return 1 `) - // updateRateLimitUsageScript atomically increments all three rate limit usage counters. - // Returns 0 if the key doesn't exist (cache miss), 1 on success. + // updateRateLimitUsageScript atomically increments all three rate limit usage counters + // with window expiration checking. If a window has expired, its usage is reset to cost + // (instead of accumulated) and the window timestamp is updated, matching the DB-side + // IncrementRateLimitUsage semantics. + // + // ARGV: [1]=cost, [2]=ttl_seconds, [3]=now_unix, [4]=window_5h_seconds, [5]=window_1d_seconds, [6]=window_7d_seconds updateRateLimitUsageScript = redis.NewScript(` local exists = redis.call('EXISTS', KEYS[1]) if exists == 0 then return 0 end local cost = tonumber(ARGV[1]) - redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost) - redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost) - redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost) + local now = tonumber(ARGV[3]) + local win5h = tonumber(ARGV[4]) + local win1d = tonumber(ARGV[5]) + local win7d = tonumber(ARGV[6]) + + -- Helper: check if window is expired and update usage + window accordingly + -- Returns nothing, modifies the hash in-place. + local function update_window(usage_field, window_field, window_duration) + local w = tonumber(redis.call('HGET', KEYS[1], window_field) or 0) + if w == 0 or (now - w) >= window_duration then + -- Window expired or never started: reset usage to cost, start new window + redis.call('HSET', KEYS[1], usage_field, tostring(cost)) + redis.call('HSET', KEYS[1], window_field, tostring(now)) + else + -- Window still valid: accumulate + redis.call('HINCRBYFLOAT', KEYS[1], usage_field, cost) + end + end + + update_window('usage_5h', 'window_5h', win5h) + update_window('usage_1d', 'window_1d', win1d) + update_window('usage_7d', 'window_7d', win7d) redis.call('EXPIRE', KEYS[1], ARGV[2]) return 1 `) @@ -280,7 +308,15 @@ func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { key := billingRateLimitKey(keyID) - _, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result() + now := time.Now().Unix() + _, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, + cost, + int(rateLimitCacheTTL.Seconds()), + now, + int(rateLimitWindow5h.Seconds()), + int(rateLimitWindow1d.Seconds()), + int(rateLimitWindow7d.Seconds()), + ).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err) return err