2026-03-05 11:50:58 +08:00
|
|
|
|
package service
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"errors"
|
|
|
|
|
|
"fmt"
|
|
|
|
|
|
"net/http"
|
|
|
|
|
|
"net/url"
|
|
|
|
|
|
"strings"
|
|
|
|
|
|
"sync/atomic"
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
|
|
|
|
|
openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
|
|
|
|
|
|
coderws "github.com/coder/websocket"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
|
"github.com/tidwall/gjson"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// openAIWSClientFrameConn adapts a client-side *coderws.Conn to the
// openaiwsv2.FrameConn interface used by the v2 relay. A nil receiver or a
// nil conn is treated as an already-closed connection by its methods.
type openAIWSClientFrameConn struct {
	// conn is the client websocket; may be nil.
	conn *coderws.Conn
}
|
|
|
|
|
|
|
2026-04-28 00:34:23 +08:00
|
|
|
|
// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
// every client→upstream frame read through ReadFrame past the OpenAI Fast
// Policy. It is the passthrough-relay counterpart of the parseClientPayload
// integration in the ingress session path. WriteFrame and Close delegate to
// inner unchanged.
//
// The return values of filter are interpreted as:
//   - (newPayload, nil, nil): forward the (possibly mutated) payload.
//   - (_, *OpenAIFastBlockedError, nil): block — ReadFrame calls onBlock (if
//     set) so it can send an error event to the client, then returns a
//     StatusPolicyViolation close error so the relay stops reading from the
//     client.
//   - (_, _, err): an error other than a block; returned to the relay as-is.
type openAIWSPolicyEnforcingFrameConn struct {
	// inner is the wrapped client connection; nil is treated as closed.
	inner openaiwsv2.FrameConn
	// filter evaluates a single client frame; nil disables filtering.
	filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
	// onBlock is called before ReadFrame reports a blocked frame; optional.
	onBlock func(blocked *OpenAIFastBlockedError)
}
|
|
|
|
|
|
|
|
|
|
|
|
// Compile-time check that the policy wrapper satisfies openaiwsv2.FrameConn.
var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
|
|
|
|
|
|
|
|
|
|
|
|
func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
|
|
|
|
|
if c == nil || c.inner == nil {
|
|
|
|
|
|
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
|
|
|
|
|
}
|
|
|
|
|
|
msgType, payload, err := c.inner.ReadFrame(ctx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return msgType, payload, err
|
|
|
|
|
|
}
|
|
|
|
|
|
if c.filter == nil {
|
|
|
|
|
|
return msgType, payload, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
updated, blocked, filterErr := c.filter(msgType, payload)
|
|
|
|
|
|
if filterErr != nil {
|
|
|
|
|
|
return msgType, payload, filterErr
|
|
|
|
|
|
}
|
|
|
|
|
|
if blocked != nil {
|
|
|
|
|
|
if c.onBlock != nil {
|
|
|
|
|
|
c.onBlock(blocked)
|
|
|
|
|
|
}
|
|
|
|
|
|
return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
|
|
|
|
|
|
}
|
|
|
|
|
|
return msgType, updated, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
|
|
|
|
|
if c == nil || c.inner == nil {
|
|
|
|
|
|
return errOpenAIWSConnClosed
|
|
|
|
|
|
}
|
|
|
|
|
|
return c.inner.WriteFrame(ctx, msgType, payload)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
|
|
|
|
|
|
if c == nil || c.inner == nil {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
return c.inner.Close()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective
|
|
|
|
|
|
// model name that should be passed to evaluateOpenAIFastPolicy for a single
|
|
|
|
|
|
// passthrough WS frame. Mirrors the HTTP-side normalization
|
|
|
|
|
|
// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
|
|
|
|
|
|
// matches model whitelists identically.
|
|
|
|
|
|
func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
|
|
|
|
|
|
if account == nil || len(payload) == 0 {
|
|
|
|
|
|
return ""
|
|
|
|
|
|
}
|
|
|
|
|
|
original := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
|
|
|
|
|
if original == "" {
|
|
|
|
|
|
return ""
|
|
|
|
|
|
}
|
|
|
|
|
|
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model
|
|
|
|
|
|
// derived from a session.update frame's session.model field. Returns "" when
|
|
|
|
|
|
// the frame is not a session.update event or carries no session.model. Used
|
|
|
|
|
|
// by the per-frame policy filter (client→upstream direction) to keep
|
|
|
|
|
|
// capturedSessionModel in sync with the session-level model the client may
|
|
|
|
|
|
// rotate mid-session.
|
|
|
|
|
|
//
|
|
|
|
|
|
// Realtime / Responses WS lets the client change the session model after
|
|
|
|
|
|
// the WS handshake via:
|
|
|
|
|
|
//
|
|
|
|
|
|
// {"type":"session.update","session":{"model":"gpt-5.5", ...}}
|
|
|
|
|
|
//
|
|
|
|
|
|
// If we only capture the model from the very first frame, a client can ship
|
|
|
|
|
|
// gpt-4o on the first response.create (whitelisted as pass), then
|
|
|
|
|
|
// session.update to gpt-5.5, then send response.create without "model" so
|
|
|
|
|
|
// the per-frame resolver returns "" and the stale capturedSessionModel falls
|
|
|
|
|
|
// back to gpt-4o — defeating the gpt-5.5 fast-policy filter.
|
|
|
|
|
|
func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
|
|
|
|
|
|
if account == nil || len(payload) == 0 {
|
|
|
|
|
|
return ""
|
|
|
|
|
|
}
|
|
|
|
|
|
frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
|
|
|
|
|
if frameType != "session.update" {
|
|
|
|
|
|
return ""
|
|
|
|
|
|
}
|
|
|
|
|
|
original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
|
|
|
|
|
|
if original == "" {
|
|
|
|
|
|
return ""
|
|
|
|
|
|
}
|
|
|
|
|
|
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-03-05 11:50:58 +08:00
|
|
|
|
// openaiWSV2PassthroughModeFields is prepended to every passthrough log line
// by logOpenAIWSV2Passthrough to tag the websocket mode and router version.
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"

// Compile-time check that openAIWSClientFrameConn satisfies openaiwsv2.FrameConn.
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
|
|
|
|
|
|
|
|
|
|
|
func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
|
|
|
|
|
if c == nil || c.conn == nil {
|
|
|
|
|
|
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
|
|
|
|
|
}
|
|
|
|
|
|
if ctx == nil {
|
|
|
|
|
|
ctx = context.Background()
|
|
|
|
|
|
}
|
|
|
|
|
|
return c.conn.Read(ctx)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
|
|
|
|
|
if c == nil || c.conn == nil {
|
|
|
|
|
|
return errOpenAIWSConnClosed
|
|
|
|
|
|
}
|
|
|
|
|
|
if ctx == nil {
|
|
|
|
|
|
ctx = context.Background()
|
|
|
|
|
|
}
|
|
|
|
|
|
return c.conn.Write(ctx, msgType, payload)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (c *openAIWSClientFrameConn) Close() error {
|
|
|
|
|
|
if c == nil || c.conn == nil {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
_ = c.conn.Close(coderws.StatusNormalClosure, "")
|
|
|
|
|
|
_ = c.conn.CloseNow()
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|
|
|
|
|
ctx context.Context,
|
|
|
|
|
|
c *gin.Context,
|
|
|
|
|
|
clientConn *coderws.Conn,
|
|
|
|
|
|
account *Account,
|
|
|
|
|
|
token string,
|
|
|
|
|
|
firstClientMessage []byte,
|
|
|
|
|
|
hooks *OpenAIWSIngressHooks,
|
|
|
|
|
|
wsDecision OpenAIWSProtocolDecision,
|
|
|
|
|
|
) error {
|
|
|
|
|
|
if s == nil {
|
|
|
|
|
|
return errors.New("service is nil")
|
|
|
|
|
|
}
|
|
|
|
|
|
if clientConn == nil {
|
|
|
|
|
|
return errors.New("client websocket is nil")
|
|
|
|
|
|
}
|
|
|
|
|
|
if account == nil {
|
|
|
|
|
|
return errors.New("account is nil")
|
|
|
|
|
|
}
|
|
|
|
|
|
if strings.TrimSpace(token) == "" {
|
|
|
|
|
|
return errors.New("token is empty")
|
|
|
|
|
|
}
|
|
|
|
|
|
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
|
|
|
|
|
|
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
|
|
|
|
|
|
logOpenAIWSV2Passthrough(
|
|
|
|
|
|
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
|
|
|
|
|
|
account.ID,
|
|
|
|
|
|
truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
|
|
|
|
|
|
truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
|
|
|
|
|
|
openaiwsv2RelayMessageTypeName(coderws.MessageText),
|
|
|
|
|
|
len(firstClientMessage),
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2026-04-28 00:34:23 +08:00
|
|
|
|
// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
|
|
|
|
|
|
// frames are filtered via a wrapping FrameConn below so every client→
|
|
|
|
|
|
// upstream frame goes through the same policy evaluator/normalize/scope as
|
|
|
|
|
|
// HTTP entrypoints.
|
|
|
|
|
|
//
|
|
|
|
|
|
// We capture the session-level model from the first frame here so the
|
|
|
|
|
|
// per-frame filter (below) can fall back to it when a follow-up frame
|
|
|
|
|
|
// omits "model" — Realtime clients are allowed to send response.create
|
|
|
|
|
|
// without re-stating the model, in which case the upstream uses the model
|
|
|
|
|
|
// negotiated at session.update time. Without this fallback, an empty
|
|
|
|
|
|
// model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
|
|
|
|
|
|
// silently passed through, defeating the policy on every frame after
|
|
|
|
|
|
// the first.
|
|
|
|
|
|
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
|
|
|
|
|
|
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
|
|
|
|
|
|
if policyErr != nil {
|
|
|
|
|
|
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
|
|
|
|
|
|
}
|
|
|
|
|
|
if blocked != nil {
|
|
|
|
|
|
// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
|
|
|
|
|
|
// writeFrameMu, writes the entire frame, and Flushes the underlying
|
|
|
|
|
|
// bufio writer before returning (write.go:42 → write.go:307-311).
|
|
|
|
|
|
// The subsequent close handshake re-acquires the same writeFrameMu
|
|
|
|
|
|
// to send the close frame, so the error event is guaranteed to
|
|
|
|
|
|
// reach the kernel send buffer before any close frame is queued.
|
|
|
|
|
|
// No explicit flush hop is required here.
|
|
|
|
|
|
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
|
|
|
|
|
|
if eventBytes != nil {
|
|
|
|
|
|
writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
|
|
|
|
|
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
|
|
|
|
|
|
cancelWrite()
|
|
|
|
|
|
}
|
|
|
|
|
|
return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
|
|
|
|
|
|
}
|
|
|
|
|
|
firstClientMessage = updatedFirst
|
|
|
|
|
|
|
|
|
|
|
|
// 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter
|
|
|
|
|
|
// 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
|
|
|
|
|
|
// 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
|
|
|
|
|
|
// "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
|
|
|
|
|
|
// 与 WS ingress(openai_ws_forwarder.go:2991 取自 payload)的语义一致。
|
|
|
|
|
|
//
|
|
|
|
|
|
// 多轮 passthrough:OpenAI Realtime / Responses WS 协议允许客户端在
|
|
|
|
|
|
// 同一连接的不同 response.create 帧上发送不同 service_tier(参考
|
|
|
|
|
|
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
|
|
|
|
|
|
// 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
|
|
|
|
|
|
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
|
|
|
|
|
// goroutine)之间同步当前 turn 的 service_tier。
|
|
|
|
|
|
// extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型,
|
|
|
|
|
|
// 可直接 Store/Load 而无需额外封装。
|
|
|
|
|
|
var requestServiceTierPtr atomic.Pointer[string]
|
|
|
|
|
|
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
|
|
|
|
|
|
|
2026-03-05 11:50:58 +08:00
|
|
|
|
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return fmt.Errorf("build ws url: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
wsHost := "-"
|
|
|
|
|
|
wsPath := "-"
|
|
|
|
|
|
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
|
|
|
|
|
|
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
|
|
|
|
|
|
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
|
|
|
|
|
|
}
|
|
|
|
|
|
logOpenAIWSV2Passthrough(
|
|
|
|
|
|
"relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
|
|
|
|
|
|
account.ID,
|
|
|
|
|
|
wsHost,
|
|
|
|
|
|
wsPath,
|
|
|
|
|
|
account.ProxyID != nil && account.Proxy != nil,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
isCodexCLI := false
|
|
|
|
|
|
if c != nil {
|
2026-03-07 14:12:38 +08:00
|
|
|
|
isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
|
2026-03-05 11:50:58 +08:00
|
|
|
|
}
|
|
|
|
|
|
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
|
|
|
|
|
isCodexCLI = true
|
|
|
|
|
|
}
|
|
|
|
|
|
headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
|
|
|
|
|
|
proxyURL := ""
|
|
|
|
|
|
if account.ProxyID != nil && account.Proxy != nil {
|
|
|
|
|
|
proxyURL = account.Proxy.URL()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
dialer := s.getOpenAIWSPassthroughDialer()
|
|
|
|
|
|
if dialer == nil {
|
|
|
|
|
|
return errors.New("openai ws passthrough dialer is nil")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
|
|
|
|
|
|
defer cancelDial()
|
|
|
|
|
|
upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
logOpenAIWSV2Passthrough(
|
|
|
|
|
|
"relay_dial_failed account_id=%d status_code=%d err=%s",
|
|
|
|
|
|
account.ID,
|
|
|
|
|
|
statusCode,
|
|
|
|
|
|
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
|
|
|
|
|
|
)
|
|
|
|
|
|
return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
|
|
|
|
|
|
}
|
|
|
|
|
|
defer func() {
|
|
|
|
|
|
_ = upstreamConn.Close()
|
|
|
|
|
|
}()
|
|
|
|
|
|
logOpenAIWSV2Passthrough(
|
|
|
|
|
|
"relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
|
|
|
|
|
|
account.ID,
|
|
|
|
|
|
statusCode,
|
|
|
|
|
|
openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
return errors.New("openai ws passthrough upstream connection does not support frame relay")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
completedTurns := atomic.Int32{}
|
2026-04-28 00:34:23 +08:00
|
|
|
|
policyClientConn := &openAIWSPolicyEnforcingFrameConn{
|
|
|
|
|
|
inner: &openAIWSClientFrameConn{conn: clientConn},
|
|
|
|
|
|
// 注意线程安全:filter 仅在 runClientToUpstream 这一条
|
|
|
|
|
|
// goroutine 中被调用(passthrough_relay.go: ReadFrame loop),
|
|
|
|
|
|
// capturedSessionModel 的读写都发生在该 goroutine 内,因此无需
|
|
|
|
|
|
// 加锁/原子化。
|
|
|
|
|
|
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
|
|
|
|
|
|
if msgType != coderws.MessageText {
|
|
|
|
|
|
return payload, nil, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
// 在评估策略前先刷新 capturedSessionModel:客户端可能通过
|
|
|
|
|
|
// session.update 修改 session-level model(Realtime /
|
|
|
|
|
|
// Responses WS 协议允许),如果不刷新就会出现
|
|
|
|
|
|
// "首帧 model=gpt-4o(pass)→ session.update 改成 gpt-5.5
|
|
|
|
|
|
// → 不带 model 的 response.create fallback 到 gpt-4o" 的
|
|
|
|
|
|
// 绕过路径。这里只看 session.update 事件中的 session.model
|
|
|
|
|
|
// 字段,response.create 自己的 model 仍然由其本帧字段决定。
|
|
|
|
|
|
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
|
|
|
|
|
capturedSessionModel = updated
|
|
|
|
|
|
}
|
|
|
|
|
|
// Per-frame model first; if the client omits "model" on a
|
|
|
|
|
|
// follow-up frame (legal in Realtime), fall back to the
|
|
|
|
|
|
// session-level model captured from the first frame so the
|
|
|
|
|
|
// model whitelist still resolves. An empty model would miss
|
|
|
|
|
|
// any whitelist and silently fall back to pass.
|
|
|
|
|
|
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
|
|
|
|
|
|
if model == "" {
|
|
|
|
|
|
model = capturedSessionModel
|
|
|
|
|
|
}
|
|
|
|
|
|
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
|
|
|
|
|
|
// 多轮 passthrough billing:仅在成功(non-block / non-err)
|
|
|
|
|
|
// 的 response.create 帧上更新 requestServiceTierPtr,使用
|
|
|
|
|
|
// filter 处理后的 payload,与首帧 policy-after-extract 语义
|
|
|
|
|
|
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
|
|
|
|
|
|
// - 非 response.create 帧(response.cancel /
|
|
|
|
|
|
// conversation.item.create / session.update 等)不携带
|
|
|
|
|
|
// per-response service_tier,不应覆盖前一轮值。
|
|
|
|
|
|
// - blocked != nil:该帧不会发送上游,billing tier 应保持
|
|
|
|
|
|
// 上一轮值。
|
|
|
|
|
|
// - policyErr != nil:异常路径,保持上一轮值。
|
|
|
|
|
|
// - 不带 service_tier 的 response.create 会让
|
|
|
|
|
|
// extractOpenAIServiceTierFromBody 返回 nil;这里有意
|
|
|
|
|
|
// 覆盖(Store(nil)),因为 OpenAI 上游对该帧实际不传
|
|
|
|
|
|
// service_tier 时按 default 处理,billing 应如实反映。
|
|
|
|
|
|
if policyErr == nil && blocked == nil &&
|
|
|
|
|
|
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
|
|
|
|
|
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
|
|
|
|
|
|
}
|
|
|
|
|
|
return out, blocked, policyErr
|
|
|
|
|
|
},
|
|
|
|
|
|
onBlock: func(blocked *OpenAIFastBlockedError) {
|
|
|
|
|
|
// See note above on Conn.Write being synchronous w.r.t. flush;
|
|
|
|
|
|
// no explicit flush is required to ensure the error event lands
|
|
|
|
|
|
// before the close frame.
|
|
|
|
|
|
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
|
|
|
|
|
|
if eventBytes == nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
|
|
|
|
|
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
|
|
|
|
|
|
cancel()
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
2026-03-05 11:50:58 +08:00
|
|
|
|
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
|
|
|
|
|
|
Ctx: ctx,
|
2026-04-28 00:34:23 +08:00
|
|
|
|
ClientConn: policyClientConn,
|
2026-03-05 11:50:58 +08:00
|
|
|
|
UpstreamConn: upstreamFrameConn,
|
|
|
|
|
|
FirstClientMessage: firstClientMessage,
|
|
|
|
|
|
Options: openaiwsv2.RelayOptions{
|
|
|
|
|
|
WriteTimeout: s.openAIWSWriteTimeout(),
|
|
|
|
|
|
IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
|
|
|
|
|
|
FirstMessageType: coderws.MessageText,
|
|
|
|
|
|
OnUsageParseFailure: func(eventType string, usageRaw string) {
|
|
|
|
|
|
logOpenAIWSV2Passthrough(
|
|
|
|
|
|
"usage_parse_failed event_type=%s usage_raw=%s",
|
|
|
|
|
|
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
|
|
|
|
|
|
truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
|
|
|
|
|
|
)
|
|
|
|
|
|
},
|
|
|
|
|
|
OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
|
|
|
|
|
|
turnNo := int(completedTurns.Add(1))
|
|
|
|
|
|
turnResult := &OpenAIForwardResult{
|
|
|
|
|
|
RequestID: turn.RequestID,
|
|
|
|
|
|
Usage: OpenAIUsage{
|
|
|
|
|
|
InputTokens: turn.Usage.InputTokens,
|
|
|
|
|
|
OutputTokens: turn.Usage.OutputTokens,
|
|
|
|
|
|
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
|
|
|
|
|
|
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
|
|
|
|
|
},
|
2026-03-06 20:46:10 +08:00
|
|
|
|
Model: turn.RequestModel,
|
2026-04-28 00:34:23 +08:00
|
|
|
|
ServiceTier: requestServiceTierPtr.Load(),
|
2026-03-06 20:46:10 +08:00
|
|
|
|
Stream: true,
|
|
|
|
|
|
OpenAIWSMode: true,
|
|
|
|
|
|
ResponseHeaders: cloneHeader(handshakeHeaders),
|
|
|
|
|
|
Duration: turn.Duration,
|
|
|
|
|
|
FirstTokenMs: turn.FirstTokenMs,
|
2026-03-05 11:50:58 +08:00
|
|
|
|
}
|
|
|
|
|
|
logOpenAIWSV2Passthrough(
|
|
|
|
|
|
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
|
|
|
|
|
|
account.ID,
|
|
|
|
|
|
turnNo,
|
|
|
|
|
|
truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
|
|
|
|
|
|
truncateOpenAIWSLogValue(turn.TerminalEventType, openAIWSLogValueMaxLen),
|
|
|
|
|
|
turnResult.Duration.Milliseconds(),
|
|
|
|
|
|
openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
|
|
|
|
|
|
turnResult.Usage.InputTokens,
|
|
|
|
|
|
turnResult.Usage.OutputTokens,
|
|
|
|
|
|
turnResult.Usage.CacheReadInputTokens,
|
|
|
|
|
|
)
|
|
|
|
|
|
if hooks != nil && hooks.AfterTurn != nil {
|
|
|
|
|
|
hooks.AfterTurn(turnNo, turnResult, nil)
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
OnTrace: func(event openaiwsv2.RelayTraceEvent) {
|
|
|
|
|
|
logOpenAIWSV2Passthrough(
|
|
|
|
|
|
"relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
|
|
|
|
|
|
account.ID,
|
|
|
|
|
|
truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
|
|
|
|
|
|
truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
|
|
|
|
|
|
truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
|
|
|
|
|
|
event.PayloadBytes,
|
|
|
|
|
|
event.Graceful,
|
|
|
|
|
|
event.WroteDownstream,
|
|
|
|
|
|
truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
|
|
|
|
|
|
)
|
|
|
|
|
|
},
|
|
|
|
|
|
},
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
result := &OpenAIForwardResult{
|
|
|
|
|
|
RequestID: relayResult.RequestID,
|
|
|
|
|
|
Usage: OpenAIUsage{
|
|
|
|
|
|
InputTokens: relayResult.Usage.InputTokens,
|
|
|
|
|
|
OutputTokens: relayResult.Usage.OutputTokens,
|
|
|
|
|
|
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
|
|
|
|
|
|
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
|
|
|
|
|
},
|
2026-03-06 20:46:10 +08:00
|
|
|
|
Model: relayResult.RequestModel,
|
2026-04-28 00:34:23 +08:00
|
|
|
|
ServiceTier: requestServiceTierPtr.Load(),
|
2026-03-06 20:46:10 +08:00
|
|
|
|
Stream: true,
|
|
|
|
|
|
OpenAIWSMode: true,
|
|
|
|
|
|
ResponseHeaders: cloneHeader(handshakeHeaders),
|
|
|
|
|
|
Duration: relayResult.Duration,
|
|
|
|
|
|
FirstTokenMs: relayResult.FirstTokenMs,
|
2026-03-05 11:50:58 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
turnCount := int(completedTurns.Load())
|
|
|
|
|
|
if relayExit == nil {
|
|
|
|
|
|
logOpenAIWSV2Passthrough(
|
|
|
|
|
|
"relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
|
|
|
|
|
|
account.ID,
|
|
|
|
|
|
truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
|
|
|
|
|
|
truncateOpenAIWSLogValue(relayResult.TerminalEventType, openAIWSLogValueMaxLen),
|
|
|
|
|
|
result.Duration.Milliseconds(),
|
|
|
|
|
|
relayResult.ClientToUpstreamFrames,
|
|
|
|
|
|
relayResult.UpstreamToClientFrames,
|
|
|
|
|
|
relayResult.DroppedDownstreamFrames,
|
|
|
|
|
|
turnCount,
|
|
|
|
|
|
)
|
|
|
|
|
|
// 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。
|
|
|
|
|
|
if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
|
|
|
|
|
|
hooks.AfterTurn(1, result, nil)
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
logOpenAIWSV2Passthrough(
|
|
|
|
|
|
"relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
|
|
|
|
|
|
account.ID,
|
|
|
|
|
|
truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
|
|
|
|
|
|
relayExit.WroteDownstream,
|
|
|
|
|
|
truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
|
|
|
|
|
|
result.Duration.Milliseconds(),
|
|
|
|
|
|
relayResult.ClientToUpstreamFrames,
|
|
|
|
|
|
relayResult.UpstreamToClientFrames,
|
|
|
|
|
|
relayResult.DroppedDownstreamFrames,
|
|
|
|
|
|
turnCount,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
relayErr := relayExit.Err
|
|
|
|
|
|
if relayExit.Stage == "idle_timeout" {
|
|
|
|
|
|
relayErr = NewOpenAIWSClientCloseError(
|
|
|
|
|
|
coderws.StatusPolicyViolation,
|
|
|
|
|
|
"client websocket idle timeout",
|
|
|
|
|
|
relayErr,
|
|
|
|
|
|
)
|
|
|
|
|
|
}
|
|
|
|
|
|
turnErr := wrapOpenAIWSIngressTurnError(
|
|
|
|
|
|
relayExit.Stage,
|
|
|
|
|
|
relayErr,
|
|
|
|
|
|
relayExit.WroteDownstream,
|
|
|
|
|
|
)
|
|
|
|
|
|
if hooks != nil && hooks.AfterTurn != nil {
|
|
|
|
|
|
hooks.AfterTurn(turnCount+1, nil, turnErr)
|
|
|
|
|
|
}
|
|
|
|
|
|
return turnErr
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
|
|
|
|
|
|
err error,
|
|
|
|
|
|
statusCode int,
|
|
|
|
|
|
handshakeHeaders http.Header,
|
|
|
|
|
|
) error {
|
|
|
|
|
|
if err == nil {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
wrappedErr := err
|
|
|
|
|
|
var dialErr *openAIWSDialError
|
|
|
|
|
|
if !errors.As(err, &dialErr) {
|
|
|
|
|
|
wrappedErr = &openAIWSDialError{
|
|
|
|
|
|
StatusCode: statusCode,
|
|
|
|
|
|
ResponseHeaders: cloneHeader(handshakeHeaders),
|
|
|
|
|
|
Err: err,
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if errors.Is(err, context.Canceled) {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
|
|
|
|
return NewOpenAIWSClientCloseError(
|
|
|
|
|
|
coderws.StatusTryAgainLater,
|
|
|
|
|
|
"upstream websocket connect timeout",
|
|
|
|
|
|
wrappedErr,
|
|
|
|
|
|
)
|
|
|
|
|
|
}
|
|
|
|
|
|
if statusCode == http.StatusTooManyRequests {
|
|
|
|
|
|
return NewOpenAIWSClientCloseError(
|
|
|
|
|
|
coderws.StatusTryAgainLater,
|
|
|
|
|
|
"upstream websocket is busy, please retry later",
|
|
|
|
|
|
wrappedErr,
|
|
|
|
|
|
)
|
|
|
|
|
|
}
|
|
|
|
|
|
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
|
|
|
|
|
return NewOpenAIWSClientCloseError(
|
|
|
|
|
|
coderws.StatusPolicyViolation,
|
|
|
|
|
|
"upstream websocket authentication failed",
|
|
|
|
|
|
wrappedErr,
|
|
|
|
|
|
)
|
|
|
|
|
|
}
|
|
|
|
|
|
if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError {
|
|
|
|
|
|
return NewOpenAIWSClientCloseError(
|
|
|
|
|
|
coderws.StatusPolicyViolation,
|
|
|
|
|
|
"upstream websocket handshake rejected",
|
|
|
|
|
|
wrappedErr,
|
|
|
|
|
|
)
|
|
|
|
|
|
}
|
|
|
|
|
|
return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
|
|
|
|
|
|
switch msgType {
|
|
|
|
|
|
case coderws.MessageText:
|
|
|
|
|
|
return "text"
|
|
|
|
|
|
case coderws.MessageBinary:
|
|
|
|
|
|
return "binary"
|
|
|
|
|
|
default:
|
|
|
|
|
|
return fmt.Sprintf("unknown(%d)", msgType)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// relayErrorText returns err.Error(), or "" for a nil error, so that log
// formatting never has to nil-check.
func relayErrorText(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
|
|
|
|
|
|
|
|
|
|
|
|
// openAIWSFirstTokenMsForLog dereferences firstTokenMs for logging, using -1
// as the sentinel when no first-token latency was recorded (nil pointer).
func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
	if firstTokenMs != nil {
		return *firstTokenMs
	}
	return -1
}
|
|
|
|
|
|
|
|
|
|
|
|
func logOpenAIWSV2Passthrough(format string, args ...any) {
|
|
|
|
|
|
logger.LegacyPrintf(
|
|
|
|
|
|
"service.openai_ws_v2",
|
|
|
|
|
|
"[OpenAI WS v2 passthrough] %s "+format,
|
|
|
|
|
|
append([]any{openaiWSV2PassthroughModeFields}, args...)...,
|
|
|
|
|
|
)
|
|
|
|
|
|
}
|