Files
sub2api/backend/internal/service/openai_ws_v2_passthrough_adapter.go
deqiying 23555be380 fix(openai): 修复 WS passthrough 使用记录缺失推理强度和 User-Agent
- 为 OpenAI Responses WebSocket v2 passthrough 补齐每轮 reasoning_effort 元数据
- 传递首帧渠道映射前模型,保留模型后缀推理强度推导能力
- 增加 usage log 端到端回归,覆盖入站 User-Agent、显式 effort 和渠道映射场景
2026-05-03 19:33:09 +08:00

665 lines
24 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync/atomic"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// openAIWSClientFrameConn adapts a client-side *coderws.Conn to the
// openaiwsv2.FrameConn interface used by the passthrough relay.
type openAIWSClientFrameConn struct {
// conn is the underlying client websocket; the methods treat a nil conn as
// an already-closed connection.
conn *coderws.Conn
}
// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
// every client→upstream frame through the OpenAI Fast Policy. It is the
// passthrough-relay equivalent of the parseClientPayload integration in the
// ingress session path. filter returns:
//   - newPayload, nil, nil: forward the (possibly mutated) payload.
//   - _, *OpenAIFastBlockedError, nil: block — the wrapper notifies onBlock
//     (which is expected to send an error event to the client) and surfaces
//     a transport-level close error so the relay stops reading from the
//     client.
//   - _, _, err: a transport error other than block; it is returned to the
//     caller of ReadFrame together with the unfiltered payload.
type openAIWSPolicyEnforcingFrameConn struct {
// inner is the wrapped connection; a nil inner is treated as closed.
inner openaiwsv2.FrameConn
// filter is applied to every frame read from inner. When nil, frames are
// returned unchanged.
filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
// onBlock, when non-nil, is invoked once for each frame the filter blocks,
// before ReadFrame returns the close error.
onBlock func(blocked *OpenAIFastBlockedError)
}
// Compile-time check that the wrapper satisfies openaiwsv2.FrameConn.
var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
// ReadFrame reads the next frame from the wrapped connection and passes it
// through the policy filter.
//
// A nil receiver or nil inner connection yields errOpenAIWSConnClosed. Read
// errors are returned as-is. When no filter is configured the frame is
// returned untouched. If the filter reports an error, the unfiltered payload
// is returned alongside that error. If the filter blocks the frame, onBlock
// (when set) is notified and a policy-violation close error is returned with
// a nil payload. Otherwise the filter's output payload is returned.
func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	if c == nil || c.inner == nil {
		return coderws.MessageText, nil, errOpenAIWSConnClosed
	}

	kind, raw, readErr := c.inner.ReadFrame(ctx)
	if readErr != nil || c.filter == nil {
		// Either the read failed or there is nothing to enforce: hand the
		// frame back exactly as the inner connection produced it.
		return kind, raw, readErr
	}

	filtered, verdict, err := c.filter(kind, raw)
	switch {
	case err != nil:
		return kind, raw, err
	case verdict == nil:
		return kind, filtered, nil
	}

	// The frame was blocked: notify first so the hook runs before the close
	// error reaches the relay.
	if notify := c.onBlock; notify != nil {
		notify(verdict)
	}
	return kind, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, verdict.Message, verdict)
}
// WriteFrame forwards the frame to the wrapped connection without applying
// any policy. It returns errOpenAIWSConnClosed when the receiver or the
// wrapped connection is nil.
func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
	if c != nil && c.inner != nil {
		return c.inner.WriteFrame(ctx, msgType, payload)
	}
	return errOpenAIWSConnClosed
}
// Close closes the wrapped connection and returns its result. A nil receiver
// or nil wrapped connection is a no-op that returns nil.
func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
	if c != nil && c.inner != nil {
		return c.inner.Close()
	}
	return nil
}
// openAIWSPassthroughPolicyModelForFrame resolves the upstream-perspective
// model name to hand to evaluateOpenAIFastPolicy for one passthrough WS
// frame. The frame's top-level "model" field is trimmed, mapped through
// account.GetMappedModel and then normalized with
// normalizeOpenAIModelForUpstream, mirroring the HTTP-side normalization so
// the WS path matches model whitelists identically.
//
// It returns "" when account is nil, the payload is empty, or the frame
// carries no (non-blank) "model".
func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
	if account != nil && len(payload) > 0 {
		if requested := strings.TrimSpace(gjson.GetBytes(payload, "model").String()); requested != "" {
			mapped := account.GetMappedModel(requested)
			return normalizeOpenAIModelForUpstream(account, mapped)
		}
	}
	return ""
}
// openAIWSPassthroughPolicyModelFromSessionFrame resolves the upstream model
// named by a session.update frame's session.model field. It returns "" when
// account is nil, the payload is empty, the frame is not a session.update
// event, or session.model is absent/blank.
//
// The per-frame policy filter (client→upstream direction) uses it to keep
// capturedSessionModel in sync with the session-level model, which the
// client may rotate mid-session after the WS handshake via:
//
//	{"type":"session.update","session":{"model":"gpt-5.5", ...}}
//
// Capturing the model only from the very first frame would leave a bypass:
// a client could send gpt-4o on the first response.create (whitelisted as
// pass), session.update to gpt-5.5, and then send response.create without
// "model". The per-frame resolver would return "" and the stale
// capturedSessionModel would fall back to gpt-4o — defeating the gpt-5.5
// fast-policy filter.
func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
	if account == nil || len(payload) == 0 {
		return ""
	}

	if kind := strings.TrimSpace(gjson.GetBytes(payload, "type").String()); kind != "session.update" {
		return ""
	}

	sessionModel := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
	if sessionModel == "" {
		return ""
	}

	mapped := account.GetMappedModel(sessionModel)
	return normalizeOpenAIModelForUpstream(account, mapped)
}
// openAIWSPassthroughUsageMeta holds the usage metadata (service tier and
// reasoning effort) derived from the most recently accepted response.create
// frame of a passthrough session, plus the pre-mapping request model used to
// derive the reasoning effort.
type openAIWSPassthroughUsageMeta struct {
// serviceTier is the latest result of extractOpenAIServiceTierFromBody.
serviceTier atomic.Pointer[string]
// reasoningEffort is the latest result of extractOpenAIReasoningEffortFromBody.
reasoningEffort atomic.Pointer[string]
// sessionRequestModel is read and written only from the client->upstream
// filter goroutine (after being seeded by the constructor); the Load side
// synchronizes through the atomic pointers above.
sessionRequestModel string
}
// newOpenAIWSPassthroughUsageMeta builds the usage metadata holder for a
// passthrough session. The session request model is the trimmed
// initialRequestModel when non-empty; otherwise it falls back to the model
// carried by firstFrame (as resolved by
// openAIWSPassthroughRequestModelForFrame).
func newOpenAIWSPassthroughUsageMeta(initialRequestModel string, firstFrame []byte) *openAIWSPassthroughUsageMeta {
	sessionModel := strings.TrimSpace(initialRequestModel)
	if sessionModel == "" {
		sessionModel = openAIWSPassthroughRequestModelForFrame(firstFrame)
	}
	return &openAIWSPassthroughUsageMeta{sessionRequestModel: sessionModel}
}
// initFromFirstFrame seeds the service tier and reasoning effort from the
// first frame's post-policy payload, deriving the reasoning effort against
// the session request model. It is a no-op on a nil receiver.
func (m *openAIWSPassthroughUsageMeta) initFromFirstFrame(policyOutput []byte) {
	if m != nil {
		tier := extractOpenAIServiceTierFromBody(policyOutput)
		m.serviceTier.Store(tier)

		effort := extractOpenAIReasoningEffortFromBody(policyOutput, m.sessionRequestModel)
		m.reasoningEffort.Store(effort)
	}
}
// updateSessionRequestModel replaces the session request model when payload
// is a session.update frame carrying a non-empty session.model. Any other
// frame, and a nil receiver, leave the state untouched.
func (m *openAIWSPassthroughUsageMeta) updateSessionRequestModel(payload []byte) {
	if m == nil {
		return
	}
	sessionModel := openAIWSPassthroughRequestModelFromSessionFrame(payload)
	if sessionModel == "" {
		return
	}
	m.sessionRequestModel = sessionModel
}
// requestModelForFrame returns the request model to attribute to payload:
// the frame's own model when it states one, otherwise the session request
// model. With a nil receiver there is no session fallback, so only the
// frame's own model (possibly "") is returned.
func (m *openAIWSPassthroughUsageMeta) requestModelForFrame(payload []byte) string {
	frameModel := openAIWSPassthroughRequestModelForFrame(payload)
	if m == nil || frameModel != "" {
		return frameModel
	}
	return m.sessionRequestModel
}
// updateFromResponseCreate overwrites the service tier and reasoning effort
// with the values extracted from a response.create frame's post-policy
// payload; requestModelForFrame is the model used to derive the reasoning
// effort. It is a no-op on a nil receiver.
func (m *openAIWSPassthroughUsageMeta) updateFromResponseCreate(policyOutput []byte, requestModelForFrame string) {
	if m != nil {
		tier := extractOpenAIServiceTierFromBody(policyOutput)
		m.serviceTier.Store(tier)

		effort := extractOpenAIReasoningEffortFromBody(policyOutput, requestModelForFrame)
		m.reasoningEffort.Store(effort)
	}
}
// openAIWSPassthroughRequestModelForFrame returns the trimmed top-level
// "model" of a response.create frame exactly as the client sent it (no
// account mapping is applied here). It returns "" for an empty payload or
// any other frame type.
func openAIWSPassthroughRequestModelForFrame(payload []byte) string {
	if len(payload) > 0 {
		frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
		if frameType == "response.create" {
			return strings.TrimSpace(gjson.GetBytes(payload, "model").String())
		}
	}
	return ""
}
// openAIWSPassthroughRequestModelFromSessionFrame returns the trimmed
// session.model of a session.update frame exactly as the client sent it (no
// account mapping is applied here). It returns "" for an empty payload or
// any other frame type.
func openAIWSPassthroughRequestModelFromSessionFrame(payload []byte) string {
	if len(payload) > 0 {
		frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
		if frameType == "session.update" {
			return strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
		}
	}
	return ""
}
// openaiWSV2PassthroughModeFields is the fixed key=value prefix that
// logOpenAIWSV2Passthrough prepends to every passthrough log line.
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
// Compile-time check that openAIWSClientFrameConn satisfies openaiwsv2.FrameConn.
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
// ReadFrame reads the next message from the client websocket. A nil receiver
// or nil connection yields errOpenAIWSConnClosed; a nil ctx is replaced with
// context.Background().
func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	if c == nil || c.conn == nil {
		return coderws.MessageText, nil, errOpenAIWSConnClosed
	}
	readCtx := ctx
	if readCtx == nil {
		readCtx = context.Background()
	}
	return c.conn.Read(readCtx)
}
// WriteFrame writes one message to the client websocket. A nil receiver or
// nil connection yields errOpenAIWSConnClosed; a nil ctx is replaced with
// context.Background().
func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
	if c == nil || c.conn == nil {
		return errOpenAIWSConnClosed
	}
	writeCtx := ctx
	if writeCtx == nil {
		writeCtx = context.Background()
	}
	return c.conn.Write(writeCtx, msgType, payload)
}
// Close shuts down the client websocket: it first attempts a normal-closure
// close and then force-closes the connection. Both steps are best-effort —
// their errors are deliberately discarded — and Close always returns nil,
// including for a nil receiver or nil connection.
func (c *openAIWSClientFrameConn) Close() error {
	if c != nil && c.conn != nil {
		_ = c.conn.Close(coderws.StatusNormalClosure, "")
		_ = c.conn.CloseNow()
	}
	return nil
}
// proxyResponsesWebSocketV2Passthrough relays an OpenAI Responses WebSocket
// v2 session between clientConn and a freshly dialed upstream connection in
// passthrough mode.
//
// Steps, in order:
//  1. Validate the receiver, clientConn, account and token.
//  2. Apply the OpenAI Fast Policy to firstClientMessage; a blocked first
//     frame gets an error event written to the client and ends the call with
//     a policy-violation close error.
//  3. Seed the usage metadata from the post-policy first frame.
//  4. Build the upstream URL and headers, then dial the upstream websocket.
//  5. Run the relay through openaiwsv2.RunEntry, with the client side wrapped
//     in a policy-enforcing FrameConn that filters every later text frame.
//  6. Report each completed turn, and the terminal failure if any, through
//     hooks.AfterTurn (hooks may be nil).
//
// It returns nil when the relay exits cleanly. Otherwise it returns an input
// validation error, a policy error, a URL/dial error (mapped through
// mapOpenAIWSPassthroughDialError), or the relay exit error wrapped by
// wrapOpenAIWSIngressTurnError.
func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
ctx context.Context,
c *gin.Context,
clientConn *coderws.Conn,
account *Account,
token string,
firstClientMessage []byte,
hooks *OpenAIWSIngressHooks,
wsDecision OpenAIWSProtocolDecision,
) error {
if s == nil {
return errors.New("service is nil")
}
if clientConn == nil {
return errors.New("client websocket is nil")
}
if account == nil {
return errors.New("account is nil")
}
if strings.TrimSpace(token) == "" {
return errors.New("token is empty")
}
// These two values are extracted for the relay_start log line only.
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
logOpenAIWSV2Passthrough(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
account.ID,
truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
openaiwsv2RelayMessageTypeName(coderws.MessageText),
len(firstClientMessage),
)
// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
// frames are filtered via a wrapping FrameConn below so every client→
// upstream frame goes through the same policy evaluator/normalize/scope as
// HTTP entrypoints.
//
// We capture the session-level model from the first frame here so the
// per-frame filter (below) can fall back to it when a follow-up frame
// omits "model" — Realtime clients are allowed to send response.create
// without re-stating the model, in which case the upstream uses the model
// negotiated at session.update time. Without this fallback, an empty
// model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
// silently passed through, defeating the policy on every frame after
// the first.
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
// hooks.InitialRequestModel, when provided, takes precedence over the first
// frame's own model as the session request model for usage metadata.
initialRequestModel := ""
if hooks != nil {
initialRequestModel = hooks.InitialRequestModel
}
usageMeta := newOpenAIWSPassthroughUsageMeta(initialRequestModel, firstClientMessage)
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
if policyErr != nil {
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
}
if blocked != nil {
// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
// writeFrameMu, writes the entire frame, and Flushes the underlying
// bufio writer before returning (write.go:42 → write.go:307-311).
// The subsequent close handshake re-acquires the same writeFrameMu
// to send the close frame, so the error event is guaranteed to
// reach the kernel send buffer before any close frame is queued.
// No explicit flush hop is required here.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes != nil {
writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
// Best-effort notification: the write error is ignored because the
// connection is about to be closed with a policy violation anyway.
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancelWrite()
}
return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
}
firstClientMessage = updatedFirst
// Extract service_tier / reasoning_effort for usage reporting only after
// the policy filter has run: when the filter hits, service_tier has already
// been removed from firstClientMessage, and billing should reflect the tier
// the upstream actually handles (nil = default) rather than the "priority"
// the user originally requested. This matches the semantics of the HTTP
// entrypoint (line ~2728, extractOpenAIServiceTier(reqBody)) and of the WS
// ingress path (openai_ws_forwarder.go:2991, taken from the payload).
//
// Multi-turn passthrough: the OpenAI Realtime / Responses WS protocol lets
// the client send a different service_tier on different response.create
// frames of the same connection (see codex-rs/core/src/client.rs
// build_responses_request, which refills the value on every request).
// atomic.Pointer[string] is therefore used to synchronize the current
// turn's usage metadata between the filter (runClientToUpstream goroutine)
// and OnTurnComplete / the final result (runUpstreamToClient goroutine).
usageMeta.initFromFirstFrame(firstClientMessage)
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
}
// Host and path are parsed for logging only; "-" is logged when the URL
// cannot be parsed.
wsHost := "-"
wsPath := "-"
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
}
logOpenAIWSV2Passthrough(
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
account.ID,
wsHost,
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
// Detect the official Codex client from the inbound headers; the gateway
// config can force the flag on regardless of the headers.
isCodexCLI := false
if c != nil {
isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
}
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
isCodexCLI = true
}
// NOTE(review): the second return value of buildOpenAIWSHeaders is
// discarded here; its meaning is not visible in this file — confirm that
// ignoring it is intended.
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
dialer := s.getOpenAIWSPassthroughDialer()
if dialer == nil {
return errors.New("openai ws passthrough dialer is nil")
}
dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
defer cancelDial()
upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
if err != nil {
logOpenAIWSV2Passthrough(
"relay_dial_failed account_id=%d status_code=%d err=%s",
account.ID,
statusCode,
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
)
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
}
// The upstream connection is closed when this function returns, on every
// path below.
defer func() {
_ = upstreamConn.Close()
}()
logOpenAIWSV2Passthrough(
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
account.ID,
statusCode,
openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
)
upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
if !ok {
return errors.New("openai ws passthrough upstream connection does not support frame relay")
}
// completedTurns counts OnTurnComplete callbacks; it provides the turn
// number passed to hooks.AfterTurn and logged on exit.
completedTurns := atomic.Int32{}
policyClientConn := &openAIWSPolicyEnforcingFrameConn{
inner: &openAIWSClientFrameConn{conn: clientConn},
// Thread-safety note: filter is only invoked from the single
// runClientToUpstream goroutine (passthrough_relay.go: ReadFrame loop);
// all reads and writes of capturedSessionModel happen inside that
// goroutine, so no locking / atomics are needed.
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
// Only text frames are policy-checked; other frame types pass through
// unchanged.
if msgType != coderws.MessageText {
return payload, nil, nil
}
// Refresh capturedSessionModel before evaluating the policy: the client
// may change the session-level model via session.update (allowed by the
// Realtime / Responses WS protocol). Without the refresh there is a
// bypass path: "first frame model=gpt-4o (pass) → session.update
// switches to gpt-5.5 → response.create without model falls back to
// gpt-4o". Only the session.model field of session.update events is
// considered here; a response.create frame's own model is still
// determined by that frame's own field.
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated
}
usageMeta.updateSessionRequestModel(payload)
requestModelForThisFrame := usageMeta.requestModelForFrame(payload)
// Per-frame model first; if the client omits "model" on a
// follow-up frame (legal in Realtime), fall back to the
// session-level model captured from the first frame so the
// model whitelist still resolves. An empty model would miss
// any whitelist and silently fall back to pass.
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
model = capturedSessionModel
}
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
// Multi-turn passthrough usage: update usageMeta only on successful
// (non-block / non-err) response.create frames, using the payload after
// filter processing, consistent with the first frame's
// extract-after-policy semantics (see the comment preceding
// usageMeta.initFromFirstFrame earlier in this function).
// - Non-response.create frames (response.cancel /
// conversation.item.create / session.update, etc.) carry no
// per-response metadata and must not overwrite the previous turn's
// value.
// - blocked != nil: the frame is not sent upstream, so the usage
// metadata keeps the previous turn's value.
// - policyErr != nil: error path; keep the previous turn's value.
// - A response.create without service_tier makes
// extractOpenAIServiceTierFromBody return nil. Overwriting with it
// (Store(nil)) is intentional: the OpenAI upstream treats a frame
// that carries no service_tier as default, and billing should
// reflect that faithfully.
if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
usageMeta.updateFromResponseCreate(out, requestModelForThisFrame)
}
return out, blocked, policyErr
},
onBlock: func(blocked *OpenAIFastBlockedError) {
// Conn.Write is synchronous w.r.t. flush (see the note in the
// first-frame blocked branch earlier in this function), so no explicit
// flush is required to ensure the error event lands before the close
// frame.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes == nil {
return
}
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancel()
},
}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
ClientConn: policyClientConn,
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
WriteTimeout: s.openAIWSWriteTimeout(),
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
FirstMessageType: coderws.MessageText,
OnUsageParseFailure: func(eventType string, usageRaw string) {
logOpenAIWSV2Passthrough(
"usage_parse_failed event_type=%s usage_raw=%s",
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
)
},
OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
turnNo := int(completedTurns.Add(1))
// ServiceTier / ReasoningEffort are whatever the filter stored most
// recently at the moment this callback runs.
turnResult := &OpenAIForwardResult{
RequestID: turn.RequestID,
Usage: OpenAIUsage{
InputTokens: turn.Usage.InputTokens,
OutputTokens: turn.Usage.OutputTokens,
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
ServiceTier: usageMeta.serviceTier.Load(),
ReasoningEffort: usageMeta.reasoningEffort.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
Duration: turn.Duration,
FirstTokenMs: turn.FirstTokenMs,
}
logOpenAIWSV2Passthrough(
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
account.ID,
turnNo,
truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen),
turnResult.Duration.Milliseconds(),
openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
turnResult.Usage.InputTokens,
turnResult.Usage.OutputTokens,
turnResult.Usage.CacheReadInputTokens,
)
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turnNo, turnResult, nil)
}
},
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
logOpenAIWSV2Passthrough(
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
account.ID,
truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
event.PayloadBytes,
event.Graceful,
event.WroteDownstream,
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
)
},
},
})
// Aggregate result for the whole relay. It is passed to hooks.AfterTurn only
// in the zero-turn fallback below, and its Duration / RequestID feed the
// completion and failure log lines.
result := &OpenAIForwardResult{
RequestID: relayResult.RequestID,
Usage: OpenAIUsage{
InputTokens: relayResult.Usage.InputTokens,
OutputTokens: relayResult.Usage.OutputTokens,
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
ServiceTier: usageMeta.serviceTier.Load(),
ReasoningEffort: usageMeta.reasoningEffort.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
Duration: relayResult.Duration,
FirstTokenMs: relayResult.FirstTokenMs,
}
turnCount := int(completedTurns.Load())
if relayExit == nil {
logOpenAIWSV2Passthrough(
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
account.ID,
truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen),
result.Duration.Milliseconds(),
relayResult.ClientToUpstreamFrames,
relayResult.UpstreamToClientFrames,
relayResult.DroppedDownstreamFrames,
turnCount,
)
// On the normal path the hook has already been invoked once per turn on
// each terminal event; fall back to a single callback only when zero
// turns completed.
if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(1, result, nil)
}
return nil
}
logOpenAIWSV2Passthrough(
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
account.ID,
truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
relayExit.WroteDownstream,
truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
result.Duration.Milliseconds(),
relayResult.ClientToUpstreamFrames,
relayResult.UpstreamToClientFrames,
relayResult.DroppedDownstreamFrames,
turnCount,
)
relayErr := relayExit.Err
// An idle timeout is reported to the client as a policy-violation close
// that wraps the original relay error.
if relayExit.Stage == "idle_timeout" {
relayErr = NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"client websocket idle timeout",
relayErr,
)
}
turnErr := wrapOpenAIWSIngressTurnError(
relayExit.Stage,
relayErr,
relayExit.WroteDownstream,
)
// The failure is attributed to the turn following the last completed one.
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turnCount+1, nil, turnErr)
}
return turnErr
}
// mapOpenAIWSPassthroughDialError converts an upstream dial failure into the
// error returned to the passthrough caller.
//
// Mapping, first match wins:
//   - nil err                         → nil
//   - context.Canceled                → err, returned untouched
//   - context.DeadlineExceeded        → close error, StatusTryAgainLater
//   - HTTP 429                        → close error, StatusTryAgainLater
//   - HTTP 401 / 403                  → close error, StatusPolicyViolation
//   - any other HTTP 4xx              → close error, StatusPolicyViolation
//   - everything else                 → "openai ws passthrough dial: ..." wrap
//
// Except for the canceled case, the cause attached to the result is err
// itself when its chain already contains an *openAIWSDialError; otherwise it
// is a new *openAIWSDialError carrying statusCode and a copy of
// handshakeHeaders.
func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
	err error,
	statusCode int,
	handshakeHeaders http.Header,
) error {
	if err == nil {
		return nil
	}
	// Cancellation is passed through as-is, so there is no need to build the
	// dial-error wrapper for it.
	if errors.Is(err, context.Canceled) {
		return err
	}

	cause := err
	var existing *openAIWSDialError
	if !errors.As(err, &existing) {
		cause = &openAIWSDialError{
			StatusCode:      statusCode,
			ResponseHeaders: cloneHeader(handshakeHeaders),
			Err:             err,
		}
	}

	switch {
	case errors.Is(err, context.DeadlineExceeded):
		return NewOpenAIWSClientCloseError(
			coderws.StatusTryAgainLater,
			"upstream websocket connect timeout",
			cause,
		)
	case statusCode == http.StatusTooManyRequests:
		return NewOpenAIWSClientCloseError(
			coderws.StatusTryAgainLater,
			"upstream websocket is busy, please retry later",
			cause,
		)
	case statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden:
		return NewOpenAIWSClientCloseError(
			coderws.StatusPolicyViolation,
			"upstream websocket authentication failed",
			cause,
		)
	case statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError:
		return NewOpenAIWSClientCloseError(
			coderws.StatusPolicyViolation,
			"upstream websocket handshake rejected",
			cause,
		)
	}
	return fmt.Errorf("openai ws passthrough dial: %w", cause)
}
// openaiwsv2RelayMessageTypeName returns the log label for a websocket
// message type: "text", "binary", or "unknown(<n>)" for any other value.
func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
	if msgType == coderws.MessageText {
		return "text"
	}
	if msgType == coderws.MessageBinary {
		return "binary"
	}
	return fmt.Sprintf("unknown(%d)", msgType)
}
// relayErrorText returns err's message for logging, or "" when err is nil.
func relayErrorText(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
// openAIWSFirstTokenMsForLog returns the first-token latency for logging, or
// -1 when no value was recorded (nil pointer).
func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
	if firstTokenMs != nil {
		return *firstTokenMs
	}
	return -1
}
func logOpenAIWSV2Passthrough(format string, args ...any) {
logger.LegacyPrintf(
"service.openai_ws_v2",
"[OpenAI WS v2 passthrough] %s "+format,
append([]any{openaiWSV2PassthroughModeFields}, args...)...,
)
}