Files
sub2api/backend/internal/service/openai_ws_forwarder.go

4043 lines
129 KiB
Go
Raw Permalink Normal View History

package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/url"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
const (
openAIWSBetaV1Value = "responses_websockets=2026-02-04"
openAIWSBetaV2Value = "responses_websockets=2026-02-06"
openAIWSTurnStateHeader = "x-codex-turn-state"
openAIWSTurnMetadataHeader = "x-codex-turn-metadata"
openAIWSLogValueMaxLen = 160
openAIWSHeaderValueMaxLen = 120
openAIWSIDValueMaxLen = 64
openAIWSEventLogHeadLimit = 20
openAIWSEventLogEveryN = 50
openAIWSBufferLogHeadLimit = 8
openAIWSBufferLogEveryN = 20
openAIWSPrewarmEventLogHead = 10
openAIWSPayloadKeySizeTopN = 6
openAIWSPayloadSizeEstimateDepth = 3
openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
openAIWSPayloadSizeEstimateMaxItems = 16
openAIWSEventFlushBatchSizeDefault = 4
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
openAIWSPayloadLogSampleDefault = 0.2
openAIWSPassthroughIdleTimeoutDefault = time.Hour
openAIWSStoreDisabledConnModeStrict = "strict"
openAIWSStoreDisabledConnModeAdaptive = "adaptive"
openAIWSStoreDisabledConnModeOff = "off"
openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found"
openAIWSMaxPrevResponseIDDeletePasses = 8
)
var openAIWSLogValueReplacer = strings.NewReplacer(
"error", "err",
"fallback", "fb",
"warning", "warnx",
"failed", "fail",
)
var openAIWSIngressPreflightPingIdle = 20 * time.Second
// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。
type openAIWSFallbackError struct {
Reason string
Err error
}
func (e *openAIWSFallbackError) Error() string {
if e == nil {
return ""
}
if e.Err == nil {
return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason))
}
return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err)
}
func (e *openAIWSFallbackError) Unwrap() error {
if e == nil {
return nil
}
return e.Err
}
func wrapOpenAIWSFallback(reason string, err error) error {
return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err}
}
// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。
type OpenAIWSClientCloseError struct {
statusCode coderws.StatusCode
reason string
err error
}
type openAIWSIngressTurnError struct {
stage string
cause error
wroteDownstream bool
}
func (e *openAIWSIngressTurnError) Error() string {
if e == nil {
return ""
}
if e.cause == nil {
return strings.TrimSpace(e.stage)
}
return e.cause.Error()
}
func (e *openAIWSIngressTurnError) Unwrap() error {
if e == nil {
return nil
}
return e.cause
}
func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error {
if cause == nil {
return nil
}
return &openAIWSIngressTurnError{
stage: strings.TrimSpace(stage),
cause: cause,
wroteDownstream: wroteDownstream,
}
}
func isOpenAIWSIngressTurnRetryable(err error) bool {
var turnErr *openAIWSIngressTurnError
if !errors.As(err, &turnErr) || turnErr == nil {
return false
}
if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) {
return false
}
if turnErr.wroteDownstream {
return false
}
switch turnErr.stage {
case "write_upstream", "read_upstream":
return true
default:
return false
}
}
func openAIWSIngressTurnRetryReason(err error) string {
var turnErr *openAIWSIngressTurnError
if !errors.As(err, &turnErr) || turnErr == nil {
return "unknown"
}
if turnErr.stage == "" {
return "unknown"
}
return turnErr.stage
}
func isOpenAIWSIngressPreviousResponseNotFound(err error) bool {
var turnErr *openAIWSIngressTurnError
if !errors.As(err, &turnErr) || turnErr == nil {
return false
}
if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound {
return false
}
return !turnErr.wroteDownstream
}
// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。
func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error {
return &OpenAIWSClientCloseError{
statusCode: statusCode,
reason: strings.TrimSpace(reason),
err: err,
}
}
func (e *OpenAIWSClientCloseError) Error() string {
if e == nil {
return ""
}
if e.err == nil {
return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason))
}
return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err)
}
func (e *OpenAIWSClientCloseError) Unwrap() error {
if e == nil {
return nil
}
return e.err
}
func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode {
if e == nil {
return coderws.StatusInternalError
}
return e.statusCode
}
func (e *OpenAIWSClientCloseError) Reason() string {
if e == nil {
return ""
}
return strings.TrimSpace(e.reason)
}
// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
type OpenAIWSIngressHooks struct {
BeforeTurn func(turn int) error
AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
}
func normalizeOpenAIWSLogValue(value string) string {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return "-"
}
return openAIWSLogValueReplacer.Replace(trimmed)
}
func truncateOpenAIWSLogValue(value string, maxLen int) string {
normalized := normalizeOpenAIWSLogValue(value)
if normalized == "-" || maxLen <= 0 {
return normalized
}
if len(normalized) <= maxLen {
return normalized
}
return normalized[:maxLen] + "..."
}
func openAIWSHeaderValueForLog(headers http.Header, key string) string {
if headers == nil {
return "-"
}
return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen)
}
func hasOpenAIWSHeader(headers http.Header, key string) bool {
if headers == nil {
return false
}
return strings.TrimSpace(headers.Get(key)) != ""
}
type openAIWSSessionHeaderResolution struct {
SessionID string
ConversationID string
SessionSource string
ConversationSource string
}
func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution {
resolution := openAIWSSessionHeaderResolution{
SessionSource: "none",
ConversationSource: "none",
}
if c != nil && c.Request != nil {
if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" {
resolution.SessionID = sessionID
resolution.SessionSource = "header_session_id"
}
if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" {
resolution.ConversationID = conversationID
resolution.ConversationSource = "header_conversation_id"
if resolution.SessionID == "" {
resolution.SessionID = conversationID
resolution.SessionSource = "header_conversation_id"
}
}
}
cacheKey := strings.TrimSpace(promptCacheKey)
if cacheKey != "" {
if resolution.SessionID == "" {
resolution.SessionID = cacheKey
resolution.SessionSource = "prompt_cache_key"
}
}
return resolution
}
func shouldLogOpenAIWSEvent(idx int, eventType string) bool {
if idx <= openAIWSEventLogHeadLimit {
return true
}
if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 {
return true
}
if eventType == "error" || isOpenAIWSTerminalEvent(eventType) {
return true
}
return false
}
func shouldLogOpenAIWSBufferedEvent(idx int) bool {
if idx <= openAIWSBufferLogHeadLimit {
return true
}
if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 {
return true
}
return false
}
func openAIWSEventMayContainModel(eventType string) bool {
switch eventType {
case "response.created",
"response.in_progress",
"response.completed",
"response.done",
"response.failed",
"response.incomplete",
"response.cancelled",
"response.canceled":
return true
default:
trimmed := strings.TrimSpace(eventType)
if trimmed == eventType {
return false
}
switch trimmed {
case "response.created",
"response.in_progress",
"response.completed",
"response.done",
"response.failed",
"response.incomplete",
"response.cancelled",
"response.canceled":
return true
default:
return false
}
}
}
func openAIWSEventMayContainToolCalls(eventType string) bool {
eventType = strings.TrimSpace(eventType)
if eventType == "" {
return false
}
if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") {
return true
}
switch eventType {
case "response.output_item.added", "response.output_item.done", "response.completed", "response.done":
return true
default:
return false
}
}
func openAIWSEventShouldParseUsage(eventType string) bool {
return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed"
}
func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
if len(message) == 0 {
return "", "", gjson.Result{}
}
values := gjson.GetManyBytes(message, "type", "response.id", "id", "response")
eventType = strings.TrimSpace(values[0].String())
if id := strings.TrimSpace(values[1].String()); id != "" {
responseID = id
} else {
responseID = strings.TrimSpace(values[2].String())
}
return eventType, responseID, values[3]
}
func openAIWSMessageLikelyContainsToolCalls(message []byte) bool {
if len(message) == 0 {
return false
}
return bytes.Contains(message, []byte(`"tool_calls"`)) ||
bytes.Contains(message, []byte(`"tool_call"`)) ||
bytes.Contains(message, []byte(`"function_call"`))
}
func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) {
if usage == nil || len(message) == 0 {
return
}
values := gjson.GetManyBytes(
message,
"response.usage.input_tokens",
"response.usage.output_tokens",
"response.usage.input_tokens_details.cached_tokens",
)
usage.InputTokens = int(values[0].Int())
usage.OutputTokens = int(values[1].Int())
usage.CacheReadInputTokens = int(values[2].Int())
}
func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
if len(message) == 0 {
return "", "", ""
}
values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message")
return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String())
}
func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) {
code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen)
errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen)
errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen)
return code, errType, errMessage
}
func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
if len(message) == 0 {
return "-", "-", "-"
}
return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message))
}
func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string {
if len(payload) == 0 {
return "-"
}
type keySize struct {
Key string
Size int
}
sizes := make([]keySize, 0, len(payload))
for key, value := range payload {
size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth)
sizes = append(sizes, keySize{Key: key, Size: size})
}
sort.Slice(sizes, func(i, j int) bool {
if sizes[i].Size == sizes[j].Size {
return sizes[i].Key < sizes[j].Key
}
return sizes[i].Size > sizes[j].Size
})
if topN <= 0 || topN > len(sizes) {
topN = len(sizes)
}
parts := make([]string, 0, topN)
for idx := 0; idx < topN; idx++ {
item := sizes[idx]
parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size))
}
return strings.Join(parts, ",")
}
func estimateOpenAIWSPayloadValueSize(value any, depth int) int {
if depth <= 0 {
return -1
}
switch v := value.(type) {
case nil:
return 0
case string:
return len(v)
case []byte:
return len(v)
case bool:
return 1
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return 8
case float32, float64:
return 8
case map[string]any:
if len(v) == 0 {
return 2
}
total := 2
count := 0
for key, item := range v {
count++
if count > openAIWSPayloadSizeEstimateMaxItems {
return -1
}
itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1)
if itemSize < 0 {
return -1
}
total += len(key) + itemSize + 3
if total > openAIWSPayloadSizeEstimateMaxBytes {
return -1
}
}
return total
case []any:
if len(v) == 0 {
return 2
}
total := 2
limit := len(v)
if limit > openAIWSPayloadSizeEstimateMaxItems {
return -1
}
for i := 0; i < limit; i++ {
itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1)
if itemSize < 0 {
return -1
}
total += itemSize + 1
if total > openAIWSPayloadSizeEstimateMaxBytes {
return -1
}
}
return total
default:
raw, err := json.Marshal(v)
if err != nil {
return -1
}
if len(raw) > openAIWSPayloadSizeEstimateMaxBytes {
return -1
}
return len(raw)
}
}
func openAIWSPayloadString(payload map[string]any, key string) string {
if len(payload) == 0 {
return ""
}
raw, ok := payload[key]
if !ok {
return ""
}
switch v := raw.(type) {
case nil:
return ""
case string:
return strings.TrimSpace(v)
case []byte:
return strings.TrimSpace(string(v))
default:
return ""
}
}
func openAIWSPayloadStringFromRaw(payload []byte, key string) string {
if len(payload) == 0 || strings.TrimSpace(key) == "" {
return ""
}
return strings.TrimSpace(gjson.GetBytes(payload, key).String())
}
func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool {
if len(payload) == 0 || strings.TrimSpace(key) == "" {
return defaultValue
}
value := gjson.GetBytes(payload, key)
if !value.Exists() {
return defaultValue
}
if value.Type != gjson.True && value.Type != gjson.False {
return defaultValue
}
return value.Bool()
}
func openAIWSSessionHashesFromID(sessionID string) (string, string) {
return deriveOpenAISessionHashes(sessionID)
}
func extractOpenAIWSImageURL(value any) string {
switch v := value.(type) {
case string:
return strings.TrimSpace(v)
case map[string]any:
if raw, ok := v["url"].(string); ok {
return strings.TrimSpace(raw)
}
}
return ""
}
func summarizeOpenAIWSInput(input any) string {
items, ok := input.([]any)
if !ok || len(items) == 0 {
return "-"
}
itemCount := len(items)
textChars := 0
imageDataURLs := 0
imageDataURLChars := 0
imageRemoteURLs := 0
handleContentItem := func(contentItem map[string]any) {
contentType, _ := contentItem["type"].(string)
switch strings.TrimSpace(contentType) {
case "input_text", "output_text", "text":
if text, ok := contentItem["text"].(string); ok {
textChars += len(text)
}
case "input_image":
imageURL := extractOpenAIWSImageURL(contentItem["image_url"])
if imageURL == "" {
return
}
if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
imageDataURLs++
imageDataURLChars += len(imageURL)
return
}
imageRemoteURLs++
}
}
handleInputItem := func(inputItem map[string]any) {
if content, ok := inputItem["content"].([]any); ok {
for _, rawContent := range content {
contentItem, ok := rawContent.(map[string]any)
if !ok {
continue
}
handleContentItem(contentItem)
}
return
}
itemType, _ := inputItem["type"].(string)
switch strings.TrimSpace(itemType) {
case "input_text", "output_text", "text":
if text, ok := inputItem["text"].(string); ok {
textChars += len(text)
}
case "input_image":
imageURL := extractOpenAIWSImageURL(inputItem["image_url"])
if imageURL == "" {
return
}
if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
imageDataURLs++
imageDataURLChars += len(imageURL)
return
}
imageRemoteURLs++
}
}
for _, rawItem := range items {
inputItem, ok := rawItem.(map[string]any)
if !ok {
continue
}
handleInputItem(inputItem)
}
return fmt.Sprintf(
"items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d",
itemCount,
textChars,
imageDataURLs,
imageDataURLChars,
imageRemoteURLs,
)
}
func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) {
if len(payload) == 0 || strings.TrimSpace(key) == "" {
return
}
if _, exists := payload[key]; !exists {
return
}
delete(payload, key)
*removed = append(*removed, key)
}
// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段,
// 避免重试成功却改变原始请求语义。
// 注意prompt_cache_key 不应在重试中移除它常用于会话稳定标识session_id 兜底)。
func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) {
if len(payload) == 0 {
return "empty", nil
}
if attempt <= 1 {
return "full", nil
}
removed := make([]string, 0, 2)
if attempt >= 2 {
dropOpenAIWSPayloadKey(payload, "include", &removed)
}
if len(removed) == 0 {
return "full", nil
}
sort.Strings(removed)
return "trim_optional_fields", removed
}
func logOpenAIWSModeInfo(format string, args ...any) {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
}
func isOpenAIWSModeDebugEnabled() bool {
return logger.L().Core().Enabled(zap.DebugLevel)
}
func logOpenAIWSModeDebug(format string, args ...any) {
if !isOpenAIWSModeDebugEnabled() {
return
}
logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
}
func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) {
if err == nil {
return
}
logger.L().Warn(
"openai.ws_bind_response_account_failed",
zap.Int64("group_id", groupID),
zap.Int64("account_id", accountID),
zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)),
zap.Error(err),
)
}
func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) {
if err == nil {
return "-", "-"
}
statusCode := coderws.CloseStatus(err)
if statusCode == -1 {
return "-", "-"
}
closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
closeReason := "-"
var closeErr coderws.CloseError
if errors.As(err, &closeErr) {
reasonText := strings.TrimSpace(closeErr.Reason)
if reasonText != "" {
closeReason = normalizeOpenAIWSLogValue(reasonText)
}
}
return normalizeOpenAIWSLogValue(closeStatus), closeReason
}
func unwrapOpenAIWSDialBaseError(err error) error {
if err == nil {
return nil
}
var dialErr *openAIWSDialError
if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil {
return dialErr.Err
}
return err
}
func openAIWSDialRespHeaderForLog(err error, key string) string {
var dialErr *openAIWSDialError
if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil {
return "-"
}
return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen)
}
func classifyOpenAIWSDialError(err error) string {
if err == nil {
return "-"
}
baseErr := unwrapOpenAIWSDialBaseError(err)
if baseErr == nil {
return "-"
}
if errors.Is(baseErr, context.DeadlineExceeded) {
return "ctx_deadline_exceeded"
}
if errors.Is(baseErr, context.Canceled) {
return "ctx_canceled"
}
var netErr net.Error
if errors.As(baseErr, &netErr) && netErr.Timeout() {
return "net_timeout"
}
if status := coderws.CloseStatus(baseErr); status != -1 {
return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status)))
}
message := strings.ToLower(strings.TrimSpace(baseErr.Error()))
switch {
case strings.Contains(message, "handshake not finished"):
return "handshake_not_finished"
case strings.Contains(message, "bad handshake"):
return "bad_handshake"
case strings.Contains(message, "connection refused"):
return "connection_refused"
case strings.Contains(message, "no such host"):
return "dns_not_found"
case strings.Contains(message, "tls"):
return "tls_error"
case strings.Contains(message, "i/o timeout"):
return "io_timeout"
case strings.Contains(message, "context deadline exceeded"):
return "ctx_deadline_exceeded"
default:
return "dial_error"
}
}
func summarizeOpenAIWSDialError(err error) (
statusCode int,
dialClass string,
closeStatus string,
closeReason string,
respServer string,
respVia string,
respCFRay string,
respRequestID string,
) {
dialClass = "-"
closeStatus = "-"
closeReason = "-"
respServer = "-"
respVia = "-"
respCFRay = "-"
respRequestID = "-"
if err == nil {
return
}
var dialErr *openAIWSDialError
if errors.As(err, &dialErr) && dialErr != nil {
statusCode = dialErr.StatusCode
respServer = openAIWSDialRespHeaderForLog(err, "server")
respVia = openAIWSDialRespHeaderForLog(err, "via")
respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray")
respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id")
}
dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err))
closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err))
return
}
func isOpenAIWSClientDisconnectError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
return true
}
switch coderws.CloseStatus(err) {
case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
return true
}
message := strings.ToLower(strings.TrimSpace(err.Error()))
if message == "" {
return false
}
return strings.Contains(message, "failed to read frame header: eof") ||
strings.Contains(message, "unexpected eof") ||
strings.Contains(message, "use of closed network connection") ||
strings.Contains(message, "connection reset by peer") ||
strings.Contains(message, "broken pipe") ||
strings.Contains(message, "an established connection was aborted")
}
func classifyOpenAIWSReadFallbackReason(err error) string {
if err == nil {
return "read_event"
}
switch coderws.CloseStatus(err) {
case coderws.StatusPolicyViolation:
return "policy_violation"
case coderws.StatusMessageTooBig:
return "message_too_big"
default:
return "read_event"
}
}
func sortedKeys(m map[string]any) []string {
if len(m) == 0 {
return nil
}
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
if s == nil {
return nil
}
s.openaiWSPoolOnce.Do(func() {
if s.openaiWSPool == nil {
s.openaiWSPool = newOpenAIWSConnPool(s.cfg)
}
})
return s.openaiWSPool
}
func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer {
if s == nil {
return nil
}
s.openaiWSPassthroughDialerOnce.Do(func() {
if s.openaiWSPassthroughDialer == nil {
s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer()
}
})
return s.openaiWSPassthroughDialer
}
func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
pool := s.getOpenAIWSConnPool()
if pool == nil {
return OpenAIWSPoolMetricsSnapshot{}
}
return pool.SnapshotMetrics()
}
type OpenAIWSPerformanceMetricsSnapshot struct {
Pool OpenAIWSPoolMetricsSnapshot `json:"pool"`
Retry OpenAIWSRetryMetricsSnapshot `json:"retry"`
Transport OpenAIWSTransportMetricsSnapshot `json:"transport"`
}
func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot {
pool := s.getOpenAIWSConnPool()
snapshot := OpenAIWSPerformanceMetricsSnapshot{
Retry: s.SnapshotOpenAIWSRetryMetrics(),
}
if pool == nil {
return snapshot
}
snapshot.Pool = pool.SnapshotMetrics()
snapshot.Transport = pool.SnapshotTransportMetrics()
return snapshot
}
func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore {
if s == nil {
return nil
}
s.openaiWSStateStoreOnce.Do(func() {
if s.openaiWSStateStore == nil {
s.openaiWSStateStore = NewOpenAIWSStateStore(s.cache)
}
})
return s.openaiWSStateStore
}
func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration {
if s != nil && s.cfg != nil {
seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds
if seconds > 0 {
return time.Duration(seconds) * time.Second
}
}
return time.Hour
}
func (s *OpenAIGatewayService) openAIWSIngressPreviousResponseRecoveryEnabled() bool {
if s != nil && s.cfg != nil {
return s.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled
}
return true
}
func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds > 0 {
return time.Duration(s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds) * time.Second
}
return 15 * time.Minute
}
func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration {
if timeout := s.openAIWSReadTimeout(); timeout > 0 {
return timeout
}
return openAIWSPassthroughIdleTimeoutDefault
}
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
}
return 2 * time.Minute
}
func (s *OpenAIGatewayService) openAIWSEventFlushBatchSize() int {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushBatchSize > 0 {
return s.cfg.Gateway.OpenAIWS.EventFlushBatchSize
}
return openAIWSEventFlushBatchSizeDefault
}
func (s *OpenAIGatewayService) openAIWSEventFlushInterval() time.Duration {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS >= 0 {
if s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS == 0 {
return 0
}
return time.Duration(s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS) * time.Millisecond
}
return openAIWSEventFlushIntervalDefault
}
func (s *OpenAIGatewayService) openAIWSPayloadLogSampleRate() float64 {
if s != nil && s.cfg != nil {
rate := s.cfg.Gateway.OpenAIWS.PayloadLogSampleRate
if rate < 0 {
return 0
}
if rate > 1 {
return 1
}
return rate
}
return openAIWSPayloadLogSampleDefault
}
func (s *OpenAIGatewayService) shouldLogOpenAIWSPayloadSchema(attempt int) bool {
// 首次尝试保留一条完整 payload_schema 便于排障。
if attempt <= 1 {
return true
}
rate := s.openAIWSPayloadLogSampleRate()
if rate <= 0 {
return false
}
if rate >= 1 {
return true
}
return rand.Float64() < rate
}
func (s *OpenAIGatewayService) shouldEmitOpenAIWSPayloadSchema(attempt int) bool {
if !s.shouldLogOpenAIWSPayloadSchema(attempt) {
return false
}
return logger.L().Core().Enabled(zap.DebugLevel)
}
func (s *OpenAIGatewayService) openAIWSDialTimeout() time.Duration {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 {
return time.Duration(s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second
}
return 10 * time.Second
}
func (s *OpenAIGatewayService) openAIWSAcquireTimeout() time.Duration {
// Acquire 覆盖“连接复用命中/排队/新建连接”三个阶段。
// 这里不再叠加 write_timeout避免高并发排队时把 TTFT 长尾拉到分钟级。
dial := s.openAIWSDialTimeout()
if dial <= 0 {
dial = 10 * time.Second
}
return dial + 2*time.Second
}
func (s *OpenAIGatewayService) buildOpenAIResponsesWSURL(account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
var targetURL string
switch account.Type {
case AccountTypeOAuth:
targetURL = chatgptCodexURL
case AccountTypeAPIKey:
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
targetURL = openaiPlatformAPIURL
} else {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return "", err
}
targetURL = buildOpenAIResponsesURL(validatedURL)
}
default:
targetURL = openaiPlatformAPIURL
}
parsed, err := url.Parse(strings.TrimSpace(targetURL))
if err != nil {
return "", fmt.Errorf("invalid target url: %w", err)
}
switch strings.ToLower(parsed.Scheme) {
case "https":
parsed.Scheme = "wss"
case "http":
parsed.Scheme = "ws"
case "wss", "ws":
// 保持不变
default:
return "", fmt.Errorf("unsupported scheme for ws: %s", parsed.Scheme)
}
return parsed.String(), nil
}
func (s *OpenAIGatewayService) buildOpenAIWSHeaders(
c *gin.Context,
account *Account,
token string,
decision OpenAIWSProtocolDecision,
isCodexCLI bool,
turnState string,
turnMetadata string,
promptCacheKey string,
) (http.Header, openAIWSSessionHeaderResolution) {
headers := make(http.Header)
headers.Set("authorization", "Bearer "+token)
sessionResolution := resolveOpenAIWSSessionHeaders(c, promptCacheKey)
if c != nil && c.Request != nil {
if v := strings.TrimSpace(c.Request.Header.Get("accept-language")); v != "" {
headers.Set("accept-language", v)
}
}
if sessionResolution.SessionID != "" {
headers.Set("session_id", sessionResolution.SessionID)
}
if sessionResolution.ConversationID != "" {
headers.Set("conversation_id", sessionResolution.ConversationID)
}
if state := strings.TrimSpace(turnState); state != "" {
headers.Set(openAIWSTurnStateHeader, state)
}
if metadata := strings.TrimSpace(turnMetadata); metadata != "" {
headers.Set(openAIWSTurnMetadataHeader, metadata)
}
if account != nil && account.Type == AccountTypeOAuth {
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
headers.Set("chatgpt-account-id", chatgptAccountID)
}
headers.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
}
betaValue := openAIWSBetaV2Value
if decision.Transport == OpenAIUpstreamTransportResponsesWebsocket {
betaValue = openAIWSBetaV1Value
}
headers.Set("OpenAI-Beta", betaValue)
customUA := ""
if account != nil {
customUA = account.GetOpenAIUserAgent()
}
if strings.TrimSpace(customUA) != "" {
headers.Set("user-agent", customUA)
} else if c != nil {
if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" {
headers.Set("user-agent", ua)
}
}
if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
headers.Set("user-agent", codexCLIUserAgent)
}
if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) {
headers.Set("user-agent", codexCLIUserAgent)
}
return headers, sessionResolution
}
func (s *OpenAIGatewayService) buildOpenAIWSCreatePayload(reqBody map[string]any, account *Account) map[string]any {
// OpenAI WS Mode 协议response.create 字段与 HTTP /responses 基本一致。
// 保留 stream 字段(与 Codex CLI 一致),仅移除 background。
payload := make(map[string]any, len(reqBody)+1)
for k, v := range reqBody {
payload[k] = v
}
delete(payload, "background")
if _, exists := payload["stream"]; !exists {
payload["stream"] = true
}
payload["type"] = "response.create"
// OAuth 默认保持 store=false避免误依赖服务端历史。
if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
payload["store"] = false
}
return payload
}
func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) {
if len(payload) == 0 {
return
}
metadata := strings.TrimSpace(turnMetadata)
if metadata == "" {
return
}
switch existing := payload["client_metadata"].(type) {
case map[string]any:
existing[openAIWSTurnMetadataHeader] = metadata
payload["client_metadata"] = existing
case map[string]string:
next := make(map[string]any, len(existing)+1)
for k, v := range existing {
next[k] = v
}
next[openAIWSTurnMetadataHeader] = metadata
payload["client_metadata"] = next
default:
payload["client_metadata"] = map[string]any{
openAIWSTurnMetadataHeader: metadata,
}
}
}
func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool {
if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() {
return true
}
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery {
return true
}
return false
}
func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool {
if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
return true
}
if len(reqBody) == 0 {
return false
}
rawStore, ok := reqBody["store"]
if !ok {
return false
}
storeEnabled, ok := rawStore.(bool)
if !ok {
return false
}
return !storeEnabled
}
func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool {
if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
return true
}
if len(reqBody) == 0 {
return false
}
storeValue := gjson.GetBytes(reqBody, "store")
if !storeValue.Exists() {
return false
}
if storeValue.Type != gjson.True && storeValue.Type != gjson.False {
return false
}
return !storeValue.Bool()
}
func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string {
if s == nil || s.cfg == nil {
return openAIWSStoreDisabledConnModeStrict
}
mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode))
switch mode {
case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff:
return mode
case "":
// 兼容旧配置:仅配置了布尔开关时按旧语义推导。
if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
return openAIWSStoreDisabledConnModeStrict
}
return openAIWSStoreDisabledConnModeOff
default:
return openAIWSStoreDisabledConnModeStrict
}
}
func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool {
switch mode {
case openAIWSStoreDisabledConnModeOff:
return false
case openAIWSStoreDisabledConnModeAdaptive:
reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_")
switch reason {
case "policy_violation", "message_too_big", "auth_failed", "write_request", "write":
return true
default:
return false
}
default:
return true
}
}
func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) {
return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes)
}
func dropPreviousResponseIDFromRawPayloadWithDeleteFn(
payload []byte,
deleteFn func([]byte, string) ([]byte, error),
) ([]byte, bool, error) {
if len(payload) == 0 {
return payload, false, nil
}
if !gjson.GetBytes(payload, "previous_response_id").Exists() {
return payload, false, nil
}
if deleteFn == nil {
deleteFn = sjson.DeleteBytes
}
updated := payload
for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses &&
gjson.GetBytes(updated, "previous_response_id").Exists(); i++ {
next, err := deleteFn(updated, "previous_response_id")
if err != nil {
return payload, false, err
}
updated = next
}
return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil
}
func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) {
normalizedPrevID := strings.TrimSpace(previousResponseID)
if len(payload) == 0 || normalizedPrevID == "" {
return payload, nil
}
updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID)
if err == nil {
return updated, nil
}
var reqBody map[string]any
if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil {
return nil, err
}
reqBody["previous_response_id"] = normalizedPrevID
rebuilt, marshalErr := json.Marshal(reqBody)
if marshalErr != nil {
return nil, marshalErr
}
return rebuilt, nil
}
func shouldInferIngressFunctionCallOutputPreviousResponseID(
storeDisabled bool,
turn int,
hasFunctionCallOutput bool,
currentPreviousResponseID string,
expectedPreviousResponseID string,
) bool {
if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
return false
}
if strings.TrimSpace(currentPreviousResponseID) != "" {
return false
}
return strings.TrimSpace(expectedPreviousResponseID) != ""
}
func alignStoreDisabledPreviousResponseID(
payload []byte,
expectedPreviousResponseID string,
) ([]byte, bool, error) {
if len(payload) == 0 {
return payload, false, nil
}
expected := strings.TrimSpace(expectedPreviousResponseID)
if expected == "" {
return payload, false, nil
}
current := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
if current == "" || current == expected {
return payload, false, nil
}
withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
if dropErr != nil {
return payload, false, dropErr
}
if !removed {
return payload, false, nil
}
updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected)
if setErr != nil {
return payload, false, setErr
}
return updated, true, nil
}
func cloneOpenAIWSPayloadBytes(payload []byte) []byte {
if len(payload) == 0 {
return nil
}
cloned := make([]byte, len(payload))
copy(cloned, payload)
return cloned
}
func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage {
if items == nil {
return nil
}
cloned := make([]json.RawMessage, 0, len(items))
for idx := range items {
cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx])))
}
return cloned
}
func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) {
trimmed := bytes.TrimSpace(raw)
if len(trimmed) == 0 {
return nil, errors.New("json is empty")
}
var decoded any
if err := json.Unmarshal(trimmed, &decoded); err != nil {
return nil, err
}
return json.Marshal(decoded)
}
func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte {
normalized, err := normalizeOpenAIWSJSONForCompare(raw)
if err != nil {
return bytes.TrimSpace(raw)
}
return normalized
}
func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) {
if len(payload) == 0 {
return nil, errors.New("payload is empty")
}
var decoded map[string]any
if err := json.Unmarshal(payload, &decoded); err != nil {
return nil, err
}
delete(decoded, "input")
delete(decoded, "previous_response_id")
return json.Marshal(decoded)
}
func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) {
if len(payload) == 0 {
return nil, false, nil
}
inputValue := gjson.GetBytes(payload, "input")
if !inputValue.Exists() {
return nil, false, nil
}
if inputValue.Type == gjson.JSON {
raw := strings.TrimSpace(inputValue.Raw)
if strings.HasPrefix(raw, "[") {
var items []json.RawMessage
if err := json.Unmarshal([]byte(raw), &items); err != nil {
return nil, true, err
}
return items, true, nil
}
return []json.RawMessage{json.RawMessage(raw)}, true, nil
}
if inputValue.Type == gjson.String {
encoded, _ := json.Marshal(inputValue.String())
return []json.RawMessage{encoded}, true, nil
}
return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil
}
func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) {
previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload)
if prevErr != nil {
return false, prevErr
}
currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
if currentErr != nil {
return false, currentErr
}
if !previousExists && !currentExists {
return true, nil
}
if !previousExists {
return len(currentItems) == 0, nil
}
if !currentExists {
return len(previousItems) == 0, nil
}
if len(currentItems) < len(previousItems) {
return false, nil
}
for idx := range previousItems {
previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx])
currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx])
if !bytes.Equal(previousNormalized, currentNormalized) {
return false, nil
}
}
return true, nil
}
func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool {
if len(prefix) == 0 {
return true
}
if len(items) < len(prefix) {
return false
}
for idx := range prefix {
previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx])
currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx])
if !bytes.Equal(previousNormalized, currentNormalized) {
return false
}
}
return true
}
func buildOpenAIWSReplayInputSequence(
previousFullInput []json.RawMessage,
previousFullInputExists bool,
currentPayload []byte,
hasPreviousResponseID bool,
) ([]json.RawMessage, bool, error) {
currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
if currentErr != nil {
return nil, false, currentErr
}
if !hasPreviousResponseID {
return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
}
if !previousFullInputExists {
return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
}
if !currentExists || len(currentItems) == 0 {
return cloneOpenAIWSRawMessages(previousFullInput), true, nil
}
if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
return cloneOpenAIWSRawMessages(currentItems), true, nil
}
merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems))
merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...)
merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...)
return merged, true, nil
}
func setOpenAIWSPayloadInputSequence(
payload []byte,
fullInput []json.RawMessage,
fullInputExists bool,
) ([]byte, error) {
if !fullInputExists {
return payload, nil
}
// Preserve [] vs null semantics when input exists but is empty.
inputForMarshal := fullInput
if inputForMarshal == nil {
inputForMarshal = []json.RawMessage{}
}
inputRaw, marshalErr := json.Marshal(inputForMarshal)
if marshalErr != nil {
return nil, marshalErr
}
return sjson.SetRawBytes(payload, "input", inputRaw)
}
func shouldKeepIngressPreviousResponseID(
previousPayload []byte,
currentPayload []byte,
lastTurnResponseID string,
hasFunctionCallOutput bool,
) (bool, string, error) {
if hasFunctionCallOutput {
return true, "has_function_call_output", nil
}
currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
if currentPreviousResponseID == "" {
return false, "missing_previous_response_id", nil
}
expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
if expectedPreviousResponseID == "" {
return false, "missing_last_turn_response_id", nil
}
if currentPreviousResponseID != expectedPreviousResponseID {
return false, "previous_response_id_mismatch", nil
}
if len(previousPayload) == 0 {
return false, "missing_previous_turn_payload", nil
}
previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload)
if previousComparableErr != nil {
return false, "non_input_compare_error", previousComparableErr
}
currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
if currentComparableErr != nil {
return false, "non_input_compare_error", currentComparableErr
}
if !bytes.Equal(previousComparable, currentComparable) {
return false, "non_input_changed", nil
}
return true, "strict_incremental_ok", nil
}
type openAIWSIngressPreviousTurnStrictState struct {
nonInputComparable []byte
}
func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) {
if len(payload) == 0 {
return nil, nil
}
nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload)
if nonInputErr != nil {
return nil, nonInputErr
}
return &openAIWSIngressPreviousTurnStrictState{
nonInputComparable: nonInputComparable,
}, nil
}
func shouldKeepIngressPreviousResponseIDWithStrictState(
previousState *openAIWSIngressPreviousTurnStrictState,
currentPayload []byte,
lastTurnResponseID string,
hasFunctionCallOutput bool,
) (bool, string, error) {
if hasFunctionCallOutput {
return true, "has_function_call_output", nil
}
currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
if currentPreviousResponseID == "" {
return false, "missing_previous_response_id", nil
}
expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
if expectedPreviousResponseID == "" {
return false, "missing_last_turn_response_id", nil
}
if currentPreviousResponseID != expectedPreviousResponseID {
return false, "previous_response_id_mismatch", nil
}
if previousState == nil {
return false, "missing_previous_turn_payload", nil
}
currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
if currentComparableErr != nil {
return false, "non_input_compare_error", currentComparableErr
}
if !bytes.Equal(previousState.nonInputComparable, currentComparable) {
return false, "non_input_changed", nil
}
return true, "strict_incremental_ok", nil
}
func (s *OpenAIGatewayService) forwardOpenAIWSV2(
ctx context.Context,
c *gin.Context,
account *Account,
reqBody map[string]any,
token string,
decision OpenAIWSProtocolDecision,
isCodexCLI bool,
reqStream bool,
originalModel string,
mappedModel string,
startTime time.Time,
attempt int,
lastFailureReason string,
) (*OpenAIForwardResult, error) {
if s == nil || account == nil {
return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil"))
}
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return nil, wrapOpenAIWSFallback("build_ws_url", err)
}
wsHost := "-"
wsPath := "-"
if parsed, parseErr := url.Parse(wsURL); parseErr == nil && parsed != nil {
if h := strings.TrimSpace(parsed.Host); h != "" {
wsHost = normalizeOpenAIWSLogValue(h)
}
if p := strings.TrimSpace(parsed.Path); p != "" {
wsPath = normalizeOpenAIWSLogValue(p)
}
}
logOpenAIWSModeDebug(
"dial_target account_id=%d account_type=%s ws_host=%s ws_path=%s",
account.ID,
account.Type,
wsHost,
wsPath,
)
payload := s.buildOpenAIWSCreatePayload(reqBody, account)
payloadStrategy, removedKeys := applyOpenAIWSRetryPayloadStrategy(payload, attempt)
previousResponseID := openAIWSPayloadString(payload, "previous_response_id")
previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
promptCacheKey := openAIWSPayloadString(payload, "prompt_cache_key")
_, hasTools := payload["tools"]
debugEnabled := isOpenAIWSModeDebugEnabled()
payloadBytes := -1
resolvePayloadBytes := func() int {
if payloadBytes >= 0 {
return payloadBytes
}
payloadBytes = len(payloadAsJSONBytes(payload))
return payloadBytes
}
streamValue := "-"
if raw, ok := payload["stream"]; ok {
streamValue = normalizeOpenAIWSLogValue(strings.TrimSpace(fmt.Sprintf("%v", raw)))
}
turnState := ""
turnMetadata := ""
if c != nil && c.Request != nil {
turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader))
}
setOpenAIWSTurnMetadata(payload, turnMetadata)
payloadEventType := openAIWSPayloadString(payload, "type")
if payloadEventType == "" {
payloadEventType = "response.create"
}
if s.shouldEmitOpenAIWSPayloadSchema(attempt) {
logOpenAIWSModeInfo(
"[debug] payload_schema account_id=%d attempt=%d event=%s payload_keys=%s payload_bytes=%d payload_key_sizes=%s input_summary=%s stream=%s payload_strategy=%s removed_keys=%s has_previous_response_id=%v has_prompt_cache_key=%v has_tools=%v",
account.ID,
attempt,
payloadEventType,
normalizeOpenAIWSLogValue(strings.Join(sortedKeys(payload), ",")),
resolvePayloadBytes(),
normalizeOpenAIWSLogValue(summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)),
normalizeOpenAIWSLogValue(summarizeOpenAIWSInput(payload["input"])),
streamValue,
normalizeOpenAIWSLogValue(payloadStrategy),
normalizeOpenAIWSLogValue(strings.Join(removedKeys, ",")),
previousResponseID != "",
promptCacheKey != "",
hasTools,
)
}
stateStore := s.getOpenAIWSStateStore()
groupID := getOpenAIGroupIDFromContext(c)
sessionHash := s.GenerateSessionHash(c, nil)
if sessionHash == "" {
var legacySessionHash string
sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey)
attachOpenAILegacySessionHashToGin(c, legacySessionHash)
}
if turnState == "" && stateStore != nil && sessionHash != "" {
if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok {
turnState = savedTurnState
}
}
preferredConnID := ""
if stateStore != nil && previousResponseID != "" {
if connID, ok := stateStore.GetResponseConn(previousResponseID); ok {
preferredConnID = connID
}
}
storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account)
if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" {
if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok {
preferredConnID = connID
}
}
storeDisabledConnMode := s.openAIWSStoreDisabledConnMode()
forceNewConnByPolicy := shouldForceNewConnOnStoreDisabled(storeDisabledConnMode, lastFailureReason)
forceNewConn := forceNewConnByPolicy && storeDisabled && previousResponseID == "" && sessionHash != "" && preferredConnID == ""
wsHeaders, sessionResolution := s.buildOpenAIWSHeaders(c, account, token, decision, isCodexCLI, turnState, turnMetadata, promptCacheKey)
logOpenAIWSModeDebug(
"acquire_start account_id=%d account_type=%s transport=%s preferred_conn_id=%s has_previous_response_id=%v session_hash=%s has_turn_state=%v turn_state_len=%d has_turn_metadata=%v turn_metadata_len=%d store_disabled=%v store_disabled_conn_mode=%s retry_last_reason=%s force_new_conn=%v header_user_agent=%s header_openai_beta=%s header_originator=%s header_accept_language=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_prompt_cache_key=%v has_chatgpt_account_id=%v has_authorization=%v has_session_id=%v has_conversation_id=%v proxy_enabled=%v",
account.ID,
account.Type,
normalizeOpenAIWSLogValue(string(decision.Transport)),
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
previousResponseID != "",
truncateOpenAIWSLogValue(sessionHash, 12),
turnState != "",
len(turnState),
turnMetadata != "",
len(turnMetadata),
storeDisabled,
normalizeOpenAIWSLogValue(storeDisabledConnMode),
truncateOpenAIWSLogValue(lastFailureReason, openAIWSLogValueMaxLen),
forceNewConn,
openAIWSHeaderValueForLog(wsHeaders, "user-agent"),
openAIWSHeaderValueForLog(wsHeaders, "openai-beta"),
openAIWSHeaderValueForLog(wsHeaders, "originator"),
openAIWSHeaderValueForLog(wsHeaders, "accept-language"),
openAIWSHeaderValueForLog(wsHeaders, "session_id"),
openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
promptCacheKey != "",
hasOpenAIWSHeader(wsHeaders, "chatgpt-account-id"),
hasOpenAIWSHeader(wsHeaders, "authorization"),
hasOpenAIWSHeader(wsHeaders, "session_id"),
hasOpenAIWSHeader(wsHeaders, "conversation_id"),
account.ProxyID != nil && account.Proxy != nil,
)
acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout())
defer acquireCancel()
lease, err := s.getOpenAIWSConnPool().Acquire(acquireCtx, openAIWSAcquireRequest{
Account: account,
WSURL: wsURL,
Headers: wsHeaders,
PreferredConnID: preferredConnID,
ForceNewConn: forceNewConn,
ProxyURL: func() string {
if account.ProxyID != nil && account.Proxy != nil {
return account.Proxy.URL()
}
return ""
}(),
})
if err != nil {
dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err)
logOpenAIWSModeInfo(
"acquire_fail account_id=%d account_type=%s transport=%s reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_new_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v",
account.ID,
account.Type,
normalizeOpenAIWSLogValue(string(decision.Transport)),
normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(err)),
dialStatus,
dialClass,
dialCloseStatus,
truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen),
dialRespServer,
dialRespVia,
dialRespCFRay,
dialRespReqID,
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
forceNewConn,
wsHost,
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
var dialErr *openAIWSDialError
if errors.As(err, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error()))
}
return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err)
}
defer lease.Release()
connID := strings.TrimSpace(lease.ConnID())
logOpenAIWSModeDebug(
"connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v",
account.ID,
account.Type,
normalizeOpenAIWSLogValue(string(decision.Transport)),
connID,
lease.Reused(),
lease.ConnPickDuration().Milliseconds(),
lease.QueueWaitDuration().Milliseconds(),
previousResponseID != "",
)
if previousResponseID != "" {
logOpenAIWSModeInfo(
"continuation_probe account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s conn_reused=%v store_disabled=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v",
account.ID,
account.Type,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(previousResponseIDKind),
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
lease.Reused(),
storeDisabled,
truncateOpenAIWSLogValue(sessionHash, 12),
openAIWSHeaderValueForLog(wsHeaders, "session_id"),
openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
turnState != "",
len(turnState),
promptCacheKey != "",
)
}
if c != nil {
SetOpsLatencyMs(c, OpsOpenAIWSConnPickMsKey, lease.ConnPickDuration().Milliseconds())
SetOpsLatencyMs(c, OpsOpenAIWSQueueWaitMsKey, lease.QueueWaitDuration().Milliseconds())
c.Set(OpsOpenAIWSConnReusedKey, lease.Reused())
if connID != "" {
c.Set(OpsOpenAIWSConnIDKey, connID)
}
}
handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader))
logOpenAIWSModeDebug(
"handshake account_id=%d conn_id=%s has_turn_state=%v turn_state_len=%d",
account.ID,
connID,
handshakeTurnState != "",
len(handshakeTurnState),
)
if handshakeTurnState != "" {
if stateStore != nil && sessionHash != "" {
stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL())
}
if c != nil {
c.Header(http.CanonicalHeaderKey(openAIWSTurnStateHeader), handshakeTurnState)
}
}
if err := s.performOpenAIWSGeneratePrewarm(
ctx,
lease,
decision,
payload,
previousResponseID,
reqBody,
account,
stateStore,
groupID,
); err != nil {
return nil, err
}
if err := lease.WriteJSONWithContextTimeout(ctx, payload, s.openAIWSWriteTimeout()); err != nil {
lease.MarkBroken()
logOpenAIWSModeInfo(
"write_request_fail account_id=%d conn_id=%s cause=%s payload_bytes=%d",
account.ID,
connID,
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
resolvePayloadBytes(),
)
return nil, wrapOpenAIWSFallback("write_request", err)
}
if debugEnabled {
logOpenAIWSModeDebug(
"write_request_sent account_id=%d conn_id=%s stream=%v payload_bytes=%d previous_response_id=%s",
account.ID,
connID,
reqStream,
resolvePayloadBytes(),
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
)
}
usage := &OpenAIUsage{}
var firstTokenMs *int
responseID := ""
var finalResponse []byte
wroteDownstream := false
needModelReplace := originalModel != mappedModel
var mappedModelBytes []byte
if needModelReplace && mappedModel != "" {
mappedModelBytes = []byte(mappedModel)
}
bufferedStreamEvents := make([][]byte, 0, 4)
eventCount := 0
tokenEventCount := 0
terminalEventCount := 0
bufferedEventCount := 0
flushedBufferedEventCount := 0
firstEventType := ""
lastEventType := ""
var flusher http.Flusher
if reqStream {
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), http.Header{}, s.responseHeaderFilter)
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
f, ok := c.Writer.(http.Flusher)
if !ok {
lease.MarkBroken()
return nil, wrapOpenAIWSFallback("streaming_not_supported", errors.New("streaming not supported"))
}
flusher = f
}
clientDisconnected := false
flushBatchSize := s.openAIWSEventFlushBatchSize()
flushInterval := s.openAIWSEventFlushInterval()
pendingFlushEvents := 0
lastFlushAt := time.Now()
flushStreamWriter := func(force bool) {
if clientDisconnected || flusher == nil || pendingFlushEvents <= 0 {
return
}
if !force && flushBatchSize > 1 && pendingFlushEvents < flushBatchSize {
if flushInterval <= 0 || time.Since(lastFlushAt) < flushInterval {
return
}
}
flusher.Flush()
pendingFlushEvents = 0
lastFlushAt = time.Now()
}
emitStreamMessage := func(message []byte, forceFlush bool) {
if clientDisconnected {
return
}
frame := make([]byte, 0, len(message)+8)
frame = append(frame, "data: "...)
frame = append(frame, message...)
frame = append(frame, '\n', '\n')
_, wErr := c.Writer.Write(frame)
if wErr == nil {
wroteDownstream = true
pendingFlushEvents++
flushStreamWriter(forceFlush)
return
}
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d", account.ID)
}
flushBufferedStreamEvents := func(reason string) {
if len(bufferedStreamEvents) == 0 {
return
}
flushed := len(bufferedStreamEvents)
for _, buffered := range bufferedStreamEvents {
emitStreamMessage(buffered, false)
}
bufferedStreamEvents = bufferedStreamEvents[:0]
flushStreamWriter(true)
flushedBufferedEventCount += flushed
if debugEnabled {
logOpenAIWSModeDebug(
"buffer_flush account_id=%d conn_id=%s reason=%s flushed=%d total_flushed=%d client_disconnected=%v",
account.ID,
connID,
truncateOpenAIWSLogValue(reason, openAIWSLogValueMaxLen),
flushed,
flushedBufferedEventCount,
clientDisconnected,
)
}
}
readTimeout := s.openAIWSReadTimeout()
for {
message, readErr := lease.ReadMessageWithContextTimeout(ctx, readTimeout)
if readErr != nil {
lease.MarkBroken()
closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
logOpenAIWSModeInfo(
"read_fail account_id=%d conn_id=%s wrote_downstream=%v close_status=%s close_reason=%s cause=%s events=%d token_events=%d terminal_events=%d buffered_pending=%d buffered_flushed=%d first_event=%s last_event=%s",
account.ID,
connID,
wroteDownstream,
closeStatus,
closeReason,
truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen),
eventCount,
tokenEventCount,
terminalEventCount,
len(bufferedStreamEvents),
flushedBufferedEventCount,
truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
)
if !wroteDownstream {
return nil, wrapOpenAIWSFallback(classifyOpenAIWSReadFallbackReason(readErr), readErr)
}
if clientDisconnected {
break
}
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(readErr.Error()), "")
return nil, fmt.Errorf("openai ws read event: %w", readErr)
}
eventType, eventResponseID, responseField := parseOpenAIWSEventEnvelope(message)
if eventType == "" {
continue
}
eventCount++
if firstEventType == "" {
firstEventType = eventType
}
lastEventType = eventType
if responseID == "" && eventResponseID != "" {
responseID = eventResponseID
}
isTokenEvent := isOpenAIWSTokenEvent(eventType)
if isTokenEvent {
tokenEventCount++
}
isTerminalEvent := isOpenAIWSTerminalEvent(eventType)
if isTerminalEvent {
terminalEventCount++
}
if firstTokenMs == nil && isTokenEvent {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if debugEnabled && shouldLogOpenAIWSEvent(eventCount, eventType) {
logOpenAIWSModeDebug(
"event_received account_id=%d conn_id=%s idx=%d type=%s bytes=%d token=%v terminal=%v buffered_pending=%d",
account.ID,
connID,
eventCount,
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
len(message),
isTokenEvent,
isTerminalEvent,
len(bufferedStreamEvents),
)
}
if !clientDisconnected {
if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(message, mappedModelBytes) {
message = replaceOpenAIWSMessageModel(message, mappedModel, originalModel)
}
if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(message) {
if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(message); changed {
message = corrected
}
}
}
if openAIWSEventShouldParseUsage(eventType) {
parseOpenAIWSResponseUsageFromCompletedEvent(message, usage)
}
if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
errMsg := strings.TrimSpace(errMsgRaw)
if errMsg == "" {
errMsg = "Upstream websocket error"
}
fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
logOpenAIWSModeInfo(
"error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s",
account.ID,
connID,
eventCount,
truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
canFallback,
errCode,
errType,
errMessage,
)
if fallbackReason == "previous_response_not_found" {
logOpenAIWSModeInfo(
"previous_response_not_found_diag account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s event_idx=%d req_stream=%v store_disabled=%v conn_reused=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v err_code=%s err_type=%s err_message=%s",
account.ID,
account.Type,
connID,
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(previousResponseIDKind),
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
eventCount,
reqStream,
storeDisabled,
lease.Reused(),
truncateOpenAIWSLogValue(sessionHash, 12),
openAIWSHeaderValueForLog(wsHeaders, "session_id"),
openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
turnState != "",
len(turnState),
promptCacheKey != "",
errCode,
errType,
errMessage,
)
}
// error 事件后连接不再可复用,避免回池后污染下一请求。
lease.MarkBroken()
if !wroteDownstream && canFallback {
return nil, wrapOpenAIWSFallback(fallbackReason, errors.New(errMsg))
}
statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw)
setOpsUpstreamError(c, statusCode, errMsg, "")
if reqStream && !clientDisconnected {
flushBufferedStreamEvents("error_event")
emitStreamMessage(message, true)
}
if !reqStream {
c.JSON(statusCode, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": errMsg,
},
})
}
return nil, fmt.Errorf("openai ws error event: %s", errMsg)
}
if reqStream {
// 在首个 token 前先缓冲事件(如 response.created
// 以便上游早期断连时仍可安全回退到 HTTP不给下游发送半截流。
shouldBuffer := firstTokenMs == nil && !isTokenEvent && !isTerminalEvent
if shouldBuffer {
buffered := make([]byte, len(message))
copy(buffered, message)
bufferedStreamEvents = append(bufferedStreamEvents, buffered)
bufferedEventCount++
if debugEnabled && shouldLogOpenAIWSBufferedEvent(bufferedEventCount) {
logOpenAIWSModeDebug(
"buffer_enqueue account_id=%d conn_id=%s idx=%d event_idx=%d event_type=%s buffer_size=%d",
account.ID,
connID,
bufferedEventCount,
eventCount,
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
len(bufferedStreamEvents),
)
}
} else {
flushBufferedStreamEvents(eventType)
emitStreamMessage(message, isTerminalEvent)
}
} else {
if responseField.Exists() && responseField.Type == gjson.JSON {
finalResponse = []byte(responseField.Raw)
}
}
if isTerminalEvent {
break
}
}
if !reqStream {
if len(finalResponse) == 0 {
logOpenAIWSModeInfo(
"missing_final_response account_id=%d conn_id=%s events=%d token_events=%d terminal_events=%d wrote_downstream=%v",
account.ID,
connID,
eventCount,
tokenEventCount,
terminalEventCount,
wroteDownstream,
)
if !wroteDownstream {
return nil, wrapOpenAIWSFallback("missing_final_response", errors.New("no terminal response payload"))
}
return nil, errors.New("ws finished without final response")
}
if needModelReplace {
finalResponse = s.replaceModelInResponseBody(finalResponse, mappedModel, originalModel)
}
finalResponse = s.correctToolCallsInResponseBody(finalResponse)
populateOpenAIUsageFromResponseJSON(finalResponse, usage)
if responseID == "" {
responseID = strings.TrimSpace(gjson.GetBytes(finalResponse, "id").String())
}
c.Data(http.StatusOK, "application/json", finalResponse)
} else {
flushStreamWriter(true)
}
if responseID != "" && stateStore != nil {
ttl := s.openAIWSResponseStickyTTL()
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl))
stateStore.BindResponseConn(responseID, lease.ConnID(), ttl)
}
if stateStore != nil && storeDisabled && sessionHash != "" {
stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL())
}
firstTokenMsValue := -1
if firstTokenMs != nil {
firstTokenMsValue = *firstTokenMs
}
logOpenAIWSModeDebug(
"completed account_id=%d conn_id=%s response_id=%s stream=%v duration_ms=%d events=%d token_events=%d terminal_events=%d buffered_events=%d buffered_flushed=%d first_event=%s last_event=%s first_token_ms=%d wrote_downstream=%v client_disconnected=%v",
account.ID,
connID,
truncateOpenAIWSLogValue(strings.TrimSpace(responseID), openAIWSIDValueMaxLen),
reqStream,
time.Since(startTime).Milliseconds(),
eventCount,
tokenEventCount,
terminalEventCount,
bufferedEventCount,
flushedBufferedEventCount,
truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
firstTokenMsValue,
wroteDownstream,
clientDisconnected,
)
return &OpenAIForwardResult{
RequestID: responseID,
Usage: *usage,
Model: originalModel,
ServiceTier: extractOpenAIServiceTier(reqBody),
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
Stream: reqStream,
OpenAIWSMode: true,
ResponseHeaders: lease.HandshakeHeaders(),
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
// ProxyResponsesWebSocketFromClient 处理客户端入站 WebSocketOpenAI Responses WS Mode并转发到上游。
// 当前实现按“单请求 -> 终止事件 -> 下一请求”的顺序代理,适配 Codex CLI 的 turn 模式。
func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
ctx context.Context,
c *gin.Context,
clientConn *coderws.Conn,
account *Account,
token string,
firstClientMessage []byte,
hooks *OpenAIWSIngressHooks,
) error {
if s == nil {
return errors.New("service is nil")
}
if c == nil {
return errors.New("gin context is nil")
}
if clientConn == nil {
return errors.New("client websocket is nil")
}
if account == nil {
return errors.New("account is nil")
}
if strings.TrimSpace(token) == "" {
return errors.New("token is empty")
}
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
ingressMode := OpenAIWSIngressModeCtxPool
if modeRouterV2Enabled {
ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
if ingressMode == OpenAIWSIngressModeOff {
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"websocket mode is disabled for this account",
nil,
)
}
switch ingressMode {
case OpenAIWSIngressModePassthrough:
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
}
return s.proxyResponsesWebSocketV2Passthrough(
ctx,
c,
clientConn,
account,
token,
firstClientMessage,
hooks,
wsDecision,
)
case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
// continue
default:
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"websocket mode only supports ctx_pool/passthrough",
nil,
)
}
}
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
}
dedicatedMode := modeRouterV2Enabled && ingressMode == OpenAIWSIngressModeDedicated
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
}
wsHost := "-"
wsPath := "-"
if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
}
debugEnabled := isOpenAIWSModeDebugEnabled()
type openAIWSClientPayload struct {
payloadRaw []byte
rawForHash []byte
promptCacheKey string
previousResponseID string
originalModel string
payloadBytes int
}
applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) {
next, err := sjson.SetBytes(current, path, value)
if err == nil {
return next, nil
}
// 仅在确实需要修改 payload 且 sjson 失败时,退回 map 路径确保兼容性。
payload := make(map[string]any)
if unmarshalErr := json.Unmarshal(current, &payload); unmarshalErr != nil {
return nil, err
}
switch path {
case "type", "model":
payload[path] = value
case "client_metadata." + openAIWSTurnMetadataHeader:
setOpenAIWSTurnMetadata(payload, fmt.Sprintf("%v", value))
default:
return nil, err
}
rebuilt, marshalErr := json.Marshal(payload)
if marshalErr != nil {
return nil, marshalErr
}
return rebuilt, nil
}
parseClientPayload := func(raw []byte) (openAIWSClientPayload, error) {
trimmed := bytes.TrimSpace(raw)
if len(trimmed) == 0 {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "empty websocket request payload", nil)
}
if !gjson.ValidBytes(trimmed) {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
}
values := gjson.GetManyBytes(trimmed, "type", "model", "prompt_cache_key", "previous_response_id")
eventType := strings.TrimSpace(values[0].String())
normalized := trimmed
switch eventType {
case "":
eventType = "response.create"
next, setErr := applyPayloadMutation(normalized, "type", eventType)
if setErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
}
normalized = next
case "response.create":
case "response.append":
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"response.append is not supported in ws v2; use response.create with previous_response_id",
nil,
)
default:
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
fmt.Sprintf("unsupported websocket request type: %s", eventType),
nil,
)
}
originalModel := strings.TrimSpace(values[1].String())
if originalModel == "" {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"model is required in response.create payload",
nil,
)
}
promptCacheKey := strings.TrimSpace(values[2].String())
previousResponseID := strings.TrimSpace(values[3].String())
previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
if previousResponseID != "" && previousResponseIDKind == OpenAIPreviousResponseIDKindMessageID {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"previous_response_id must be a response.id (resp_*), not a message id",
nil,
)
}
if turnMetadata := strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)); turnMetadata != "" {
next, setErr := applyPayloadMutation(normalized, "client_metadata."+openAIWSTurnMetadataHeader, turnMetadata)
if setErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
}
normalized = next
}
mappedModel := account.GetMappedModel(originalModel)
if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" {
mappedModel = normalizedModel
}
if mappedModel != originalModel {
next, setErr := applyPayloadMutation(normalized, "model", mappedModel)
if setErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
}
normalized = next
}
return openAIWSClientPayload{
payloadRaw: normalized,
rawForHash: trimmed,
promptCacheKey: promptCacheKey,
previousResponseID: previousResponseID,
originalModel: originalModel,
payloadBytes: len(normalized),
}, nil
}
firstPayload, err := parseClientPayload(firstClientMessage)
if err != nil {
return err
}
turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
stateStore := s.getOpenAIWSStateStore()
groupID := getOpenAIGroupIDFromContext(c)
sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash)
if turnState == "" && stateStore != nil && sessionHash != "" {
if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok {
turnState = savedTurnState
}
}
preferredConnID := ""
if stateStore != nil && firstPayload.previousResponseID != "" {
if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok {
preferredConnID = connID
}
}
storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account)
storeDisabledConnMode := s.openAIWSStoreDisabledConnMode()
if stateStore != nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" {
if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok {
preferredConnID = connID
}
}
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey)
baseAcquireReq := openAIWSAcquireRequest{
Account: account,
WSURL: wsURL,
Headers: wsHeaders,
ProxyURL: func() string {
if account.ProxyID != nil && account.Proxy != nil {
return account.Proxy.URL()
}
return ""
}(),
ForceNewConn: false,
}
pool := s.getOpenAIWSConnPool()
if pool == nil {
return errors.New("openai ws conn pool is nil")
}
logOpenAIWSModeInfo(
"ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s store_disabled=%v has_session_hash=%v has_previous_response_id=%v",
account.ID,
account.Type,
normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
wsHost,
wsPath,
normalizeOpenAIWSLogValue(ingressMode),
storeDisabled,
sessionHash != "",
firstPayload.previousResponseID != "",
)
if debugEnabled {
logOpenAIWSModeDebug(
"ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v",
account.ID,
account.Type,
normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
wsHost,
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
sessionHash != "",
firstPayload.previousResponseID != "",
storeDisabled,
)
}
if firstPayload.previousResponseID != "" {
firstPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(firstPayload.previousResponseID)
logOpenAIWSModeInfo(
"ingress_ws_continuation_probe account_id=%d turn=%d previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s session_hash=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v",
account.ID,
1,
truncateOpenAIWSLogValue(firstPayload.previousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(firstPreviousResponseIDKind),
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(sessionHash, 12),
openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"),
openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"),
turnState != "",
len(turnState),
firstPayload.promptCacheKey != "",
storeDisabled,
)
}
acquireTimeout := s.openAIWSAcquireTimeout()
if acquireTimeout <= 0 {
acquireTimeout = 30 * time.Second
}
acquireTurnLease := func(turn int, preferred string, forcePreferredConn bool) (*openAIWSConnLease, error) {
req := cloneOpenAIWSAcquireRequest(baseAcquireReq)
req.PreferredConnID = strings.TrimSpace(preferred)
req.ForcePreferredConn = forcePreferredConn
// dedicated 模式下每次获取均新建连接,避免跨会话复用残留上下文。
req.ForceNewConn = dedicatedMode
acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout)
lease, acquireErr := pool.Acquire(acquireCtx, req)
acquireCancel()
if acquireErr != nil {
dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr)
logOpenAIWSModeInfo(
"ingress_ws_upstream_acquire_fail account_id=%d turn=%d reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_preferred_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v",
account.ID,
turn,
normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(acquireErr)),
dialStatus,
dialClass,
dialCloseStatus,
truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen),
dialRespServer,
dialRespVia,
dialRespCFRay,
dialRespReqID,
truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen),
forcePreferredConn,
wsHost,
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
var dialErr *openAIWSDialError
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
}
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
return nil, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"upstream continuation connection is unavailable; please restart the conversation",
acquireErr,
)
}
if errors.Is(acquireErr, context.DeadlineExceeded) || errors.Is(acquireErr, errOpenAIWSConnQueueFull) {
return nil, NewOpenAIWSClientCloseError(
coderws.StatusTryAgainLater,
"upstream websocket is busy, please retry later",
acquireErr,
)
}
return nil, acquireErr
}
connID := strings.TrimSpace(lease.ConnID())
if handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)); handshakeTurnState != "" {
turnState = handshakeTurnState
if stateStore != nil && sessionHash != "" {
stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL())
}
updatedHeaders := cloneHeader(baseAcquireReq.Headers)
if updatedHeaders == nil {
updatedHeaders = make(http.Header)
}
updatedHeaders.Set(openAIWSTurnStateHeader, handshakeTurnState)
baseAcquireReq.Headers = updatedHeaders
}
logOpenAIWSModeInfo(
"ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
lease.Reused(),
lease.ConnPickDuration().Milliseconds(),
lease.QueueWaitDuration().Milliseconds(),
truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen),
)
return lease, nil
}
writeClientMessage := func(message []byte) error {
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
defer cancel()
return clientConn.Write(writeCtx, coderws.MessageText, message)
}
readClientMessage := func() ([]byte, error) {
msgType, payload, readErr := clientConn.Read(ctx)
if readErr != nil {
return nil, readErr
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
return nil, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
fmt.Sprintf("unsupported websocket client message type: %s", msgType.String()),
nil,
)
}
return payload, nil
}
sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) {
if lease == nil {
return nil, errors.New("upstream websocket lease is nil")
}
turnStart := time.Now()
wroteDownstream := false
if err := lease.WriteJSONWithContextTimeout(ctx, json.RawMessage(payload), s.openAIWSWriteTimeout()); err != nil {
return nil, wrapOpenAIWSIngressTurnError(
"write_upstream",
fmt.Errorf("write upstream websocket request: %w", err),
false,
)
}
if debugEnabled {
logOpenAIWSModeDebug(
"ingress_ws_turn_request_sent account_id=%d turn=%d conn_id=%s payload_bytes=%d",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
payloadBytes,
)
}
responseID := ""
usage := OpenAIUsage{}
var firstTokenMs *int
reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true)
turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID)
turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key")
turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account)
turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
eventCount := 0
tokenEventCount := 0
terminalEventCount := 0
firstEventType := ""
lastEventType := ""
needModelReplace := false
clientDisconnected := false
mappedModel := ""
var mappedModelBytes []byte
if originalModel != "" {
mappedModel = account.GetMappedModel(originalModel)
if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" {
mappedModel = normalizedModel
}
needModelReplace = mappedModel != "" && mappedModel != originalModel
if needModelReplace {
mappedModelBytes = []byte(mappedModel)
}
}
for {
upstreamMessage, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout())
if readErr != nil {
lease.MarkBroken()
return nil, wrapOpenAIWSIngressTurnError(
"read_upstream",
fmt.Errorf("read upstream websocket event: %w", readErr),
wroteDownstream,
)
}
eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage)
if responseID == "" && eventResponseID != "" {
responseID = eventResponseID
}
if eventType != "" {
eventCount++
if firstEventType == "" {
firstEventType = eventType
}
lastEventType = eventType
}
if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage)
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw)
fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound &&
turnPreviousResponseID != "" &&
!turnHasFunctionCallOutput &&
s.openAIWSIngressPreviousResponseRecoveryEnabled() &&
!wroteDownstream
if recoverablePrevNotFound {
// 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。
logOpenAIWSModeInfo(
"ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
eventCount,
truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
errCode,
errType,
errMessage,
truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(turnPreviousResponseIDKind),
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
turnStoreDisabled,
turnPromptCacheKey != "",
)
} else {
logOpenAIWSModeInfo(
"ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
eventCount,
truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
errCode,
errType,
errMessage,
truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(turnPreviousResponseIDKind),
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
turnStoreDisabled,
turnPromptCacheKey != "",
)
}
// previous_response_not_found 在 ingress 模式支持单次恢复重试:
// 不把该 error 直接下发客户端,而是由上层去掉 previous_response_id 后重放当前 turn。
if recoverablePrevNotFound {
lease.MarkBroken()
errMsg := strings.TrimSpace(errMsgRaw)
if errMsg == "" {
errMsg = "previous response not found"
}
return nil, wrapOpenAIWSIngressTurnError(
openAIWSIngressStagePreviousResponseNotFound,
errors.New(errMsg),
false,
)
}
}
isTokenEvent := isOpenAIWSTokenEvent(eventType)
if isTokenEvent {
tokenEventCount++
}
isTerminalEvent := isOpenAIWSTerminalEvent(eventType)
if isTerminalEvent {
terminalEventCount++
}
if firstTokenMs == nil && isTokenEvent {
ms := int(time.Since(turnStart).Milliseconds())
firstTokenMs = &ms
}
if openAIWSEventShouldParseUsage(eventType) {
parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage)
}
if !clientDisconnected {
if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) {
upstreamMessage = replaceOpenAIWSMessageModel(upstreamMessage, mappedModel, originalModel)
}
if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) {
if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(upstreamMessage); changed {
upstreamMessage = corrected
}
}
if err := writeClientMessage(upstreamMessage); err != nil {
if isOpenAIWSClientDisconnectError(err) {
clientDisconnected = true
closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err)
logOpenAIWSModeInfo(
"ingress_ws_client_disconnected_drain account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
closeStatus,
truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
)
} else {
return nil, wrapOpenAIWSIngressTurnError(
"write_client",
fmt.Errorf("write client websocket event: %w", err),
wroteDownstream,
)
}
} else {
wroteDownstream = true
}
}
if isTerminalEvent {
// 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。
if clientDisconnected {
lease.MarkBroken()
}
firstTokenMsValue := -1
if firstTokenMs != nil {
firstTokenMsValue = *firstTokenMs
}
if debugEnabled {
logOpenAIWSModeDebug(
"ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
time.Since(turnStart).Milliseconds(),
eventCount,
tokenEventCount,
terminalEventCount,
truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
firstTokenMsValue,
clientDisconnected,
)
}
return &OpenAIForwardResult{
RequestID: responseID,
Usage: usage,
Model: originalModel,
ServiceTier: extractOpenAIServiceTierFromBody(payload),
ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
Stream: reqStream,
OpenAIWSMode: true,
ResponseHeaders: lease.HandshakeHeaders(),
Duration: time.Since(turnStart),
FirstTokenMs: firstTokenMs,
}, nil
}
}
}
currentPayload := firstPayload.payloadRaw
currentOriginalModel := firstPayload.originalModel
currentPayloadBytes := firstPayload.payloadBytes
isStrictAffinityTurn := func(payload []byte) bool {
if !storeDisabled {
return false
}
return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != ""
}
var sessionLease *openAIWSConnLease
sessionConnID := ""
pinnedSessionConnID := ""
unpinSessionConn := func(connID string) {
connID = strings.TrimSpace(connID)
if connID == "" || pinnedSessionConnID != connID {
return
}
pool.UnpinConn(account.ID, connID)
pinnedSessionConnID = ""
}
pinSessionConn := func(connID string) {
if !storeDisabled {
return
}
connID = strings.TrimSpace(connID)
if connID == "" || pinnedSessionConnID == connID {
return
}
if pinnedSessionConnID != "" {
pool.UnpinConn(account.ID, pinnedSessionConnID)
pinnedSessionConnID = ""
}
if pool.PinConn(account.ID, connID) {
pinnedSessionConnID = connID
}
}
releaseSessionLease := func() {
if sessionLease == nil {
return
}
if dedicatedMode {
// dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。
sessionLease.MarkBroken()
}
unpinSessionConn(sessionConnID)
sessionLease.Release()
if debugEnabled {
logOpenAIWSModeDebug(
"ingress_ws_upstream_released account_id=%d conn_id=%s",
account.ID,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
)
}
}
defer releaseSessionLease()
turn := 1
turnRetry := 0
turnPrevRecoveryTried := false
lastTurnFinishedAt := time.Time{}
lastTurnResponseID := ""
lastTurnPayload := []byte(nil)
var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState
lastTurnReplayInput := []json.RawMessage(nil)
lastTurnReplayInputExists := false
currentTurnReplayInput := []json.RawMessage(nil)
currentTurnReplayInputExists := false
skipBeforeTurn := false
resetSessionLease := func(markBroken bool) {
if sessionLease == nil {
return
}
if markBroken {
sessionLease.MarkBroken()
}
releaseSessionLease()
sessionLease = nil
sessionConnID = ""
preferredConnID = ""
}
recoverIngressPrevResponseNotFound := func(relayErr error, turn int, connID string) bool {
if !isOpenAIWSIngressPreviousResponseNotFound(relayErr) {
return false
}
if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() {
return false
}
if isStrictAffinityTurn(currentPayload) {
// Layer 2严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。
// 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。
logOpenAIWSModeInfo(
"ingress_ws_prev_response_recovery_layer2 account_id=%d turn=%d conn_id=%s store_disabled_conn_mode=%s action=drop_previous_response_id_retry",
account.ID,
turn,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(storeDisabledConnMode),
)
}
turnPrevRecoveryTried = true
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
if dropErr != nil || !removed {
reason := "not_removed"
if dropErr != nil {
reason = "drop_error"
}
logOpenAIWSModeInfo(
"ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(reason),
)
return false
}
updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
updatedPayload,
currentTurnReplayInput,
currentTurnReplayInputExists,
)
if setInputErr != nil {
logOpenAIWSModeInfo(
"ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
)
return false
}
logOpenAIWSModeInfo(
"ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id retry=1",
account.ID,
turn,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
)
currentPayload = updatedWithInput
currentPayloadBytes = len(updatedWithInput)
resetSessionLease(true)
skipBeforeTurn = true
return true
}
retryIngressTurn := func(relayErr error, turn int, connID string) bool {
if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 {
return false
}
if isStrictAffinityTurn(currentPayload) {
logOpenAIWSModeInfo(
"ingress_ws_turn_retry_skip account_id=%d turn=%d conn_id=%s reason=strict_affinity",
account.ID,
turn,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
)
return false
}
turnRetry++
logOpenAIWSModeInfo(
"ingress_ws_turn_retry account_id=%d turn=%d retry=%d reason=%s conn_id=%s",
account.ID,
turn,
turnRetry,
truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
)
resetSessionLease(true)
skipBeforeTurn = true
return true
}
for {
if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil {
if err := hooks.BeforeTurn(turn); err != nil {
return err
}
}
skipBeforeTurn = false
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
expectedPrev := strings.TrimSpace(lastTurnResponseID)
hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
// store=false + function_call_output 场景必须有续链锚点。
// 若客户端未传 previous_response_id优先回填上一轮响应 ID避免上游报 call_id 无法关联。
if shouldInferIngressFunctionCallOutputPreviousResponseID(
storeDisabled,
turn,
hasFunctionCallOutput,
currentPreviousResponseID,
expectedPrev,
) {
updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev)
if setPrevErr != nil {
logOpenAIWSModeInfo(
"ingress_ws_function_call_output_prev_infer_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
)
} else {
currentPayload = updatedPayload
currentPayloadBytes = len(updatedPayload)
currentPreviousResponseID = expectedPrev
logOpenAIWSModeInfo(
"ingress_ws_function_call_output_prev_infer account_id=%d turn=%d conn_id=%s action=set_previous_response_id previous_response_id=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
)
}
}
nextReplayInput, nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence(
lastTurnReplayInput,
lastTurnReplayInputExists,
currentPayload,
currentPreviousResponseID != "",
)
if replayInputErr != nil {
logOpenAIWSModeInfo(
"ingress_ws_replay_input_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(replayInputErr.Error(), openAIWSLogValueMaxLen),
)
currentTurnReplayInput = nil
currentTurnReplayInputExists = false
} else {
currentTurnReplayInput = nextReplayInput
currentTurnReplayInputExists = nextReplayInputExists
}
if storeDisabled && turn > 1 && currentPreviousResponseID != "" {
shouldKeepPreviousResponseID := false
strictReason := ""
var strictErr error
if lastTurnStrictState != nil {
shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseIDWithStrictState(
lastTurnStrictState,
currentPayload,
lastTurnResponseID,
hasFunctionCallOutput,
)
} else {
shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID(
lastTurnPayload,
currentPayload,
lastTurnResponseID,
hasFunctionCallOutput,
)
}
if strictErr != nil {
logOpenAIWSModeInfo(
"ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s cause=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(strictReason),
truncateOpenAIWSLogValue(strictErr.Error(), openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
hasFunctionCallOutput,
)
} else if !shouldKeepPreviousResponseID {
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
if dropErr != nil || !removed {
dropReason := "not_removed"
if dropErr != nil {
dropReason = "drop_error"
}
logOpenAIWSModeInfo(
"ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(strictReason),
normalizeOpenAIWSLogValue(dropReason),
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
hasFunctionCallOutput,
)
} else {
updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
updatedPayload,
currentTurnReplayInput,
currentTurnReplayInputExists,
)
if setInputErr != nil {
logOpenAIWSModeInfo(
"ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=set_full_input_error previous_response_id=%s expected_previous_response_id=%s cause=%s has_function_call_output=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(strictReason),
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
hasFunctionCallOutput,
)
} else {
currentPayload = updatedWithInput
currentPayloadBytes = len(updatedWithInput)
logOpenAIWSModeInfo(
"ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_full_create reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(strictReason),
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
hasFunctionCallOutput,
)
currentPreviousResponseID = ""
}
}
}
}
forcePreferredConn := isStrictAffinityTurn(currentPayload)
if sessionLease == nil {
acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn)
if acquireErr != nil {
return fmt.Errorf("acquire upstream websocket: %w", acquireErr)
}
sessionLease = acquiredLease
sessionConnID = strings.TrimSpace(sessionLease.ConnID())
if storeDisabled {
pinSessionConn(sessionConnID)
} else {
unpinSessionConn(sessionConnID)
}
}
shouldPreflightPing := turn > 1 && sessionLease != nil && turnRetry == 0
if shouldPreflightPing && openAIWSIngressPreflightPingIdle > 0 && !lastTurnFinishedAt.IsZero() {
if time.Since(lastTurnFinishedAt) < openAIWSIngressPreflightPingIdle {
shouldPreflightPing = false
}
}
if shouldPreflightPing {
if pingErr := sessionLease.PingWithTimeout(openAIWSConnHealthCheckTO); pingErr != nil {
logOpenAIWSModeInfo(
"ingress_ws_upstream_preflight_ping_fail account_id=%d turn=%d conn_id=%s cause=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen),
)
if forcePreferredConn {
if !turnPrevRecoveryTried && currentPreviousResponseID != "" {
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
if dropErr != nil || !removed {
reason := "not_removed"
if dropErr != nil {
reason = "drop_error"
}
logOpenAIWSModeInfo(
"ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s previous_response_id=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(reason),
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
)
} else {
updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
updatedPayload,
currentTurnReplayInput,
currentTurnReplayInputExists,
)
if setInputErr != nil {
logOpenAIWSModeInfo(
"ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error previous_response_id=%s cause=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
)
} else {
logOpenAIWSModeInfo(
"ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
)
turnPrevRecoveryTried = true
currentPayload = updatedWithInput
currentPayloadBytes = len(updatedWithInput)
resetSessionLease(true)
skipBeforeTurn = true
continue
}
}
}
resetSessionLease(true)
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
"upstream continuation connection is unavailable; please restart the conversation",
pingErr,
)
}
resetSessionLease(true)
acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn)
if acquireErr != nil {
return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr)
}
sessionLease = acquiredLease
sessionConnID = strings.TrimSpace(sessionLease.ConnID())
if storeDisabled {
pinSessionConn(sessionConnID)
}
}
}
connID := sessionConnID
if currentPreviousResponseID != "" {
chainedFromLast := expectedPrev != "" && currentPreviousResponseID == expectedPrev
currentPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(currentPreviousResponseID)
logOpenAIWSModeInfo(
"ingress_ws_turn_chain account_id=%d turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v preferred_conn_id=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(currentPreviousResponseIDKind),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
chainedFromLast,
truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"),
openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"),
turnState != "",
len(turnState),
openAIWSPayloadStringFromRaw(currentPayload, "prompt_cache_key") != "",
storeDisabled,
)
}
result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel)
if relayErr != nil {
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
continue
}
if retryIngressTurn(relayErr, turn, connID) {
continue
}
finalErr := relayErr
if unwrapped := errors.Unwrap(relayErr); unwrapped != nil {
finalErr = unwrapped
}
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turn, nil, finalErr)
}
sessionLease.MarkBroken()
return finalErr
}
turnRetry = 0
turnPrevRecoveryTried = false
lastTurnFinishedAt = time.Now()
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turn, result, nil)
}
if result == nil {
return errors.New("websocket turn result is nil")
}
responseID := strings.TrimSpace(result.RequestID)
lastTurnResponseID = responseID
lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload)
lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput)
lastTurnReplayInputExists = currentTurnReplayInputExists
nextStrictState, strictStateErr := buildOpenAIWSIngressPreviousTurnStrictState(currentPayload)
if strictStateErr != nil {
lastTurnStrictState = nil
logOpenAIWSModeInfo(
"ingress_ws_prev_response_strict_state_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(strictStateErr.Error(), openAIWSLogValueMaxLen),
)
} else {
lastTurnStrictState = nextStrictState
}
if responseID != "" && stateStore != nil {
ttl := s.openAIWSResponseStickyTTL()
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl))
stateStore.BindResponseConn(responseID, connID, ttl)
}
if stateStore != nil && storeDisabled && sessionHash != "" {
stateStore.BindSessionConn(groupID, sessionHash, connID, s.openAIWSSessionStickyTTL())
}
if connID != "" {
preferredConnID = connID
}
nextClientMessage, readErr := readClientMessage()
if readErr != nil {
if isOpenAIWSClientDisconnectError(readErr) {
closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
logOpenAIWSModeInfo(
"ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s",
account.ID,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
closeStatus,
truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
)
return nil
}
return fmt.Errorf("read client websocket request: %w", readErr)
}
nextPayload, parseErr := parseClientPayload(nextClientMessage)
if parseErr != nil {
return parseErr
}
if nextPayload.promptCacheKey != "" {
// ingress 会话在整个客户端 WS 生命周期内复用同一上游连接;
// prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。
updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey)
baseAcquireReq.Headers = updatedHeaders
}
if nextPayload.previousResponseID != "" {
expectedPrev := strings.TrimSpace(lastTurnResponseID)
chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev
nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID)
logOpenAIWSModeInfo(
"ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v",
account.ID,
turn,
turn+1,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(nextPreviousResponseIDKind),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
chainedFromLast,
nextPayload.promptCacheKey != "",
storeDisabled,
)
}
if stateStore != nil && nextPayload.previousResponseID != "" {
if stickyConnID, ok := stateStore.GetResponseConn(nextPayload.previousResponseID); ok {
if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID {
logOpenAIWSModeInfo(
"ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
)
} else {
preferredConnID = stickyConnID
}
}
}
currentPayload = nextPayload.payloadRaw
currentOriginalModel = nextPayload.originalModel
currentPayloadBytes = nextPayload.payloadBytes
storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account)
if !storeDisabled {
unpinSessionConn(sessionConnID)
}
turn++
}
}
func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool {
return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled
}
// performOpenAIWSGeneratePrewarm 在 WSv2 下执行可选的 generate=false 预热。
// 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。
func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
ctx context.Context,
lease *openAIWSConnLease,
decision OpenAIWSProtocolDecision,
payload map[string]any,
previousResponseID string,
reqBody map[string]any,
account *Account,
stateStore OpenAIWSStateStore,
groupID int64,
) error {
if s == nil {
return nil
}
if lease == nil || account == nil {
logOpenAIWSModeInfo("prewarm_skip reason=invalid_state has_lease=%v has_account=%v", lease != nil, account != nil)
return nil
}
connID := strings.TrimSpace(lease.ConnID())
if !s.isOpenAIWSGeneratePrewarmEnabled() {
return nil
}
if decision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
logOpenAIWSModeInfo(
"prewarm_skip account_id=%d conn_id=%s reason=transport_not_v2 transport=%s",
account.ID,
connID,
normalizeOpenAIWSLogValue(string(decision.Transport)),
)
return nil
}
if strings.TrimSpace(previousResponseID) != "" {
logOpenAIWSModeInfo(
"prewarm_skip account_id=%d conn_id=%s reason=has_previous_response_id previous_response_id=%s",
account.ID,
connID,
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
)
return nil
}
if lease.IsPrewarmed() {
logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=already_prewarmed", account.ID, connID)
return nil
}
if NeedsToolContinuation(reqBody) {
logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=tool_continuation", account.ID, connID)
return nil
}
prewarmStart := time.Now()
logOpenAIWSModeInfo("prewarm_start account_id=%d conn_id=%s", account.ID, connID)
prewarmPayload := make(map[string]any, len(payload)+1)
for k, v := range payload {
prewarmPayload[k] = v
}
prewarmPayload["generate"] = false
prewarmPayloadJSON := payloadAsJSONBytes(prewarmPayload)
if err := lease.WriteJSONWithContextTimeout(ctx, prewarmPayload, s.openAIWSWriteTimeout()); err != nil {
lease.MarkBroken()
logOpenAIWSModeInfo(
"prewarm_write_fail account_id=%d conn_id=%s cause=%s",
account.ID,
connID,
truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
)
return wrapOpenAIWSFallback("prewarm_write", err)
}
logOpenAIWSModeInfo("prewarm_write_sent account_id=%d conn_id=%s payload_bytes=%d", account.ID, connID, len(prewarmPayloadJSON))
prewarmResponseID := ""
prewarmEventCount := 0
prewarmTerminalCount := 0
for {
message, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout())
if readErr != nil {
lease.MarkBroken()
closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
logOpenAIWSModeInfo(
"prewarm_read_fail account_id=%d conn_id=%s close_status=%s close_reason=%s cause=%s events=%d",
account.ID,
connID,
closeStatus,
closeReason,
truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen),
prewarmEventCount,
)
return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr)
}
eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message)
if eventType == "" {
continue
}
prewarmEventCount++
if prewarmResponseID == "" && eventResponseID != "" {
prewarmResponseID = eventResponseID
}
if prewarmEventCount <= openAIWSPrewarmEventLogHead || eventType == "error" || isOpenAIWSTerminalEvent(eventType) {
logOpenAIWSModeInfo(
"prewarm_event account_id=%d conn_id=%s idx=%d type=%s bytes=%d",
account.ID,
connID,
prewarmEventCount,
truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
len(message),
)
}
if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
errMsg := strings.TrimSpace(errMsgRaw)
if errMsg == "" {
errMsg = "OpenAI websocket prewarm error"
}
fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
logOpenAIWSModeInfo(
"prewarm_error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s",
account.ID,
connID,
prewarmEventCount,
truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
canFallback,
errCode,
errType,
errMessage,
)
lease.MarkBroken()
if canFallback {
return wrapOpenAIWSFallback("prewarm_"+fallbackReason, errors.New(errMsg))
}
return wrapOpenAIWSFallback("prewarm_error_event", errors.New(errMsg))
}
if isOpenAIWSTerminalEvent(eventType) {
prewarmTerminalCount++
break
}
}
lease.MarkPrewarmed()
if prewarmResponseID != "" && stateStore != nil {
ttl := s.openAIWSResponseStickyTTL()
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl))
stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl)
}
logOpenAIWSModeInfo(
"prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d",
account.ID,
connID,
truncateOpenAIWSLogValue(prewarmResponseID, openAIWSIDValueMaxLen),
prewarmEventCount,
prewarmTerminalCount,
time.Since(prewarmStart).Milliseconds(),
)
return nil
}
func payloadAsJSON(payload map[string]any) string {
return string(payloadAsJSONBytes(payload))
}
func payloadAsJSONBytes(payload map[string]any) []byte {
if len(payload) == 0 {
return []byte("{}")
}
body, err := json.Marshal(payload)
if err != nil {
return []byte("{}")
}
return body
}
func isOpenAIWSTerminalEvent(eventType string) bool {
switch strings.TrimSpace(eventType) {
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return true
default:
return false
}
}
func isOpenAIWSTokenEvent(eventType string) bool {
eventType = strings.TrimSpace(eventType)
if eventType == "" {
return false
}
switch eventType {
case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
return false
}
if strings.Contains(eventType, ".delta") {
return true
}
if strings.HasPrefix(eventType, "response.output_text") {
return true
}
if strings.HasPrefix(eventType, "response.output") {
return true
}
return eventType == "response.completed" || eventType == "response.done"
}
func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte {
if len(message) == 0 {
return message
}
if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel {
return message
}
if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) {
return message
}
modelValues := gjson.GetManyBytes(message, "model", "response.model")
replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel
replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel
if !replaceModel && !replaceResponseModel {
return message
}
updated := message
if replaceModel {
if next, err := sjson.SetBytes(updated, "model", toModel); err == nil {
updated = next
}
}
if replaceResponseModel {
if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil {
updated = next
}
}
return updated
}
func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) {
if usage == nil || len(body) == 0 {
return
}
values := gjson.GetManyBytes(
body,
"usage.input_tokens",
"usage.output_tokens",
"usage.input_tokens_details.cached_tokens",
)
usage.InputTokens = int(values[0].Int())
usage.OutputTokens = int(values[1].Int())
usage.CacheReadInputTokens = int(values[2].Int())
}
func getOpenAIGroupIDFromContext(c *gin.Context) int64 {
if c == nil {
return 0
}
value, exists := c.Get("api_key")
if !exists {
return 0
}
apiKey, ok := value.(*APIKey)
if !ok || apiKey == nil || apiKey.GroupID == nil {
return 0
}
return *apiKey.GroupID
}
// SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。
// 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。
func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
ctx context.Context,
groupID *int64,
previousResponseID string,
requestedModel string,
excludedIDs map[int64]struct{},
) (*AccountSelectionResult, error) {
if s == nil {
return nil, nil
}
responseID := strings.TrimSpace(previousResponseID)
if responseID == "" {
return nil, nil
}
store := s.getOpenAIWSStateStore()
if store == nil {
return nil, nil
}
accountID, err := store.GetResponseAccount(ctx, derefGroupID(groupID), responseID)
if err != nil || accountID <= 0 {
return nil, nil
}
if excludedIDs != nil {
if _, excluded := excludedIDs[accountID]; excluded {
return nil, nil
}
}
account, err := s.getSchedulableAccount(ctx, accountID)
if err != nil || account == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
// 非 WSv2 场景(如 force_http/全局关闭)不应使用 previous_response_id 粘连,
// 以保持“回滚到 HTTP”后的历史行为一致性。
if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return nil, nil
}
if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() || !account.IsSchedulable() {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil, nil
}
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {
logOpenAIWSBindResponseAccountWarn(
derefGroupID(groupID),
accountID,
responseID,
store.BindResponseAccount(ctx, derefGroupID(groupID), responseID, accountID, s.openAIWSResponseStickyTTL()),
)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
cfg := s.schedulingConfig()
if s.concurrencyService != nil {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
return nil, nil
}
func classifyOpenAIWSAcquireError(err error) string {
if err == nil {
return "acquire_conn"
}
var dialErr *openAIWSDialError
if errors.As(err, &dialErr) {
switch dialErr.StatusCode {
case 426:
return "upgrade_required"
case 401, 403:
return "auth_failed"
case 429:
return "upstream_rate_limited"
}
if dialErr.StatusCode >= 500 {
return "upstream_5xx"
}
return "dial_failed"
}
if errors.Is(err, errOpenAIWSConnQueueFull) {
return "conn_queue_full"
}
if errors.Is(err, errOpenAIWSPreferredConnUnavailable) {
return "preferred_conn_unavailable"
}
if errors.Is(err, context.DeadlineExceeded) {
return "acquire_timeout"
}
return "acquire_conn"
}
func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool {
code := strings.ToLower(strings.TrimSpace(codeRaw))
errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
msg := strings.ToLower(strings.TrimSpace(msgRaw))
if strings.Contains(errType, "rate_limit") || strings.Contains(errType, "usage_limit") {
return true
}
if strings.Contains(code, "rate_limit") || strings.Contains(code, "usage_limit") || strings.Contains(code, "insufficient_quota") {
return true
}
if strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached") {
return true
}
if strings.Contains(msg, "rate limit") && (strings.Contains(msg, "reached") || strings.Contains(msg, "exceeded")) {
return true
}
return false
}
func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) {
if s == nil || s.rateLimitService == nil || account == nil || account.Platform != PlatformOpenAI {
return
}
if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
return
}
s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
}
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
code := strings.ToLower(strings.TrimSpace(codeRaw))
errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
msg := strings.ToLower(strings.TrimSpace(msgRaw))
switch code {
case "upgrade_required":
return "upgrade_required", true
case "websocket_not_supported", "websocket_unsupported":
return "ws_unsupported", true
case "websocket_connection_limit_reached":
return "ws_connection_limit_reached", true
case "previous_response_not_found":
return "previous_response_not_found", true
}
if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
return "upstream_rate_limited", false
}
if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
return "upgrade_required", true
}
if strings.Contains(errType, "upgrade") {
return "upgrade_required", true
}
if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") {
return "ws_unsupported", true
}
if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") {
return "ws_connection_limit_reached", true
}
if strings.Contains(msg, "previous_response_not_found") ||
(strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) {
return "previous_response_not_found", true
}
if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") {
return "upstream_error_event", true
}
return "event_error", false
}
func classifyOpenAIWSErrorEvent(message []byte) (string, bool) {
if len(message) == 0 {
return "event_error", false
}
return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message))
}
func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
code := strings.ToLower(strings.TrimSpace(codeRaw))
errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
switch {
case strings.Contains(errType, "invalid_request"),
strings.Contains(code, "invalid_request"),
strings.Contains(code, "bad_request"),
code == "previous_response_not_found":
return http.StatusBadRequest
case strings.Contains(errType, "authentication"),
strings.Contains(code, "invalid_api_key"),
strings.Contains(code, "unauthorized"):
return http.StatusUnauthorized
case strings.Contains(errType, "permission"),
strings.Contains(code, "forbidden"):
return http.StatusForbidden
case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""):
return http.StatusTooManyRequests
default:
return http.StatusBadGateway
}
}
func openAIWSErrorHTTPStatus(message []byte) int {
if len(message) == 0 {
return http.StatusBadGateway
}
codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message)
return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
}
func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration {
if s == nil || s.cfg == nil {
return 30 * time.Second
}
seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds
if seconds <= 0 {
return 0
}
return time.Duration(seconds) * time.Second
}
func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool {
if s == nil || accountID <= 0 {
return false
}
cooldown := s.openAIWSFallbackCooldown()
if cooldown <= 0 {
return false
}
rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID)
if !ok || rawUntil == nil {
return false
}
until, ok := rawUntil.(time.Time)
if !ok || until.IsZero() {
s.openaiWSFallbackUntil.Delete(accountID)
return false
}
if time.Now().Before(until) {
return true
}
s.openaiWSFallbackUntil.Delete(accountID)
return false
}
func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) {
if s == nil || accountID <= 0 {
return
}
cooldown := s.openAIWSFallbackCooldown()
if cooldown <= 0 {
return
}
s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown))
}
func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) {
if s == nil || accountID <= 0 {
return
}
s.openaiWSFallbackUntil.Delete(accountID)
}