mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
- 为 OpenAI Responses WebSocket v2 passthrough 补齐每轮 reasoning_effort 元数据
- 传递首帧渠道映射前模型,保留模型后缀推理强度推导能力
- 增加 usage log 端到端回归,覆盖入站 User-Agent、显式 effort 和渠道映射场景
665 lines
24 KiB
Go
665 lines
24 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"net/url"
|
||
"strings"
|
||
"sync/atomic"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
|
||
coderws "github.com/coder/websocket"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
type openAIWSClientFrameConn struct {
|
||
conn *coderws.Conn
|
||
}
|
||
|
||
// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
|
||
// every client→upstream frame through the OpenAI Fast Policy. It is the
|
||
// passthrough-relay equivalent of the parseClientPayload integration in the
|
||
// ingress session path. filter returns:
|
||
// - newPayload, nil, nil: forward the (possibly mutated) payload
|
||
// - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error
|
||
// event via onBlock and surfaces a transport-level error so the relay
|
||
// stops reading from the client.
|
||
// - _, _, err: a transport error other than block.
|
||
type openAIWSPolicyEnforcingFrameConn struct {
|
||
inner openaiwsv2.FrameConn
|
||
filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
|
||
onBlock func(blocked *OpenAIFastBlockedError)
|
||
}
|
||
|
||
// Compile-time check that *openAIWSPolicyEnforcingFrameConn satisfies openaiwsv2.FrameConn.
var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
|
||
|
||
func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||
if c == nil || c.inner == nil {
|
||
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
||
}
|
||
msgType, payload, err := c.inner.ReadFrame(ctx)
|
||
if err != nil {
|
||
return msgType, payload, err
|
||
}
|
||
if c.filter == nil {
|
||
return msgType, payload, nil
|
||
}
|
||
updated, blocked, filterErr := c.filter(msgType, payload)
|
||
if filterErr != nil {
|
||
return msgType, payload, filterErr
|
||
}
|
||
if blocked != nil {
|
||
if c.onBlock != nil {
|
||
c.onBlock(blocked)
|
||
}
|
||
return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
|
||
}
|
||
return msgType, updated, nil
|
||
}
|
||
|
||
func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||
if c == nil || c.inner == nil {
|
||
return errOpenAIWSConnClosed
|
||
}
|
||
return c.inner.WriteFrame(ctx, msgType, payload)
|
||
}
|
||
|
||
func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
|
||
if c == nil || c.inner == nil {
|
||
return nil
|
||
}
|
||
return c.inner.Close()
|
||
}
|
||
|
||
// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective
|
||
// model name that should be passed to evaluateOpenAIFastPolicy for a single
|
||
// passthrough WS frame. Mirrors the HTTP-side normalization
|
||
// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
|
||
// matches model whitelists identically.
|
||
func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
|
||
if account == nil || len(payload) == 0 {
|
||
return ""
|
||
}
|
||
original := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||
if original == "" {
|
||
return ""
|
||
}
|
||
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
||
}
|
||
|
||
// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model
|
||
// derived from a session.update frame's session.model field. Returns "" when
|
||
// the frame is not a session.update event or carries no session.model. Used
|
||
// by the per-frame policy filter (client→upstream direction) to keep
|
||
// capturedSessionModel in sync with the session-level model the client may
|
||
// rotate mid-session.
|
||
//
|
||
// Realtime / Responses WS lets the client change the session model after
|
||
// the WS handshake via:
|
||
//
|
||
// {"type":"session.update","session":{"model":"gpt-5.5", ...}}
|
||
//
|
||
// If we only capture the model from the very first frame, a client can ship
|
||
// gpt-4o on the first response.create (whitelisted as pass), then
|
||
// session.update to gpt-5.5, then send response.create without "model" so
|
||
// the per-frame resolver returns "" and the stale capturedSessionModel falls
|
||
// back to gpt-4o — defeating the gpt-5.5 fast-policy filter.
|
||
func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
|
||
if account == nil || len(payload) == 0 {
|
||
return ""
|
||
}
|
||
frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
||
if frameType != "session.update" {
|
||
return ""
|
||
}
|
||
original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
|
||
if original == "" {
|
||
return ""
|
||
}
|
||
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
||
}
|
||
|
||
// openAIWSPassthroughUsageMeta carries the usage metadata (service tier and
// reasoning effort) that the passthrough relay attaches to usage results,
// along with the session-level request model used when deriving the
// reasoning effort.
type openAIWSPassthroughUsageMeta struct {
	serviceTier     atomic.Pointer[string]
	reasoningEffort atomic.Pointer[string]

	// sessionRequestModel is read and written only on the client->upstream
	// filter goroutine; the Load side synchronizes through the atomic
	// pointers above.
	sessionRequestModel string
}
|
||
|
||
func newOpenAIWSPassthroughUsageMeta(initialRequestModel string, firstFrame []byte) *openAIWSPassthroughUsageMeta {
|
||
meta := &openAIWSPassthroughUsageMeta{
|
||
sessionRequestModel: strings.TrimSpace(initialRequestModel),
|
||
}
|
||
if meta.sessionRequestModel == "" {
|
||
meta.sessionRequestModel = openAIWSPassthroughRequestModelForFrame(firstFrame)
|
||
}
|
||
return meta
|
||
}
|
||
|
||
func (m *openAIWSPassthroughUsageMeta) initFromFirstFrame(policyOutput []byte) {
|
||
if m == nil {
|
||
return
|
||
}
|
||
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
|
||
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, m.sessionRequestModel))
|
||
}
|
||
|
||
func (m *openAIWSPassthroughUsageMeta) updateSessionRequestModel(payload []byte) {
|
||
if m == nil {
|
||
return
|
||
}
|
||
if model := openAIWSPassthroughRequestModelFromSessionFrame(payload); model != "" {
|
||
m.sessionRequestModel = model
|
||
}
|
||
}
|
||
|
||
func (m *openAIWSPassthroughUsageMeta) requestModelForFrame(payload []byte) string {
|
||
if m == nil {
|
||
return openAIWSPassthroughRequestModelForFrame(payload)
|
||
}
|
||
if model := openAIWSPassthroughRequestModelForFrame(payload); model != "" {
|
||
return model
|
||
}
|
||
return m.sessionRequestModel
|
||
}
|
||
|
||
func (m *openAIWSPassthroughUsageMeta) updateFromResponseCreate(policyOutput []byte, requestModelForFrame string) {
|
||
if m == nil {
|
||
return
|
||
}
|
||
m.serviceTier.Store(extractOpenAIServiceTierFromBody(policyOutput))
|
||
m.reasoningEffort.Store(extractOpenAIReasoningEffortFromBody(policyOutput, requestModelForFrame))
|
||
}
|
||
|
||
func openAIWSPassthroughRequestModelForFrame(payload []byte) string {
|
||
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
|
||
return ""
|
||
}
|
||
return strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||
}
|
||
|
||
func openAIWSPassthroughRequestModelFromSessionFrame(payload []byte) string {
|
||
if len(payload) == 0 || strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "session.update" {
|
||
return ""
|
||
}
|
||
return strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
|
||
}
|
||
|
||
// openaiWSV2PassthroughModeFields is the fixed key=value prefix that
// logOpenAIWSV2Passthrough places at the start of every log message.
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
|
||
|
||
// Compile-time check that *openAIWSClientFrameConn satisfies openaiwsv2.FrameConn.
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
||
|
||
func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||
if c == nil || c.conn == nil {
|
||
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
||
}
|
||
if ctx == nil {
|
||
ctx = context.Background()
|
||
}
|
||
return c.conn.Read(ctx)
|
||
}
|
||
|
||
func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||
if c == nil || c.conn == nil {
|
||
return errOpenAIWSConnClosed
|
||
}
|
||
if ctx == nil {
|
||
ctx = context.Background()
|
||
}
|
||
return c.conn.Write(ctx, msgType, payload)
|
||
}
|
||
|
||
func (c *openAIWSClientFrameConn) Close() error {
|
||
if c == nil || c.conn == nil {
|
||
return nil
|
||
}
|
||
_ = c.conn.Close(coderws.StatusNormalClosure, "")
|
||
_ = c.conn.CloseNow()
|
||
return nil
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||
ctx context.Context,
|
||
c *gin.Context,
|
||
clientConn *coderws.Conn,
|
||
account *Account,
|
||
token string,
|
||
firstClientMessage []byte,
|
||
hooks *OpenAIWSIngressHooks,
|
||
wsDecision OpenAIWSProtocolDecision,
|
||
) error {
|
||
if s == nil {
|
||
return errors.New("service is nil")
|
||
}
|
||
if clientConn == nil {
|
||
return errors.New("client websocket is nil")
|
||
}
|
||
if account == nil {
|
||
return errors.New("account is nil")
|
||
}
|
||
if strings.TrimSpace(token) == "" {
|
||
return errors.New("token is empty")
|
||
}
|
||
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
|
||
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
|
||
logOpenAIWSV2Passthrough(
|
||
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
|
||
account.ID,
|
||
truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
|
||
openaiwsv2RelayMessageTypeName(coderws.MessageText),
|
||
len(firstClientMessage),
|
||
)
|
||
|
||
// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
|
||
// frames are filtered via a wrapping FrameConn below so every client→
|
||
// upstream frame goes through the same policy evaluator/normalize/scope as
|
||
// HTTP entrypoints.
|
||
//
|
||
// We capture the session-level model from the first frame here so the
|
||
// per-frame filter (below) can fall back to it when a follow-up frame
|
||
// omits "model" — Realtime clients are allowed to send response.create
|
||
// without re-stating the model, in which case the upstream uses the model
|
||
// negotiated at session.update time. Without this fallback, an empty
|
||
// model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
|
||
// silently passed through, defeating the policy on every frame after
|
||
// the first.
|
||
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
|
||
initialRequestModel := ""
|
||
if hooks != nil {
|
||
initialRequestModel = hooks.InitialRequestModel
|
||
}
|
||
usageMeta := newOpenAIWSPassthroughUsageMeta(initialRequestModel, firstClientMessage)
|
||
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
|
||
if policyErr != nil {
|
||
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
|
||
}
|
||
if blocked != nil {
|
||
// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
|
||
// writeFrameMu, writes the entire frame, and Flushes the underlying
|
||
// bufio writer before returning (write.go:42 → write.go:307-311).
|
||
// The subsequent close handshake re-acquires the same writeFrameMu
|
||
// to send the close frame, so the error event is guaranteed to
|
||
// reach the kernel send buffer before any close frame is queued.
|
||
// No explicit flush hop is required here.
|
||
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
|
||
if eventBytes != nil {
|
||
writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
|
||
cancelWrite()
|
||
}
|
||
return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
|
||
}
|
||
firstClientMessage = updatedFirst
|
||
|
||
// 在 policy filter 之后再提取 service_tier / reasoning_effort 用于
|
||
// usage 上报:filter
|
||
// 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
|
||
// 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
|
||
// "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
|
||
// 与 WS ingress(openai_ws_forwarder.go:2991 取自 payload)的语义一致。
|
||
//
|
||
// 多轮 passthrough:OpenAI Realtime / Responses WS 协议允许客户端在
|
||
// 同一连接的不同 response.create 帧上发送不同 service_tier(参考
|
||
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
|
||
// 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
|
||
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
||
// goroutine)之间同步当前 turn 的 usage metadata。
|
||
usageMeta.initFromFirstFrame(firstClientMessage)
|
||
|
||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||
if err != nil {
|
||
return fmt.Errorf("build ws url: %w", err)
|
||
}
|
||
wsHost := "-"
|
||
wsPath := "-"
|
||
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
|
||
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
|
||
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
|
||
}
|
||
logOpenAIWSV2Passthrough(
|
||
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
|
||
account.ID,
|
||
wsHost,
|
||
wsPath,
|
||
account.ProxyID != nil && account.Proxy != nil,
|
||
)
|
||
|
||
isCodexCLI := false
|
||
if c != nil {
|
||
isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
|
||
}
|
||
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||
isCodexCLI = true
|
||
}
|
||
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
dialer := s.getOpenAIWSPassthroughDialer()
|
||
if dialer == nil {
|
||
return errors.New("openai ws passthrough dialer is nil")
|
||
}
|
||
|
||
dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
|
||
defer cancelDial()
|
||
upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
|
||
if err != nil {
|
||
logOpenAIWSV2Passthrough(
|
||
"relay_dial_failed account_id=%d status_code=%d err=%s",
|
||
account.ID,
|
||
statusCode,
|
||
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
|
||
)
|
||
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
|
||
}
|
||
defer func() {
|
||
_ = upstreamConn.Close()
|
||
}()
|
||
logOpenAIWSV2Passthrough(
|
||
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
|
||
account.ID,
|
||
statusCode,
|
||
openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
|
||
)
|
||
|
||
upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
|
||
if !ok {
|
||
return errors.New("openai ws passthrough upstream connection does not support frame relay")
|
||
}
|
||
|
||
completedTurns := atomic.Int32{}
|
||
policyClientConn := &openAIWSPolicyEnforcingFrameConn{
|
||
inner: &openAIWSClientFrameConn{conn: clientConn},
|
||
// 注意线程安全:filter 仅在 runClientToUpstream 这一条
|
||
// goroutine 中被调用(passthrough_relay.go: ReadFrame loop),
|
||
// capturedSessionModel 的读写都发生在该 goroutine 内,因此无需
|
||
// 加锁/原子化。
|
||
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
|
||
if msgType != coderws.MessageText {
|
||
return payload, nil, nil
|
||
}
|
||
// 在评估策略前先刷新 capturedSessionModel:客户端可能通过
|
||
// session.update 修改 session-level model(Realtime /
|
||
// Responses WS 协议允许),如果不刷新就会出现
|
||
// "首帧 model=gpt-4o(pass)→ session.update 改成 gpt-5.5
|
||
// → 不带 model 的 response.create fallback 到 gpt-4o" 的
|
||
// 绕过路径。这里只看 session.update 事件中的 session.model
|
||
// 字段,response.create 自己的 model 仍然由其本帧字段决定。
|
||
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
||
capturedSessionModel = updated
|
||
}
|
||
usageMeta.updateSessionRequestModel(payload)
|
||
requestModelForThisFrame := usageMeta.requestModelForFrame(payload)
|
||
// Per-frame model first; if the client omits "model" on a
|
||
// follow-up frame (legal in Realtime), fall back to the
|
||
// session-level model captured from the first frame so the
|
||
// model whitelist still resolves. An empty model would miss
|
||
// any whitelist and silently fall back to pass.
|
||
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
|
||
if model == "" {
|
||
model = capturedSessionModel
|
||
}
|
||
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
|
||
// 多轮 passthrough usage:仅在成功(non-block / non-err)
|
||
// 的 response.create 帧上更新 usageMeta,使用
|
||
// filter 处理后的 payload,与首帧 policy-after-extract 语义
|
||
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
|
||
// - 非 response.create 帧(response.cancel /
|
||
// conversation.item.create / session.update 等)不携带
|
||
// per-response metadata,不应覆盖前一轮值。
|
||
// - blocked != nil:该帧不会发送上游,usage metadata 应保持
|
||
// 上一轮值。
|
||
// - policyErr != nil:异常路径,保持上一轮值。
|
||
// - 不带 service_tier 的 response.create 会让
|
||
// extractOpenAIServiceTierFromBody 返回 nil;这里有意
|
||
// 覆盖(Store(nil)),因为 OpenAI 上游对该帧实际不传
|
||
// service_tier 时按 default 处理,billing 应如实反映。
|
||
if policyErr == nil && blocked == nil &&
|
||
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||
usageMeta.updateFromResponseCreate(out, requestModelForThisFrame)
|
||
}
|
||
return out, blocked, policyErr
|
||
},
|
||
onBlock: func(blocked *OpenAIFastBlockedError) {
|
||
// See note above on Conn.Write being synchronous w.r.t. flush;
|
||
// no explicit flush is required to ensure the error event lands
|
||
// before the close frame.
|
||
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
|
||
if eventBytes == nil {
|
||
return
|
||
}
|
||
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
|
||
cancel()
|
||
},
|
||
}
|
||
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
|
||
Ctx: ctx,
|
||
ClientConn: policyClientConn,
|
||
UpstreamConn: upstreamFrameConn,
|
||
FirstClientMessage: firstClientMessage,
|
||
Options: openaiwsv2.RelayOptions{
|
||
WriteTimeout: s.openAIWSWriteTimeout(),
|
||
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
|
||
FirstMessageType: coderws.MessageText,
|
||
OnUsageParseFailure: func(eventType string, usageRaw string) {
|
||
logOpenAIWSV2Passthrough(
|
||
"usage_parse_failed event_type=%s usage_raw=%s",
|
||
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
|
||
)
|
||
},
|
||
OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
|
||
turnNo := int(completedTurns.Add(1))
|
||
turnResult := &OpenAIForwardResult{
|
||
RequestID: turn.RequestID,
|
||
Usage: OpenAIUsage{
|
||
InputTokens: turn.Usage.InputTokens,
|
||
OutputTokens: turn.Usage.OutputTokens,
|
||
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
|
||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||
},
|
||
Model: turn.RequestModel,
|
||
ServiceTier: usageMeta.serviceTier.Load(),
|
||
ReasoningEffort: usageMeta.reasoningEffort.Load(),
|
||
Stream: true,
|
||
OpenAIWSMode: true,
|
||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||
Duration: turn.Duration,
|
||
FirstTokenMs: turn.FirstTokenMs,
|
||
}
|
||
logOpenAIWSV2Passthrough(
|
||
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
|
||
account.ID,
|
||
turnNo,
|
||
truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen),
|
||
turnResult.Duration.Milliseconds(),
|
||
openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
|
||
turnResult.Usage.InputTokens,
|
||
turnResult.Usage.OutputTokens,
|
||
turnResult.Usage.CacheReadInputTokens,
|
||
)
|
||
if hooks != nil && hooks.AfterTurn != nil {
|
||
hooks.AfterTurn(turnNo, turnResult, nil)
|
||
}
|
||
},
|
||
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
|
||
logOpenAIWSV2Passthrough(
|
||
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
|
||
account.ID,
|
||
truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
|
||
truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
|
||
event.PayloadBytes,
|
||
event.Graceful,
|
||
event.WroteDownstream,
|
||
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
|
||
)
|
||
},
|
||
},
|
||
})
|
||
|
||
result := &OpenAIForwardResult{
|
||
RequestID: relayResult.RequestID,
|
||
Usage: OpenAIUsage{
|
||
InputTokens: relayResult.Usage.InputTokens,
|
||
OutputTokens: relayResult.Usage.OutputTokens,
|
||
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
|
||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||
},
|
||
Model: relayResult.RequestModel,
|
||
ServiceTier: usageMeta.serviceTier.Load(),
|
||
ReasoningEffort: usageMeta.reasoningEffort.Load(),
|
||
Stream: true,
|
||
OpenAIWSMode: true,
|
||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||
Duration: relayResult.Duration,
|
||
FirstTokenMs: relayResult.FirstTokenMs,
|
||
}
|
||
|
||
turnCount := int(completedTurns.Load())
|
||
if relayExit == nil {
|
||
logOpenAIWSV2Passthrough(
|
||
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
|
||
account.ID,
|
||
truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
|
||
truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen),
|
||
result.Duration.Milliseconds(),
|
||
relayResult.ClientToUpstreamFrames,
|
||
relayResult.UpstreamToClientFrames,
|
||
relayResult.DroppedDownstreamFrames,
|
||
turnCount,
|
||
)
|
||
// 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。
|
||
if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
|
||
hooks.AfterTurn(1, result, nil)
|
||
}
|
||
return nil
|
||
}
|
||
logOpenAIWSV2Passthrough(
|
||
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
|
||
account.ID,
|
||
truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
|
||
relayExit.WroteDownstream,
|
||
truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
|
||
result.Duration.Milliseconds(),
|
||
relayResult.ClientToUpstreamFrames,
|
||
relayResult.UpstreamToClientFrames,
|
||
relayResult.DroppedDownstreamFrames,
|
||
turnCount,
|
||
)
|
||
|
||
relayErr := relayExit.Err
|
||
if relayExit.Stage == "idle_timeout" {
|
||
relayErr = NewOpenAIWSClientCloseError(
|
||
coderws.StatusPolicyViolation,
|
||
"client websocket idle timeout",
|
||
relayErr,
|
||
)
|
||
}
|
||
turnErr := wrapOpenAIWSIngressTurnError(
|
||
relayExit.Stage,
|
||
relayErr,
|
||
relayExit.WroteDownstream,
|
||
)
|
||
if hooks != nil && hooks.AfterTurn != nil {
|
||
hooks.AfterTurn(turnCount+1, nil, turnErr)
|
||
}
|
||
return turnErr
|
||
}
|
||
|
||
func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
|
||
err error,
|
||
statusCode int,
|
||
handshakeHeaders http.Header,
|
||
) error {
|
||
if err == nil {
|
||
return nil
|
||
}
|
||
wrappedErr := err
|
||
var dialErr *openAIWSDialError
|
||
if !errors.As(err, &dialErr) {
|
||
wrappedErr = &openAIWSDialError{
|
||
StatusCode: statusCode,
|
||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||
Err: err,
|
||
}
|
||
}
|
||
|
||
if errors.Is(err, context.Canceled) {
|
||
return err
|
||
}
|
||
if errors.Is(err, context.DeadlineExceeded) {
|
||
return NewOpenAIWSClientCloseError(
|
||
coderws.StatusTryAgainLater,
|
||
"upstream websocket connect timeout",
|
||
wrappedErr,
|
||
)
|
||
}
|
||
if statusCode == http.StatusTooManyRequests {
|
||
return NewOpenAIWSClientCloseError(
|
||
coderws.StatusTryAgainLater,
|
||
"upstream websocket is busy, please retry later",
|
||
wrappedErr,
|
||
)
|
||
}
|
||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||
return NewOpenAIWSClientCloseError(
|
||
coderws.StatusPolicyViolation,
|
||
"upstream websocket authentication failed",
|
||
wrappedErr,
|
||
)
|
||
}
|
||
if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError {
|
||
return NewOpenAIWSClientCloseError(
|
||
coderws.StatusPolicyViolation,
|
||
"upstream websocket handshake rejected",
|
||
wrappedErr,
|
||
)
|
||
}
|
||
return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr)
|
||
}
|
||
|
||
func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
|
||
switch msgType {
|
||
case coderws.MessageText:
|
||
return "text"
|
||
case coderws.MessageBinary:
|
||
return "binary"
|
||
default:
|
||
return fmt.Sprintf("unknown(%d)", msgType)
|
||
}
|
||
}
|
||
|
||
// relayErrorText returns err.Error(), or "" when err is nil, so the value
// can be used directly as a log field.
func relayErrorText(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
|
||
|
||
// openAIWSFirstTokenMsForLog returns the value firstTokenMs points to, or -1
// when it is nil, so the value can be logged with a %d verb.
func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
	if firstTokenMs != nil {
		return *firstTokenMs
	}
	return -1
}
|
||
|
||
func logOpenAIWSV2Passthrough(format string, args ...any) {
|
||
logger.LegacyPrintf(
|
||
"service.openai_ws_v2",
|
||
"[OpenAI WS v2 passthrough] %s "+format,
|
||
append([]any{openaiWSV2PassthroughModeFields}, args...)...,
|
||
)
|
||
}
|