sub2api/backend/internal/service/openai_ws_v2_passthrough_adapter.go
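DaydreamCoding 30f55a1f72 feat(openai): complete OpenAI Fast/Flex Policy implementation (HTTP + WebSocket + Admin)
Mirroring the fast-mode filtering of Claude BetaPolicy, this adds a
tri-state pass / filter / block policy for the service_tier field sent to
OpenAI upstreams (priority / flex, including normalization of the client
alias "fast" → "priority"). It covers every OpenAI entry point plus the
admin configuration entry point.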

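Backend core
- Add the SettingKeyOpenAIFastPolicySettings, OpenAIFastPolicyRule, and
  OpenAIFastPolicySettings configuration models; each rule spans the
  service_tier × action × scope × model whitelist × fallback action
  dimensions.
- SettingService.Get/SetOpenAIFastPolicySettings; when the setting is
  missing, return the built-in default policy (priority is filtered for
  all models, the whitelist is empty, fallback=pass). Rationale:
  service_tier=fast is a user-level switch orthogonal to the model field,
  so pinning the default to specific model slugs would leave a bypass
  path such as "gpt-4 + fast passes priority through to the upstream".
  A JSON parse failure no longer falls back silently; slog.Warn records
  the dirty data so operators can locate it.
- service_tier normalization (trim + ToLower + fast→priority + allowlist
  of priority/flex) and policy evaluation (evaluateOpenAIFastPolicy) are
  the single source of truth shared by HTTP and WS. The pure function
  evaluateOpenAIFastPolicyWithSettings is extracted and paired with a
  ctx-bound settings snapshot (withOpenAIFastPolicyContext /
  openAIFastPolicySettingsFromContext): a long-lived WS session
  prefetches the settings once at entry and reuses them for all frames,
  so settingService is not hit on every frame.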
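
The normalization and evaluation steps above can be sketched as a pair of pure functions. This is a minimal illustration, not the repository's code: the names Action, Rule, normalizeServiceTier, matchesModel, and evaluate are hypothetical stand-ins, the real OpenAIFastPolicyRule also carries a scope and an error message, and the whitelist semantics (empty list matches every model, trailing * is a prefix match, fallback applies on a whitelist miss) are assumptions inferred from the description.

```go
package main

import "strings"

// Action is the tri-state outcome of a policy rule (hypothetical type name).
type Action string

const (
	ActionPass   Action = "pass"
	ActionFilter Action = "filter"
	ActionBlock  Action = "block"
)

// Rule is a simplified, hypothetical stand-in for OpenAIFastPolicyRule.
type Rule struct {
	ServiceTier    string   // normalized tier: "priority" or "flex"
	Action         Action   // applied when the model matches the whitelist
	ModelWhitelist []string // assumption: empty matches every model
	FallbackAction Action   // assumption: applied on a whitelist miss
}

// normalizeServiceTier trims and lowercases the raw value, maps the client
// alias "fast" to "priority", and returns "" for anything that is not
// priority or flex.
func normalizeServiceTier(raw string) string {
	tier := strings.ToLower(strings.TrimSpace(raw))
	if tier == "fast" {
		tier = "priority"
	}
	if tier == "priority" || tier == "flex" {
		return tier
	}
	return ""
}

// matchesModel reports whether model matches any pattern. A trailing "*"
// is treated as a prefix wildcard (assumption).
func matchesModel(patterns []string, model string) bool {
	if len(patterns) == 0 {
		return true
	}
	for _, p := range patterns {
		if strings.HasSuffix(p, "*") {
			if strings.HasPrefix(model, strings.TrimSuffix(p, "*")) {
				return true
			}
		} else if p == model {
			return true
		}
	}
	return false
}

// evaluate returns the action of the first rule whose tier matches the
// normalized request tier. Requests with no recognized tier always pass.
func evaluate(rules []Rule, rawTier, model string) Action {
	tier := normalizeServiceTier(rawTier)
	if tier == "" {
		return ActionPass
	}
	for _, r := range rules {
		if r.ServiceTier != tier {
			continue
		}
		if matchesModel(r.ModelWhitelist, model) {
			return r.Action
		}
		return r.FallbackAction
	}
	return ActionPass
}
```

Because evaluate takes the rules as an argument instead of loading them, a long-lived WebSocket session can load the settings once and call it for every frame, which is the point of the ctx-bound snapshot described above.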

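HTTP entry points (4)
- Chat Completions, Anthropic-compatible (Messages, including the second
  hit for BetaFastMode→priority), native Responses, and Passthrough
  Responses all call applyOpenAIFastPolicyToBody; filter deletes the
  top-level service_tier with sjson, and block returns a 403
  forbidden_error JSON response.
- All 4 entry points use the model from the upstream perspective (the
  slug after GetMappedModel + normalizeOpenAIModelForUpstream + Codex
  OAuth normalize), so chat / messages / native /responses / passthrough
  do not hit the whitelist differently because of differing model
  dimensions.
- On the pass path the client alias "fast" is also normalized to
  "priority" and written back to the body; otherwise the native
  /responses and passthrough entry points would forward "fast" verbatim
  and the upstream would respond with 400 or reject the request
  (normalizeResponsesBodyServiceTier on the chat-completions entry point
  already behaved this way).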
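
As a rough sketch of what the three actions do to a request body: the commit uses sjson to delete the top-level field in place, whereas this illustration swaps in encoding/json from the standard library, so it re-marshals the body and does not preserve key order. The function name applyServiceTierAction, the errBlocked sentinel, and the choice to leave a non-string service_tier untouched are assumptions made for the example.

```go
package main

import (
	"encoding/json"
	"errors"
	"strings"
)

// errBlocked is a hypothetical sentinel; a caller would map it to the
// 403 forbidden_error response described above.
var errBlocked = errors.New("service_tier blocked by policy")

// applyServiceTierAction applies an already-decided action ("pass",
// "filter", or "block") to the top-level service_tier of a JSON body.
func applyServiceTierAction(body []byte, action string) ([]byte, error) {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		return nil, err
	}
	raw, ok := fields["service_tier"]
	if !ok {
		return body, nil // nothing to do: bytes stay unchanged
	}
	var tier string
	if err := json.Unmarshal(raw, &tier); err != nil {
		return body, nil // assumption: non-string values are left untouched
	}
	switch action {
	case "block":
		// The body is not modified on block; the caller short-circuits.
		return body, errBlocked
	case "filter":
		delete(fields, "service_tier")
	default:
		// pass: still rewrite the client alias so the upstream never
		// sees "fast".
		if strings.ToLower(strings.TrimSpace(tier)) != "fast" {
			return body, nil
		}
		fields["service_tier"] = json.RawMessage(`"priority"`)
	}
	return json.Marshal(fields)
}
```

Deciding the action separately from applying it keeps this step free of settings lookups, so the same helper shape can serve all four entry points.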

WebSocket entry points
- Add applyOpenAIFastPolicyToWSResponseCreate: it strictly matches
  type="response.create" and handles only the top-level service_tier;
  filter deletes the field with sjson, and block returns a typed
  *OpenAIFastBlockedError.
- The ingress path calls it inside parseClientPayload. On a block hit it
  first writes a Realtime-style error event and then returns
  OpenAIWSClientCloseError(StatusPolicyViolation=1008), relying on the
  synchronous flush of the underlying WebSocket Conn.Write to guarantee
  that the error arrives before the close.
- The passthrough path applies the policy to firstClientMessage before
  RunEntry and wraps ReadFrame in openAIWSPolicyEnforcingFrameConn to
  enforce the policy on every client→upstream frame; later frames without
  a model field fall back to capturedSessionModel. The filter closure
  also watches the session.model field of session.update /
  session.created frames to refresh capturedSessionModel, closing the
  mid-session bypass "first frame model=gpt-4o (pass) → session.update
  switches to gpt-5.5 → response.create without model falls back to
  gpt-4o".
- passthrough billing: requestServiceTier is extracted from
  firstClientMessage after the policy filter, so on a filter hit
  OpenAIForwardResult.ServiceTier reports nil (default tier), consistent
  with the HTTP entry points (reqBody comes from the post-filter map) and
  WS ingress (payload comes from the post-filter bytes).
- Error event schema: {event_id: "evt_<32hex>", type: "error",
  error: {type: "forbidden_error", code: "policy_violation", message}},
  compatible with the error event parsing of the OpenAI codex client.
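
The error event schema above could be produced along these lines. This is a sketch only: the repository's builder is buildOpenAIFastPolicyBlockedWSEvent, whose body is not shown here, so the names blockedEvent, blockedEventError, and buildBlockedEvent are hypothetical, and generating the 32 hex characters from 16 random bytes is an assumption about how "evt_<32hex>" is filled in.

```go
package main

import (
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
)

// blockedEventError mirrors the nested "error" object of the schema.
type blockedEventError struct {
	Type    string `json:"type"`
	Code    string `json:"code"`
	Message string `json:"message"`
}

// blockedEvent mirrors the top-level error event of the schema.
type blockedEvent struct {
	EventID string            `json:"event_id"`
	Type    string            `json:"type"`
	Error   blockedEventError `json:"error"`
}

// buildBlockedEvent returns the JSON bytes of an error event carrying the
// given policy message. 16 random bytes hex-encode to 32 characters.
func buildBlockedEvent(message string) ([]byte, error) {
	var id [16]byte
	if _, err := rand.Read(id[:]); err != nil {
		return nil, err
	}
	return json.Marshal(blockedEvent{
		EventID: "evt_" + hex.EncodeToString(id[:]),
		Type:    "error",
		Error: blockedEventError{
			Type:    "forbidden_error",
			Code:    "policy_violation",
			Message: message,
		},
	})
}
```

The resulting bytes would be written to the client as a text frame before the connection is closed with status 1008, matching the ordering described for the ingress path.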

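Admin / Frontend
- dto.SystemSettings / UpdateSettingsRequest gain an
  openai_fast_policy_settings field (omitempty), wired into bulk GET/PUT.
- The Gateway tab of the Settings page gains a Fast/Flex Policy form card
  that configures every field: service_tier × action × scope × model
  whitelist × fallback action.
- Frontend guard: the openaiFastPolicyLoaded flag allows write-back only
  when the GET actually returned the field, so a rollout or an error
  cannot overwrite the default rules with an empty value; the
  saveSettings write-back loop skips this field, which is handled by
  dedicated refresh logic; error_message is sent only when action=block,
  matching the backend's omitempty behavior.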

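Tests
- HTTP path: openai_fast_policy_test.go covers the default configuration
  (whitelist=[], priority filtered for all models), block with a custom
  error, scope differentiation, filter removing the field, block leaving
  the body unchanged, block short-circuiting the upstream call, and
  Anthropic BetaFastMode triggering the OpenAI fast policy.
- WebSocket path: openai_fast_policy_ws_test.go covers
    helper units (filter / fast→priority normalization / flex
    pass-through / block typed error / bytes unchanged when service_tier
    is absent / non-response.create frames untouched / frames with an
    empty type untouched / event_id+code field assertions / tolerance of
    a non-string service_tier) +
    pass-path fast alias normalization regression +
    ingress end to end (upstream payload has no service_tier after
    filter / after block the client receives the error event first and
    then close 1008, with zero upstream writes) +
    passthrough capturedSessionModel fallback cases (established by the
    first frame under a whitelist policy, a missing model hits the
    fallback, documentation of the leak when no fallback exists) +
    passthrough regression for the mid-session bypass where
    session.update / session.created rotates capturedSessionModel +
    passthrough billing post-filter ServiceTier and idempotent filter
    regressions.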
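
The mid-session bypass regression listed above comes down to keeping a session-level model in sync with session.update frames. The following sketch shows that bookkeeping in isolation. It is an assumption-laden illustration: sessionModelTracker, frame, and resolve are hypothetical names, the file itself reads fields with gjson rather than encoding/json, and the account-specific model mapping is omitted.

```go
package main

import (
	"encoding/json"
	"strings"
)

// frame holds only the fields the tracker inspects.
type frame struct {
	Type    string `json:"type"`
	Model   string `json:"model"`
	Session struct {
		Model string `json:"model"`
	} `json:"session"`
}

// sessionModelTracker remembers the session-level model so that frames
// without their own "model" still resolve to a policy model. It is meant
// to be used from a single goroutine.
type sessionModelTracker struct {
	captured string
}

// resolve returns the model the policy should be evaluated against for
// one client→upstream frame.
func (t *sessionModelTracker) resolve(payload []byte) string {
	var f frame
	if err := json.Unmarshal(payload, &f); err != nil {
		return t.captured
	}
	// A session.update may rotate the session-level model mid-session.
	if f.Type == "session.update" {
		if m := strings.TrimSpace(f.Session.Model); m != "" {
			t.captured = m
		}
	}
	// The frame's own model wins; otherwise fall back to the session's.
	if m := strings.TrimSpace(f.Model); m != "" {
		return m
	}
	return t.captured
}
```

A regression test in this style would feed a response.create with gpt-4o, then a session.update to gpt-5.5, then a response.create without a model, and assert that the last frame resolves to gpt-5.5 rather than the stale gpt-4o.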

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 11:15:09 +08:00

591 lines
22 KiB
Go
package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync/atomic"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type openAIWSClientFrameConn struct {
conn *coderws.Conn
}
// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
// every client→upstream frame through the OpenAI Fast Policy. It is the
// passthrough-relay equivalent of the parseClientPayload integration in the
// ingress session path. filter returns:
//   - newPayload, nil, nil: forward the (possibly mutated) payload
//   - _, *OpenAIFastBlockedError, nil: block; the wrapper sends an error
//     event via onBlock and surfaces a transport-level error so the relay
//     stops reading from the client.
//   - _, _, err: a transport error other than block.
type openAIWSPolicyEnforcingFrameConn struct {
inner openaiwsv2.FrameConn
filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
onBlock func(blocked *OpenAIFastBlockedError)
}
var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.inner == nil {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
msgType, payload, err := c.inner.ReadFrame(ctx)
if err != nil {
return msgType, payload, err
}
if c.filter == nil {
return msgType, payload, nil
}
updated, blocked, filterErr := c.filter(msgType, payload)
if filterErr != nil {
return msgType, payload, filterErr
}
if blocked != nil {
if c.onBlock != nil {
c.onBlock(blocked)
}
return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
}
return msgType, updated, nil
}
func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.inner == nil {
return errOpenAIWSConnClosed
}
return c.inner.WriteFrame(ctx, msgType, payload)
}
func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
if c == nil || c.inner == nil {
return nil
}
return c.inner.Close()
}
// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective
// model name that should be passed to evaluateOpenAIFastPolicy for a single
// passthrough WS frame. Mirrors the HTTP-side normalization
// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
// matches model whitelists identically.
func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
if account == nil || len(payload) == 0 {
return ""
}
original := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if original == "" {
return ""
}
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
}
// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model
// derived from a session.update frame's session.model field. Returns "" when
// the frame is not a session.update event or carries no session.model. Used
// by the per-frame policy filter (client→upstream direction) to keep
// capturedSessionModel in sync with the session-level model the client may
// rotate mid-session.
//
// Realtime / Responses WS lets the client change the session model after
// the WS handshake via:
//
//	{"type":"session.update","session":{"model":"gpt-5.5", ...}}
//
// If we only capture the model from the very first frame, a client can ship
// gpt-4o on the first response.create (whitelisted as pass), then
// session.update to gpt-5.5, then send response.create without "model" so
// the per-frame resolver returns "" and the stale capturedSessionModel falls
// back to gpt-4o — defeating the gpt-5.5 fast-policy filter.
func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
if account == nil || len(payload) == 0 {
return ""
}
frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
if frameType != "session.update" {
return ""
}
original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
if original == "" {
return ""
}
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
}
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if c == nil || c.conn == nil {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Read(ctx)
}
func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
}
if ctx == nil {
ctx = context.Background()
}
return c.conn.Write(ctx, msgType, payload)
}
func (c *openAIWSClientFrameConn) Close() error {
if c == nil || c.conn == nil {
return nil
}
_ = c.conn.Close(coderws.StatusNormalClosure, "")
_ = c.conn.CloseNow()
return nil
}
func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
ctx context.Context,
c *gin.Context,
clientConn *coderws.Conn,
account *Account,
token string,
firstClientMessage []byte,
hooks *OpenAIWSIngressHooks,
wsDecision OpenAIWSProtocolDecision,
) error {
if s == nil {
return errors.New("service is nil")
}
if clientConn == nil {
return errors.New("client websocket is nil")
}
if account == nil {
return errors.New("account is nil")
}
if strings.TrimSpace(token) == "" {
return errors.New("token is empty")
}
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
logOpenAIWSV2Passthrough(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
account.ID,
truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
openaiwsv2RelayMessageTypeName(coderws.MessageText),
len(firstClientMessage),
)
// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
// frames are filtered via a wrapping FrameConn below so every client→
// upstream frame goes through the same policy evaluator/normalize/scope as
// HTTP entrypoints.
//
// We capture the session-level model from the first frame here so the
// per-frame filter (below) can fall back to it when a follow-up frame
// omits "model". Realtime clients are allowed to send response.create
// without re-stating the model, in which case the upstream uses the model
// negotiated at session.update time. Without this fallback, an empty
// model would miss any configured model whitelist and be silently passed
// through, defeating the policy on every frame after the first.
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
if policyErr != nil {
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
}
if blocked != nil {
// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
// writeFrameMu, writes the entire frame, and Flushes the underlying
// bufio writer before returning (write.go:42 → write.go:307-311).
// The subsequent close handshake re-acquires the same writeFrameMu
// to send the close frame, so the error event is guaranteed to
// reach the kernel send buffer before any close frame is queued.
// No explicit flush hop is required here.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes != nil {
writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancelWrite()
}
return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
}
firstClientMessage = updatedFirst
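// Extract service_tier for billing only after the policy filter has run.
// When the filter matches, service_tier has already been removed from
// firstClientMessage, so billing should reflect the tier the upstream
// actually processes (nil = default) rather than the "priority" the user
// originally requested. This matches the HTTP entry points (line ~2728,
// extractOpenAIServiceTier(reqBody)) and WS ingress
// (openai_ws_forwarder.go:2991, taken from the payload).
//
// Multi-turn passthrough: the OpenAI Realtime / Responses WS protocol lets
// a client send a different service_tier on each response.create frame of
// the same connection (see codex-rs/core/src/client.rs, where
// build_responses_request refills the value on every call). We therefore
// use an atomic.Pointer[string] to synchronize the current turn's
// service_tier between the filter (runClientToUpstream goroutine) and
// OnTurnComplete / the final result (runUpstreamToClient goroutine).
// extractOpenAIServiceTierFromBody returns *string, which is already a
// pointer type and can be passed to Store/Load without extra wrapping.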
var requestServiceTierPtr atomic.Pointer[string]
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
}
wsHost := "-"
wsPath := "-"
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
}
logOpenAIWSV2Passthrough(
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
account.ID,
wsHost,
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
isCodexCLI := false
if c != nil {
isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
}
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
isCodexCLI = true
}
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
dialer := s.getOpenAIWSPassthroughDialer()
if dialer == nil {
return errors.New("openai ws passthrough dialer is nil")
}
dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
defer cancelDial()
upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
if err != nil {
logOpenAIWSV2Passthrough(
"relay_dial_failed account_id=%d status_code=%d err=%s",
account.ID,
statusCode,
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
)
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
}
defer func() {
_ = upstreamConn.Close()
}()
logOpenAIWSV2Passthrough(
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
account.ID,
statusCode,
openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
)
upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
if !ok {
return errors.New("openai ws passthrough upstream connection does not support frame relay")
}
completedTurns := atomic.Int32{}
policyClientConn := &openAIWSPolicyEnforcingFrameConn{
inner: &openAIWSClientFrameConn{conn: clientConn},
// Thread safety: filter is only ever called from the single
// runClientToUpstream goroutine (the ReadFrame loop in
// passthrough_relay.go), and every read and write of
// capturedSessionModel happens on that goroutine, so no lock or atomic
// is needed.
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if msgType != coderws.MessageText {
return payload, nil, nil
}
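// Refresh capturedSessionModel before evaluating the policy. The client
// may change the session-level model via session.update (allowed by the
// Realtime / Responses WS protocol); without the refresh there is a
// bypass path: "first frame model=gpt-4o (pass) → session.update
// switches to gpt-5.5 → response.create without model falls back to
// gpt-4o". Only session.model in session.update events is consulted
// here; a response.create frame's own model is still determined by that
// frame's field.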
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated
}
// Per-frame model first; if the client omits "model" on a
// follow-up frame (legal in Realtime), fall back to the
// session-level model captured from the first frame or the most
// recent session.update, so the model whitelist still resolves. An
// empty model would miss any whitelist and silently fall back to
// pass.
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
model = capturedSessionModel
}
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
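// Multi-turn passthrough billing: update requestServiceTierPtr only on
// response.create frames that succeed (non-block / non-err), using the
// payload after the filter has run, consistent with the first frame's
// extract-after-policy semantics (see the
// extractOpenAIServiceTierFromBody comment above).
// - Frames other than response.create (response.cancel /
//   conversation.item.create / session.update, etc.) carry no
//   per-response service_tier and must not overwrite the previous
//   turn's value.
// - blocked != nil: the frame is never sent upstream, so the billing
//   tier keeps the previous turn's value.
// - policyErr != nil: error path, keep the previous turn's value.
// - A response.create without service_tier makes
//   extractOpenAIServiceTierFromBody return nil; this intentionally
//   overwrites the value (Store(nil)), because the upstream treats a
//   frame without service_tier as the default tier and billing should
//   reflect that faithfully.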
if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
}
return out, blocked, policyErr
},
onBlock: func(blocked *OpenAIFastBlockedError) {
// See note above on Conn.Write being synchronous w.r.t. flush;
// no explicit flush is required to ensure the error event lands
// before the close frame.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes == nil {
return
}
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancel()
},
}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
ClientConn: policyClientConn,
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
WriteTimeout: s.openAIWSWriteTimeout(),
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
FirstMessageType: coderws.MessageText,
OnUsageParseFailure: func(eventType string, usageRaw string) {
logOpenAIWSV2Passthrough(
"usage_parse_failed event_type=%s usage_raw=%s",
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
)
},
OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
turnNo := int(completedTurns.Add(1))
turnResult := &OpenAIForwardResult{
RequestID: turn.RequestID,
Usage: OpenAIUsage{
InputTokens: turn.Usage.InputTokens,
OutputTokens: turn.Usage.OutputTokens,
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
Duration: turn.Duration,
FirstTokenMs: turn.FirstTokenMs,
}
logOpenAIWSV2Passthrough(
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
account.ID,
turnNo,
truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen),
turnResult.Duration.Milliseconds(),
openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
turnResult.Usage.InputTokens,
turnResult.Usage.OutputTokens,
turnResult.Usage.CacheReadInputTokens,
)
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turnNo, turnResult, nil)
}
},
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
logOpenAIWSV2Passthrough(
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
account.ID,
truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
event.PayloadBytes,
event.Graceful,
event.WroteDownstream,
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
)
},
},
})
result := &OpenAIForwardResult{
RequestID: relayResult.RequestID,
Usage: OpenAIUsage{
InputTokens: relayResult.Usage.InputTokens,
OutputTokens: relayResult.Usage.OutputTokens,
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
Duration: relayResult.Duration,
FirstTokenMs: relayResult.FirstTokenMs,
}
turnCount := int(completedTurns.Load())
if relayExit == nil {
logOpenAIWSV2Passthrough(
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
account.ID,
truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen),
result.Duration.Milliseconds(),
relayResult.ClientToUpstreamFrames,
relayResult.UpstreamToClientFrames,
relayResult.DroppedDownstreamFrames,
turnCount,
)
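// On the normal path each turn has already triggered the callback on its
// terminal event; fire the callback once as a fallback only when there
// were zero turns.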
if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(1, result, nil)
}
return nil
}
logOpenAIWSV2Passthrough(
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
account.ID,
truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
relayExit.WroteDownstream,
truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
result.Duration.Milliseconds(),
relayResult.ClientToUpstreamFrames,
relayResult.UpstreamToClientFrames,
relayResult.DroppedDownstreamFrames,
turnCount,
)
relayErr := relayExit.Err
if relayExit.Stage == "idle_timeout" {
relayErr = NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"client websocket idle timeout",
relayErr,
)
}
turnErr := wrapOpenAIWSIngressTurnError(
relayExit.Stage,
relayErr,
relayExit.WroteDownstream,
)
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turnCount+1, nil, turnErr)
}
return turnErr
}
func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
err error,
statusCode int,
handshakeHeaders http.Header,
) error {
if err == nil {
return nil
}
wrappedErr := err
var dialErr *openAIWSDialError
if !errors.As(err, &dialErr) {
wrappedErr = &openAIWSDialError{
StatusCode: statusCode,
ResponseHeaders: cloneHeader(handshakeHeaders),
Err: err,
}
}
if errors.Is(err, context.Canceled) {
return err
}
if errors.Is(err, context.DeadlineExceeded) {
return NewOpenAIWSClientCloseError(
coderws.StatusTryAgainLater,
"upstream websocket connect timeout",
wrappedErr,
)
}
if statusCode == http.StatusTooManyRequests {
return NewOpenAIWSClientCloseError(
coderws.StatusTryAgainLater,
"upstream websocket is busy, please retry later",
wrappedErr,
)
}
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"upstream websocket authentication failed",
wrappedErr,
)
}
if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError {
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"upstream websocket handshake rejected",
wrappedErr,
)
}
return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr)
}
func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
switch msgType {
case coderws.MessageText:
return "text"
case coderws.MessageBinary:
return "binary"
default:
return fmt.Sprintf("unknown(%d)", msgType)
}
}
func relayErrorText(err error) string {
if err == nil {
return ""
}
return err.Error()
}
func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
if firstTokenMs == nil {
return -1
}
return *firstTokenMs
}
func logOpenAIWSV2Passthrough(format string, args ...any) {
logger.LegacyPrintf(
"service.openai_ws_v2",
"[OpenAI WS v2 passthrough] %s "+format,
append([]any{openaiWSV2PassthroughModeFields}, args...)...,
)
}