mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-11 02:24:45 +08:00
441 lines
12 KiB
Go
441 lines
12 KiB
Go
|
|
package service
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"crypto/sha256"
|
|||
|
|
"encoding/hex"
|
|||
|
|
"fmt"
|
|||
|
|
"strings"
|
|||
|
|
"sync"
|
|||
|
|
"sync/atomic"
|
|||
|
|
"time"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// Tunables for the default OpenAI WS state store.
const (
	// openAIWSResponseAccountCachePrefix namespaces the hashed
	// response_id -> account_id keys written to the GatewayCache.
	openAIWSResponseAccountCachePrefix = "openai:response:"

	// openAIWSStateStoreCleanupInterval is the minimum spacing between two
	// incremental cleanup passes over the local maps.
	openAIWSStateStoreCleanupInterval = 60 * time.Second
	// openAIWSStateStoreCleanupMaxPerMap bounds how many entries a single
	// cleanup pass inspects in each local map.
	openAIWSStateStoreCleanupMaxPerMap = 512
	// openAIWSStateStoreMaxEntriesPerMap is the hard size cap of each local map.
	openAIWSStateStoreMaxEntriesPerMap = 1 << 16

	// openAIWSStateStoreRedisTimeout bounds each GatewayCache call.
	openAIWSStateStoreRedisTimeout = 3 * time.Second
)
|
|||
|
|
|
|||
|
|
// Local binding records. Each pairs the bound value with its expiry instant;
// lookups ignore a binding once that instant has passed.
type (
	// openAIWSAccountBinding is a local response_id -> account_id entry.
	openAIWSAccountBinding struct {
		accountID int64
		expiresAt time.Time
	}

	// openAIWSConnBinding is a local response_id -> conn_id entry.
	openAIWSConnBinding struct {
		connID    string
		expiresAt time.Time
	}

	// openAIWSTurnStateBinding is a local session -> turn state entry.
	openAIWSTurnStateBinding struct {
		turnState string
		expiresAt time.Time
	}

	// openAIWSSessionConnBinding is a local session -> conn_id entry.
	openAIWSSessionConnBinding struct {
		connID    string
		expiresAt time.Time
	}
)
|
|||
|
|
|
|||
|
|
// OpenAIWSStateStore manages the sticky state used by WSv2:
//   - response_id -> account_id, for routing a continuation back to its account;
//   - response_id -> conn_id, for reusing context within a connection;
//   - (group, session) -> turn state and (group, session) -> conn_id.
//
// response_id -> account_id prefers the GatewayCache (Redis) and also keeps a
// local hot copy. response_id -> conn_id is only valid within this process.
type OpenAIWSStateStore interface {
	// BindResponseAccount records which account served responseID for ttl.
	BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error
	// GetResponseAccount returns the account bound to responseID (0 if none).
	GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error)
	// DeleteResponseAccount removes the responseID -> account binding.
	DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error

	// BindResponseConn records which connection produced responseID for ttl.
	BindResponseConn(responseID, connID string, ttl time.Duration)
	// GetResponseConn returns the connection bound to responseID, if live.
	GetResponseConn(responseID string) (string, bool)
	// DeleteResponseConn removes the responseID -> conn binding.
	DeleteResponseConn(responseID string)

	// BindSessionTurnState stores the turn state of a session for ttl.
	BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
	// GetSessionTurnState returns the stored turn state of a session, if live.
	GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
	// DeleteSessionTurnState removes the turn state of a session.
	DeleteSessionTurnState(groupID int64, sessionHash string)

	// BindSessionConn records which connection a session uses for ttl.
	BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
	// GetSessionConn returns the connection bound to a session, if live.
	GetSessionConn(groupID int64, sessionHash string) (string, bool)
	// DeleteSessionConn removes the session -> conn binding.
	DeleteSessionConn(groupID int64, sessionHash string)
}
|
|||
|
|
|
|||
|
|
type defaultOpenAIWSStateStore struct {
|
|||
|
|
cache GatewayCache
|
|||
|
|
|
|||
|
|
responseToAccountMu sync.RWMutex
|
|||
|
|
responseToAccount map[string]openAIWSAccountBinding
|
|||
|
|
responseToConnMu sync.RWMutex
|
|||
|
|
responseToConn map[string]openAIWSConnBinding
|
|||
|
|
sessionToTurnStateMu sync.RWMutex
|
|||
|
|
sessionToTurnState map[string]openAIWSTurnStateBinding
|
|||
|
|
sessionToConnMu sync.RWMutex
|
|||
|
|
sessionToConn map[string]openAIWSSessionConnBinding
|
|||
|
|
|
|||
|
|
lastCleanupUnixNano atomic.Int64
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewOpenAIWSStateStore 创建默认 WS 状态存储。
|
|||
|
|
func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
|
|||
|
|
store := &defaultOpenAIWSStateStore{
|
|||
|
|
cache: cache,
|
|||
|
|
responseToAccount: make(map[string]openAIWSAccountBinding, 256),
|
|||
|
|
responseToConn: make(map[string]openAIWSConnBinding, 256),
|
|||
|
|
sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256),
|
|||
|
|
sessionToConn: make(map[string]openAIWSSessionConnBinding, 256),
|
|||
|
|
}
|
|||
|
|
store.lastCleanupUnixNano.Store(time.Now().UnixNano())
|
|||
|
|
return store
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
|
|||
|
|
id := normalizeOpenAIWSResponseID(responseID)
|
|||
|
|
if id == "" || accountID <= 0 {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
ttl = normalizeOpenAIWSTTL(ttl)
|
|||
|
|
s.maybeCleanup()
|
|||
|
|
|
|||
|
|
expiresAt := time.Now().Add(ttl)
|
|||
|
|
s.responseToAccountMu.Lock()
|
|||
|
|
ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap)
|
|||
|
|
s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt}
|
|||
|
|
s.responseToAccountMu.Unlock()
|
|||
|
|
|
|||
|
|
if s.cache == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
cacheKey := openAIWSResponseAccountCacheKey(id)
|
|||
|
|
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
|||
|
|
defer cancel()
|
|||
|
|
return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
|
|||
|
|
id := normalizeOpenAIWSResponseID(responseID)
|
|||
|
|
if id == "" {
|
|||
|
|
return 0, nil
|
|||
|
|
}
|
|||
|
|
s.maybeCleanup()
|
|||
|
|
|
|||
|
|
now := time.Now()
|
|||
|
|
s.responseToAccountMu.RLock()
|
|||
|
|
if binding, ok := s.responseToAccount[id]; ok {
|
|||
|
|
if now.Before(binding.expiresAt) {
|
|||
|
|
accountID := binding.accountID
|
|||
|
|
s.responseToAccountMu.RUnlock()
|
|||
|
|
return accountID, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
s.responseToAccountMu.RUnlock()
|
|||
|
|
|
|||
|
|
if s.cache == nil {
|
|||
|
|
return 0, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
cacheKey := openAIWSResponseAccountCacheKey(id)
|
|||
|
|
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
|||
|
|
defer cancel()
|
|||
|
|
accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey)
|
|||
|
|
if err != nil || accountID <= 0 {
|
|||
|
|
// 缓存读取失败不阻断主流程,按未命中降级。
|
|||
|
|
return 0, nil
|
|||
|
|
}
|
|||
|
|
return accountID, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
|
|||
|
|
id := normalizeOpenAIWSResponseID(responseID)
|
|||
|
|
if id == "" {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
s.responseToAccountMu.Lock()
|
|||
|
|
delete(s.responseToAccount, id)
|
|||
|
|
s.responseToAccountMu.Unlock()
|
|||
|
|
|
|||
|
|
if s.cache == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
|||
|
|
defer cancel()
|
|||
|
|
return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
|
|||
|
|
id := normalizeOpenAIWSResponseID(responseID)
|
|||
|
|
conn := strings.TrimSpace(connID)
|
|||
|
|
if id == "" || conn == "" {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
ttl = normalizeOpenAIWSTTL(ttl)
|
|||
|
|
s.maybeCleanup()
|
|||
|
|
|
|||
|
|
s.responseToConnMu.Lock()
|
|||
|
|
ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap)
|
|||
|
|
s.responseToConn[id] = openAIWSConnBinding{
|
|||
|
|
connID: conn,
|
|||
|
|
expiresAt: time.Now().Add(ttl),
|
|||
|
|
}
|
|||
|
|
s.responseToConnMu.Unlock()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
|
|||
|
|
id := normalizeOpenAIWSResponseID(responseID)
|
|||
|
|
if id == "" {
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
s.maybeCleanup()
|
|||
|
|
|
|||
|
|
now := time.Now()
|
|||
|
|
s.responseToConnMu.RLock()
|
|||
|
|
binding, ok := s.responseToConn[id]
|
|||
|
|
s.responseToConnMu.RUnlock()
|
|||
|
|
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
return binding.connID, true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
|
|||
|
|
id := normalizeOpenAIWSResponseID(responseID)
|
|||
|
|
if id == "" {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
s.responseToConnMu.Lock()
|
|||
|
|
delete(s.responseToConn, id)
|
|||
|
|
s.responseToConnMu.Unlock()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
|
|||
|
|
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
|||
|
|
state := strings.TrimSpace(turnState)
|
|||
|
|
if key == "" || state == "" {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
ttl = normalizeOpenAIWSTTL(ttl)
|
|||
|
|
s.maybeCleanup()
|
|||
|
|
|
|||
|
|
s.sessionToTurnStateMu.Lock()
|
|||
|
|
ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap)
|
|||
|
|
s.sessionToTurnState[key] = openAIWSTurnStateBinding{
|
|||
|
|
turnState: state,
|
|||
|
|
expiresAt: time.Now().Add(ttl),
|
|||
|
|
}
|
|||
|
|
s.sessionToTurnStateMu.Unlock()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
|
|||
|
|
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
|||
|
|
if key == "" {
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
s.maybeCleanup()
|
|||
|
|
|
|||
|
|
now := time.Now()
|
|||
|
|
s.sessionToTurnStateMu.RLock()
|
|||
|
|
binding, ok := s.sessionToTurnState[key]
|
|||
|
|
s.sessionToTurnStateMu.RUnlock()
|
|||
|
|
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.turnState) == "" {
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
return binding.turnState, true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) {
|
|||
|
|
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
|||
|
|
if key == "" {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
s.sessionToTurnStateMu.Lock()
|
|||
|
|
delete(s.sessionToTurnState, key)
|
|||
|
|
s.sessionToTurnStateMu.Unlock()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
|
|||
|
|
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
|||
|
|
conn := strings.TrimSpace(connID)
|
|||
|
|
if key == "" || conn == "" {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
ttl = normalizeOpenAIWSTTL(ttl)
|
|||
|
|
s.maybeCleanup()
|
|||
|
|
|
|||
|
|
s.sessionToConnMu.Lock()
|
|||
|
|
ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap)
|
|||
|
|
s.sessionToConn[key] = openAIWSSessionConnBinding{
|
|||
|
|
connID: conn,
|
|||
|
|
expiresAt: time.Now().Add(ttl),
|
|||
|
|
}
|
|||
|
|
s.sessionToConnMu.Unlock()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) {
|
|||
|
|
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
|||
|
|
if key == "" {
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
s.maybeCleanup()
|
|||
|
|
|
|||
|
|
now := time.Now()
|
|||
|
|
s.sessionToConnMu.RLock()
|
|||
|
|
binding, ok := s.sessionToConn[key]
|
|||
|
|
s.sessionToConnMu.RUnlock()
|
|||
|
|
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
|
|||
|
|
return "", false
|
|||
|
|
}
|
|||
|
|
return binding.connID, true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) {
|
|||
|
|
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
|||
|
|
if key == "" {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
s.sessionToConnMu.Lock()
|
|||
|
|
delete(s.sessionToConn, key)
|
|||
|
|
s.sessionToConnMu.Unlock()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *defaultOpenAIWSStateStore) maybeCleanup() {
|
|||
|
|
if s == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
now := time.Now()
|
|||
|
|
last := time.Unix(0, s.lastCleanupUnixNano.Load())
|
|||
|
|
if now.Sub(last) < openAIWSStateStoreCleanupInterval {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if !s.lastCleanupUnixNano.CompareAndSwap(last.UnixNano(), now.UnixNano()) {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 增量限额清理,避免高规模下一次性全量扫描导致长时间阻塞。
|
|||
|
|
s.responseToAccountMu.Lock()
|
|||
|
|
cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap)
|
|||
|
|
s.responseToAccountMu.Unlock()
|
|||
|
|
|
|||
|
|
s.responseToConnMu.Lock()
|
|||
|
|
cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap)
|
|||
|
|
s.responseToConnMu.Unlock()
|
|||
|
|
|
|||
|
|
s.sessionToTurnStateMu.Lock()
|
|||
|
|
cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap)
|
|||
|
|
s.sessionToTurnStateMu.Unlock()
|
|||
|
|
|
|||
|
|
s.sessionToConnMu.Lock()
|
|||
|
|
cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap)
|
|||
|
|
s.sessionToConnMu.Unlock()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) {
|
|||
|
|
if len(bindings) == 0 || maxScan <= 0 {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
scanned := 0
|
|||
|
|
for key, binding := range bindings {
|
|||
|
|
if now.After(binding.expiresAt) {
|
|||
|
|
delete(bindings, key)
|
|||
|
|
}
|
|||
|
|
scanned++
|
|||
|
|
if scanned >= maxScan {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) {
|
|||
|
|
if len(bindings) == 0 || maxScan <= 0 {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
scanned := 0
|
|||
|
|
for key, binding := range bindings {
|
|||
|
|
if now.After(binding.expiresAt) {
|
|||
|
|
delete(bindings, key)
|
|||
|
|
}
|
|||
|
|
scanned++
|
|||
|
|
if scanned >= maxScan {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
|
|||
|
|
if len(bindings) == 0 || maxScan <= 0 {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
scanned := 0
|
|||
|
|
for key, binding := range bindings {
|
|||
|
|
if now.After(binding.expiresAt) {
|
|||
|
|
delete(bindings, key)
|
|||
|
|
}
|
|||
|
|
scanned++
|
|||
|
|
if scanned >= maxScan {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
|
|||
|
|
if len(bindings) == 0 || maxScan <= 0 {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
scanned := 0
|
|||
|
|
for key, binding := range bindings {
|
|||
|
|
if now.After(binding.expiresAt) {
|
|||
|
|
delete(bindings, key)
|
|||
|
|
}
|
|||
|
|
scanned++
|
|||
|
|
if scanned >= maxScan {
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ensureBindingCapacity makes room for incomingKey in bindings before an insert.
//
// When the map already holds maxEntries or more entries and incomingKey is not
// yet present, one arbitrary entry is evicted. Overwriting an existing key does
// not grow the map, so nothing is evicted in that case. A non-positive
// maxEntries disables the cap.
func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
	if maxEntries <= 0 || len(bindings) < maxEntries {
		return
	}
	if _, present := bindings[incomingKey]; present {
		return
	}
	// Hard cap: evict any one entry (map iteration order is unspecified) so
	// memory stays bounded.
	for victim := range bindings {
		delete(bindings, victim)
		break
	}
}
|
|||
|
|
|
|||
|
|
// normalizeOpenAIWSResponseID canonicalizes a response ID by stripping
// surrounding whitespace. An all-blank ID normalizes to "".
func normalizeOpenAIWSResponseID(responseID string) string {
	normalized := strings.TrimSpace(responseID)
	return normalized
}
|
|||
|
|
|
|||
|
|
func openAIWSResponseAccountCacheKey(responseID string) string {
|
|||
|
|
sum := sha256.Sum256([]byte(responseID))
|
|||
|
|
return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:])
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// normalizeOpenAIWSTTL returns ttl when it is positive and one hour otherwise.
func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
	const fallbackTTL = time.Hour
	if ttl > 0 {
		return ttl
	}
	return fallbackTTL
}
|
|||
|
|
|
|||
|
|
// openAIWSSessionTurnStateKey builds the local-map key "groupID:sessionHash"
// shared by the session turn-state and session-conn maps. It returns "" when
// sessionHash is blank after trimming.
func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
	if hash := strings.TrimSpace(sessionHash); hash != "" {
		return fmt.Sprintf("%d:%s", groupID, hash)
	}
	return ""
}
|
|||
|
|
|
|||
|
|
func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
|||
|
|
if ctx == nil {
|
|||
|
|
ctx = context.Background()
|
|||
|
|
}
|
|||
|
|
return context.WithTimeout(ctx, openAIWSStateStoreRedisTimeout)
|
|||
|
|
}
|