mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-17 21:34:45 +08:00
808 lines
22 KiB
Go
808 lines
22 KiB
Go
|
|
package openai_ws_v2
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"errors"
|
|||
|
|
"io"
|
|||
|
|
"net"
|
|||
|
|
"strconv"
|
|||
|
|
"strings"
|
|||
|
|
"sync/atomic"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
coderws "github.com/coder/websocket"
|
|||
|
|
"github.com/tidwall/gjson"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
type FrameConn interface {
|
|||
|
|
ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
|
|||
|
|
WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
|
|||
|
|
Close() error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Usage holds token counters taken from upstream "response.usage" payloads.
type Usage struct {
	// InputTokens comes from response.usage.input_tokens.
	InputTokens int
	// OutputTokens comes from response.usage.output_tokens.
	OutputTokens int
	// CacheCreationInputTokens is not assigned by any code visible in this
	// file.
	CacheCreationInputTokens int
	// CacheReadInputTokens comes from
	// response.usage.input_tokens_details.cached_tokens.
	CacheReadInputTokens int
}
|
|||
|
|
|
|||
|
|
type RelayResult struct {
|
|||
|
|
RequestModel string
|
|||
|
|
Usage Usage
|
|||
|
|
RequestID string
|
|||
|
|
TerminalEventType string
|
|||
|
|
FirstTokenMs *int
|
|||
|
|
Duration time.Duration
|
|||
|
|
ClientToUpstreamFrames int64
|
|||
|
|
UpstreamToClientFrames int64
|
|||
|
|
DroppedDownstreamFrames int64
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
type RelayTurnResult struct {
|
|||
|
|
RequestModel string
|
|||
|
|
Usage Usage
|
|||
|
|
RequestID string
|
|||
|
|
TerminalEventType string
|
|||
|
|
Duration time.Duration
|
|||
|
|
FirstTokenMs *int
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RelayExit describes a non-graceful end of a relay.
type RelayExit struct {
	// Stage names the step that ended the relay, e.g. "relay_init",
	// "write_upstream", "read_client" or "client_disconnected".
	Stage string
	// Err is the error associated with that stage.
	Err error
	// WroteDownstream reports whether a frame had been written to the client
	// before the relay ended.
	WroteDownstream bool
}
|
|||
|
|
|
|||
|
|
type RelayOptions struct {
|
|||
|
|
WriteTimeout time.Duration
|
|||
|
|
IdleTimeout time.Duration
|
|||
|
|
UpstreamDrainTimeout time.Duration
|
|||
|
|
FirstMessageType coderws.MessageType
|
|||
|
|
OnUsageParseFailure func(eventType string, usageRaw string)
|
|||
|
|
OnTurnComplete func(turn RelayTurnResult)
|
|||
|
|
OnTrace func(event RelayTraceEvent)
|
|||
|
|
Now func() time.Time
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RelayTraceEvent is a single diagnostic event delivered to
// RelayOptions.OnTrace.
type RelayTraceEvent struct {
	// Stage identifies the step that produced the event, e.g. "relay_start".
	Stage string
	// Direction is "client_to_upstream", "upstream_to_client", "watchdog" or
	// empty.
	Direction string
	// MessageType is the frame type rendered as text, when applicable.
	MessageType string
	// PayloadBytes is the size of the frame payload, when applicable.
	PayloadBytes int
	// Graceful reports whether the event is classified as a normal
	// disconnect or completion.
	Graceful bool
	// WroteDownstream reports whether a frame had already reached the client.
	WroteDownstream bool
	// Error is the error text, or empty when there is none.
	Error string
}
|
|||
|
|
|
|||
|
|
type relayState struct {
|
|||
|
|
usage Usage
|
|||
|
|
requestModel string
|
|||
|
|
lastResponseID string
|
|||
|
|
terminalEventType string
|
|||
|
|
firstTokenMs *int
|
|||
|
|
turnTimingByID map[string]*relayTurnTiming
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// relayExitSignal is what a relay goroutine sends on the exit channel when it
// stops.
type relayExitSignal struct {
	// stage names the step that ended the goroutine, e.g. "read_client".
	stage string
	// err is the error that caused the exit, if any.
	err error
	// graceful marks the exit as a normal disconnect or completion.
	graceful bool
	// wroteDownstream reports whether a frame had been written to the client.
	wroteDownstream bool
}
|
|||
|
|
|
|||
|
|
type observedUpstreamEvent struct {
|
|||
|
|
terminal bool
|
|||
|
|
eventType string
|
|||
|
|
responseID string
|
|||
|
|
usage Usage
|
|||
|
|
duration time.Duration
|
|||
|
|
firstToken *int
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// relayTurnTiming records timing for one response ID.
type relayTurnTiming struct {
	// startAt is when the first event for the response ID was observed.
	startAt time.Time
	// firstTokenMs is the delay from startAt to the first token event in
	// milliseconds; nil until one is seen.
	firstTokenMs *int
}
|
|||
|
|
|
|||
|
|
func Relay(
|
|||
|
|
ctx context.Context,
|
|||
|
|
clientConn FrameConn,
|
|||
|
|
upstreamConn FrameConn,
|
|||
|
|
firstClientMessage []byte,
|
|||
|
|
options RelayOptions,
|
|||
|
|
) (RelayResult, *RelayExit) {
|
|||
|
|
result := RelayResult{RequestModel: strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())}
|
|||
|
|
if clientConn == nil || upstreamConn == nil {
|
|||
|
|
return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
|
|||
|
|
}
|
|||
|
|
if ctx == nil {
|
|||
|
|
ctx = context.Background()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
nowFn := options.Now
|
|||
|
|
if nowFn == nil {
|
|||
|
|
nowFn = time.Now
|
|||
|
|
}
|
|||
|
|
writeTimeout := options.WriteTimeout
|
|||
|
|
if writeTimeout <= 0 {
|
|||
|
|
writeTimeout = 2 * time.Minute
|
|||
|
|
}
|
|||
|
|
drainTimeout := options.UpstreamDrainTimeout
|
|||
|
|
if drainTimeout <= 0 {
|
|||
|
|
drainTimeout = 1200 * time.Millisecond
|
|||
|
|
}
|
|||
|
|
firstMessageType := options.FirstMessageType
|
|||
|
|
if firstMessageType != coderws.MessageBinary {
|
|||
|
|
firstMessageType = coderws.MessageText
|
|||
|
|
}
|
|||
|
|
startAt := nowFn()
|
|||
|
|
state := &relayState{requestModel: result.RequestModel}
|
|||
|
|
onTrace := options.OnTrace
|
|||
|
|
|
|||
|
|
relayCtx, relayCancel := context.WithCancel(ctx)
|
|||
|
|
defer relayCancel()
|
|||
|
|
|
|||
|
|
lastActivity := atomic.Int64{}
|
|||
|
|
lastActivity.Store(nowFn().UnixNano())
|
|||
|
|
markActivity := func() {
|
|||
|
|
lastActivity.Store(nowFn().UnixNano())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
|
|||
|
|
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
|||
|
|
defer cancel()
|
|||
|
|
return upstreamConn.WriteFrame(writeCtx, msgType, payload)
|
|||
|
|
}
|
|||
|
|
writeClient := func(msgType coderws.MessageType, payload []byte) error {
|
|||
|
|
writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
|
|||
|
|
defer cancel()
|
|||
|
|
return clientConn.WriteFrame(writeCtx, msgType, payload)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
clientToUpstreamFrames := &atomic.Int64{}
|
|||
|
|
upstreamToClientFrames := &atomic.Int64{}
|
|||
|
|
droppedDownstreamFrames := &atomic.Int64{}
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "relay_start",
|
|||
|
|
PayloadBytes: len(firstClientMessage),
|
|||
|
|
MessageType: relayMessageTypeString(firstMessageType),
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
|
|||
|
|
result.Duration = nowFn().Sub(startAt)
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "write_first_message_failed",
|
|||
|
|
Direction: "client_to_upstream",
|
|||
|
|
MessageType: relayMessageTypeString(firstMessageType),
|
|||
|
|
PayloadBytes: len(firstClientMessage),
|
|||
|
|
Error: err.Error(),
|
|||
|
|
})
|
|||
|
|
return result, &RelayExit{Stage: "write_upstream", Err: err}
|
|||
|
|
}
|
|||
|
|
clientToUpstreamFrames.Add(1)
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "write_first_message_ok",
|
|||
|
|
Direction: "client_to_upstream",
|
|||
|
|
MessageType: relayMessageTypeString(firstMessageType),
|
|||
|
|
PayloadBytes: len(firstClientMessage),
|
|||
|
|
})
|
|||
|
|
markActivity()
|
|||
|
|
|
|||
|
|
exitCh := make(chan relayExitSignal, 3)
|
|||
|
|
dropDownstreamWrites := atomic.Bool{}
|
|||
|
|
go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
|
|||
|
|
go runUpstreamToClient(
|
|||
|
|
relayCtx,
|
|||
|
|
upstreamConn,
|
|||
|
|
writeClient,
|
|||
|
|
startAt,
|
|||
|
|
nowFn,
|
|||
|
|
state,
|
|||
|
|
options.OnUsageParseFailure,
|
|||
|
|
options.OnTurnComplete,
|
|||
|
|
&dropDownstreamWrites,
|
|||
|
|
upstreamToClientFrames,
|
|||
|
|
droppedDownstreamFrames,
|
|||
|
|
markActivity,
|
|||
|
|
onTrace,
|
|||
|
|
exitCh,
|
|||
|
|
)
|
|||
|
|
go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)
|
|||
|
|
|
|||
|
|
firstExit := <-exitCh
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "first_exit",
|
|||
|
|
Direction: relayDirectionFromStage(firstExit.stage),
|
|||
|
|
Graceful: firstExit.graceful,
|
|||
|
|
WroteDownstream: firstExit.wroteDownstream,
|
|||
|
|
Error: relayErrorString(firstExit.err),
|
|||
|
|
})
|
|||
|
|
combinedWroteDownstream := firstExit.wroteDownstream
|
|||
|
|
secondExit := relayExitSignal{graceful: true}
|
|||
|
|
hasSecondExit := false
|
|||
|
|
|
|||
|
|
// 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。
|
|||
|
|
if firstExit.stage == "read_client" && firstExit.graceful {
|
|||
|
|
dropDownstreamWrites.Store(true)
|
|||
|
|
secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
|
|||
|
|
} else {
|
|||
|
|
relayCancel()
|
|||
|
|
_ = upstreamConn.Close()
|
|||
|
|
secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
|
|||
|
|
}
|
|||
|
|
if hasSecondExit {
|
|||
|
|
combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "second_exit",
|
|||
|
|
Direction: relayDirectionFromStage(secondExit.stage),
|
|||
|
|
Graceful: secondExit.graceful,
|
|||
|
|
WroteDownstream: secondExit.wroteDownstream,
|
|||
|
|
Error: relayErrorString(secondExit.err),
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
relayCancel()
|
|||
|
|
_ = upstreamConn.Close()
|
|||
|
|
|
|||
|
|
enrichResult(&result, state, nowFn().Sub(startAt))
|
|||
|
|
result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
|
|||
|
|
result.UpstreamToClientFrames = upstreamToClientFrames.Load()
|
|||
|
|
result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
|
|||
|
|
if firstExit.stage == "read_client" && firstExit.graceful {
|
|||
|
|
stage := "client_disconnected"
|
|||
|
|
exitErr := firstExit.err
|
|||
|
|
if hasSecondExit && !secondExit.graceful {
|
|||
|
|
stage = secondExit.stage
|
|||
|
|
exitErr = secondExit.err
|
|||
|
|
}
|
|||
|
|
if exitErr == nil {
|
|||
|
|
exitErr = io.EOF
|
|||
|
|
}
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "relay_exit",
|
|||
|
|
Direction: relayDirectionFromStage(stage),
|
|||
|
|
Graceful: false,
|
|||
|
|
WroteDownstream: combinedWroteDownstream,
|
|||
|
|
Error: relayErrorString(exitErr),
|
|||
|
|
})
|
|||
|
|
return result, &RelayExit{
|
|||
|
|
Stage: stage,
|
|||
|
|
Err: exitErr,
|
|||
|
|
WroteDownstream: combinedWroteDownstream,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "relay_complete",
|
|||
|
|
Graceful: true,
|
|||
|
|
WroteDownstream: combinedWroteDownstream,
|
|||
|
|
})
|
|||
|
|
_ = clientConn.Close()
|
|||
|
|
return result, nil
|
|||
|
|
}
|
|||
|
|
if !firstExit.graceful {
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "relay_exit",
|
|||
|
|
Direction: relayDirectionFromStage(firstExit.stage),
|
|||
|
|
Graceful: false,
|
|||
|
|
WroteDownstream: combinedWroteDownstream,
|
|||
|
|
Error: relayErrorString(firstExit.err),
|
|||
|
|
})
|
|||
|
|
return result, &RelayExit{
|
|||
|
|
Stage: firstExit.stage,
|
|||
|
|
Err: firstExit.err,
|
|||
|
|
WroteDownstream: combinedWroteDownstream,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if hasSecondExit && !secondExit.graceful {
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "relay_exit",
|
|||
|
|
Direction: relayDirectionFromStage(secondExit.stage),
|
|||
|
|
Graceful: false,
|
|||
|
|
WroteDownstream: combinedWroteDownstream,
|
|||
|
|
Error: relayErrorString(secondExit.err),
|
|||
|
|
})
|
|||
|
|
return result, &RelayExit{
|
|||
|
|
Stage: secondExit.stage,
|
|||
|
|
Err: secondExit.err,
|
|||
|
|
WroteDownstream: combinedWroteDownstream,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "relay_complete",
|
|||
|
|
Graceful: true,
|
|||
|
|
WroteDownstream: combinedWroteDownstream,
|
|||
|
|
})
|
|||
|
|
_ = clientConn.Close()
|
|||
|
|
return result, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func runClientToUpstream(
|
|||
|
|
ctx context.Context,
|
|||
|
|
clientConn FrameConn,
|
|||
|
|
writeUpstream func(msgType coderws.MessageType, payload []byte) error,
|
|||
|
|
markActivity func(),
|
|||
|
|
forwardedFrames *atomic.Int64,
|
|||
|
|
onTrace func(event RelayTraceEvent),
|
|||
|
|
exitCh chan<- relayExitSignal,
|
|||
|
|
) {
|
|||
|
|
for {
|
|||
|
|
msgType, payload, err := clientConn.ReadFrame(ctx)
|
|||
|
|
if err != nil {
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "read_client_failed",
|
|||
|
|
Direction: "client_to_upstream",
|
|||
|
|
Error: err.Error(),
|
|||
|
|
Graceful: isDisconnectError(err),
|
|||
|
|
})
|
|||
|
|
exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: isDisconnectError(err)}
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
markActivity()
|
|||
|
|
if err := writeUpstream(msgType, payload); err != nil {
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "write_upstream_failed",
|
|||
|
|
Direction: "client_to_upstream",
|
|||
|
|
MessageType: relayMessageTypeString(msgType),
|
|||
|
|
PayloadBytes: len(payload),
|
|||
|
|
Error: err.Error(),
|
|||
|
|
})
|
|||
|
|
exitCh <- relayExitSignal{stage: "write_upstream", err: err}
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if forwardedFrames != nil {
|
|||
|
|
forwardedFrames.Add(1)
|
|||
|
|
}
|
|||
|
|
markActivity()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func runUpstreamToClient(
|
|||
|
|
ctx context.Context,
|
|||
|
|
upstreamConn FrameConn,
|
|||
|
|
writeClient func(msgType coderws.MessageType, payload []byte) error,
|
|||
|
|
startAt time.Time,
|
|||
|
|
nowFn func() time.Time,
|
|||
|
|
state *relayState,
|
|||
|
|
onUsageParseFailure func(eventType string, usageRaw string),
|
|||
|
|
onTurnComplete func(turn RelayTurnResult),
|
|||
|
|
dropDownstreamWrites *atomic.Bool,
|
|||
|
|
forwardedFrames *atomic.Int64,
|
|||
|
|
droppedFrames *atomic.Int64,
|
|||
|
|
markActivity func(),
|
|||
|
|
onTrace func(event RelayTraceEvent),
|
|||
|
|
exitCh chan<- relayExitSignal,
|
|||
|
|
) {
|
|||
|
|
wroteDownstream := false
|
|||
|
|
for {
|
|||
|
|
msgType, payload, err := upstreamConn.ReadFrame(ctx)
|
|||
|
|
if err != nil {
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "read_upstream_failed",
|
|||
|
|
Direction: "upstream_to_client",
|
|||
|
|
Error: err.Error(),
|
|||
|
|
Graceful: isDisconnectError(err),
|
|||
|
|
WroteDownstream: wroteDownstream,
|
|||
|
|
})
|
|||
|
|
exitCh <- relayExitSignal{
|
|||
|
|
stage: "read_upstream",
|
|||
|
|
err: err,
|
|||
|
|
graceful: isDisconnectError(err),
|
|||
|
|
wroteDownstream: wroteDownstream,
|
|||
|
|
}
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
markActivity()
|
|||
|
|
observedEvent := observedUpstreamEvent{}
|
|||
|
|
switch msgType {
|
|||
|
|
case coderws.MessageText:
|
|||
|
|
observedEvent = observeUpstreamMessage(state, payload, startAt, nowFn, onUsageParseFailure)
|
|||
|
|
case coderws.MessageBinary:
|
|||
|
|
// binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。
|
|||
|
|
}
|
|||
|
|
emitTurnComplete(onTurnComplete, state, observedEvent)
|
|||
|
|
if dropDownstreamWrites != nil && dropDownstreamWrites.Load() {
|
|||
|
|
if droppedFrames != nil {
|
|||
|
|
droppedFrames.Add(1)
|
|||
|
|
}
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "drop_downstream_frame",
|
|||
|
|
Direction: "upstream_to_client",
|
|||
|
|
MessageType: relayMessageTypeString(msgType),
|
|||
|
|
PayloadBytes: len(payload),
|
|||
|
|
WroteDownstream: wroteDownstream,
|
|||
|
|
})
|
|||
|
|
if observedEvent.terminal {
|
|||
|
|
exitCh <- relayExitSignal{
|
|||
|
|
stage: "drain_terminal",
|
|||
|
|
graceful: true,
|
|||
|
|
wroteDownstream: wroteDownstream,
|
|||
|
|
}
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
markActivity()
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if err := writeClient(msgType, payload); err != nil {
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "write_client_failed",
|
|||
|
|
Direction: "upstream_to_client",
|
|||
|
|
MessageType: relayMessageTypeString(msgType),
|
|||
|
|
PayloadBytes: len(payload),
|
|||
|
|
WroteDownstream: wroteDownstream,
|
|||
|
|
Error: err.Error(),
|
|||
|
|
})
|
|||
|
|
exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream}
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
wroteDownstream = true
|
|||
|
|
if forwardedFrames != nil {
|
|||
|
|
forwardedFrames.Add(1)
|
|||
|
|
}
|
|||
|
|
markActivity()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func runIdleWatchdog(
|
|||
|
|
ctx context.Context,
|
|||
|
|
nowFn func() time.Time,
|
|||
|
|
idleTimeout time.Duration,
|
|||
|
|
lastActivity *atomic.Int64,
|
|||
|
|
onTrace func(event RelayTraceEvent),
|
|||
|
|
exitCh chan<- relayExitSignal,
|
|||
|
|
) {
|
|||
|
|
if idleTimeout <= 0 {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
checkInterval := minDuration(idleTimeout/4, 5*time.Second)
|
|||
|
|
if checkInterval < time.Second {
|
|||
|
|
checkInterval = time.Second
|
|||
|
|
}
|
|||
|
|
ticker := time.NewTicker(checkInterval)
|
|||
|
|
defer ticker.Stop()
|
|||
|
|
|
|||
|
|
for {
|
|||
|
|
select {
|
|||
|
|
case <-ctx.Done():
|
|||
|
|
return
|
|||
|
|
case <-ticker.C:
|
|||
|
|
last := time.Unix(0, lastActivity.Load())
|
|||
|
|
if nowFn().Sub(last) < idleTimeout {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
emitRelayTrace(onTrace, RelayTraceEvent{
|
|||
|
|
Stage: "idle_timeout_triggered",
|
|||
|
|
Direction: "watchdog",
|
|||
|
|
Error: context.DeadlineExceeded.Error(),
|
|||
|
|
})
|
|||
|
|
exitCh <- relayExitSignal{stage: "idle_timeout", err: context.DeadlineExceeded}
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
|
|||
|
|
if onTrace == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
onTrace(event)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func relayMessageTypeString(msgType coderws.MessageType) string {
|
|||
|
|
switch msgType {
|
|||
|
|
case coderws.MessageText:
|
|||
|
|
return "text"
|
|||
|
|
case coderws.MessageBinary:
|
|||
|
|
return "binary"
|
|||
|
|
default:
|
|||
|
|
return "unknown(" + strconv.Itoa(int(msgType)) + ")"
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// relayDirectionFromStage maps an exit stage to the trace direction it belongs
// to. Stages it does not know map to the empty string.
func relayDirectionFromStage(stage string) string {
	direction := ""
	switch stage {
	case "read_client":
		direction = "client_to_upstream"
	case "write_upstream":
		direction = "client_to_upstream"
	case "read_upstream":
		direction = "upstream_to_client"
	case "write_client":
		direction = "upstream_to_client"
	case "drain_terminal":
		direction = "upstream_to_client"
	case "idle_timeout":
		direction = "watchdog"
	}
	return direction
}
|
|||
|
|
|
|||
|
|
// relayErrorString returns err.Error(), or the empty string for a nil error.
func relayErrorString(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
|
|||
|
|
|
|||
|
|
func observeUpstreamMessage(
|
|||
|
|
state *relayState,
|
|||
|
|
message []byte,
|
|||
|
|
startAt time.Time,
|
|||
|
|
nowFn func() time.Time,
|
|||
|
|
onUsageParseFailure func(eventType string, usageRaw string),
|
|||
|
|
) observedUpstreamEvent {
|
|||
|
|
if state == nil || len(message) == 0 {
|
|||
|
|
return observedUpstreamEvent{}
|
|||
|
|
}
|
|||
|
|
values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
|
|||
|
|
eventType := strings.TrimSpace(values[0].String())
|
|||
|
|
if eventType == "" {
|
|||
|
|
return observedUpstreamEvent{}
|
|||
|
|
}
|
|||
|
|
responseID := strings.TrimSpace(values[1].String())
|
|||
|
|
if responseID == "" {
|
|||
|
|
responseID = strings.TrimSpace(values[2].String())
|
|||
|
|
}
|
|||
|
|
// 仅 terminal 事件兜底读取顶层 id,避免把 event_id 当成 response_id 关联到 turn。
|
|||
|
|
if responseID == "" && isTerminalEvent(eventType) {
|
|||
|
|
responseID = strings.TrimSpace(values[3].String())
|
|||
|
|
}
|
|||
|
|
now := nowFn()
|
|||
|
|
|
|||
|
|
if state.firstTokenMs == nil && isTokenEvent(eventType) {
|
|||
|
|
ms := int(now.Sub(startAt).Milliseconds())
|
|||
|
|
if ms >= 0 {
|
|||
|
|
state.firstTokenMs = &ms
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)
|
|||
|
|
observed := observedUpstreamEvent{
|
|||
|
|
eventType: eventType,
|
|||
|
|
responseID: responseID,
|
|||
|
|
usage: parsedUsage,
|
|||
|
|
}
|
|||
|
|
if responseID != "" {
|
|||
|
|
turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
|
|||
|
|
if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) {
|
|||
|
|
ms := int(now.Sub(turnTiming.startAt).Milliseconds())
|
|||
|
|
if ms >= 0 {
|
|||
|
|
turnTiming.firstTokenMs = &ms
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if !isTerminalEvent(eventType) {
|
|||
|
|
return observed
|
|||
|
|
}
|
|||
|
|
observed.terminal = true
|
|||
|
|
state.terminalEventType = eventType
|
|||
|
|
if responseID != "" {
|
|||
|
|
state.lastResponseID = responseID
|
|||
|
|
if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok {
|
|||
|
|
duration := now.Sub(turnTiming.startAt)
|
|||
|
|
if duration < 0 {
|
|||
|
|
duration = 0
|
|||
|
|
}
|
|||
|
|
observed.duration = duration
|
|||
|
|
observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return observed
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func emitTurnComplete(
|
|||
|
|
onTurnComplete func(turn RelayTurnResult),
|
|||
|
|
state *relayState,
|
|||
|
|
observed observedUpstreamEvent,
|
|||
|
|
) {
|
|||
|
|
if onTurnComplete == nil || !observed.terminal {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
responseID := strings.TrimSpace(observed.responseID)
|
|||
|
|
if responseID == "" {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
requestModel := ""
|
|||
|
|
if state != nil {
|
|||
|
|
requestModel = state.requestModel
|
|||
|
|
}
|
|||
|
|
onTurnComplete(RelayTurnResult{
|
|||
|
|
RequestModel: requestModel,
|
|||
|
|
Usage: observed.usage,
|
|||
|
|
RequestID: responseID,
|
|||
|
|
TerminalEventType: observed.eventType,
|
|||
|
|
Duration: observed.duration,
|
|||
|
|
FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken),
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
|
|||
|
|
if state == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
if state.turnTimingByID == nil {
|
|||
|
|
state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
|
|||
|
|
}
|
|||
|
|
timing, ok := state.turnTimingByID[responseID]
|
|||
|
|
if !ok || timing == nil || timing.startAt.IsZero() {
|
|||
|
|
timing = &relayTurnTiming{startAt: now}
|
|||
|
|
state.turnTimingByID[responseID] = timing
|
|||
|
|
return timing
|
|||
|
|
}
|
|||
|
|
return timing
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
|
|||
|
|
if state == nil || state.turnTimingByID == nil {
|
|||
|
|
return relayTurnTiming{}, false
|
|||
|
|
}
|
|||
|
|
timing, ok := state.turnTimingByID[responseID]
|
|||
|
|
if !ok || timing == nil {
|
|||
|
|
return relayTurnTiming{}, false
|
|||
|
|
}
|
|||
|
|
delete(state.turnTimingByID, responseID)
|
|||
|
|
return *timing, true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// openAIWSRelayCloneIntPtr returns a pointer to a fresh copy of *v, or nil
// when v is nil, so the caller's value is not aliased.
func openAIWSRelayCloneIntPtr(v *int) *int {
	if v != nil {
		copied := *v
		return &copied
	}
	return nil
}
|
|||
|
|
|
|||
|
|
func parseUsageAndAccumulate(
|
|||
|
|
state *relayState,
|
|||
|
|
message []byte,
|
|||
|
|
eventType string,
|
|||
|
|
onParseFailure func(eventType string, usageRaw string),
|
|||
|
|
) Usage {
|
|||
|
|
if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
|
|||
|
|
return Usage{}
|
|||
|
|
}
|
|||
|
|
usageResult := gjson.GetBytes(message, "response.usage")
|
|||
|
|
if !usageResult.Exists() {
|
|||
|
|
return Usage{}
|
|||
|
|
}
|
|||
|
|
usageRaw := strings.TrimSpace(usageResult.Raw)
|
|||
|
|
if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") {
|
|||
|
|
recordUsageParseFailure()
|
|||
|
|
if onParseFailure != nil {
|
|||
|
|
onParseFailure(eventType, usageRaw)
|
|||
|
|
}
|
|||
|
|
return Usage{}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
inputResult := gjson.GetBytes(message, "response.usage.input_tokens")
|
|||
|
|
outputResult := gjson.GetBytes(message, "response.usage.output_tokens")
|
|||
|
|
cachedResult := gjson.GetBytes(message, "response.usage.input_tokens_details.cached_tokens")
|
|||
|
|
|
|||
|
|
inputTokens, inputOK := parseUsageIntField(inputResult, true)
|
|||
|
|
outputTokens, outputOK := parseUsageIntField(outputResult, true)
|
|||
|
|
cachedTokens, cachedOK := parseUsageIntField(cachedResult, false)
|
|||
|
|
if !inputOK || !outputOK || !cachedOK {
|
|||
|
|
recordUsageParseFailure()
|
|||
|
|
if onParseFailure != nil {
|
|||
|
|
onParseFailure(eventType, usageRaw)
|
|||
|
|
}
|
|||
|
|
// 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。
|
|||
|
|
return Usage{}
|
|||
|
|
}
|
|||
|
|
parsedUsage := Usage{
|
|||
|
|
InputTokens: inputTokens,
|
|||
|
|
OutputTokens: outputTokens,
|
|||
|
|
CacheReadInputTokens: cachedTokens,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
state.usage.InputTokens += parsedUsage.InputTokens
|
|||
|
|
state.usage.OutputTokens += parsedUsage.OutputTokens
|
|||
|
|
state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
|
|||
|
|
return parsedUsage
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
|
|||
|
|
if !value.Exists() {
|
|||
|
|
return 0, !required
|
|||
|
|
}
|
|||
|
|
if value.Type != gjson.Number {
|
|||
|
|
return 0, false
|
|||
|
|
}
|
|||
|
|
return int(value.Int()), true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
|
|||
|
|
if result == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
result.Duration = duration
|
|||
|
|
if state == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
result.RequestModel = state.requestModel
|
|||
|
|
result.Usage = state.usage
|
|||
|
|
result.RequestID = state.lastResponseID
|
|||
|
|
result.TerminalEventType = state.terminalEventType
|
|||
|
|
result.FirstTokenMs = state.firstTokenMs
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func isDisconnectError(err error) bool {
|
|||
|
|
if err == nil {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
switch coderws.CloseStatus(err) {
|
|||
|
|
case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
message := strings.ToLower(strings.TrimSpace(err.Error()))
|
|||
|
|
if message == "" {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
return strings.Contains(message, "failed to read frame header: eof") ||
|
|||
|
|
strings.Contains(message, "unexpected eof") ||
|
|||
|
|
strings.Contains(message, "use of closed network connection") ||
|
|||
|
|
strings.Contains(message, "connection reset by peer") ||
|
|||
|
|
strings.Contains(message, "broken pipe")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// isTerminalEvent reports whether eventType ends a response: completed, done,
// failed, incomplete, or cancelled (either spelling).
func isTerminalEvent(eventType string) bool {
	switch eventType {
	case "response.completed",
		"response.done",
		"response.failed",
		"response.incomplete",
		"response.cancelled",
		"response.canceled":
		return true
	}
	return false
}
|
|||
|
|
|
|||
|
|
// shouldParseUsage reports whether usage is parsed for eventType. Only
// "response.completed", "response.done" and "response.failed" qualify.
func shouldParseUsage(eventType string) bool {
	return eventType == "response.completed" ||
		eventType == "response.done" ||
		eventType == "response.failed"
}
|
|||
|
|
|
|||
|
|
// isTokenEvent reports whether eventType counts as output for first-token
// timing.
//
// Four lifecycle events ("response.created", "response.in_progress",
// "response.output_item.added", "response.output_item.done") never count.
// Otherwise an event counts when its type contains ".delta", starts with
// "response.output" (which also covers "response.output_text"), or is
// "response.completed" / "response.done". The empty string never counts.
func isTokenEvent(eventType string) bool {
	switch eventType {
	case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
		return false
	case "response.completed", "response.done":
		return true
	}
	return strings.Contains(eventType, ".delta") || strings.HasPrefix(eventType, "response.output")
}
|
|||
|
|
|
|||
|
|
// minDuration returns the smaller of a and b, except that a non-positive
// argument is treated as "unset": when a is non-positive b is returned as-is,
// and otherwise when b is non-positive a is returned.
func minDuration(a, b time.Duration) time.Duration {
	switch {
	case a <= 0:
		return b
	case b <= 0:
		return a
	case a < b:
		return a
	default:
		return b
	}
}
|
|||
|
|
|
|||
|
|
func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
|
|||
|
|
if timeout <= 0 {
|
|||
|
|
timeout = 200 * time.Millisecond
|
|||
|
|
}
|
|||
|
|
select {
|
|||
|
|
case sig := <-exitCh:
|
|||
|
|
return sig, true
|
|||
|
|
case <-time.After(timeout):
|
|||
|
|
return relayExitSignal{}, false
|
|||
|
|
}
|
|||
|
|
}
|