Files
sub2api/backend/internal/repository/scheduler_cache.go
shaw 8bf2a7b88a fix(scheduler): resolve SetSnapshot race conditions and remove usage throttle
Backend: Fix three race conditions in SetSnapshot that caused account
scheduling anomalies and broken sticky sessions:
- Use Lua CAS script for atomic version activation, preventing version
  rollback when concurrent goroutines write snapshots simultaneously
- Add UnlockBucket to release rebuild lock immediately after completion
  instead of waiting 30s TTL expiry
- Replace immediate DEL of old snapshots with 60s EXPIRE grace period,
  preventing readers from hitting empty ZRANGE during version switches

Frontend: Remove serial queue throttle (1-2s delay per request) from
usage loading since backend now uses passive sampling. All usage
requests execute immediately in parallel.
2026-04-29 22:48:39 +08:00

503 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"encoding/json"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// Redis key names/prefixes and tuning constants used by the scheduler cache.
const (
// schedulerBucketSetKey is the Redis set holding the string form of every
// bucket that has had a snapshot activated (SADD in the activation script,
// SMEMBERS in ListBuckets).
schedulerBucketSetKey = "sched:buckets"
// schedulerOutboxWatermarkKey stores the outbox watermark as a base-10 integer string.
schedulerOutboxWatermarkKey = "sched:outbox:watermark"
// schedulerAccountPrefix + account ID holds the full JSON-encoded account.
schedulerAccountPrefix = "sched:acc:"
// schedulerAccountMetaPrefix + account ID holds the reduced JSON payload
// produced by buildSchedulerMetadataAccount.
schedulerAccountMetaPrefix = "sched:meta:"
// schedulerActivePrefix + bucket holds the currently active snapshot version.
schedulerActivePrefix = "sched:active:"
// schedulerReadyPrefix + bucket holds "1" once a snapshot has been activated.
schedulerReadyPrefix = "sched:ready:"
// schedulerVersionPrefix + bucket is the INCR counter that hands out snapshot versions.
schedulerVersionPrefix = "sched:ver:"
// schedulerSnapshotPrefix starts every versioned snapshot key
// (see schedulerSnapshotKey for the full layout).
schedulerSnapshotPrefix = "sched:"
// schedulerLockPrefix + bucket is the SETNX lock key used by TryLockBucket / UnlockBucket.
schedulerLockPrefix = "sched:lock:"
// defaultSchedulerSnapshotMGetChunkSize is the default number of keys per MGET call.
defaultSchedulerSnapshotMGetChunkSize = 128
// defaultSchedulerSnapshotWriteChunkSize is the default write batch size
// (accounts per pipeline flush, members per ZADD).
defaultSchedulerSnapshotWriteChunkSize = 256
// snapshotGraceTTLSeconds is the grace period, in seconds, before a superseded
// snapshot expires. It replaces an immediate DEL so that readers still working
// on the old version have enough time to finish their ZRANGE.
snapshotGraceTTLSeconds = 60
)
var (
// activateSnapshotScript atomically switches the active snapshot version
// (compare-and-set executed inside Redis).
//
// The switch only happens when the new version number is >= the currently
// active version, which prevents a version rollback when several writers
// publish snapshots concurrently. When the new version is older, the script
// deletes the freshly written snapshot (KEYS[4]) and leaves the active
// pointer untouched.
//
// The superseded snapshot is given an EXPIRE grace period instead of being
// deleted immediately, to avoid racing with readers.
//
// KEYS[1] = activeKey (sched:active:{bucket})
// KEYS[2] = readyKey (sched:ready:{bucket})
// KEYS[3] = bucketSetKey (sched:buckets)
// KEYS[4] = snapshotKey (the newly written snapshot key)
// ARGV[1] = new version number, as a string
// ARGV[2] = bucket string (used for SADD)
// ARGV[3] = snapshot key prefix (used to build the old snapshot key)
// ARGV[4] = grace-period TTL in seconds
//
// Returns 1 = activated, 0 = version too old, not activated.
//
// NOTE(review): the old snapshot key is built inside the script as
// ARGV[3] .. currentActive and is therefore not listed in KEYS. That works
// with the single-node *redis.Client used here; confirm before moving to
// Redis Cluster, where scripts are expected to declare every key they touch.
activateSnapshotScript = redis.NewScript(`
local currentActive = redis.call('GET', KEYS[1])
local newVersion = tonumber(ARGV[1])
if currentActive ~= false then
local curVersion = tonumber(currentActive)
if curVersion and newVersion < curVersion then
redis.call('DEL', KEYS[4])
return 0
end
end
redis.call('SET', KEYS[1], ARGV[1])
redis.call('SET', KEYS[2], '1')
redis.call('SADD', KEYS[3], ARGV[2])
if currentActive ~= false and currentActive ~= ARGV[1] then
redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4]))
end
return 1
`)
)
// schedulerCache is the Redis-backed value that NewSchedulerCache returns as a
// service.SchedulerCache.
type schedulerCache struct {
// rdb is the Redis client every operation goes through.
rdb *redis.Client
// mgetChunkSize caps the number of keys sent in one MGET (see mgetChunked).
mgetChunkSize int
// writeChunkSize caps the number of accounts per pipeline flush in
// writeAccounts and the number of members per ZADD in SetSnapshot.
writeChunkSize int
}
// NewSchedulerCache builds a scheduler cache on top of rdb using the default
// MGET and write chunk sizes.
func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
	const (
		readChunk  = defaultSchedulerSnapshotMGetChunkSize
		writeChunk = defaultSchedulerSnapshotWriteChunkSize
	)
	return newSchedulerCacheWithChunkSizes(rdb, readChunk, writeChunk)
}
// newSchedulerCacheWithChunkSizes builds a scheduler cache with explicit chunk
// sizes. A non-positive size is replaced by the corresponding default.
func newSchedulerCacheWithChunkSizes(rdb *redis.Client, mgetChunkSize, writeChunkSize int) service.SchedulerCache {
	// Start from the defaults and only override them with usable values.
	cache := &schedulerCache{
		rdb:            rdb,
		mgetChunkSize:  defaultSchedulerSnapshotMGetChunkSize,
		writeChunkSize: defaultSchedulerSnapshotWriteChunkSize,
	}
	if mgetChunkSize > 0 {
		cache.mgetChunkSize = mgetChunkSize
	}
	if writeChunkSize > 0 {
		cache.writeChunkSize = writeChunkSize
	}
	return cache
}
// GetSnapshot loads the active snapshot of bucket.
//
// It returns (accounts, true, nil) on a hit and (nil, false, nil) on a miss.
// A miss is reported when the ready flag is absent or not "1", when no active
// version is recorded, when the snapshot has no members, or when any member's
// metadata entry is missing. Redis and decode failures are returned as errors.
func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
	// The bucket must be flagged ready before its snapshot is consulted.
	ready, err := c.rdb.Get(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket)).Result()
	switch {
	case err == redis.Nil:
		return nil, false, nil
	case err != nil:
		return nil, false, err
	case ready != "1":
		return nil, false, nil
	}

	// Resolve which snapshot version is currently active.
	version, err := c.rdb.Get(ctx, schedulerBucketKey(schedulerActivePrefix, bucket)).Result()
	switch {
	case err == redis.Nil:
		return nil, false, nil
	case err != nil:
		return nil, false, err
	}

	memberIDs, err := c.rdb.ZRange(ctx, schedulerSnapshotKey(bucket, version), 0, -1).Result()
	if err != nil {
		return nil, false, err
	}
	if len(memberIDs) == 0 {
		// An empty snapshot is treated as a cache miss so that the database
		// fallback query runs. This addresses the race where an account is
		// bound to a group immediately after the group is created.
		return nil, false, nil
	}

	metaKeys := make([]string, len(memberIDs))
	for i, id := range memberIDs {
		metaKeys[i] = schedulerAccountMetaKey(id)
	}
	payloads, err := c.mgetChunked(ctx, metaKeys)
	if err != nil {
		return nil, false, err
	}

	accounts := make([]*service.Account, 0, len(payloads))
	for _, payload := range payloads {
		if payload == nil {
			// One missing metadata entry makes the whole snapshot a miss.
			return nil, false, nil
		}
		account, decodeErr := decodeCachedAccount(payload)
		if decodeErr != nil {
			return nil, false, decodeErr
		}
		accounts = append(accounts, account)
	}
	return accounts, true, nil
}
// SetSnapshot publishes accounts as the new snapshot of bucket.
//
// Phase 1 allocates a fresh version with INCR, writes the per-account payloads
// and fills the versioned sorted set (score = position in accounts, so the
// caller's ordering is preserved). Phase 2 runs activateSnapshotScript, which
// switches the active pointer only if this version is not older than the one
// already active.
//
// A nil return means the write succeeded; it does not guarantee activation,
// because the script's "version too old" result (0) is not treated as an error.
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
	// Phase 1: allocate a new version and write the snapshot data.
	// INCR hands every caller a unique, increasing version. The versioned
	// snapshot key is new and not yet referenced by the active pointer, so
	// readers cannot observe it while it is being filled.
	versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
	version, err := c.rdb.Incr(ctx, versionKey).Result()
	if err != nil {
		return err
	}
	versionStr := strconv.FormatInt(version, 10)
	snapshotKey := schedulerSnapshotKey(bucket, versionStr)

	if err := c.writeAccounts(ctx, accounts); err != nil {
		return err
	}

	if len(accounts) > 0 {
		// The index is used as the score to keep the ordering of the input slice.
		members := make([]redis.Z, 0, len(accounts))
		for idx := range accounts {
			members = append(members, redis.Z{
				Score:  float64(idx),
				Member: strconv.FormatInt(accounts[idx].ID, 10),
			})
		}

		// Guard against a zero/negative chunk size (e.g. a schedulerCache built
		// without the constructor): the loop below would otherwise never
		// advance. mgetChunked applies the same fallback.
		chunkSize := c.writeChunkSize
		if chunkSize <= 0 {
			chunkSize = defaultSchedulerSnapshotWriteChunkSize
		}

		pipe := c.rdb.Pipeline()
		for start := 0; start < len(members); start += chunkSize {
			end := start + chunkSize
			if end > len(members) {
				end = len(members)
			}
			pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
		}
		if _, err := pipe.Exec(ctx); err != nil {
			// The key may be partially written, carries no TTL and will never
			// be activated, so remove it instead of leaking it. Best effort:
			// the original error is what the caller needs, and if ctx is
			// already cancelled the delete simply fails too.
			_ = c.rdb.Del(ctx, snapshotKey).Err()
			return err
		}
	}

	// Phase 2: atomic CAS activation.
	// The Lua script only moves the active pointer when the new version is >=
	// the active one, and gives the superseded snapshot an EXPIRE grace period
	// instead of deleting it, so in-flight readers are not cut off.
	//
	// No cleanup is attempted when the script call itself fails: the script
	// may have run on the server even though the reply was lost, in which case
	// snapshotKey is already the active snapshot.
	activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
	readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
	// An empty version yields exactly the "...:v" prefix the script appends
	// the old version to, keeping the key layout defined in one place.
	snapshotKeyPrefix := schedulerSnapshotKey(bucket, "")

	keys := []string{activeKey, readyKey, schedulerBucketSetKey, snapshotKey}
	args := []any{versionStr, bucket.String(), snapshotKeyPrefix, snapshotGraceTTLSeconds}
	return activateSnapshotScript.Run(ctx, c.rdb, keys, args...).Err()
}
// GetAccount fetches the full cached account for accountID.
// It returns (nil, nil) when the key does not exist.
func (c *schedulerCache) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
	id := strconv.FormatInt(accountID, 10)
	payload, err := c.rdb.Get(ctx, schedulerAccountKey(id)).Result()
	switch {
	case err == redis.Nil:
		return nil, nil
	case err != nil:
		return nil, err
	}
	return decodeCachedAccount(payload)
}
// SetAccount writes the full and metadata payloads of a single account.
// A nil account or one with a non-positive ID is silently ignored.
func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Account) error {
	if account != nil && account.ID > 0 {
		single := []service.Account{*account}
		return c.writeAccounts(ctx, single)
	}
	return nil
}
// DeleteAccount removes both the full and the metadata entry of accountID in a
// single DEL. Non-positive IDs are ignored.
func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
	if accountID <= 0 {
		return nil
	}
	id := strconv.FormatInt(accountID, 10)
	doomed := []string{schedulerAccountKey(id), schedulerAccountMetaKey(id)}
	return c.rdb.Del(ctx, doomed...).Err()
}
// UpdateLastUsed sets LastUsedAt on every cached account listed in updates
// (account ID -> timestamp) and rewrites both its full and metadata payloads.
// Accounts that are not in the cache are skipped.
//
// The accounts are read with MGET and written back through one pipeline; if
// decoding or encoding any account fails, the error is returned and nothing
// queued so far is sent.
//
// NOTE(review): the read and the write-back are separate round trips, so a
// concurrent change to the same account between them would be overwritten —
// confirm this is acceptable for callers.
func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
	if len(updates) == 0 {
		return nil
	}

	// Map iteration order is random, so keep the IDs in a slice that runs
	// parallel to the keys handed to MGET.
	accountIDs := make([]int64, 0, len(updates))
	fullKeys := make([]string, 0, len(updates))
	for accountID := range updates {
		accountIDs = append(accountIDs, accountID)
		fullKeys = append(fullKeys, schedulerAccountKey(strconv.FormatInt(accountID, 10)))
	}

	payloads, err := c.mgetChunked(ctx, fullKeys)
	if err != nil {
		return err
	}

	pipe := c.rdb.Pipeline()
	for i, payload := range payloads {
		if payload == nil {
			continue
		}
		account, err := decodeCachedAccount(payload)
		if err != nil {
			return err
		}

		usedAt := updates[accountIDs[i]]
		account.LastUsedAt = &usedAt

		fullJSON, err := json.Marshal(account)
		if err != nil {
			return err
		}
		metaJSON, err := json.Marshal(buildSchedulerMetadataAccount(*account))
		if err != nil {
			return err
		}

		metaKey := schedulerAccountMetaKey(strconv.FormatInt(accountIDs[i], 10))
		pipe.Set(ctx, fullKeys[i], fullJSON, 0)
		pipe.Set(ctx, metaKey, metaJSON, 0)
	}

	_, err = pipe.Exec(ctx)
	return err
}
// TryLockBucket tries to take the lock of bucket with SETNX and the given TTL.
// It reports true when the lock was acquired. The stored value is the current
// time in Unix nanoseconds.
func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
	lockKey := schedulerBucketKey(schedulerLockPrefix, bucket)
	stamp := time.Now().UnixNano()
	acquired, err := c.rdb.SetNX(ctx, lockKey, stamp, ttl).Result()
	return acquired, err
}
// UnlockBucket releases the lock of bucket by deleting its lock key.
//
// NOTE(review): the delete is unconditional — the value written by
// TryLockBucket is not compared — so a caller whose lock already expired could
// remove a lock that someone else has taken since. Confirm callers always
// finish within the TTL.
func (c *schedulerCache) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
	return c.rdb.Del(ctx, schedulerBucketKey(schedulerLockPrefix, bucket)).Err()
}
// ListBuckets returns every bucket recorded in the bucket set. Entries that
// service.ParseSchedulerBucket rejects are skipped.
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
	entries, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
	if err != nil {
		return nil, err
	}

	buckets := make([]service.SchedulerBucket, 0, len(entries))
	for i := range entries {
		if parsed, ok := service.ParseSchedulerBucket(entries[i]); ok {
			buckets = append(buckets, parsed)
		}
	}
	return buckets, nil
}
// GetOutboxWatermark reads the stored outbox watermark. It returns 0 when no
// watermark has been stored, and an error when the stored value is not a
// base-10 int64.
func (c *schedulerCache) GetOutboxWatermark(ctx context.Context) (int64, error) {
	raw, err := c.rdb.Get(ctx, schedulerOutboxWatermarkKey).Result()
	switch {
	case err == redis.Nil:
		return 0, nil
	case err != nil:
		return 0, err
	}

	watermark, parseErr := strconv.ParseInt(raw, 10, 64)
	if parseErr != nil {
		// Return an explicit 0: ParseInt yields a clamped value on range errors.
		return 0, parseErr
	}
	return watermark, nil
}
// SetOutboxWatermark stores id as the outbox watermark, without expiration.
func (c *schedulerCache) SetOutboxWatermark(ctx context.Context, id int64) error {
	encoded := strconv.FormatInt(id, 10)
	return c.rdb.Set(ctx, schedulerOutboxWatermarkKey, encoded, 0).Err()
}
// schedulerBucketKey builds "<prefix><groupID>:<platform>:<mode>" for bucket.
func schedulerBucketKey(prefix string, bucket service.SchedulerBucket) string {
	suffix := fmt.Sprintf("%d:%s:%s", bucket.GroupID, bucket.Platform, bucket.Mode)
	return prefix + suffix
}
// schedulerSnapshotKey builds the versioned snapshot key
// "sched:<groupID>:<platform>:<mode>:v<version>" for bucket.
func schedulerSnapshotKey(bucket service.SchedulerBucket, version string) string {
	// The part before ":v" has the same layout as every other per-bucket key.
	base := schedulerBucketKey(schedulerSnapshotPrefix, bucket)
	return base + ":v" + version
}
// schedulerAccountKey returns the key under which the full payload of the
// account with the given (already stringified) ID is stored.
func schedulerAccountKey(id string) string {
	key := schedulerAccountPrefix + id
	return key
}
// schedulerAccountMetaKey returns the key under which the metadata payload of
// the account with the given (already stringified) ID is stored.
func schedulerAccountMetaKey(id string) string {
	key := schedulerAccountMetaPrefix + id
	return key
}
// ptrTime returns a pointer to a fresh copy of t.
func ptrTime(t time.Time) *time.Time {
	p := new(time.Time)
	*p = t
	return p
}
// decodeCachedAccount turns a cached JSON payload into an account.
// val must be a string or a []byte; any other dynamic type (including nil)
// yields an error, as does invalid JSON.
func decodeCachedAccount(val any) (*service.Account, error) {
	var raw []byte
	if text, isString := val.(string); isString {
		raw = []byte(text)
	} else if blob, isBytes := val.([]byte); isBytes {
		raw = blob
	} else {
		return nil, fmt.Errorf("unexpected account cache type: %T", val)
	}

	decoded := new(service.Account)
	if err := json.Unmarshal(raw, decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}
// writeAccounts stores, for every account, its full JSON payload and its
// reduced metadata payload, both without expiration. Writes are pipelined and
// the pipeline is executed each time writeChunkSize accounts have been queued,
// plus once at the end for any remainder.
//
// If encoding an account fails, the error is returned immediately and the
// commands queued since the last flush are not sent.
func (c *schedulerCache) writeAccounts(ctx context.Context, accounts []service.Account) error {
	if len(accounts) == 0 {
		return nil
	}

	pipe := c.rdb.Pipeline()
	queued := 0 // accounts (not commands) queued since the last Exec
	for i := range accounts {
		fullJSON, err := json.Marshal(accounts[i])
		if err != nil {
			return err
		}
		metaJSON, err := json.Marshal(buildSchedulerMetadataAccount(accounts[i]))
		if err != nil {
			return err
		}

		id := strconv.FormatInt(accounts[i].ID, 10)
		pipe.Set(ctx, schedulerAccountKey(id), fullJSON, 0)
		pipe.Set(ctx, schedulerAccountMetaKey(id), metaJSON, 0)
		queued++

		if queued < c.writeChunkSize {
			continue
		}
		// Chunk is full: send it and start a new pipeline.
		if _, err := pipe.Exec(ctx); err != nil {
			return err
		}
		pipe, queued = c.rdb.Pipeline(), 0
	}

	if queued == 0 {
		return nil
	}
	_, err := pipe.Exec(ctx)
	return err
}
// mgetChunked fetches keys with MGET in batches of at most mgetChunkSize keys
// (falling back to the default when that size is not positive) and returns the
// values in the same order as keys. An empty input yields an empty, non-nil
// slice. The first failing batch aborts the call.
func (c *schedulerCache) mgetChunked(ctx context.Context, keys []string) ([]any, error) {
	results := make([]any, 0, len(keys))
	if len(keys) == 0 {
		return results, nil
	}

	step := c.mgetChunkSize
	if step <= 0 {
		step = defaultSchedulerSnapshotMGetChunkSize
	}

	// Consume the key list from the front, one batch at a time.
	for rest := keys; len(rest) > 0; {
		n := step
		if n > len(rest) {
			n = len(rest)
		}
		batch, err := c.rdb.MGet(ctx, rest[:n]...).Result()
		if err != nil {
			return nil, err
		}
		results = append(results, batch...)
		rest = rest[n:]
	}
	return results, nil
}
// buildSchedulerMetadataAccount returns a reduced copy of account: only the
// fields listed below are carried over, Credentials and Extra are narrowed by
// filterSchedulerCredentials / filterSchedulerExtra, and every other field is
// left at its zero value.
func buildSchedulerMetadataAccount(account service.Account) service.Account {
	var meta service.Account

	// Identity and static configuration.
	meta.ID = account.ID
	meta.Name = account.Name
	meta.Platform = account.Platform
	meta.Type = account.Type
	meta.Concurrency = account.Concurrency
	meta.LoadFactor = account.LoadFactor
	meta.Priority = account.Priority
	meta.RateMultiplier = account.RateMultiplier

	// Status and availability.
	meta.Status = account.Status
	meta.LastUsedAt = account.LastUsedAt
	meta.ExpiresAt = account.ExpiresAt
	meta.AutoPauseOnExpired = account.AutoPauseOnExpired
	meta.Schedulable = account.Schedulable
	meta.RateLimitedAt = account.RateLimitedAt
	meta.RateLimitResetAt = account.RateLimitResetAt
	meta.OverloadUntil = account.OverloadUntil
	meta.TempUnschedulableUntil = account.TempUnschedulableUntil
	meta.TempUnschedulableReason = account.TempUnschedulableReason

	// Session window.
	meta.SessionWindowStart = account.SessionWindowStart
	meta.SessionWindowEnd = account.SessionWindowEnd
	meta.SessionWindowStatus = account.SessionWindowStatus

	// Free-form maps, restricted to allow-listed keys.
	meta.Credentials = filterSchedulerCredentials(account.Credentials)
	meta.Extra = filterSchedulerExtra(account.Extra)

	return meta
}
// filterSchedulerCredentials returns a new map holding only the allow-listed
// credential keys whose value is present and non-nil. It returns nil when the
// input is empty or nothing survives the filter. The input is not modified.
func filterSchedulerCredentials(credentials map[string]any) map[string]any {
	allowed := [...]string{"model_mapping", "api_key", "project_id", "oauth_type"}

	// Allocated lazily so that "nothing kept" naturally returns nil; lookups
	// on a nil or empty input map simply find nothing.
	var kept map[string]any
	for _, name := range allowed {
		value, present := credentials[name]
		if !present || value == nil {
			continue
		}
		if kept == nil {
			kept = make(map[string]any)
		}
		kept[name] = value
	}
	return kept
}
// filterSchedulerExtra returns a new map holding only the allow-listed extra
// keys whose value is present and non-nil. It returns nil when the input is
// empty or nothing survives the filter. The input is not modified.
func filterSchedulerExtra(extra map[string]any) map[string]any {
	allowed := [...]string{
		"mixed_scheduling",
		"window_cost_limit",
		"window_cost_sticky_reserve",
		"max_sessions",
		"session_idle_timeout_minutes",
		"openai_oauth_responses_websockets_v2_enabled",
		"openai_oauth_responses_websockets_v2_mode",
		"openai_apikey_responses_websockets_v2_enabled",
		"openai_apikey_responses_websockets_v2_mode",
		"responses_websockets_v2_enabled",
		"openai_ws_enabled",
		"openai_ws_force_http",
	}

	// Allocated lazily so that "nothing kept" naturally returns nil; lookups
	// on a nil or empty input map simply find nothing.
	var kept map[string]any
	for _, name := range allowed {
		value, present := extra[name]
		if !present || value == nil {
			continue
		}
		if kept == nil {
			kept = make(map[string]any)
		}
		kept[name] = value
	}
	return kept
}