Mirror of https://gitee.com/wanwujie/sub2api, synced 2026-05-04 21:20:51 +08:00
Backend: Fix three race conditions in SetSnapshot that caused account scheduling anomalies and broken sticky sessions:
- Use a Lua CAS script for atomic version activation, preventing version rollback when concurrent goroutines write snapshots simultaneously
- Add UnlockBucket to release the rebuild lock immediately after completion instead of waiting for the 30s TTL to expire
- Replace immediate DEL of old snapshots with a 60s EXPIRE grace period, preventing readers from hitting an empty ZRANGE during version switches

Frontend: Remove the serial queue throttle (1-2s delay per request) from usage loading, since the backend now uses passive sampling. All usage requests execute immediately in parallel.
503 lines · 14 KiB · Go
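
// Illustrative caller-side sketch of the rebuild flow implied by the commit
// message above. It is not part of the upstream file: rebuildBucket and
// loadAccountsFromDB are hypothetical names, the 30s TTL comes from the commit
// message, and the SchedulerCache methods are assumed from the implementations
// below. The point is that the lock is now released explicitly via
// UnlockBucket instead of being left to expire.
//
//	func rebuildBucket(ctx context.Context, cache service.SchedulerCache, bucket service.SchedulerBucket) error {
//		locked, err := cache.TryLockBucket(ctx, bucket, 30*time.Second)
//		if err != nil || !locked {
//			return err // another worker is already rebuilding this bucket
//		}
//		defer cache.UnlockBucket(ctx, bucket) // release immediately; don't wait for TTL expiry
//
//		accounts, err := loadAccountsFromDB(ctx, bucket) // hypothetical DB query
//		if err != nil {
//			return err
//		}
//		return cache.SetSnapshot(ctx, bucket, accounts)
//	}
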
package repository

import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/redis/go-redis/v9"
)

const (
	schedulerBucketSetKey       = "sched:buckets"
	schedulerOutboxWatermarkKey = "sched:outbox:watermark"
	schedulerAccountPrefix      = "sched:acc:"
	schedulerAccountMetaPrefix  = "sched:meta:"
	schedulerActivePrefix       = "sched:active:"
	schedulerReadyPrefix        = "sched:ready:"
	schedulerVersionPrefix      = "sched:ver:"
	schedulerSnapshotPrefix     = "sched:"
	schedulerLockPrefix         = "sched:lock:"

	defaultSchedulerSnapshotMGetChunkSize  = 128
	defaultSchedulerSnapshotWriteChunkSize = 256

	// snapshotGraceTTLSeconds is the grace period (in seconds) before an old snapshot expires.
	// It replaces an immediate DEL, giving readers still on the old version enough time to finish their ZRANGE.
	snapshotGraceTTLSeconds = 60
)
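
// Key layout example (the group ID, platform, and mode values below are
// hypothetical, chosen only to illustrate the prefixes above). For a bucket
// {GroupID: 42, Platform: "claude", Mode: "default"}, the key helpers at the
// bottom of this file produce:
//
//	sched:active:42:claude:default   currently active snapshot version
//	sched:ready:42:claude:default    "1" once a snapshot has been published
//	sched:ver:42:claude:default      monotonically increasing version counter (INCR)
//	sched:42:claude:default:v7       versioned snapshot ZSET of account IDs
//	sched:lock:42:claude:default     rebuild lock (SETNX with TTL)
//	sched:acc:<id>, sched:meta:<id>  full and trimmed per-account JSON payloads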

var (
	// activateSnapshotScript atomically switches the snapshot version with a CAS.
	// It only switches when the new version number is >= the currently active version,
	// preventing version rollback caused by concurrent writes.
	// The old snapshot gets a grace-period EXPIRE instead of an immediate DEL, avoiding a race with readers.
	//
	// KEYS[1] = activeKey (sched:active:{bucket})
	// KEYS[2] = readyKey (sched:ready:{bucket})
	// KEYS[3] = bucketSetKey (sched:buckets)
	// KEYS[4] = snapshotKey (the newly written snapshot key)
	// ARGV[1] = new version number as a string
	// ARGV[2] = bucket string (for SADD)
	// ARGV[3] = snapshot key prefix (used to build the old snapshot key)
	// ARGV[4] = grace-period TTL in seconds
	//
	// Returns 1 = activated, 0 = version too old, not activated
	activateSnapshotScript = redis.NewScript(`
local currentActive = redis.call('GET', KEYS[1])
local newVersion = tonumber(ARGV[1])

if currentActive ~= false then
	local curVersion = tonumber(currentActive)
	if curVersion and newVersion < curVersion then
		redis.call('DEL', KEYS[4])
		return 0
	end
end

redis.call('SET', KEYS[1], ARGV[1])
redis.call('SET', KEYS[2], '1')
redis.call('SADD', KEYS[3], ARGV[2])

if currentActive ~= false and currentActive ~= ARGV[1] then
	redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4]))
end

return 1
`)
)
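
// Example of the race the CAS script closes (version numbers are illustrative):
// goroutine A INCRs the version counter to 7 and goroutine B to 8; B finishes
// writing first and activates v8. When A later runs the script with ARGV[1]=7,
// the script sees 7 < 8, deletes A's now-orphaned snapshot key, and returns 0,
// so the active pointer never rolls back from v8 to v7. Without the CAS, A's
// unconditional SET would re-activate the stale v7 snapshot.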

type schedulerCache struct {
	rdb            *redis.Client
	mgetChunkSize  int
	writeChunkSize int
}

func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
	return newSchedulerCacheWithChunkSizes(rdb, defaultSchedulerSnapshotMGetChunkSize, defaultSchedulerSnapshotWriteChunkSize)
}

func newSchedulerCacheWithChunkSizes(rdb *redis.Client, mgetChunkSize, writeChunkSize int) service.SchedulerCache {
	if mgetChunkSize <= 0 {
		mgetChunkSize = defaultSchedulerSnapshotMGetChunkSize
	}
	if writeChunkSize <= 0 {
		writeChunkSize = defaultSchedulerSnapshotWriteChunkSize
	}
	return &schedulerCache{
		rdb:            rdb,
		mgetChunkSize:  mgetChunkSize,
		writeChunkSize: writeChunkSize,
	}
}

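// GetSnapshot returns the cached account list for a bucket. It checks the ready
// flag, resolves the currently active snapshot version, loads the ordered
// account IDs from the versioned ZSET, and hydrates them from the per-account
// metadata keys. Any gap (missing flag, empty ZSET, or missing metadata) is
// reported as a cache miss so the caller can fall back to the database.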
func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
	readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
	readyVal, err := c.rdb.Get(ctx, readyKey).Result()
	if err == redis.Nil {
		return nil, false, nil
	}
	if err != nil {
		return nil, false, err
	}
	if readyVal != "1" {
		return nil, false, nil
	}

	activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
	activeVal, err := c.rdb.Get(ctx, activeKey).Result()
	if err == redis.Nil {
		return nil, false, nil
	}
	if err != nil {
		return nil, false, err
	}

	snapshotKey := schedulerSnapshotKey(bucket, activeVal)
	ids, err := c.rdb.ZRange(ctx, snapshotKey, 0, -1).Result()
	if err != nil {
		return nil, false, err
	}
	if len(ids) == 0 {
		// Treat an empty snapshot as a cache miss and fall back to the database query.
		// This resolves the race when accounts are bound immediately after a new group is created.
		return nil, false, nil
	}

	keys := make([]string, 0, len(ids))
	for _, id := range ids {
		keys = append(keys, schedulerAccountMetaKey(id))
	}
	values, err := c.mgetChunked(ctx, keys)
	if err != nil {
		return nil, false, err
	}

	accounts := make([]*service.Account, 0, len(values))
	for _, val := range values {
		if val == nil {
			return nil, false, nil
		}
		account, err := decodeCachedAccount(val)
		if err != nil {
			return nil, false, err
		}
		accounts = append(accounts, account)
	}

	return accounts, true, nil
}

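// SetSnapshot publishes a new snapshot for a bucket in two phases: it first
// allocates a fresh version via INCR and writes the account payloads plus the
// versioned ZSET (still invisible to readers), then atomically activates that
// version with a compare-and-set Lua script that refuses to roll back to an
// older version and lets the previous snapshot expire after a grace period.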
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
	// Phase 1: allocate a new version number and write the snapshot data.
	// INCR guarantees each caller gets a unique, increasing version number.
	// The snapshotKey being written is a new versioned key that readers do not know about yet, so there is no race.
	versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
	version, err := c.rdb.Incr(ctx, versionKey).Result()
	if err != nil {
		return err
	}

	versionStr := strconv.FormatInt(version, 10)
	snapshotKey := schedulerSnapshotKey(bucket, versionStr)

	if err := c.writeAccounts(ctx, accounts); err != nil {
		return err
	}

	if len(accounts) > 0 {
		// Use the index as the score to preserve the ordering returned by the database.
		members := make([]redis.Z, 0, len(accounts))
		for idx, account := range accounts {
			members = append(members, redis.Z{
				Score:  float64(idx),
				Member: strconv.FormatInt(account.ID, 10),
			})
		}
		pipe := c.rdb.Pipeline()
		for start := 0; start < len(members); start += c.writeChunkSize {
			end := start + c.writeChunkSize
			if end > len(members) {
				end = len(members)
			}
			pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
		}
		if _, err := pipe.Exec(ctx); err != nil {
			return err
		}
	}

	// Phase 2: atomically activate the version with a CAS.
	// The Lua script only switches the active pointer when the new version is >= the currently active version,
	// preventing version rollback caused by concurrent writes.
	// The old snapshot gets a grace-period EXPIRE instead of an immediate DEL, avoiding a race with readers.
	activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
	readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
	snapshotKeyPrefix := fmt.Sprintf("%s%d:%s:%s:v", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode)

	keys := []string{activeKey, readyKey, schedulerBucketSetKey, snapshotKey}
	args := []any{versionStr, bucket.String(), snapshotKeyPrefix, snapshotGraceTTLSeconds}

	_, err = activateSnapshotScript.Run(ctx, c.rdb, keys, args...).Result()
	if err != nil {
		return err
	}

	return nil
}

func (c *schedulerCache) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
	key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
	val, err := c.rdb.Get(ctx, key).Result()
	if err == redis.Nil {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return decodeCachedAccount(val)
}

func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Account) error {
	if account == nil || account.ID <= 0 {
		return nil
	}
	return c.writeAccounts(ctx, []service.Account{*account})
}

func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
	if accountID <= 0 {
		return nil
	}
	id := strconv.FormatInt(accountID, 10)
	return c.rdb.Del(ctx, schedulerAccountKey(id), schedulerAccountMetaKey(id)).Err()
}

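// UpdateLastUsed patches the LastUsedAt field of already-cached accounts in
// bulk: it MGETs the full payloads, rewrites both the full (sched:acc:) and
// trimmed metadata (sched:meta:) entries in a single pipeline, and silently
// skips IDs that are no longer cached.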
func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
	if len(updates) == 0 {
		return nil
	}

	keys := make([]string, 0, len(updates))
	ids := make([]int64, 0, len(updates))
	for id := range updates {
		keys = append(keys, schedulerAccountKey(strconv.FormatInt(id, 10)))
		ids = append(ids, id)
	}

	values, err := c.mgetChunked(ctx, keys)
	if err != nil {
		return err
	}

	pipe := c.rdb.Pipeline()
	for i, val := range values {
		if val == nil {
			continue
		}
		account, err := decodeCachedAccount(val)
		if err != nil {
			return err
		}
		account.LastUsedAt = ptrTime(updates[ids[i]])
		updated, err := json.Marshal(account)
		if err != nil {
			return err
		}
		metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(*account))
		if err != nil {
			return err
		}
		pipe.Set(ctx, keys[i], updated, 0)
		pipe.Set(ctx, schedulerAccountMetaKey(strconv.FormatInt(ids[i], 10)), metaPayload, 0)
	}
	_, err = pipe.Exec(ctx)
	return err
}

func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
	key := schedulerBucketKey(schedulerLockPrefix, bucket)
	return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
}

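// UnlockBucket releases the rebuild lock as soon as a snapshot rebuild
// completes, instead of leaving the key to expire via its TTL, so the next
// rebuild of the bucket is not blocked for the remainder of the lock window.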
func (c *schedulerCache) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
	key := schedulerBucketKey(schedulerLockPrefix, bucket)
	return c.rdb.Del(ctx, key).Err()
}

func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
	raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
	if err != nil {
		return nil, err
	}
	out := make([]service.SchedulerBucket, 0, len(raw))
	for _, entry := range raw {
		bucket, ok := service.ParseSchedulerBucket(entry)
		if !ok {
			continue
		}
		out = append(out, bucket)
	}
	return out, nil
}

func (c *schedulerCache) GetOutboxWatermark(ctx context.Context) (int64, error) {
	val, err := c.rdb.Get(ctx, schedulerOutboxWatermarkKey).Result()
	if err == redis.Nil {
		return 0, nil
	}
	if err != nil {
		return 0, err
	}
	id, err := strconv.ParseInt(val, 10, 64)
	if err != nil {
		return 0, err
	}
	return id, nil
}

func (c *schedulerCache) SetOutboxWatermark(ctx context.Context, id int64) error {
	return c.rdb.Set(ctx, schedulerOutboxWatermarkKey, strconv.FormatInt(id, 10), 0).Err()
}

func schedulerBucketKey(prefix string, bucket service.SchedulerBucket) string {
	return fmt.Sprintf("%s%d:%s:%s", prefix, bucket.GroupID, bucket.Platform, bucket.Mode)
}

func schedulerSnapshotKey(bucket service.SchedulerBucket, version string) string {
	return fmt.Sprintf("%s%d:%s:%s:v%s", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode, version)
}

func schedulerAccountKey(id string) string {
	return schedulerAccountPrefix + id
}

func schedulerAccountMetaKey(id string) string {
	return schedulerAccountMetaPrefix + id
}

func ptrTime(t time.Time) *time.Time {
	return &t
}

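// decodeCachedAccount unmarshals a cached JSON payload (returned by go-redis as
// either a string or a []byte) into a service.Account.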
func decodeCachedAccount(val any) (*service.Account, error) {
	var payload []byte
	switch raw := val.(type) {
	case string:
		payload = []byte(raw)
	case []byte:
		payload = raw
	default:
		return nil, fmt.Errorf("unexpected account cache type: %T", val)
	}
	var account service.Account
	if err := json.Unmarshal(payload, &account); err != nil {
		return nil, err
	}
	return &account, nil
}

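// writeAccounts stores each account under both its full key (sched:acc:<id>)
// and its trimmed metadata key (sched:meta:<id>), flushing the pipeline every
// writeChunkSize accounts to bound the size of a single round trip.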
func (c *schedulerCache) writeAccounts(ctx context.Context, accounts []service.Account) error {
	if len(accounts) == 0 {
		return nil
	}

	pipe := c.rdb.Pipeline()
	pending := 0
	flush := func() error {
		if pending == 0 {
			return nil
		}
		if _, err := pipe.Exec(ctx); err != nil {
			return err
		}
		pipe = c.rdb.Pipeline()
		pending = 0
		return nil
	}

	for _, account := range accounts {
		fullPayload, err := json.Marshal(account)
		if err != nil {
			return err
		}
		metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(account))
		if err != nil {
			return err
		}

		id := strconv.FormatInt(account.ID, 10)
		pipe.Set(ctx, schedulerAccountKey(id), fullPayload, 0)
		pipe.Set(ctx, schedulerAccountMetaKey(id), metaPayload, 0)
		pending++
		if pending >= c.writeChunkSize {
			if err := flush(); err != nil {
				return err
			}
		}
	}

	return flush()
}

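// mgetChunked performs MGET in chunks of at most mgetChunkSize keys and
// concatenates the results in order, so large buckets never issue one
// oversized MGET.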
func (c *schedulerCache) mgetChunked(ctx context.Context, keys []string) ([]any, error) {
	if len(keys) == 0 {
		return []any{}, nil
	}

	out := make([]any, 0, len(keys))
	chunkSize := c.mgetChunkSize
	if chunkSize <= 0 {
		chunkSize = defaultSchedulerSnapshotMGetChunkSize
	}
	for start := 0; start < len(keys); start += chunkSize {
		end := start + chunkSize
		if end > len(keys) {
			end = len(keys)
		}
		part, err := c.rdb.MGet(ctx, keys[start:end]...).Result()
		if err != nil {
			return nil, err
		}
		out = append(out, part...)
	}
	return out, nil
}

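// buildSchedulerMetadataAccount projects an account onto the subset of fields
// the scheduler needs, with credentials and extra settings reduced to an
// allowlist, keeping the sched:meta: payloads small.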
func buildSchedulerMetadataAccount(account service.Account) service.Account {
	return service.Account{
		ID:                      account.ID,
		Name:                    account.Name,
		Platform:                account.Platform,
		Type:                    account.Type,
		Concurrency:             account.Concurrency,
		LoadFactor:              account.LoadFactor,
		Priority:                account.Priority,
		RateMultiplier:          account.RateMultiplier,
		Status:                  account.Status,
		LastUsedAt:              account.LastUsedAt,
		ExpiresAt:               account.ExpiresAt,
		AutoPauseOnExpired:      account.AutoPauseOnExpired,
		Schedulable:             account.Schedulable,
		RateLimitedAt:           account.RateLimitedAt,
		RateLimitResetAt:        account.RateLimitResetAt,
		OverloadUntil:           account.OverloadUntil,
		TempUnschedulableUntil:  account.TempUnschedulableUntil,
		TempUnschedulableReason: account.TempUnschedulableReason,
		SessionWindowStart:      account.SessionWindowStart,
		SessionWindowEnd:        account.SessionWindowEnd,
		SessionWindowStatus:     account.SessionWindowStatus,
		Credentials:             filterSchedulerCredentials(account.Credentials),
		Extra:                   filterSchedulerExtra(account.Extra),
	}
}

func filterSchedulerCredentials(credentials map[string]any) map[string]any {
	if len(credentials) == 0 {
		return nil
	}
	keys := []string{"model_mapping", "api_key", "project_id", "oauth_type"}
	filtered := make(map[string]any)
	for _, key := range keys {
		if value, ok := credentials[key]; ok && value != nil {
			filtered[key] = value
		}
	}
	if len(filtered) == 0 {
		return nil
	}
	return filtered
}

func filterSchedulerExtra(extra map[string]any) map[string]any {
	if len(extra) == 0 {
		return nil
	}
	keys := []string{
		"mixed_scheduling",
		"window_cost_limit",
		"window_cost_sticky_reserve",
		"max_sessions",
		"session_idle_timeout_minutes",
		"openai_oauth_responses_websockets_v2_enabled",
		"openai_oauth_responses_websockets_v2_mode",
		"openai_apikey_responses_websockets_v2_enabled",
		"openai_apikey_responses_websockets_v2_mode",
		"responses_websockets_v2_enabled",
		"openai_ws_enabled",
		"openai_ws_force_http",
	}
	filtered := make(map[string]any)
	for _, key := range keys {
		if value, ok := extra[key]; ok && value != nil {
			filtered[key] = value
		}
	}
	if len(filtered) == 0 {
		return nil
	}
	return filtered
}