Files

1714 lines
40 KiB
Go
Raw Permalink Normal View History

package service
import (
"context"
"errors"
"fmt"
"math"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"golang.org/x/sync/errgroup"
)
// Tunables for the OpenAI WebSocket connection pool.
const (
// NOTE(review): not referenced in the visible part of this file; presumably the
// default returned by maxConnAge() — confirm.
openAIWSConnMaxAge = 60 * time.Minute
// Idle time after which shouldHealthCheckConn asks for a ping before a pooled
// connection is handed out.
openAIWSConnHealthCheckIdle = 90 * time.Second
// Timeout used for health-check pings; also the fallback when pingWithTimeout
// is given a non-positive timeout.
openAIWSConnHealthCheckTO = 2 * time.Second
// Added to dialTimeout() to form the deadline of each prewarm dial.
openAIWSConnPrewarmExtraDelay = 2 * time.Second
// Minimum spacing between the inline per-account cleanups done by acquire.
openAIWSAcquireCleanupInterval = 3 * time.Second
// Tick period of the background ping worker.
openAIWSBackgroundPingInterval = 30 * time.Second
// Tick period of the background cleanup worker.
openAIWSBackgroundSweepTicker = 30 * time.Second
// NOTE(review): the two prewarm-failure settings below are not referenced in the
// visible part of this file; presumably consumed by shouldSuppressPrewarmLocked — confirm.
openAIWSPrewarmFailureWindow = 30 * time.Second
openAIWSPrewarmFailureSuppress = 2
)
// Sentinel errors returned by the pool, its connections and its leases.
var (
// Returned when a lease or connection is nil, released or already closed.
errOpenAIWSConnClosed = errors.New("openai ws connection closed")
// Returned by acquire when the effective connection limit is not positive, when
// the chosen connection already has the maximum number of waiters, or when
// ForceNewConn is set and no capacity is available.
errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full")
// Returned by acquire when ForcePreferredConn is set but the preferred ID is
// empty or not present in the account pool.
errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable")
)
// openAIWSDialError reports a failed WebSocket dial together with whatever
// HTTP handshake details the dialer returned.
type openAIWSDialError struct {
	// StatusCode is the handshake status reported by the dialer; values <= 0
	// are left out of the error message.
	StatusCode int
	// ResponseHeaders holds the handshake response headers, if any.
	ResponseHeaders http.Header
	// Err is the underlying cause.
	Err error
}

// Error formats the failure, including the status code only when it is
// positive. A nil receiver yields the empty string.
func (e *openAIWSDialError) Error() string {
	switch {
	case e == nil:
		return ""
	case e.StatusCode > 0:
		return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err)
	default:
		return fmt.Sprintf("openai ws dial failed: %v", e.Err)
	}
}

// Unwrap exposes the underlying cause so errors.Is / errors.As can reach it.
// A nil receiver unwraps to nil.
func (e *openAIWSDialError) Unwrap() error {
	if e != nil {
		return e.Err
	}
	return nil
}
// openAIWSAcquireRequest carries everything the pool needs to pick or dial a
// connection for one account.
type openAIWSAcquireRequest struct {
// Account owns the connection; acquire rejects a nil account or a non-positive ID.
Account *Account
// WSURL is the dial target; acquire rejects an empty (trimmed) value.
WSURL string
// Headers are passed to the dialer.
Headers http.Header
// ProxyURL is passed to the dialer.
ProxyURL string
// PreferredConnID names the pooled connection to try first, if any.
PreferredConnID string
// ForceNewConn forces this acquire to obtain a new connection (avoids reuse
// cross-contaminating the continuation state kept inside a connection).
ForceNewConn bool
// ForcePreferredConn forces this acquire to use only PreferredConnID; drifting
// to any other connection is forbidden.
ForcePreferredConn bool
}
// openAIWSConnLease is the handle returned by the pool's Acquire; it gives the
// holder access to one pooled connection until Release is called.
type openAIWSConnLease struct {
// pool is used by MarkBroken to evict the connection.
pool *openAIWSConnPool
// accountID identifies the account pool the connection belongs to.
accountID int64
// conn is the leased connection.
conn *openAIWSConn
// queueWait is the time spent waiting for the connection's lease token (zero
// when the connection was obtained without queueing).
queueWait time.Duration
// connPick is the time acquire spent selecting the connection.
connPick time.Duration
// reused is true when the connection already existed in the pool.
reused bool
// released flips to true on the first Release call.
released atomic.Bool
}
// activeConn returns the leased connection, or errOpenAIWSConnClosed when the
// lease is nil, has no connection, or has already been released.
func (l *openAIWSConnLease) activeConn() (*openAIWSConn, error) {
	usable := l != nil && l.conn != nil && !l.released.Load()
	if !usable {
		return nil, errOpenAIWSConnClosed
	}
	return l.conn, nil
}
// ConnID returns the ID of the leased connection, or "" when the lease or its
// connection is nil.
func (l *openAIWSConnLease) ConnID() string {
	if l != nil && l.conn != nil {
		return l.conn.id
	}
	return ""
}
// QueueWaitDuration returns how long the acquire waited in the connection's
// queue; a nil lease reports zero.
func (l *openAIWSConnLease) QueueWaitDuration() time.Duration {
	if l != nil {
		return l.queueWait
	}
	return 0
}
// ConnPickDuration returns how long connection selection took; a nil lease
// reports zero.
func (l *openAIWSConnLease) ConnPickDuration() time.Duration {
	if l != nil {
		return l.connPick
	}
	return 0
}
// Reused reports whether the lease was served from an existing pooled
// connection; a nil lease reports false.
func (l *openAIWSConnLease) Reused() bool {
	return l != nil && l.reused
}
// HandshakeHeader returns one handshake response header of the leased
// connection, or "" when the lease or its connection is nil.
func (l *openAIWSConnLease) HandshakeHeader(name string) string {
	if l != nil && l.conn != nil {
		return l.conn.handshakeHeader(name)
	}
	return ""
}
// HandshakeHeaders returns a clone of the leased connection's handshake
// headers, or nil when the lease or its connection is nil.
func (l *openAIWSConnLease) HandshakeHeaders() http.Header {
	if l != nil && l.conn != nil {
		return cloneHeader(l.conn.handshakeHeaders)
	}
	return nil
}
// IsPrewarmed reports the prewarmed flag of the leased connection; false when
// the lease or its connection is nil.
func (l *openAIWSConnLease) IsPrewarmed() bool {
	if l != nil && l.conn != nil {
		return l.conn.isPrewarmed()
	}
	return false
}
// MarkPrewarmed sets the prewarmed flag on the leased connection; it does
// nothing when the lease or its connection is nil.
func (l *openAIWSConnLease) MarkPrewarmed() {
	if l != nil && l.conn != nil {
		l.conn.markPrewarmed()
	}
}
// WriteJSON writes value on the leased connection using a background context
// bounded by timeout. It fails with errOpenAIWSConnClosed if the lease is no
// longer active.
func (l *openAIWSConnLease) WriteJSON(value any, timeout time.Duration) error {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.writeJSONWithTimeout(context.Background(), value, timeout)
	}
	return leaseErr
}
// WriteJSONWithContextTimeout writes value on the leased connection using ctx
// bounded by timeout. It fails with errOpenAIWSConnClosed if the lease is no
// longer active.
func (l *openAIWSConnLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.writeJSONWithTimeout(ctx, value, timeout)
	}
	return leaseErr
}
// WriteJSONContext writes value on the leased connection under ctx with no
// extra timeout. It fails with errOpenAIWSConnClosed if the lease is no longer
// active.
func (l *openAIWSConnLease) WriteJSONContext(ctx context.Context, value any) error {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.writeJSON(value, ctx)
	}
	return leaseErr
}
// ReadMessage reads one message from the leased connection, bounded by
// timeout. It fails with errOpenAIWSConnClosed if the lease is no longer
// active.
func (l *openAIWSConnLease) ReadMessage(timeout time.Duration) ([]byte, error) {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.readMessageWithTimeout(timeout)
	}
	return nil, leaseErr
}
// ReadMessageContext reads one message from the leased connection under ctx.
// It fails with errOpenAIWSConnClosed if the lease is no longer active.
func (l *openAIWSConnLease) ReadMessageContext(ctx context.Context) ([]byte, error) {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.readMessage(ctx)
	}
	return nil, leaseErr
}
// ReadMessageWithContextTimeout reads one message from the leased connection
// using ctx bounded by timeout. It fails with errOpenAIWSConnClosed if the
// lease is no longer active.
func (l *openAIWSConnLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.readMessageWithContextTimeout(ctx, timeout)
	}
	return nil, leaseErr
}
// PingWithTimeout pings the leased connection. It fails with
// errOpenAIWSConnClosed if the lease is no longer active.
func (l *openAIWSConnLease) PingWithTimeout(timeout time.Duration) error {
	active, leaseErr := l.activeConn()
	if leaseErr == nil {
		return active.pingWithTimeout(timeout)
	}
	return leaseErr
}
// MarkBroken evicts the leased connection from its pool. It is a no-op when
// the lease, its pool or its connection is nil, or once the lease has been
// released.
func (l *openAIWSConnLease) MarkBroken() {
	held := l != nil && l.pool != nil && l.conn != nil && !l.released.Load()
	if held {
		l.pool.evictConn(l.accountID, l.conn.id)
	}
}
// Release hands the connection's lease token back. Only the first call has
// any effect; later calls, and calls on a nil lease or a lease without a
// connection, do nothing.
func (l *openAIWSConnLease) Release() {
	if l == nil || l.conn == nil {
		return
	}
	// The compare-and-swap makes sure exactly one caller returns the token.
	if l.released.CompareAndSwap(false, true) {
		l.conn.release()
	}
}
// openAIWSConn is one pooled WebSocket connection plus the bookkeeping the
// pool needs to lease, health-check and retire it.
type openAIWSConn struct {
// id is the pool-assigned identifier (see nextConnID).
id string
// ws is the underlying client connection.
ws openAIWSClientConn
// handshakeHeaders is a clone of the headers returned by the dial.
handshakeHeaders http.Header
// leaseCh has capacity 1 and holds a token while the connection is free;
// acquiring the connection means receiving that token.
leaseCh chan struct{}
// closedCh is closed by close() to signal that the connection is gone.
closedCh chan struct{}
// closeOnce makes close() idempotent.
closeOnce sync.Once
// readMu serializes reads on ws.
readMu sync.Mutex
// writeMu serializes writes and pings on ws.
writeMu sync.Mutex
// waiters counts acquirers currently queued on this connection.
waiters atomic.Int32
// createdAtNano is the creation time in Unix nanoseconds.
createdAtNano atomic.Int64
// lastUsedNano is the last touch() time in Unix nanoseconds.
lastUsedNano atomic.Int64
// prewarmed is set by markPrewarmed.
prewarmed atomic.Bool
}
// newOpenAIWSConn wraps ws in a pooled connection that starts out free (its
// lease token is available) with creation and last-used times set to now. The
// handshake headers are cloned; the second parameter is unused.
func newOpenAIWSConn(id string, _ int64, ws openAIWSClientConn, handshakeHeaders http.Header) *openAIWSConn {
	bornNano := time.Now().UnixNano()
	c := &openAIWSConn{
		id:               id,
		ws:               ws,
		handshakeHeaders: cloneHeader(handshakeHeaders),
		leaseCh:          make(chan struct{}, 1),
		closedCh:         make(chan struct{}),
	}
	c.createdAtNano.Store(bornNano)
	c.lastUsedNano.Store(bornNano)
	// Seed the single lease token so the first acquire succeeds immediately.
	c.leaseCh <- struct{}{}
	return c
}
// tryAcquire takes the lease token without blocking. It returns false when
// the receiver is nil, the connection is closed, or the token is already held.
func (c *openAIWSConn) tryAcquire() bool {
	if c == nil {
		return false
	}
	// Fast reject for connections that are already closed.
	select {
	case <-c.closedCh:
		return false
	default:
	}
	// Non-blocking attempt to take the token.
	select {
	case <-c.leaseCh:
	default:
		return false
	}
	// The connection may have been closed while we were taking the token; if
	// so, give the token back and report failure.
	select {
	case <-c.closedCh:
		c.release()
		return false
	default:
		return true
	}
}
// acquire blocks until the lease token is obtained, ctx is done, or the
// connection is closed. It returns ctx.Err() on cancellation and
// errOpenAIWSConnClosed for a nil or closed connection.
func (c *openAIWSConn) acquire(ctx context.Context) error {
	if c == nil {
		return errOpenAIWSConnClosed
	}
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-c.closedCh:
		return errOpenAIWSConnClosed
	case <-c.leaseCh:
		// Token obtained; fall through to the closed re-check below.
	}
	// A closed connection can still yield a token, so check again before
	// reporting success and hand the token back if it is closed.
	select {
	case <-c.closedCh:
		c.release()
		return errOpenAIWSConnClosed
	default:
		return nil
	}
}
// release puts the lease token back (dropping it if one is already present)
// and refreshes the last-used timestamp. A nil receiver is ignored.
func (c *openAIWSConn) release() {
	if c != nil {
		select {
		case c.leaseCh <- struct{}{}:
		default:
			// Token slot already full; nothing to return.
		}
		c.touch()
	}
}
// close shuts the connection down exactly once: it closes closedCh, closes
// the underlying socket (ignoring its error) and makes a lease token
// available without blocking. A nil receiver is ignored.
func (c *openAIWSConn) close() {
	if c == nil {
		return
	}
	shutdown := func() {
		close(c.closedCh)
		if c.ws != nil {
			_ = c.ws.Close()
		}
		select {
		case c.leaseCh <- struct{}{}:
		default:
		}
	}
	c.closeOnce.Do(shutdown)
}
// writeJSONWithTimeout writes value under parent, additionally bounded by
// timeout when it is positive. A nil parent is replaced by a background
// context. Nil or closed connections yield errOpenAIWSConnClosed.
func (c *openAIWSConn) writeJSONWithTimeout(parent context.Context, value any, timeout time.Duration) error {
	if c == nil {
		return errOpenAIWSConnClosed
	}
	select {
	case <-c.closedCh:
		return errOpenAIWSConnClosed
	default:
	}
	if parent == nil {
		parent = context.Background()
	}
	if timeout > 0 {
		bounded, cancel := context.WithTimeout(parent, timeout)
		defer cancel()
		return c.writeJSON(value, bounded)
	}
	return c.writeJSON(value, parent)
}
// writeJSON sends value as JSON on the underlying socket while holding
// writeMu, then refreshes the last-used timestamp on success.
//
// A nil writeCtx is replaced by a background context. A nil receiver or a nil
// underlying socket yields errOpenAIWSConnClosed; any other error comes from
// the socket's WriteJSON.
func (c *openAIWSConn) writeJSON(value any, writeCtx context.Context) error {
	// Guard the receiver before touching writeMu: the other methods on this
	// type tolerate a nil receiver, and locking through a nil pointer panics.
	if c == nil {
		return errOpenAIWSConnClosed
	}
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if c.ws == nil {
		return errOpenAIWSConnClosed
	}
	if writeCtx == nil {
		writeCtx = context.Background()
	}
	if err := c.ws.WriteJSON(writeCtx, value); err != nil {
		return err
	}
	c.touch()
	return nil
}
// readMessageWithTimeout reads one message using a background context bounded
// by timeout.
func (c *openAIWSConn) readMessageWithTimeout(timeout time.Duration) ([]byte, error) {
	root := context.Background()
	return c.readMessageWithContextTimeout(root, timeout)
}
// readMessageWithContextTimeout reads one message under parent, additionally
// bounded by timeout when it is positive. A nil parent is replaced by a
// background context. Nil or closed connections yield errOpenAIWSConnClosed.
func (c *openAIWSConn) readMessageWithContextTimeout(parent context.Context, timeout time.Duration) ([]byte, error) {
	if c == nil {
		return nil, errOpenAIWSConnClosed
	}
	select {
	case <-c.closedCh:
		return nil, errOpenAIWSConnClosed
	default:
	}
	if parent == nil {
		parent = context.Background()
	}
	if timeout > 0 {
		bounded, cancel := context.WithTimeout(parent, timeout)
		defer cancel()
		return c.readMessage(bounded)
	}
	return c.readMessage(parent)
}
// readMessage reads one message from the underlying socket while holding
// readMu, then refreshes the last-used timestamp on success.
//
// A nil readCtx is replaced by a background context. A nil receiver or a nil
// underlying socket yields errOpenAIWSConnClosed; any other error comes from
// the socket's ReadMessage.
func (c *openAIWSConn) readMessage(readCtx context.Context) ([]byte, error) {
	// Guard the receiver before touching readMu: the other methods on this
	// type tolerate a nil receiver, and locking through a nil pointer panics.
	if c == nil {
		return nil, errOpenAIWSConnClosed
	}
	c.readMu.Lock()
	defer c.readMu.Unlock()
	if c.ws == nil {
		return nil, errOpenAIWSConnClosed
	}
	if readCtx == nil {
		readCtx = context.Background()
	}
	payload, err := c.ws.ReadMessage(readCtx)
	if err != nil {
		return nil, err
	}
	c.touch()
	return payload, nil
}
// pingWithTimeout pings the underlying socket while holding writeMu. A
// non-positive timeout falls back to openAIWSConnHealthCheckTO. Nil or closed
// connections, and connections without a socket, yield errOpenAIWSConnClosed.
func (c *openAIWSConn) pingWithTimeout(timeout time.Duration) error {
	if c == nil {
		return errOpenAIWSConnClosed
	}
	select {
	case <-c.closedCh:
		return errOpenAIWSConnClosed
	default:
	}

	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	if c.ws == nil {
		return errOpenAIWSConnClosed
	}
	limit := timeout
	if limit <= 0 {
		limit = openAIWSConnHealthCheckTO
	}
	pingCtx, cancel := context.WithTimeout(context.Background(), limit)
	defer cancel()
	if pingErr := c.ws.Ping(pingCtx); pingErr != nil {
		return pingErr
	}
	return nil
}
// touch records the current time as the connection's last-used instant. A nil
// receiver is ignored.
func (c *openAIWSConn) touch() {
	if c != nil {
		c.lastUsedNano.Store(time.Now().UnixNano())
	}
}
// createdAt returns the creation time, or the zero time when the receiver is
// nil or no positive timestamp has been stored.
func (c *openAIWSConn) createdAt() time.Time {
	if c == nil {
		return time.Time{}
	}
	if stamp := c.createdAtNano.Load(); stamp > 0 {
		return time.Unix(0, stamp)
	}
	return time.Time{}
}
// lastUsedAt returns the last-used time, or the zero time when the receiver is
// nil or no positive timestamp has been stored.
func (c *openAIWSConn) lastUsedAt() time.Time {
	if c == nil {
		return time.Time{}
	}
	if stamp := c.lastUsedNano.Load(); stamp > 0 {
		return time.Unix(0, stamp)
	}
	return time.Time{}
}
// idleDuration returns now minus the last-used time, or zero when the
// receiver is nil or the last-used time is unknown.
func (c *openAIWSConn) idleDuration(now time.Time) time.Duration {
	if c == nil {
		return 0
	}
	if last := c.lastUsedAt(); !last.IsZero() {
		return now.Sub(last)
	}
	return 0
}
// age returns now minus the creation time, or zero when the receiver is nil
// or the creation time is unknown.
func (c *openAIWSConn) age(now time.Time) time.Duration {
	if c == nil {
		return 0
	}
	if born := c.createdAt(); !born.IsZero() {
		return now.Sub(born)
	}
	return 0
}
// isLeased reports whether the lease token is currently taken, i.e. the token
// channel is empty. A nil receiver reports false.
func (c *openAIWSConn) isLeased() bool {
	return c != nil && len(c.leaseCh) == 0
}
// handshakeHeader looks up name (trimmed) in the stored handshake headers and
// returns the trimmed value, or "" when the receiver or the headers are nil.
func (c *openAIWSConn) handshakeHeader(name string) string {
	if c == nil || c.handshakeHeaders == nil {
		return ""
	}
	key := strings.TrimSpace(name)
	value := c.handshakeHeaders.Get(key)
	return strings.TrimSpace(value)
}
// isPrewarmed reports the prewarmed flag; a nil receiver reports false.
func (c *openAIWSConn) isPrewarmed() bool {
	return c != nil && c.prewarmed.Load()
}
// markPrewarmed sets the prewarmed flag; a nil receiver is ignored.
func (c *openAIWSConn) markPrewarmed() {
	if c != nil {
		c.prewarmed.Store(true)
	}
}
// openAIWSAccountPool holds the pooled connections and prewarm bookkeeping of
// a single account.
type openAIWSAccountPool struct {
// mu is held by the visible code whenever the fields below are read or written.
mu sync.Mutex
// conns maps connection ID to connection.
conns map[string]*openAIWSConn
// pinnedConns maps connection ID to its pin count (see PinConn / UnpinConn);
// pinned connections are skipped by the age and idle eviction paths.
pinnedConns map[string]int
// creating counts dials in flight; it is added to len(conns) when checking capacity.
creating int
// lastCleanupAt is when cleanupAccountLocked last ran for this account.
lastCleanupAt time.Time
// lastAcquire is a clone of the most recent acquire request; it serves as the
// template for prewarm dials and for computing the effective connection limit.
lastAcquire *openAIWSAcquireRequest
// prewarmActive is true while a prewarmConns goroutine is running.
prewarmActive bool
// prewarmUntil is the earliest time ensureTargetIdleAsync may start another prewarm.
prewarmUntil time.Time
// prewarmFails is incremented on each dial failure and reset on dial success.
// NOTE(review): presumably read by shouldSuppressPrewarmLocked — confirm.
prewarmFails int
// prewarmFailAt is the time of the latest dial failure (zero after a success).
prewarmFailAt time.Time
}
// OpenAIWSPoolMetricsSnapshot is a point-in-time copy of the pool counters
// produced by SnapshotMetrics.
type OpenAIWSPoolMetricsSnapshot struct {
// AcquireTotal counts Acquire calls.
AcquireTotal int64
// AcquireReuseTotal counts leases served from an existing connection.
AcquireReuseTotal int64
// AcquireCreateTotal counts connections dialed successfully by acquire.
AcquireCreateTotal int64
// AcquireQueueWaitTotal counts acquires that had to queue on a connection.
AcquireQueueWaitTotal int64
// AcquireQueueWaitMsTotal sums the queue wait, in milliseconds, of queued
// acquires that eventually obtained a lease.
AcquireQueueWaitMsTotal int64
// ConnPickTotal counts recorded connection-pick measurements.
ConnPickTotal int64
// ConnPickMsTotal sums the recorded pick durations in milliseconds.
ConnPickMsTotal int64
// ScaleUpTotal counts connections requested by prewarm.
ScaleUpTotal int64
// ScaleDownTotal counts idle connections evicted to shrink a pool.
ScaleDownTotal int64
}
// openAIWSPoolMetrics holds the live atomic counters behind
// OpenAIWSPoolMetricsSnapshot; each field feeds the snapshot field of the
// matching name (the two *Ms fields feed the *MsTotal snapshot fields).
type openAIWSPoolMetrics struct {
acquireTotal atomic.Int64
acquireReuseTotal atomic.Int64
acquireCreateTotal atomic.Int64
acquireQueueWaitTotal atomic.Int64
acquireQueueWaitMs atomic.Int64
connPickTotal atomic.Int64
connPickMs atomic.Int64
scaleUpTotal atomic.Int64
scaleDownTotal atomic.Int64
}
// openAIWSConnPool manages per-account pools of OpenAI WebSocket connections
// and the background workers that ping and clean them up.
type openAIWSConnPool struct {
// cfg supplies the gateway settings read by the pool's limit helpers.
cfg *config.Config
// The underlying WS client implementation is decoupled behind an interface;
// coder/websocket is used by default.
clientDialer openAIWSClientDialer
accounts sync.Map // key: int64(accountID), value: *openAIWSAccountPool
// seq is the counter used to build unique connection IDs.
seq atomic.Uint64
// metrics holds the live pool counters.
metrics openAIWSPoolMetrics
// workerStopCh is closed by Close to stop the background workers.
workerStopCh chan struct{}
// workerWg tracks the background worker goroutines.
workerWg sync.WaitGroup
// closeOnce makes Close idempotent.
closeOnce sync.Once
}
// newOpenAIWSConnPool builds a pool with the default client dialer and starts
// its background workers before returning it.
func newOpenAIWSConnPool(cfg *config.Config) *openAIWSConnPool {
	p := new(openAIWSConnPool)
	p.cfg = cfg
	p.clientDialer = newDefaultOpenAIWSClientDialer()
	p.workerStopCh = make(chan struct{})
	p.startBackgroundWorkers()
	return p
}
// SnapshotMetrics copies the current pool counters. A nil pool yields the zero
// snapshot. Each counter is loaded independently, one after another.
func (p *openAIWSConnPool) SnapshotMetrics() OpenAIWSPoolMetricsSnapshot {
	var snap OpenAIWSPoolMetricsSnapshot
	if p == nil {
		return snap
	}
	m := &p.metrics
	snap.AcquireTotal = m.acquireTotal.Load()
	snap.AcquireReuseTotal = m.acquireReuseTotal.Load()
	snap.AcquireCreateTotal = m.acquireCreateTotal.Load()
	snap.AcquireQueueWaitTotal = m.acquireQueueWaitTotal.Load()
	snap.AcquireQueueWaitMsTotal = m.acquireQueueWaitMs.Load()
	snap.ConnPickTotal = m.connPickTotal.Load()
	snap.ConnPickMsTotal = m.connPickMs.Load()
	snap.ScaleUpTotal = m.scaleUpTotal.Load()
	snap.ScaleDownTotal = m.scaleDownTotal.Load()
	return snap
}
// SnapshotTransportMetrics returns the dialer's transport metrics when the
// configured dialer exposes them; otherwise (or for a nil pool) it returns the
// zero snapshot.
func (p *openAIWSConnPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
	if p != nil {
		if source, supported := p.clientDialer.(openAIWSTransportMetricsDialer); supported {
			return source.SnapshotTransportMetrics()
		}
	}
	return OpenAIWSTransportMetricsSnapshot{}
}
// setClientDialerForTest swaps the dialer used by the pool. Nil pools and nil
// dialers are ignored.
func (p *openAIWSConnPool) setClientDialerForTest(dialer openAIWSClientDialer) {
	if p != nil && dialer != nil {
		p.clientDialer = dialer
	}
}
// Close stops the background workers and closes every idle connection. It
// should be called during graceful shutdown. Only the first call does any
// work; connections that are currently leased are left open.
func (p *openAIWSConnPool) Close() {
	if p == nil {
		return
	}
	p.closeOnce.Do(func() {
		if p.workerStopCh != nil {
			close(p.workerStopCh)
		}
		p.workerWg.Wait()

		// Walk every account pool and close all idle connections.
		closeIdle := func(_, value any) bool {
			ap, isPool := value.(*openAIWSAccountPool)
			if isPool && ap != nil {
				ap.mu.Lock()
				for _, conn := range ap.conns {
					if conn == nil || conn.isLeased() {
						continue
					}
					conn.close()
				}
				ap.mu.Unlock()
			}
			return true
		}
		p.accounts.Range(closeIdle)
	})
}
// startBackgroundWorkers launches the ping worker and the cleanup worker, each
// in its own goroutine tracked by workerWg. It does nothing for a nil pool or
// when no stop channel has been set up.
func (p *openAIWSConnPool) startBackgroundWorkers() {
	if p == nil || p.workerStopCh == nil {
		return
	}
	workers := []func(){
		p.runBackgroundPingWorker,
		p.runBackgroundCleanupWorker,
	}
	p.workerWg.Add(len(workers))
	for _, run := range workers {
		run := run
		go func() {
			defer p.workerWg.Done()
			run()
		}()
	}
}
// openAIWSIdlePingCandidate pairs an idle connection with the account it
// belongs to, so the ping sweep can evict it from the right account pool.
type openAIWSIdlePingCandidate struct {
accountID int64
conn *openAIWSConn
}
// runBackgroundPingWorker runs a ping sweep every
// openAIWSBackgroundPingInterval until workerStopCh is closed.
func (p *openAIWSConnPool) runBackgroundPingWorker() {
	if p == nil {
		return
	}
	tick := time.NewTicker(openAIWSBackgroundPingInterval)
	defer tick.Stop()
	for {
		select {
		case <-p.workerStopCh:
			return
		case <-tick.C:
			p.runBackgroundPingSweep()
		}
	}
}
// runBackgroundPingSweep pings every currently idle connection, at most ten
// at a time, and evicts those whose ping fails. It returns after all pings
// have finished.
func (p *openAIWSConnPool) runBackgroundPingSweep() {
	if p == nil {
		return
	}
	var pings errgroup.Group
	pings.SetLimit(10)
	for _, candidate := range p.snapshotIdleConnsForPing() {
		accountID, conn := candidate.accountID, candidate.conn
		// Re-check idleness: the connection may have been picked up since the
		// snapshot was taken.
		busy := conn == nil || conn.isLeased() || conn.waiters.Load() > 0
		if busy {
			continue
		}
		pings.Go(func() error {
			if pingErr := conn.pingWithTimeout(openAIWSConnHealthCheckTO); pingErr != nil {
				p.evictConn(accountID, conn.id)
			}
			return nil
		})
	}
	_ = pings.Wait()
}
// snapshotIdleConnsForPing collects, across all account pools, the connections
// that are neither leased nor waited on. Each account pool is locked while it
// is scanned. A nil pool yields nil.
func (p *openAIWSConnPool) snapshotIdleConnsForPing() []openAIWSIdlePingCandidate {
	if p == nil {
		return nil
	}
	idle := make([]openAIWSIdlePingCandidate, 0)
	p.accounts.Range(func(key, value any) bool {
		accountID, isID := key.(int64)
		if !isID || accountID <= 0 {
			return true
		}
		ap, isPool := value.(*openAIWSAccountPool)
		if !isPool || ap == nil {
			return true
		}
		ap.mu.Lock()
		defer ap.mu.Unlock()
		for _, conn := range ap.conns {
			if conn != nil && !conn.isLeased() && conn.waiters.Load() <= 0 {
				idle = append(idle, openAIWSIdlePingCandidate{accountID: accountID, conn: conn})
			}
		}
		return true
	})
	return idle
}
// runBackgroundCleanupWorker runs a cleanup sweep every
// openAIWSBackgroundSweepTicker until workerStopCh is closed.
func (p *openAIWSConnPool) runBackgroundCleanupWorker() {
	if p == nil {
		return
	}
	tick := time.NewTicker(openAIWSBackgroundSweepTicker)
	defer tick.Stop()
	for {
		select {
		case <-p.workerStopCh:
			return
		case <-tick.C:
			p.runBackgroundCleanupSweep(time.Now())
		}
	}
}
// runBackgroundCleanupSweep runs cleanupAccountLocked on every account pool
// and, once all pools have been visited, closes the evicted connections
// outside of any account lock. The connection limit is the account's effective
// limit when a previous acquire recorded the account, else the hard cap.
func (p *openAIWSConnPool) runBackgroundCleanupSweep(now time.Time) {
	if p == nil {
		return
	}
	var batches [][]*openAIWSConn
	p.accounts.Range(func(_ any, value any) bool {
		ap, isPool := value.(*openAIWSAccountPool)
		if !isPool || ap == nil {
			return true
		}
		limit := p.maxConnsHardCap()
		ap.mu.Lock()
		if last := ap.lastAcquire; last != nil && last.Account != nil {
			limit = p.effectiveMaxConnsByAccount(last.Account)
		}
		removed := p.cleanupAccountLocked(ap, now, limit)
		ap.lastCleanupAt = now
		ap.mu.Unlock()
		if len(removed) > 0 {
			batches = append(batches, removed)
		}
		return true
	})
	for _, batch := range batches {
		closeOpenAIWSConns(batch)
	}
}
// Acquire counts the attempt and then obtains a lease for a clone of req,
// starting with retry count zero.
func (p *openAIWSConnPool) Acquire(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConnLease, error) {
	if p != nil {
		p.metrics.acquireTotal.Add(1)
	}
	owned := cloneOpenAIWSAcquireRequest(req)
	return p.acquire(ctx, owned, 0)
}
// acquire obtains a lease on a connection for req.Account.
//
// Order of attempts while holding the account lock:
//  1. With ForcePreferredConn (and reuse allowed): use only PreferredConnID,
//     taking it immediately when free or queueing on it otherwise.
//  2. Otherwise, when reuse is allowed: try the preferred connection, then the
//     least-busy connection, then any other connection, all without blocking.
//  3. With ForceNewConn at capacity: evict the oldest idle connection.
//  4. If capacity remains: dial a new connection.
//  5. With ForceNewConn and no capacity: fail with errOpenAIWSConnQueueFull.
//  6. Otherwise queue on the least-busy connection (the preferred one if it
//     exists), subject to the per-connection queue limit.
//
// A reused connection that shouldHealthCheckConn flags is pinged first; on a
// failed ping it is closed and evicted and the whole acquire is retried once
// (retry counts the retries performed so far). Connections evicted by the
// inline cleanup are closed only after the account lock has been released.
func (p *openAIWSConnPool) acquire(ctx context.Context, req openAIWSAcquireRequest, retry int) (*openAIWSConnLease, error) {
// Validate the request before touching any pool state.
if p == nil || req.Account == nil || req.Account.ID <= 0 {
return nil, errors.New("invalid ws acquire request")
}
if stringsTrim(req.WSURL) == "" {
return nil, errors.New("ws url is empty")
}
accountID := req.Account.ID
effectiveMaxConns := p.effectiveMaxConnsByAccount(req.Account)
if effectiveMaxConns <= 0 {
return nil, errOpenAIWSConnQueueFull
}
var evicted []*openAIWSConn
ap := p.getOrCreateAccountPool(accountID)
ap.mu.Lock()
// Remember the request so prewarm and background cleanup can reuse it.
ap.lastAcquire = cloneOpenAIWSAcquireRequestPtr(&req)
now := time.Now()
// Inline cleanup, rate-limited by openAIWSAcquireCleanupInterval.
if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval {
evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns)
ap.lastCleanupAt = now
}
pickStartedAt := time.Now()
allowReuse := !req.ForceNewConn
preferredConnID := stringsTrim(req.PreferredConnID)
// ForcePreferredConn only applies when reuse is allowed.
forcePreferredConn := allowReuse && req.ForcePreferredConn
if allowReuse {
if forcePreferredConn {
// Forced-preferred path: the preferred connection or nothing.
if preferredConnID == "" {
p.recordConnPickDuration(time.Since(pickStartedAt))
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSPreferredConnUnavailable
}
preferredConn, ok := ap.conns[preferredConnID]
if !ok || preferredConn == nil {
p.recordConnPickDuration(time.Since(pickStartedAt))
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSPreferredConnUnavailable
}
// Preferred connection is free: take it without queueing.
if preferredConn.tryAcquire() {
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
if p.shouldHealthCheckConn(preferredConn) {
if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
preferredConn.close()
p.evictConn(accountID, preferredConn.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
lease := &openAIWSConnLease{
pool: p,
accountID: accountID,
conn: preferredConn,
connPick: connPick,
reused: true,
}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
// Preferred connection is busy: queue on it if its queue has room.
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
if int(preferredConn.waiters.Load()) >= p.queueLimitPerConn() {
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSConnQueueFull
}
// The waiter count is raised under the account lock and dropped when this
// call returns.
preferredConn.waiters.Add(1)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
defer preferredConn.waiters.Add(-1)
waitStart := time.Now()
p.metrics.acquireQueueWaitTotal.Add(1)
if err := preferredConn.acquire(ctx); err != nil {
// The connection was closed while we waited: retry once from scratch.
if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
if p.shouldHealthCheckConn(preferredConn) {
if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
preferredConn.release()
preferredConn.close()
p.evictConn(accountID, preferredConn.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
queueWait := time.Since(waitStart)
p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds())
lease := &openAIWSConnLease{
pool: p,
accountID: accountID,
conn: preferredConn,
queueWait: queueWait,
connPick: connPick,
reused: true,
}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
// Non-forced reuse, step 1: the preferred connection, if it is free.
if preferredConnID != "" {
if conn, ok := ap.conns[preferredConnID]; ok && conn.tryAcquire() {
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
if p.shouldHealthCheckConn(conn) {
if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
conn.close()
p.evictConn(accountID, conn.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
}
// Step 2: the least-busy connection, if it is free.
best := p.pickLeastBusyConnLocked(ap, "")
if best != nil && best.tryAcquire() {
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
if p.shouldHealthCheckConn(best) {
if err := best.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
best.close()
p.evictConn(accountID, best.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: best, connPick: connPick, reused: true}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
// Step 3: any other connection that happens to be free.
for _, conn := range ap.conns {
if conn == nil || conn == best {
continue
}
if conn.tryAcquire() {
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
if p.shouldHealthCheckConn(conn) {
if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
conn.close()
p.evictConn(accountID, conn.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
}
}
// ForceNewConn at capacity: make room by dropping the oldest idle connection.
if req.ForceNewConn && len(ap.conns)+ap.creating >= effectiveMaxConns {
if idle := p.pickOldestIdleConnLocked(ap); idle != nil {
delete(ap.conns, idle.id)
evicted = append(evicted, idle)
p.metrics.scaleDownTotal.Add(1)
}
}
// Capacity available: dial a new connection with the account lock released.
if len(ap.conns)+ap.creating < effectiveMaxConns {
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
// Reserve the slot while the dial is in flight.
ap.creating++
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
conn, dialErr := p.dialConn(ctx, req)
ap = p.getOrCreateAccountPool(accountID)
ap.mu.Lock()
ap.creating--
if dialErr != nil {
ap.prewarmFails++
ap.prewarmFailAt = time.Now()
ap.mu.Unlock()
return nil, dialErr
}
ap.conns[conn.id] = conn
ap.prewarmFails = 0
ap.prewarmFailAt = time.Time{}
ap.mu.Unlock()
p.metrics.acquireCreateTotal.Add(1)
// The new connection is already visible in the pool, so another caller may
// have taken it; in that case wait for it like any other acquirer.
if !conn.tryAcquire() {
if err := conn.acquire(ctx); err != nil {
conn.close()
p.evictConn(accountID, conn.id)
return nil, err
}
}
lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick}
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
// ForceNewConn never queues on an existing connection.
if req.ForceNewConn {
p.recordConnPickDuration(time.Since(pickStartedAt))
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSConnQueueFull
}
// No free connection and no capacity: queue on the least-busy connection
// (the preferred one when it exists).
target := p.pickLeastBusyConnLocked(ap, req.PreferredConnID)
connPick := time.Since(pickStartedAt)
p.recordConnPickDuration(connPick)
if target == nil {
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSConnClosed
}
if int(target.waiters.Load()) >= p.queueLimitPerConn() {
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
return nil, errOpenAIWSConnQueueFull
}
target.waiters.Add(1)
ap.mu.Unlock()
closeOpenAIWSConns(evicted)
defer target.waiters.Add(-1)
waitStart := time.Now()
p.metrics.acquireQueueWaitTotal.Add(1)
if err := target.acquire(ctx); err != nil {
// The connection was closed while we waited: retry once from scratch.
if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
if p.shouldHealthCheckConn(target) {
if err := target.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
target.release()
target.close()
p.evictConn(accountID, target.id)
if retry < 1 {
return p.acquire(ctx, req, retry+1)
}
return nil, err
}
}
queueWait := time.Since(waitStart)
p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds())
lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: target, queueWait: queueWait, connPick: connPick, reused: true}
p.metrics.acquireReuseTotal.Add(1)
p.ensureTargetIdleAsync(accountID)
return lease, nil
}
// recordConnPickDuration adds one pick measurement to the metrics; negative
// durations are counted as zero. A nil pool is ignored.
func (p *openAIWSConnPool) recordConnPickDuration(duration time.Duration) {
	if p == nil {
		return
	}
	elapsedMs := int64(0)
	if duration > 0 {
		elapsedMs = duration.Milliseconds()
	}
	p.metrics.connPickTotal.Add(1)
	p.metrics.connPickMs.Add(elapsedMs)
}
// pickOldestIdleConnLocked returns the connection with the earliest last-used
// time among those that are not leased, not waited on and not pinned, or nil
// when there is none. The caller must hold ap.mu.
func (p *openAIWSConnPool) pickOldestIdleConnLocked(ap *openAIWSAccountPool) *openAIWSConn {
	if ap == nil || len(ap.conns) == 0 {
		return nil
	}
	var victim *openAIWSConn
	for _, candidate := range ap.conns {
		if candidate == nil {
			continue
		}
		evictable := !candidate.isLeased() &&
			candidate.waiters.Load() <= 0 &&
			!p.isConnPinnedLocked(ap, candidate.id)
		if !evictable {
			continue
		}
		if victim == nil || candidate.lastUsedAt().Before(victim.lastUsedAt()) {
			victim = candidate
		}
	}
	return victim
}
// getOrCreateAccountPool returns the account pool for accountID, creating and
// storing an empty one when none exists yet. It returns nil for a nil pool or
// a non-positive account ID.
func (p *openAIWSConnPool) getOrCreateAccountPool(accountID int64) *openAIWSAccountPool {
	if p == nil || accountID <= 0 {
		return nil
	}
	asPool := func(v any) *openAIWSAccountPool {
		if pool, isPool := v.(*openAIWSAccountPool); isPool {
			return pool
		}
		return nil
	}
	if stored, found := p.accounts.Load(accountID); found {
		if pool := asPool(stored); pool != nil {
			return pool
		}
	}
	fresh := &openAIWSAccountPool{
		conns:       make(map[string]*openAIWSConn),
		pinnedConns: make(map[string]int),
	}
	// LoadOrStore keeps whichever value was stored first, so concurrent callers
	// end up sharing one pool.
	winner, _ := p.accounts.LoadOrStore(accountID, fresh)
	if pool := asPool(winner); pool != nil {
		return pool
	}
	return fresh
}
// ensureAccountPoolLocked is kept for compatibility with older call sites; it
// simply delegates to getOrCreateAccountPool.
func (p *openAIWSConnPool) ensureAccountPoolLocked(accountID int64) *openAIWSAccountPool {
return p.getOrCreateAccountPool(accountID)
}
// getAccountPool looks up the account pool for accountID without creating
// one. The boolean is true only when a non-nil pool was found.
func (p *openAIWSConnPool) getAccountPool(accountID int64) (*openAIWSAccountPool, bool) {
	if p == nil || accountID <= 0 {
		return nil, false
	}
	stored, found := p.accounts.Load(accountID)
	if !found || stored == nil {
		return nil, false
	}
	pool, isPool := stored.(*openAIWSAccountPool)
	if !isPool || pool == nil {
		return pool, false
	}
	return pool, true
}
// isConnPinnedLocked reports whether connID has a positive pin count in ap.
// The caller must hold ap.mu.
func (p *openAIWSConnPool) isConnPinnedLocked(ap *openAIWSAccountPool, connID string) bool {
	hasPins := ap != nil && connID != "" && len(ap.pinnedConns) > 0
	return hasPins && ap.pinnedConns[connID] > 0
}
// cleanupAccountLocked removes dead and surplus connections from ap and
// returns them so the caller can close them after releasing ap.mu. The caller
// must hold ap.mu.
//
// Phase 1 drops nil entries and already-closed connections, then drops
// unpinned connections older than maxConnAge() that are neither leased nor
// waited on.
// Phase 2 trims the pool toward the idle limit (maxIdlePerAccount(), capped by
// maxConns; a non-positive maxConns falls back to the hard cap) by evicting
// the least recently used connections that are not leased, waited on or
// pinned.
//
// The returned slice is never nil; it is empty when nothing was evicted.
func (p *openAIWSConnPool) cleanupAccountLocked(ap *openAIWSAccountPool, now time.Time, maxConns int) []*openAIWSConn {
	if ap == nil {
		return nil
	}
	maxAge := p.maxConnAge()
	evicted := make([]*openAIWSConn, 0)

	// forget removes every trace of id from the account pool. delete on a nil
	// or empty map is a no-op, so no length guard is needed.
	forget := func(id string) {
		delete(ap.conns, id)
		delete(ap.pinnedConns, id)
	}

	for id, conn := range ap.conns {
		if conn == nil {
			forget(id)
			continue
		}
		// Closed connections go regardless of pins.
		select {
		case <-conn.closedCh:
			forget(id)
			evicted = append(evicted, conn)
			continue
		default:
		}
		if p.isConnPinnedLocked(ap, id) {
			continue
		}
		// Connections with waiters must not be evicted during cleanup, otherwise
		// the waiting acquire would receive a closed error. The trim phase below
		// applies the same rule.
		idle := !conn.isLeased() && conn.waiters.Load() <= 0
		if maxAge > 0 && idle && conn.age(now) > maxAge {
			forget(id)
			evicted = append(evicted, conn)
		}
	}

	if maxConns <= 0 {
		maxConns = p.maxConnsHardCap()
	}
	maxIdle := p.maxIdlePerAccount()
	if maxIdle < 0 || maxIdle > maxConns {
		maxIdle = maxConns
	}
	if maxIdle < 0 || len(ap.conns) <= maxIdle {
		return evicted
	}

	idleConns := make([]*openAIWSConn, 0, len(ap.conns))
	for id, conn := range ap.conns {
		if conn == nil {
			forget(id)
			continue
		}
		// Connections with waiters must not be evicted during cleanup, otherwise
		// the waiting acquire would receive a closed error.
		if conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) {
			continue
		}
		idleConns = append(idleConns, conn)
	}
	// Least recently used first.
	sort.SliceStable(idleConns, func(i, j int) bool {
		return idleConns[i].lastUsedAt().Before(idleConns[j].lastUsedAt())
	})
	redundant := len(ap.conns) - maxIdle
	if redundant > len(idleConns) {
		redundant = len(idleConns)
	}
	for i := 0; i < redundant; i++ {
		conn := idleConns[i]
		forget(conn.id)
		evicted = append(evicted, conn)
	}
	if redundant > 0 {
		p.metrics.scaleDownTotal.Add(int64(redundant))
	}
	return evicted
}
// pickLeastBusyConnLocked returns the map entry for preferredConnID when one
// exists; otherwise the non-nil connection with the fewest waiters, ties going
// to the least recently used. It returns nil for an empty pool. The caller
// must hold ap.mu.
func (p *openAIWSConnPool) pickLeastBusyConnLocked(ap *openAIWSAccountPool, preferredConnID string) *openAIWSConn {
	if ap == nil || len(ap.conns) == 0 {
		return nil
	}
	if wanted := stringsTrim(preferredConnID); wanted != "" {
		if preferred, found := ap.conns[wanted]; found {
			return preferred
		}
	}
	var (
		best         *openAIWSConn
		bestWaiters  int32
		bestLastUsed time.Time
	)
	for _, candidate := range ap.conns {
		if candidate == nil {
			continue
		}
		queued, used := candidate.waiters.Load(), candidate.lastUsedAt()
		better := best == nil ||
			queued < bestWaiters ||
			(queued == bestWaiters && used.Before(bestLastUsed))
		if better {
			best, bestWaiters, bestLastUsed = candidate, queued, used
		}
	}
	return best
}
// accountPoolLoadLocked counts the leased connections (inflight) and sums the
// queued acquirers (waiters) across ap. A nil pool reports zeros. The caller
// must hold ap.mu.
func accountPoolLoadLocked(ap *openAIWSAccountPool) (inflight int, waiters int) {
	if ap != nil {
		for _, conn := range ap.conns {
			if conn == nil {
				continue
			}
			if conn.isLeased() {
				inflight++
			}
			waiters += int(conn.waiters.Load())
		}
	}
	return inflight, waiters
}
// AccountPoolLoad returns a snapshot of the in-flight and queued load, plus
// the connection count, for the given account's pool. Unknown accounts and
// invalid arguments report zeros.
func (p *openAIWSConnPool) AccountPoolLoad(accountID int64) (inflight int, waiters int, conns int) {
	if p == nil || accountID <= 0 {
		return 0, 0, 0
	}
	ap, found := p.getAccountPool(accountID)
	if !found || ap == nil {
		return 0, 0, 0
	}
	ap.mu.Lock()
	defer ap.mu.Unlock()
	inflight, waiters = accountPoolLoadLocked(ap)
	conns = len(ap.conns)
	return inflight, waiters, conns
}
// ensureTargetIdleAsync starts a background prewarm when the account pool has
// fewer connections (existing plus in-flight dials) than its target. It does
// nothing when no acquire has been recorded yet, a prewarm is already running,
// the cooldown has not elapsed, or shouldSuppressPrewarmLocked says to hold
// off. The missing connections are reserved in ap.creating before the
// goroutine is launched.
func (p *openAIWSConnPool) ensureTargetIdleAsync(accountID int64) {
	if p == nil || accountID <= 0 {
		return
	}
	ap, found := p.getAccountPool(accountID)
	if !found || ap == nil {
		return
	}
	ap.mu.Lock()
	defer ap.mu.Unlock()

	if ap.lastAcquire == nil || ap.prewarmActive {
		return
	}
	now := time.Now()
	coolingDown := !ap.prewarmUntil.IsZero() && now.Before(ap.prewarmUntil)
	if coolingDown || p.shouldSuppressPrewarmLocked(ap, now) {
		return
	}

	limit := p.maxConnsHardCap()
	if account := ap.lastAcquire.Account; account != nil {
		limit = p.effectiveMaxConnsByAccount(account)
	}
	missing := p.targetConnCountLocked(ap, limit) - (len(ap.conns) + ap.creating)
	if missing <= 0 {
		return
	}

	template := cloneOpenAIWSAcquireRequest(*ap.lastAcquire)
	ap.prewarmActive = true
	if cooldown := p.prewarmCooldown(); cooldown > 0 {
		ap.prewarmUntil = now.Add(cooldown)
	}
	ap.creating += missing
	p.metrics.scaleUpTotal.Add(int64(missing))
	go p.prewarmConns(accountID, template, missing)
}
// targetConnCountLocked computes how many connections the account pool should
// hold, clamped to [minIdle, maxConns]. With no demand (leased plus waiting)
// the answer is the clamped minIdle. Otherwise it is ceil(demand/utilization)
// (1 for a single unit of demand), raised to at least one more than the
// current count whenever there are waiters. The caller must hold ap.mu.
// NOTE(review): assumes targetUtilization() returns a positive value — confirm.
func (p *openAIWSConnPool) targetConnCountLocked(ap *openAIWSAccountPool, maxConns int) int {
	if ap == nil || maxConns <= 0 {
		return 0
	}
	floor := p.minIdlePerAccount()
	if floor < 0 {
		floor = 0
	}
	if floor > maxConns {
		floor = maxConns
	}

	inflight, waiters := accountPoolLoadLocked(ap)
	utilization := p.targetUtilization()
	demand := inflight + waiters
	if demand <= 0 {
		return floor
	}

	want := 1
	if demand > 1 {
		want = int(math.Ceil(float64(demand) / utilization))
	}
	if grown := len(ap.conns) + 1; waiters > 0 && want < grown {
		want = grown
	}
	if want < floor {
		want = floor
	}
	if want > maxConns {
		want = maxConns
	}
	return want
}
// prewarmConns dials up to total connections for accountID, one after another,
// and adds each successful one to the account pool. Every attempt gives back
// one reserved slot in ap.creating. Dial failures are recorded in the prewarm
// failure counters; a connection that would exceed the account's effective
// limit is closed instead of stored. prewarmActive is cleared on return.
func (p *openAIWSConnPool) prewarmConns(accountID int64, req openAIWSAcquireRequest, total int) {
	defer func() {
		ap, found := p.getAccountPool(accountID)
		if !found || ap == nil {
			return
		}
		ap.mu.Lock()
		ap.prewarmActive = false
		ap.mu.Unlock()
	}()

	for remaining := total; remaining > 0; remaining-- {
		dialCtx, cancel := context.WithTimeout(context.Background(), p.dialTimeout()+openAIWSConnPrewarmExtraDelay)
		conn, dialErr := p.dialConn(dialCtx, req)
		cancel()

		ap, found := p.getAccountPool(accountID)
		if !found || ap == nil {
			// The account pool is gone; drop the connection and stop.
			if conn != nil {
				conn.close()
			}
			return
		}

		ap.mu.Lock()
		if ap.creating > 0 {
			ap.creating--
		}
		switch {
		case dialErr != nil:
			ap.prewarmFails++
			ap.prewarmFailAt = time.Now()
			ap.mu.Unlock()
		case len(ap.conns) >= p.effectiveMaxConnsByAccount(req.Account):
			// The pool filled up while we were dialing.
			ap.mu.Unlock()
			conn.close()
		default:
			ap.conns[conn.id] = conn
			ap.prewarmFails = 0
			ap.prewarmFailAt = time.Time{}
			ap.mu.Unlock()
		}
	}
}
// evictConn removes connID (and any pin entry for it) from the account pool
// and then closes the connection outside the account lock. Unknown accounts,
// unknown connection IDs and invalid arguments are ignored.
func (p *openAIWSConnPool) evictConn(accountID int64, connID string) {
	if p == nil || accountID <= 0 || stringsTrim(connID) == "" {
		return
	}
	ap, found := p.getAccountPool(accountID)
	if !found || ap == nil {
		return
	}

	ap.mu.Lock()
	victim, present := ap.conns[connID]
	if present {
		delete(ap.conns, connID)
		if len(ap.pinnedConns) > 0 {
			delete(ap.pinnedConns, connID)
		}
	}
	ap.mu.Unlock()

	if victim != nil {
		victim.close()
	}
}
// PinConn increments the pin count of connID in the account pool and reports
// whether it did. It returns false for invalid arguments, an unknown account,
// or a connection ID that is not in the pool.
func (p *openAIWSConnPool) PinConn(accountID int64, connID string) bool {
	if p == nil || accountID <= 0 {
		return false
	}
	id := stringsTrim(connID)
	if id == "" {
		return false
	}
	ap, found := p.getAccountPool(accountID)
	if !found || ap == nil {
		return false
	}

	ap.mu.Lock()
	defer ap.mu.Unlock()
	if _, present := ap.conns[id]; !present {
		return false
	}
	if ap.pinnedConns == nil {
		ap.pinnedConns = make(map[string]int)
	}
	ap.pinnedConns[id]++
	return true
}
// UnpinConn decrements the pin count of connID in the account pool, removing
// the entry once the count would drop to zero or below. Invalid arguments and
// unknown accounts are ignored.
func (p *openAIWSConnPool) UnpinConn(accountID int64, connID string) {
	if p == nil || accountID <= 0 {
		return
	}
	id := stringsTrim(connID)
	if id == "" {
		return
	}
	ap, found := p.getAccountPool(accountID)
	if !found || ap == nil {
		return
	}

	ap.mu.Lock()
	defer ap.mu.Unlock()
	if len(ap.pinnedConns) == 0 {
		return
	}
	if pins := ap.pinnedConns[id]; pins > 1 {
		ap.pinnedConns[id] = pins - 1
	} else {
		delete(ap.pinnedConns, id)
	}
}
// dialConn opens a new upstream WebSocket connection for req and wraps it in
// an openAIWSConn tagged with a freshly generated connection ID.
//
// A plain error is returned when the pool or its dialer is nil, or when
// req.Account is nil. A failed dial, or a dialer that reports success but
// hands back a nil connection, is returned as *openAIWSDialError carrying the
// handshake status code and a copy of the handshake response headers.
func (p *openAIWSConnPool) dialConn(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConn, error) {
	if p == nil || p.clientDialer == nil {
		return nil, errors.New("openai ws client dialer is nil")
	}
	// The account ID is needed once the dial succeeds. Reject a missing
	// account up front rather than panicking afterwards with a live
	// connection that nothing owns or closes.
	if req.Account == nil {
		return nil, errors.New("openai ws account is nil")
	}
	conn, status, handshakeHeaders, err := p.clientDialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL)
	if err != nil {
		return nil, &openAIWSDialError{
			StatusCode:      status,
			ResponseHeaders: cloneHeader(handshakeHeaders),
			Err:             err,
		}
	}
	if conn == nil {
		return nil, &openAIWSDialError{
			StatusCode:      status,
			ResponseHeaders: cloneHeader(handshakeHeaders),
			Err:             errors.New("openai ws dialer returned nil connection"),
		}
	}
	id := p.nextConnID(req.Account.ID)
	return newOpenAIWSConn(id, req.Account.ID, conn, handshakeHeaders), nil
}
// nextConnID returns a new connection ID of the form
// "oa_ws_<accountID>_<seq>", where seq is the value obtained by adding 1 to
// the pool's sequence counter.
func (p *openAIWSConnPool) nextConnID(accountID int64) string {
	const prefix = "oa_ws_"
	next := p.seq.Add(1)

	// Assemble the ID in a single pre-sized byte buffer.
	id := append(make([]byte, 0, 32), prefix...)
	id = strconv.AppendInt(id, accountID, 10)
	id = append(id, '_')
	id = strconv.AppendUint(id, next, 10)
	return string(id)
}
// shouldHealthCheckConn reports whether conn has been idle for at least
// openAIWSConnHealthCheckIdle as of now. A nil conn yields false.
func (p *openAIWSConnPool) shouldHealthCheckConn(conn *openAIWSConn) bool {
	if conn != nil {
		idle := conn.idleDuration(time.Now())
		return idle >= openAIWSConnHealthCheckIdle
	}
	return false
}
// maxConnsHardCap returns the configured MaxConnsPerAccount when it is
// positive, and 8 when the pool or its config is nil or the setting is not
// positive.
func (p *openAIWSConnPool) maxConnsHardCap() int {
	const fallback = 8
	if p == nil || p.cfg == nil {
		return fallback
	}
	if configured := p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount; configured > 0 {
		return configured
	}
	return fallback
}
// dynamicMaxConnsEnabled returns the configured
// DynamicMaxConnsByAccountConcurrencyEnabled flag, or false when the pool or
// its config is nil.
func (p *openAIWSConnPool) dynamicMaxConnsEnabled() bool {
	if p == nil || p.cfg == nil {
		return false
	}
	return p.cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled
}
// modeRouterV2Enabled returns the configured ModeRouterV2Enabled flag, or
// false when the pool or its config is nil.
func (p *openAIWSConnPool) modeRouterV2Enabled() bool {
	configured := p != nil && p.cfg != nil
	return configured && p.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
}
// maxConnsFactorByAccount returns the connection-count multiplier configured
// for the account's type: OAuthMaxConnsFactor for AccountTypeOAuth and
// APIKeyMaxConnsFactor for AccountTypeAPIKey. It returns 1.0 when the pool,
// its config, or the account is nil, when the account type is neither of
// those, or when the configured factor is not positive.
func (p *openAIWSConnPool) maxConnsFactorByAccount(account *Account) float64 {
	if p == nil || p.cfg == nil || account == nil {
		return 1.0
	}

	// Stays at zero for account types that have no dedicated factor.
	var configured float64
	switch account.Type {
	case AccountTypeOAuth:
		configured = p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor
	case AccountTypeAPIKey:
		configured = p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor
	}
	if configured > 0 {
		return configured
	}
	return 1.0
}
// effectiveMaxConnsByAccount returns the maximum number of connections the
// pool may hold for account.
//
// Resolution order:
//   - A non-positive hard cap yields 0.
//   - With mode router v2 enabled: a nil account gets the hard cap; otherwise
//     the account's Concurrency is returned as-is, or 0 when it is not
//     positive.
//   - With a nil account or dynamic sizing disabled: the hard cap.
//   - With non-positive Concurrency: the hard cap.
//   - Otherwise ceil(Concurrency * factor), clamped to the range [1, hard cap].
func (p *openAIWSConnPool) effectiveMaxConnsByAccount(account *Account) int {
	hardCap := p.maxConnsHardCap()
	if hardCap <= 0 {
		return 0
	}
	if p.modeRouterV2Enabled() {
		if account == nil {
			return hardCap
		}
		if account.Concurrency <= 0 {
			return 0
		}
		return account.Concurrency
	}
	if account == nil || !p.dynamicMaxConnsEnabled() {
		return hardCap
	}
	if account.Concurrency <= 0 {
		// For "unlimited" concurrency values (0, -1, ...), the global hard cap
		// still acts as the backstop.
		return hardCap
	}
	factor := p.maxConnsFactorByAccount(account)
	if factor <= 0 {
		factor = 1.0
	}

	// Clamp while the value is still a float: converting an out-of-range or
	// NaN float64 to int gives an implementation-dependent result, so the
	// conversion must only happen once the value is known to fit.
	scaled := math.Ceil(float64(account.Concurrency) * factor)
	switch {
	case scaled >= float64(hardCap):
		return hardCap
	case scaled >= 1:
		return int(scaled)
	default:
		// Below 1, or NaN (every comparison with NaN is false).
		return 1
	}
}
// minIdlePerAccount returns the configured MinIdlePerAccount when it is
// non-negative, and 0 when the pool or its config is nil or the setting is
// negative.
func (p *openAIWSConnPool) minIdlePerAccount() int {
	if p == nil || p.cfg == nil {
		return 0
	}
	if configured := p.cfg.Gateway.OpenAIWS.MinIdlePerAccount; configured >= 0 {
		return configured
	}
	return 0
}
// maxIdlePerAccount returns the configured MaxIdlePerAccount when it is
// non-negative, and 4 when the pool or its config is nil or the setting is
// negative.
func (p *openAIWSConnPool) maxIdlePerAccount() int {
	const fallback = 4
	if p == nil || p.cfg == nil {
		return fallback
	}
	configured := p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount
	if configured < 0 {
		return fallback
	}
	return configured
}
// maxConnAge returns the fixed openAIWSConnMaxAge constant; it does not read
// the pool's config.
func (p *openAIWSConnPool) maxConnAge() time.Duration {
	var age time.Duration = openAIWSConnMaxAge
	return age
}
// queueLimitPerConn returns the configured QueueLimitPerConn when it is
// positive, and 256 when the pool or its config is nil or the setting is not
// positive.
func (p *openAIWSConnPool) queueLimitPerConn() int {
	const fallback = 256
	if p == nil || p.cfg == nil {
		return fallback
	}
	if limit := p.cfg.Gateway.OpenAIWS.QueueLimitPerConn; limit > 0 {
		return limit
	}
	return fallback
}
// targetUtilization returns the configured PoolTargetUtilization when it
// lies in the interval (0, 1], and 0.7 otherwise (including when the pool or
// its config is nil).
func (p *openAIWSConnPool) targetUtilization() float64 {
	const fallback = 0.7
	if p == nil || p.cfg == nil {
		return fallback
	}
	configured := p.cfg.Gateway.OpenAIWS.PoolTargetUtilization
	if configured > 0 && configured <= 1 {
		return configured
	}
	return fallback
}
// prewarmCooldown returns the configured PrewarmCooldownMS converted to a
// time.Duration in milliseconds. It returns 0 when the pool or its config is
// nil or the setting is not positive.
func (p *openAIWSConnPool) prewarmCooldown() time.Duration {
	if p == nil || p.cfg == nil {
		return 0
	}
	ms := p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS
	if ms <= 0 {
		return 0
	}
	return time.Duration(ms) * time.Millisecond
}
// shouldSuppressPrewarmLocked reports whether prewarming should be skipped
// for ap because of recent prewarm failures.
//
// A nil ap yields true. Otherwise the result is true only when the recorded
// failure count has reached openAIWSPrewarmFailureSuppress and the last
// failure is no older than openAIWSPrewarmFailureWindow relative to now.
//
// The function also resets stale state: a positive failure count with a zero
// timestamp is cleared, and both count and timestamp are cleared once the
// failure window has elapsed. It reads and writes ap's fields without taking
// ap.mu itself; the "Locked" suffix suggests the caller is expected to hold
// it — confirm at the call sites.
func (p *openAIWSConnPool) shouldSuppressPrewarmLocked(ap *openAIWSAccountPool, now time.Time) bool {
	switch {
	case ap == nil:
		return true
	case ap.prewarmFails <= 0:
		return false
	case ap.prewarmFailAt.IsZero():
		// A failure count without a timestamp cannot be aged; discard it.
		ap.prewarmFails = 0
		return false
	case now.Sub(ap.prewarmFailAt) > openAIWSPrewarmFailureWindow:
		// The last failure fell out of the window; start over.
		ap.prewarmFails = 0
		ap.prewarmFailAt = time.Time{}
		return false
	default:
		return ap.prewarmFails >= openAIWSPrewarmFailureSuppress
	}
}
// dialTimeout returns the configured DialTimeoutSeconds converted to a
// time.Duration in seconds. It returns 10 seconds when the pool or its
// config is nil or the setting is not positive.
func (p *openAIWSConnPool) dialTimeout() time.Duration {
	const fallback = 10 * time.Second
	if p == nil || p.cfg == nil {
		return fallback
	}
	if secs := p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds; secs > 0 {
		return time.Duration(secs) * time.Second
	}
	return fallback
}
// cloneOpenAIWSAcquireRequest returns a copy of req whose Headers map is
// duplicated via cloneHeader and whose WSURL, ProxyURL and PreferredConnID
// are whitespace-trimmed. All other fields are copied unchanged.
func cloneOpenAIWSAcquireRequest(req openAIWSAcquireRequest) openAIWSAcquireRequest {
	// req was received by value, so it is already a distinct struct; only
	// the fields that need detaching or normalising are rewritten.
	req.Headers = cloneHeader(req.Headers)
	req.WSURL = stringsTrim(req.WSURL)
	req.ProxyURL = stringsTrim(req.ProxyURL)
	req.PreferredConnID = stringsTrim(req.PreferredConnID)
	return req
}
// cloneOpenAIWSAcquireRequestPtr returns a pointer to a copy of *req produced
// by cloneOpenAIWSAcquireRequest, or nil when req is nil.
func cloneOpenAIWSAcquireRequestPtr(req *openAIWSAcquireRequest) *openAIWSAcquireRequest {
	if req != nil {
		dup := cloneOpenAIWSAcquireRequest(*req)
		return &dup
	}
	return nil
}
// cloneHeader returns a copy of src in which every value slice has its own
// backing array. A nil src yields nil. Keys whose value slice is empty (nil
// or zero-length) are kept and mapped to nil.
func cloneHeader(src http.Header) http.Header {
	if src == nil {
		return nil
	}
	out := make(http.Header, len(src))
	for key, values := range src {
		// Stays nil for empty value slices.
		var dup []string
		if n := len(values); n > 0 {
			dup = make([]string, n)
			copy(dup, values)
		}
		out[key] = dup
	}
	return out
}
// closeOpenAIWSConns calls close on every non-nil connection in conns, in
// slice order. A nil or empty slice is a no-op.
func closeOpenAIWSConns(conns []*openAIWSConn) {
	for i := range conns {
		if c := conns[i]; c != nil {
			c.close()
		}
	}
}
// stringsTrim returns value with leading and trailing white space removed,
// as defined by strings.TrimSpace.
func stringsTrim(value string) string {
	trimmed := strings.TrimSpace(value)
	return trimmed
}