mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-05 05:30:44 +08:00
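Mirroring the fast-mode filtering in Claude's BetaPolicy, this change adds a
three-state pass / filter / block policy for the service_tier field sent to
OpenAI upstreams (priority / flex, including normalization of the client
alias "fast" to "priority"). It covers every OpenAI entrypoint plus the admin
configuration entrypoint.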
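Backend core
- Add the SettingKeyOpenAIFastPolicySettings, OpenAIFastPolicyRule, and
  OpenAIFastPolicySettings configuration models. A rule spans the dimensions
  service_tier × action × scope × model whitelist × fallback action.
- SettingService.Get/SetOpenAIFastPolicySettings; when the setting is missing,
  return the built-in default policy (priority is filtered for all models,
  the whitelist is empty, fallback=pass). Rationale: service_tier=fast is a
  user-level switch that is orthogonal to the model field, so locking the
  default to specific model slugs would leave a bypass of the form "use
  gpt-4 + fast and have priority passed through to the upstream". JSON parse
  failures no longer fall back silently; slog.Warn records the dirty data so
  operators can locate it.
- service_tier normalization (trim + ToLower + fast→priority + only
  priority/flex accepted) and policy evaluation (evaluateOpenAIFastPolicy)
  are the single source of truth, shared by HTTP and WS. The pure function
  evaluateOpenAIFastPolicyWithSettings is extracted and paired with a
  ctx-bound settings snapshot (withOpenAIFastPolicyContext /
  openAIFastPolicySettingsFromContext): a long-lived WS session prefetches
  the settings once at entry and reuses them for every frame, so
  settingService is not hit per frame.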
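
The normalization and three-state evaluation described above can be pictured with a minimal sketch. The `Rule` type, its field names, and the whitelist/fallback semantics (empty whitelist matches every model, fallback applies on a whitelist miss) are assumptions made for illustration; they are not the repository's actual OpenAIFastPolicyRule, and the scope dimension is omitted.

```go
package main

import (
	"fmt"
	"strings"
)

// Action is the three-state outcome named in the commit message.
type Action string

const (
	ActionPass   Action = "pass"
	ActionFilter Action = "filter"
	ActionBlock  Action = "block"
)

// Rule is a hypothetical, simplified stand-in for OpenAIFastPolicyRule.
type Rule struct {
	ServiceTier    string   // "priority" or "flex"
	Action         Action   // applied when the model matches
	ModelWhitelist []string // assumption: empty means "every model"
	FallbackAction Action   // assumption: applied when the whitelist misses
}

// normalizeServiceTier trims, lowercases, maps the "fast" alias to
// "priority", and returns "" for anything other than priority/flex.
func normalizeServiceTier(raw string) string {
	tier := strings.ToLower(strings.TrimSpace(raw))
	if tier == "fast" {
		tier = "priority"
	}
	if tier == "priority" || tier == "flex" {
		return tier
	}
	return ""
}

// evaluate returns the action of the first rule whose tier matches.
// Requests without a recognized tier, or without a matching rule, pass.
func evaluate(rules []Rule, rawTier, model string) Action {
	tier := normalizeServiceTier(rawTier)
	if tier == "" {
		return ActionPass
	}
	for _, r := range rules {
		if r.ServiceTier != tier {
			continue
		}
		if len(r.ModelWhitelist) == 0 {
			return r.Action
		}
		for _, m := range r.ModelWhitelist {
			if m == model {
				return r.Action
			}
		}
		return r.FallbackAction
	}
	return ActionPass
}

func main() {
	// Default policy from the commit message: filter priority for all models.
	defaults := []Rule{{ServiceTier: "priority", Action: ActionFilter, FallbackAction: ActionPass}}
	fmt.Println(evaluate(defaults, " FAST ", "gpt-4")) // filter
	fmt.Println(evaluate(defaults, "flex", "gpt-4"))   // pass
}
```

Keeping the evaluator a pure function of (rules, tier, model) is what lets one settings snapshot serve every frame of a long-lived session.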
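HTTP entrypoints (4)
- Chat Completions, Anthropic-compatible (Messages, including the second hit
  where BetaFastMode maps to priority), native Responses, and Passthrough
  Responses all call applyOpenAIFastPolicyToBody. filter deletes the
  top-level service_tier via sjson; block returns a 403 forbidden_error JSON
  response.
- All four entrypoints use the upstream-perspective model (the slug after
  GetMappedModel + normalizeOpenAIModelForUpstream + Codex OAuth
  normalization), so that chat / messages / native /responses / passthrough
  do not hit the whitelist differently because each sees a different model
  value.
- On the pass path, the client alias "fast" is also normalized to "priority"
  and written back to the body; otherwise the native /responses and
  passthrough entrypoints would forward "fast" verbatim and the upstream
  would answer 400 or reject the request (normalizeResponsesBodyServiceTier
  on the chat-completions entrypoint already behaved this way).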
WebSocket entrypoints
- Add applyOpenAIFastPolicyToWSResponseCreate: it strictly matches
  type="response.create" and only handles the top-level service_tier; filter
  deletes the field via sjson, block returns a typed *OpenAIFastBlockedError.
- The ingress path calls it inside parseClientPayload. On block it first
  writes a Realtime-style error event and then returns
  OpenAIWSClientCloseError(StatusPolicyViolation = 1008), relying on the
  synchronous flush of the underlying WebSocket Conn.Write to guarantee that
  the error arrives before the close.
- The passthrough path applies the policy to firstClientMessage before
  RunEntry, and wraps ReadFrame with openAIWSPolicyEnforcingFrameConn so that
  every client→upstream frame is evaluated; later frames without a model
  field fall back to capturedSessionModel. The filter closure also watches
  session.update / session.created frames and refreshes capturedSessionModel
  from their session.model field, closing the mid-session bypass "first
  frame model=gpt-4o (pass) → session.update switches to gpt-5.5 →
  response.create without model falls back to gpt-4o".
- passthrough billing: requestServiceTier is extracted from
  firstClientMessage after the policy filter; when the filter hits,
  OpenAIForwardResult.ServiceTier reports nil (default tier), matching the
  semantics of the HTTP entrypoints (reqBody comes from the post-filter map)
  and WS ingress (payload comes from the post-filter bytes).
- Error event schema: {event_id: "evt_<32hex>", type: "error",
  error: {type: "forbidden_error", code: "policy_violation", message}},
  compatible with the error event parsing in the OpenAI codex client.
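Admin / Frontend
- dto.SystemSettings / UpdateSettingsRequest gain an
  openai_fast_policy_settings field (omitempty), wired into bulk GET/PUT.
- The Gateway tab of the Settings page gains a Fast/Flex Policy form card
  that configures every field: service_tier × action × scope × model
  whitelist × fallback action.
- Frontend guard: the openaiFastPolicyLoaded flag permits write-back only
  when GET actually returned the field, so a rollout or an error cannot
  overwrite the default rules with an empty set; the saveSettings write-back
  loop skips this field and leaves it to dedicated refresh logic;
  error_message is sent only when action=block, matching the backend
  omitempty behavior.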
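Tests
- HTTP path: openai_fast_policy_test.go covers the default configuration
  (whitelist=[], priority filtered for all models) / block with a custom
  error / scope distinctions / filter deleting the field / block leaving the
  body unchanged / block short-circuiting the upstream / Anthropic
  BetaFastMode triggering the OpenAI fast policy.
- WebSocket path: openai_fast_policy_ws_test.go covers
  helper units (filter / fast→priority normalization / flex pass-through /
  block typed error / bytes unchanged when service_tier is absent /
  non-response.create frames untouched / frames with an empty type untouched
  / event_id + code field assertions / tolerance of a non-string
  service_tier) +
  pass-path fast alias normalization regression +
  ingress end to end (upstream receives no service_tier after filter / after
  block the client receives the error event first, then close 1008, with 0
  upstream writes) +
  passthrough capturedSessionModel fallback cases (established by the first
  frame under a whitelist policy, fallback hit when model is missing, leak
  documented when no fallback exists) +
  mid-session bypass regression in which passthrough session.update /
  session.created rotates capturedSessionModel +
  passthrough billing post-filter ServiceTier and idempotent filter
  regression.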
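
The mid-session rotation scenario that these regressions exercise can be reproduced in isolation with a small model tracker. This is a sketch under assumptions: `sessionModelTracker` and `resolve` are hypothetical names, model mapping and upstream normalization are skipped, and only session.update frames are treated as rotating the session model.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// frame holds the only fields the tracker needs from a client frame.
type frame struct {
	Type    string `json:"type"`
	Model   string `json:"model"`
	Session struct {
		Model string `json:"model"`
	} `json:"session"`
}

// sessionModelTracker remembers the session-level model. It is seeded with
// the first frame's model and refreshed by session.update frames.
type sessionModelTracker struct {
	captured string
}

// resolve returns the model the policy should evaluate for this frame:
// the frame's own model when present, else the captured session model.
func (t *sessionModelTracker) resolve(payload []byte) string {
	var f frame
	if err := json.Unmarshal(payload, &f); err != nil {
		return t.captured // unparsable frame: fall back to the session model
	}
	if f.Type == "session.update" {
		if m := strings.TrimSpace(f.Session.Model); m != "" {
			t.captured = m
		}
	}
	if m := strings.TrimSpace(f.Model); m != "" {
		return m
	}
	return t.captured
}

func main() {
	tracker := &sessionModelTracker{captured: "gpt-4o"} // from the first frame
	frames := []string{
		`{"type":"session.update","session":{"model":"gpt-5.5"}}`,
		`{"type":"response.create","service_tier":"priority"}`,
	}
	for _, f := range frames {
		fmt.Println(tracker.resolve([]byte(f)))
	}
}
```

Without the session.update branch, the second frame would resolve to the stale first-frame model, which is exactly the bypass the regression test guards against.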
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
591 lines
22 KiB
Go
package service

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"sync/atomic"

	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
	openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
	coderws "github.com/coder/websocket"
	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
)

type openAIWSClientFrameConn struct {
	conn *coderws.Conn
}

// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
// every client→upstream frame through the OpenAI Fast Policy. It is the
// passthrough-relay equivalent of the parseClientPayload integration in the
// ingress session path. filter returns:
//   - newPayload, nil, nil: forward the (possibly mutated) payload
//   - _, *OpenAIFastBlockedError, nil: block; the wrapper sends an error
//     event via onBlock and surfaces a transport-level error so the relay
//     stops reading from the client.
//   - _, _, err: a transport error other than block.
type openAIWSPolicyEnforcingFrameConn struct {
	inner   openaiwsv2.FrameConn
	filter  func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
	onBlock func(blocked *OpenAIFastBlockedError)
}

var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)

func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	if c == nil || c.inner == nil {
		return coderws.MessageText, nil, errOpenAIWSConnClosed
	}
	msgType, payload, err := c.inner.ReadFrame(ctx)
	if err != nil {
		return msgType, payload, err
	}
	if c.filter == nil {
		return msgType, payload, nil
	}
	updated, blocked, filterErr := c.filter(msgType, payload)
	if filterErr != nil {
		return msgType, payload, filterErr
	}
	if blocked != nil {
		if c.onBlock != nil {
			c.onBlock(blocked)
		}
		return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
	}
	return msgType, updated, nil
}

func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
	if c == nil || c.inner == nil {
		return errOpenAIWSConnClosed
	}
	return c.inner.WriteFrame(ctx, msgType, payload)
}

func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
	if c == nil || c.inner == nil {
		return nil
	}
	return c.inner.Close()
}

// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective
// model name that should be passed to evaluateOpenAIFastPolicy for a single
// passthrough WS frame. Mirrors the HTTP-side normalization
// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
// matches model whitelists identically.
func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
	if account == nil || len(payload) == 0 {
		return ""
	}
	original := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
	if original == "" {
		return ""
	}
	return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
}

// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model
// derived from a session.update frame's session.model field. Returns "" when
// the frame is not a session.update event or carries no session.model. Used
// by the per-frame policy filter (client→upstream direction) to keep
// capturedSessionModel in sync with the session-level model the client may
// rotate mid-session.
//
// Realtime / Responses WS lets the client change the session model after
// the WS handshake via:
//
//	{"type":"session.update","session":{"model":"gpt-5.5", ...}}
//
// If we only capture the model from the very first frame, a client can ship
// gpt-4o on the first response.create (whitelisted as pass), then
// session.update to gpt-5.5, then send response.create without "model" so
// the per-frame resolver returns "" and the stale capturedSessionModel falls
// back to gpt-4o, defeating the gpt-5.5 fast-policy filter.
func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
	if account == nil || len(payload) == 0 {
		return ""
	}
	frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
	if frameType != "session.update" {
		return ""
	}
	original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
	if original == "" {
		return ""
	}
	return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
}

const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"

var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)

func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	if c == nil || c.conn == nil {
		return coderws.MessageText, nil, errOpenAIWSConnClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}
	return c.conn.Read(ctx)
}

func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
	if c == nil || c.conn == nil {
		return errOpenAIWSConnClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}
	return c.conn.Write(ctx, msgType, payload)
}

func (c *openAIWSClientFrameConn) Close() error {
	if c == nil || c.conn == nil {
		return nil
	}
	_ = c.conn.Close(coderws.StatusNormalClosure, "")
	_ = c.conn.CloseNow()
	return nil
}

func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
	ctx context.Context,
	c *gin.Context,
	clientConn *coderws.Conn,
	account *Account,
	token string,
	firstClientMessage []byte,
	hooks *OpenAIWSIngressHooks,
	wsDecision OpenAIWSProtocolDecision,
) error {
	if s == nil {
		return errors.New("service is nil")
	}
	if clientConn == nil {
		return errors.New("client websocket is nil")
	}
	if account == nil {
		return errors.New("account is nil")
	}
	if strings.TrimSpace(token) == "" {
		return errors.New("token is empty")
	}
	requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
	requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
	logOpenAIWSV2Passthrough(
		"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
		account.ID,
		truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
		truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
		openaiwsv2RelayMessageTypeName(coderws.MessageText),
		len(firstClientMessage),
	)

	// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
	// frames are filtered via a wrapping FrameConn below so every client→
	// upstream frame goes through the same policy evaluator/normalize/scope as
	// HTTP entrypoints.
	//
	// We capture the session-level model from the first frame here so the
	// per-frame filter (below) can fall back to it when a follow-up frame
	// omits "model". Realtime clients are allowed to send response.create
	// without re-stating the model, in which case the upstream uses the model
	// negotiated at session.update time. Without this fallback, an empty
	// model would miss a model whitelist such as ["gpt-5.5","gpt-5.5*"] and
	// be silently passed through, defeating the policy on every frame after
	// the first.
	capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
	updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
	if policyErr != nil {
		return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
	}
	if blocked != nil {
		// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
		// writeFrameMu, writes the entire frame, and Flushes the underlying
		// bufio writer before returning (write.go:42 → write.go:307-311).
		// The subsequent close handshake re-acquires the same writeFrameMu
		// to send the close frame, so the error event is guaranteed to
		// reach the kernel send buffer before any close frame is queued.
		// No explicit flush hop is required here.
		eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
		if eventBytes != nil {
			writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
			_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
			cancelWrite()
		}
		return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
	}
	firstClientMessage = updatedFirst

	// Extract service_tier for billing only after the policy filter has run:
	// when the filter hits, service_tier has already been removed from
	// firstClientMessage, and billing should reflect the tier the upstream
	// actually processed (nil = default) rather than the "priority" the user
	// originally requested. This matches the HTTP entrypoints (line ~2728,
	// extractOpenAIServiceTier(reqBody)) and WS ingress
	// (openai_ws_forwarder.go:2991, taken from payload).
	//
	// Multi-turn passthrough: the OpenAI Realtime / Responses WS protocol
	// lets a client send a different service_tier on each response.create
	// frame of the same connection (see codex-rs/core/src/client.rs,
	// build_responses_request, which refills the value every time). An
	// atomic.Pointer[string] therefore synchronizes the current turn's
	// service_tier between the filter (runClientToUpstream goroutine) and
	// OnTurnComplete / the final result (runUpstreamToClient goroutine).
	// extractOpenAIServiceTierFromBody returns *string, which is already a
	// pointer and can be passed to Store/Load without extra wrapping.
	var requestServiceTierPtr atomic.Pointer[string]
	requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))

	wsURL, err := s.buildOpenAIResponsesWSURL(account)
	if err != nil {
		return fmt.Errorf("build ws url: %w", err)
	}
	wsHost := "-"
	wsPath := "-"
	if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
		wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
		wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
	}
	logOpenAIWSV2Passthrough(
		"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
		account.ID,
		wsHost,
		wsPath,
		account.ProxyID != nil && account.Proxy != nil,
	)

	isCodexCLI := false
	if c != nil {
		isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
	}
	if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
		isCodexCLI = true
	}
	headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	dialer := s.getOpenAIWSPassthroughDialer()
	if dialer == nil {
		return errors.New("openai ws passthrough dialer is nil")
	}

	dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
	defer cancelDial()
	upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
	if err != nil {
		logOpenAIWSV2Passthrough(
			"relay_dial_failed account_id=%d status_code=%d err=%s",
			account.ID,
			statusCode,
			truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
		)
		return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
	}
	defer func() {
		_ = upstreamConn.Close()
	}()
	logOpenAIWSV2Passthrough(
		"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
		account.ID,
		statusCode,
		openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
	)

	upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
	if !ok {
		return errors.New("openai ws passthrough upstream connection does not support frame relay")
	}

	completedTurns := atomic.Int32{}
	policyClientConn := &openAIWSPolicyEnforcingFrameConn{
		inner: &openAIWSClientFrameConn{conn: clientConn},
		// Thread safety: filter is only called from the single
		// runClientToUpstream goroutine (passthrough_relay.go: ReadFrame
		// loop), and capturedSessionModel is read and written only inside
		// that goroutine, so no lock or atomic is needed.
		filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
			if msgType != coderws.MessageText {
				return payload, nil, nil
			}
			// Refresh capturedSessionModel before evaluating the policy: the
			// client may change the session-level model via session.update
			// (allowed by the Realtime / Responses WS protocol). Without the
			// refresh there is a bypass: "first frame model=gpt-4o (pass) →
			// session.update switches to gpt-5.5 → response.create without
			// model falls back to gpt-4o". Only the session.model field of
			// session.update events is inspected here; a response.create's
			// own model is still decided by that frame's field.
			if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
				capturedSessionModel = updated
			}
			// Per-frame model first; if the client omits "model" on a
			// follow-up frame (legal in Realtime), fall back to the
			// session-level model captured from the first frame so the
			// model whitelist still resolves. An empty model would miss
			// any whitelist and silently fall back to pass.
			model := openAIWSPassthroughPolicyModelForFrame(account, payload)
			if model == "" {
				model = capturedSessionModel
			}
			out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
			// Multi-turn passthrough billing: update requestServiceTierPtr
			// only on successful (non-block, non-error) response.create
			// frames, using the post-filter payload, consistent with the
			// extract-after-filter semantics of the first frame (see the
			// extractOpenAIServiceTierFromBody comment above).
			//   - Non-response.create frames (response.cancel /
			//     conversation.item.create / session.update, etc.) carry no
			//     per-response service_tier and must not overwrite the
			//     previous turn's value.
			//   - blocked != nil: the frame is never sent upstream, so the
			//     billing tier keeps the previous turn's value.
			//   - policyErr != nil: error path, keep the previous turn's value.
			//   - A response.create without service_tier makes
			//     extractOpenAIServiceTierFromBody return nil; overwriting
			//     with Store(nil) is intentional, because the upstream treats
			//     a frame that carries no service_tier as the default tier
			//     and billing should reflect that.
			if policyErr == nil && blocked == nil &&
				strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
				requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
			}
			return out, blocked, policyErr
		},
		onBlock: func(blocked *OpenAIFastBlockedError) {
			// See note above on Conn.Write being synchronous w.r.t. flush;
			// no explicit flush is required to ensure the error event lands
			// before the close frame.
			eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
			if eventBytes == nil {
				return
			}
			writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
			_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
			cancel()
		},
	}
	relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
		Ctx:                ctx,
		ClientConn:         policyClientConn,
		UpstreamConn:       upstreamFrameConn,
		FirstClientMessage: firstClientMessage,
		Options: openaiwsv2.RelayOptions{
			WriteTimeout:     s.openAIWSWriteTimeout(),
			IdleTimeout:      s.openAIWSPassthroughIdleTimeout(),
			FirstMessageType: coderws.MessageText,
			OnUsageParseFailure: func(eventType string, usageRaw string) {
				logOpenAIWSV2Passthrough(
					"usage_parse_failed event_type=%s usage_raw=%s",
					truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
					truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
				)
			},
			OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
				turnNo := int(completedTurns.Add(1))
				turnResult := &OpenAIForwardResult{
					RequestID: turn.RequestID,
					Usage: OpenAIUsage{
						InputTokens:              turn.Usage.InputTokens,
						OutputTokens:             turn.Usage.OutputTokens,
						CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
						CacheReadInputTokens:     turn.Usage.CacheReadInputTokens,
					},
					Model:           turn.RequestModel,
					ServiceTier:     requestServiceTierPtr.Load(),
					Stream:          true,
					OpenAIWSMode:    true,
					ResponseHeaders: cloneHeader(handshakeHeaders),
					Duration:        turn.Duration,
					FirstTokenMs:    turn.FirstTokenMs,
				}
				logOpenAIWSV2Passthrough(
					"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
					account.ID,
					turnNo,
					truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
					truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen),
					turnResult.Duration.Milliseconds(),
					openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
					turnResult.Usage.InputTokens,
					turnResult.Usage.OutputTokens,
					turnResult.Usage.CacheReadInputTokens,
				)
				if hooks != nil && hooks.AfterTurn != nil {
					hooks.AfterTurn(turnNo, turnResult, nil)
				}
			},
			OnTrace: func(event openaiwsv2.RelayTraceEvent) {
				logOpenAIWSV2Passthrough(
					"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
					account.ID,
					truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
					truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
					truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
					event.PayloadBytes,
					event.Graceful,
					event.WroteDownstream,
					truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
				)
			},
		},
	})

	result := &OpenAIForwardResult{
		RequestID: relayResult.RequestID,
		Usage: OpenAIUsage{
			InputTokens:              relayResult.Usage.InputTokens,
			OutputTokens:             relayResult.Usage.OutputTokens,
			CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
			CacheReadInputTokens:     relayResult.Usage.CacheReadInputTokens,
		},
		Model:           relayResult.RequestModel,
		ServiceTier:     requestServiceTierPtr.Load(),
		Stream:          true,
		OpenAIWSMode:    true,
		ResponseHeaders: cloneHeader(handshakeHeaders),
		Duration:        relayResult.Duration,
		FirstTokenMs:    relayResult.FirstTokenMs,
	}

	turnCount := int(completedTurns.Load())
	if relayExit == nil {
		logOpenAIWSV2Passthrough(
			"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
			account.ID,
			truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
			truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen),
			result.Duration.Milliseconds(),
			relayResult.ClientToUpstreamFrames,
			relayResult.UpstreamToClientFrames,
			relayResult.DroppedDownstreamFrames,
			turnCount,
		)
		// On the normal path AfterTurn has already fired once per turn on
		// each terminal event; call it once as a fallback only when no turn
		// completed.
		if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
			hooks.AfterTurn(1, result, nil)
		}
		return nil
	}
	logOpenAIWSV2Passthrough(
		"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
		account.ID,
		truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
		relayExit.WroteDownstream,
		truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
		result.Duration.Milliseconds(),
		relayResult.ClientToUpstreamFrames,
		relayResult.UpstreamToClientFrames,
		relayResult.DroppedDownstreamFrames,
		turnCount,
	)

	relayErr := relayExit.Err
	if relayExit.Stage == "idle_timeout" {
		relayErr = NewOpenAIWSClientCloseError(
			coderws.StatusPolicyViolation,
			"client websocket idle timeout",
			relayErr,
		)
	}
	turnErr := wrapOpenAIWSIngressTurnError(
		relayExit.Stage,
		relayErr,
		relayExit.WroteDownstream,
	)
	if hooks != nil && hooks.AfterTurn != nil {
		hooks.AfterTurn(turnCount+1, nil, turnErr)
	}
	return turnErr
}

func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
	err error,
	statusCode int,
	handshakeHeaders http.Header,
) error {
	if err == nil {
		return nil
	}
	wrappedErr := err
	var dialErr *openAIWSDialError
	if !errors.As(err, &dialErr) {
		wrappedErr = &openAIWSDialError{
			StatusCode:      statusCode,
			ResponseHeaders: cloneHeader(handshakeHeaders),
			Err:             err,
		}
	}

	if errors.Is(err, context.Canceled) {
		return err
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return NewOpenAIWSClientCloseError(
			coderws.StatusTryAgainLater,
			"upstream websocket connect timeout",
			wrappedErr,
		)
	}
	if statusCode == http.StatusTooManyRequests {
		return NewOpenAIWSClientCloseError(
			coderws.StatusTryAgainLater,
			"upstream websocket is busy, please retry later",
			wrappedErr,
		)
	}
	if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
		return NewOpenAIWSClientCloseError(
			coderws.StatusPolicyViolation,
			"upstream websocket authentication failed",
			wrappedErr,
		)
	}
	if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError {
		return NewOpenAIWSClientCloseError(
			coderws.StatusPolicyViolation,
			"upstream websocket handshake rejected",
			wrappedErr,
		)
	}
	return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr)
}

func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
	switch msgType {
	case coderws.MessageText:
		return "text"
	case coderws.MessageBinary:
		return "binary"
	default:
		return fmt.Sprintf("unknown(%d)", msgType)
	}
}

func relayErrorText(err error) string {
	if err == nil {
		return ""
	}
	return err.Error()
}

func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
	if firstTokenMs == nil {
		return -1
	}
	return *firstTokenMs
}

func logOpenAIWSV2Passthrough(format string, args ...any) {
	logger.LegacyPrintf(
		"service.openai_ws_v2",
		"[OpenAI WS v2 passthrough] %s "+format,
		append([]any{openaiWSV2PassthroughModeFields}, args...)...,
	)
}