2025-12-19 23:39:28 +08:00
|
|
|
|
package repository
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
2026-01-08 23:07:00 +08:00
|
|
|
|
"fmt"
|
2025-12-19 23:39:28 +08:00
|
|
|
|
"time"
|
|
|
|
|
|
|
2025-12-25 17:15:01 +08:00
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
2025-12-19 23:39:28 +08:00
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
const stickySessionPrefix = "sticky_session:"
|
|
|
|
|
|
|
2026-02-07 12:31:10 +08:00
|
|
|
|
// Gemini Trie Lua 脚本
|
|
|
|
|
|
const (
|
|
|
|
|
|
// geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本
|
|
|
|
|
|
// KEYS[1] = trie key
|
|
|
|
|
|
// ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d")
|
|
|
|
|
|
// ARGV[2] = TTL seconds (用于刷新)
|
|
|
|
|
|
// 返回: 最长匹配的 value (uuid:accountID) 或 nil
|
|
|
|
|
|
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
|
|
|
|
|
geminiTrieFindScript = `
|
|
|
|
|
|
local chain = ARGV[1]
|
|
|
|
|
|
local ttl = tonumber(ARGV[2])
|
|
|
|
|
|
local lastMatch = nil
|
|
|
|
|
|
local path = ""
|
|
|
|
|
|
|
|
|
|
|
|
for part in string.gmatch(chain, "[^-]+") do
|
|
|
|
|
|
path = path == "" and part or path .. "-" .. part
|
|
|
|
|
|
local val = redis.call('HGET', KEYS[1], path)
|
|
|
|
|
|
if val and val ~= "" then
|
|
|
|
|
|
lastMatch = val
|
|
|
|
|
|
end
|
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
|
|
if lastMatch then
|
|
|
|
|
|
redis.call('EXPIRE', KEYS[1], ttl)
|
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
|
|
return lastMatch
|
|
|
|
|
|
`
|
|
|
|
|
|
|
|
|
|
|
|
// geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本
|
|
|
|
|
|
// KEYS[1] = trie key
|
|
|
|
|
|
// ARGV[1] = digestChain
|
|
|
|
|
|
// ARGV[2] = value (uuid:accountID)
|
|
|
|
|
|
// ARGV[3] = TTL seconds
|
|
|
|
|
|
geminiTrieSaveScript = `
|
|
|
|
|
|
local chain = ARGV[1]
|
|
|
|
|
|
local value = ARGV[2]
|
|
|
|
|
|
local ttl = tonumber(ARGV[3])
|
|
|
|
|
|
local path = ""
|
|
|
|
|
|
|
|
|
|
|
|
for part in string.gmatch(chain, "[^-]+") do
|
|
|
|
|
|
path = path == "" and part or path .. "-" .. part
|
|
|
|
|
|
end
|
|
|
|
|
|
redis.call('HSET', KEYS[1], path, value)
|
|
|
|
|
|
redis.call('EXPIRE', KEYS[1], ttl)
|
|
|
|
|
|
return "OK"
|
|
|
|
|
|
`
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// 模型负载统计相关常量
|
|
|
|
|
|
const (
|
|
|
|
|
|
modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀
|
|
|
|
|
|
modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀
|
|
|
|
|
|
modelLoadTTL = 24 * time.Hour // 调用次数 TTL(24 小时无调用后清零)
|
|
|
|
|
|
modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2025-12-19 23:39:28 +08:00
|
|
|
|
type gatewayCache struct {
|
|
|
|
|
|
rdb *redis.Client
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-25 17:15:01 +08:00
|
|
|
|
func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
|
2025-12-19 23:39:28 +08:00
|
|
|
|
return &gatewayCache{rdb: rdb}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-08 23:07:00 +08:00
|
|
|
|
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
|
|
|
|
|
|
// 格式: sticky_session:{groupID}:{sessionHash}
|
|
|
|
|
|
func buildSessionKey(groupID int64, sessionHash string) string {
|
|
|
|
|
|
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
|
|
|
|
|
key := buildSessionKey(groupID, sessionHash)
|
2025-12-19 23:39:28 +08:00
|
|
|
|
return c.rdb.Get(ctx, key).Int64()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-08 23:07:00 +08:00
|
|
|
|
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
|
|
|
|
|
key := buildSessionKey(groupID, sessionHash)
|
2025-12-19 23:39:28 +08:00
|
|
|
|
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-08 23:07:00 +08:00
|
|
|
|
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
|
|
|
|
|
key := buildSessionKey(groupID, sessionHash)
|
2025-12-19 23:39:28 +08:00
|
|
|
|
return c.rdb.Expire(ctx, key, ttl).Err()
|
|
|
|
|
|
}
|
2026-01-20 11:19:32 +08:00
|
|
|
|
|
|
|
|
|
|
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
|
|
|
|
|
|
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
|
|
|
|
|
|
// 以便下次请求能够重新选择可用账号。
|
|
|
|
|
|
//
|
|
|
|
|
|
// DeleteSessionAccountID removes the sticky session binding for the given session.
|
|
|
|
|
|
// Called when the bound account becomes unavailable (e.g., error status, disabled,
|
|
|
|
|
|
// or unschedulable), allowing subsequent requests to select a new available account.
|
|
|
|
|
|
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
|
|
|
|
|
key := buildSessionKey(groupID, sessionHash)
|
|
|
|
|
|
return c.rdb.Del(ctx, key).Err()
|
|
|
|
|
|
}
|
2026-02-07 12:31:10 +08:00
|
|
|
|
|
|
|
|
|
|
// ============ Antigravity 模型负载统计方法 ============
|
|
|
|
|
|
|
|
|
|
|
|
// modelLoadKey 构建模型调用次数 key
|
|
|
|
|
|
// 格式: ag:model_load:{accountID}:{model}
|
|
|
|
|
|
func modelLoadKey(accountID int64, model string) string {
|
|
|
|
|
|
return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// modelLastUsedKey 构建模型最后调度时间 key
|
|
|
|
|
|
// 格式: ag:model_last_used:{accountID}:{model}
|
|
|
|
|
|
func modelLastUsedKey(accountID int64, model string) string {
|
|
|
|
|
|
return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// IncrModelCallCount 增加模型调用次数并更新最后调度时间
|
|
|
|
|
|
// 返回更新后的调用次数
|
|
|
|
|
|
func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) {
|
|
|
|
|
|
loadKey := modelLoadKey(accountID, model)
|
|
|
|
|
|
lastUsedKey := modelLastUsedKey(accountID, model)
|
|
|
|
|
|
|
|
|
|
|
|
pipe := c.rdb.Pipeline()
|
|
|
|
|
|
incrCmd := pipe.Incr(ctx, loadKey)
|
|
|
|
|
|
pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL
|
|
|
|
|
|
pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL)
|
|
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
|
|
|
|
return 0, err
|
|
|
|
|
|
}
|
|
|
|
|
|
return incrCmd.Val(), nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// GetModelLoadBatch 批量获取账号的模型负载信息
|
|
|
|
|
|
func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) {
|
|
|
|
|
|
if len(accountIDs) == 0 {
|
|
|
|
|
|
return make(map[int64]*service.ModelLoadInfo), nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model)
|
|
|
|
|
|
return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作
|
|
|
|
|
|
func (c *gatewayCache) pipelineModelLoadGet(
|
|
|
|
|
|
ctx context.Context,
|
|
|
|
|
|
accountIDs []int64,
|
|
|
|
|
|
model string,
|
|
|
|
|
|
) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) {
|
|
|
|
|
|
pipe := c.rdb.Pipeline()
|
|
|
|
|
|
loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
|
|
|
|
|
|
lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs))
|
|
|
|
|
|
|
|
|
|
|
|
for _, id := range accountIDs {
|
|
|
|
|
|
loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model))
|
|
|
|
|
|
lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model))
|
|
|
|
|
|
}
|
|
|
|
|
|
_, _ = pipe.Exec(ctx) // 忽略错误,key 不存在是正常的
|
|
|
|
|
|
return loadCmds, lastUsedCmds
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// parseModelLoadResults 解析 Pipeline 结果
|
|
|
|
|
|
func (c *gatewayCache) parseModelLoadResults(
|
|
|
|
|
|
accountIDs []int64,
|
|
|
|
|
|
loadCmds map[int64]*redis.StringCmd,
|
|
|
|
|
|
lastUsedCmds map[int64]*redis.StringCmd,
|
|
|
|
|
|
) map[int64]*service.ModelLoadInfo {
|
|
|
|
|
|
result := make(map[int64]*service.ModelLoadInfo, len(accountIDs))
|
|
|
|
|
|
for _, id := range accountIDs {
|
|
|
|
|
|
result[id] = &service.ModelLoadInfo{
|
|
|
|
|
|
CallCount: getInt64OrZero(loadCmds[id]),
|
|
|
|
|
|
LastUsedAt: getTimeOrZero(lastUsedCmds[id]),
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return result
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0
|
|
|
|
|
|
func getInt64OrZero(cmd *redis.StringCmd) int64 {
|
|
|
|
|
|
val, _ := cmd.Int64()
|
|
|
|
|
|
return val
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// getTimeOrZero 从 StringCmd 获取 time.Time,失败返回零值
|
|
|
|
|
|
func getTimeOrZero(cmd *redis.StringCmd) time.Time {
|
|
|
|
|
|
val, err := cmd.Int64()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return time.Time{}
|
|
|
|
|
|
}
|
|
|
|
|
|
return time.Unix(val, 0)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============
|
|
|
|
|
|
|
|
|
|
|
|
// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询)
|
|
|
|
|
|
// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL
|
|
|
|
|
|
func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) {
|
|
|
|
|
|
if digestChain == "" {
|
|
|
|
|
|
return "", 0, false
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
|
|
|
|
|
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
|
|
|
|
|
|
|
|
|
|
|
// 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返
|
|
|
|
|
|
// 查找成功时自动刷新 TTL,防止活跃会话意外过期
|
|
|
|
|
|
result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result()
|
|
|
|
|
|
if err != nil || result == nil {
|
|
|
|
|
|
return "", 0, false
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
value, ok := result.(string)
|
|
|
|
|
|
if !ok || value == "" {
|
|
|
|
|
|
return "", 0, false
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
uuid, accountID, ok = service.ParseGeminiSessionValue(value)
|
|
|
|
|
|
return uuid, accountID, ok
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本)
|
|
|
|
|
|
func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error {
|
|
|
|
|
|
if digestChain == "" {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
trieKey := service.BuildGeminiTrieKey(groupID, prefixHash)
|
|
|
|
|
|
value := service.FormatGeminiSessionValue(uuid, accountID)
|
|
|
|
|
|
ttlSeconds := int(service.GeminiSessionTTL().Seconds())
|
|
|
|
|
|
|
|
|
|
|
|
return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err()
|
|
|
|
|
|
}
|