Files
sub2api/backend/internal/service/openai_ws_protocol_forward_test.go

1220 lines
39 KiB
Go
Raw Permalink Normal View History

package service
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled sends a request
// carrying previous_response_id to an API-key account with WSv2 enabled. The account's
// base URL points at a server that answers every request with 404, so the WS attempt
// cannot succeed. Forward must return an error and a nil result, and the HTTP upstream
// stub must never be called (no HTTP fallback in WS mode).
func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Answers every request, including the WS upgrade, with 404 Not Found.
	notFoundServer := httptest.NewServer(http.HandlerFunc(http.NotFound))
	defer notFoundServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	// HTTP upstream stub; the assertions below require that it stays unused.
	const usageJSON = `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(usageJSON)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
	}
	account := &Account{
		ID:          1,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": notFoundServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.Error(t, err)
	require.Nil(t, res)
	require.Nil(t, httpUp.lastReq, "WS 模式下失败时不应回退 HTTP")
}
// TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled verifies that a
// request whose client transport is marked as HTTP is forwarded over the HTTP upstream,
// even though WSv2 is enabled for the account. On that path the forwarded body must no
// longer contain previous_response_id. The transport decision recorded on the gin
// context must be HTTP/SSE with reason "client_protocol_http".
func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testing.T) {
	gin.SetMode(gin.TestMode)

	notFoundServer := httptest.NewServer(http.HandlerFunc(http.NotFound))
	defer notFoundServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")
	// Mark the inbound client connection as plain HTTP.
	SetOpenAIClientTransport(ginCtx, OpenAIClientTransportHTTP)

	const usageJSON = `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(usageJSON)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
	}
	account := &Account{
		ID:          101,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": notFoundServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_keep","input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.NoError(t, err)
	require.NotNil(t, res)
	require.False(t, res.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发")
	require.NotNil(t, httpUp.lastReq, "HTTP 入站应命中 HTTP 上游")
	prevID := gjson.GetBytes(httpUp.lastBody, "previous_response_id")
	require.False(t, prevID.Exists(), "HTTP 路径应沿用原逻辑移除 previous_response_id")

	// The service records its transport choice on the gin context.
	decision, _ := ginCtx.Get("openai_ws_transport_decision")
	reason, _ := ginCtx.Get("openai_ws_transport_reason")
	require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision)
	require.Equal(t, "client_protocol_http", reason)
}
// TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled verifies the
// case where the global OpenAI WS switch is off. Even though the account opts into
// responses_websockets_v2, the request must succeed via the HTTP upstream, and
// previous_response_id must be stripped from the body sent there.
func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) {
	gin.SetMode(gin.TestMode)

	notFoundServer := httptest.NewServer(http.HandlerFunc(http.NotFound))
	defer notFoundServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const usageJSON = `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(usageJSON)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	// WS is switched off globally; only the v2 flag is left on.
	wsCfg.Enabled = false
	wsCfg.ResponsesWebsocketsV2 = true

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
	}
	account := &Account{
		ID:          1,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": notFoundServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.NoError(t, err)
	require.NotNil(t, res)
	prevID := gjson.GetBytes(httpUp.lastBody, "previous_response_id")
	require.False(t, prevID.Exists())
}
// TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP points a WSv2-enabled account at
// a server that rejects the WS upgrade with 426 Upgrade Required. Despite "FallbackHTTP" in
// the name, the assertions forbid any HTTP fallback:
//   - Forward must return an error mentioning "upgrade_required";
//   - the HTTP upstream stub must stay unused;
//   - the client must receive a 426 response whose body mentions "426".
func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	upgradeRequiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUpgradeRequired)
		_, _ = io.WriteString(w, "upgrade required")
	}))
	defer upgradeRequiredServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const usageJSON = `{"usage":{"input_tokens":8,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}`
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(usageJSON)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.FallbackCooldownSeconds = 1

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
	}
	account := &Account{
		ID:          12,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": upgradeRequiredServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_426","input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.Error(t, err)
	require.Nil(t, res)
	require.Contains(t, err.Error(), "upgrade_required")
	require.Nil(t, httpUp.lastReq, "WS 模式下不应再回退 HTTP")
	require.Equal(t, http.StatusUpgradeRequired, recorder.Code)
	require.Contains(t, recorder.Body.String(), "426")
}
// TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS marks the account as being in
// WS fallback cooling and then forwards a request. The request must still fail on the WS
// path without touching the HTTP upstream. No "openai_ws_fallback_cooling" key may appear
// on the gin context, because the cooling shortcut path has been removed.
func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) {
	gin.SetMode(gin.TestMode)

	notFoundServer := httptest.NewServer(http.HandlerFunc(http.NotFound))
	defer notFoundServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const usageJSON = `{"usage":{"input_tokens":2,"output_tokens":3,"input_tokens_details":{"cached_tokens":0}}}`
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(usageJSON)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.FallbackCooldownSeconds = 30

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
	}
	account := &Account{
		ID:          21,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": notFoundServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	// Put the account into fallback cooling before the request is forwarded.
	svc.markOpenAIWSFallbackCooling(account.ID, "upgrade_required")

	payload := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_cooling","input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.Error(t, err)
	require.Nil(t, res)
	require.Nil(t, httpUp.lastReq, "WS 模式下不应再回退 HTTP")
	_, cooling := ginCtx.Get("openai_ws_fallback_cooling")
	require.False(t, cooling, "已移除 fallback cooling 快捷回退路径")
}
// TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled covers the configuration
// where Responses WebSockets v1 is on and v2 is off. Forward must:
//   - reject the request with an error mentioning "ws v1";
//   - reply 400 with a body mentioning "WSv1";
//   - never issue a request to the HTTP upstream.
func TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const usageJSON = `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(usageJSON)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	// v1 on, v2 off.
	wsCfg.ResponsesWebsockets = true
	wsCfg.ResponsesWebsocketsV2 = false

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
	}
	account := &Account{
		ID:          31,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": "https://api.openai.com/v1/responses",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_v1","input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.Error(t, err)
	require.Nil(t, res)
	require.Contains(t, err.Error(), "ws v1")
	require.Equal(t, http.StatusBadRequest, recorder.Code)
	require.Contains(t, recorder.Body.String(), "WSv1")
	require.Nil(t, httpUp.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求")
}
// TestNewOpenAIGatewayService_InitializesOpenAIWSResolver builds the service through its
// constructor and checks the protocol resolver it exposes. Resolving a nil account must
// yield the HTTP/SSE transport with reason "account_missing".
func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
	cfg := &config.Config{}
	// Only cfg (the 7th argument) is supplied; every other dependency is nil.
	svc := NewOpenAIGatewayService(
		nil, nil, nil, nil, nil, nil,
		cfg,
		nil, nil, nil, nil, nil, nil, nil, nil,
	)

	got := svc.getOpenAIWSProtocolResolver().Resolve(nil)
	require.Equal(t, OpenAIUpstreamTransportHTTPSSE, got.Transport)
	require.Equal(t, "account_missing", got.Reason)
}
// TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenReturnsWSError
// first writes a 202 response to the client. It then forwards to an account whose WS
// upgrade is rejected with 426. Because the downstream response has already been written,
// Forward must return an error mentioning "ws fallback" rather than retrying over the
// HTTP upstream.
func TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenReturnsWSError(t *testing.T) {
	gin.SetMode(gin.TestMode)

	upgradeRequiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUpgradeRequired)
		_, _ = io.WriteString(w, "upgrade required")
	}))
	defer upgradeRequiredServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")
	// Commit a downstream response before Forward runs.
	ginCtx.String(http.StatusAccepted, "already-written")

	const usageJSON = `{"usage":{"input_tokens":1,"output_tokens":1}}`
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(usageJSON)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.FallbackCooldownSeconds = 1

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
	}
	account := &Account{
		ID:          41,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": upgradeRequiredServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.Error(t, err)
	require.Nil(t, res)
	require.Contains(t, err.Error(), "ws fallback")
	require.Nil(t, httpUp.lastReq, "已写下游响应时,不应再回退 HTTP")
}
// TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP simulates an upstream
// that accepts a streaming WS request, emits only response.created (no token events), and
// then closes with 1011 (internal server error). Forward must return an error without
// falling back to the HTTP upstream. It must also not write a partial stream to the client.
func TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := wsUpgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() { _ = conn.Close() }()

		var frame map[string]any
		if readErr := conn.ReadJSON(&frame); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}

		// Send only response.created (a non-token event), then close immediately. This
		// simulates the production case of the upstream dropping the connection early
		// on an internal error.
		created := map[string]any{
			"type": "response.created",
			"response": map[string]any{
				"id":    "resp_ws_created_only",
				"model": "gpt-5.3-codex",
			},
		}
		if writeErr := conn.WriteJSON(created); writeErr != nil {
			t.Errorf("write response.created failed: %v", writeErr)
			return
		}
		closeFrame := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "")
		_ = conn.WriteControl(websocket.CloseMessage, closeFrame, time.Now().Add(time.Second))
	}))
	defer wsServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	// SSE body for the HTTP upstream stub, which must stay unused.
	const sseBody = "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" +
		"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" +
		"data: [DONE]\n\n"
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
			Body:       io.NopCloser(strings.NewReader(sseBody)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.FallbackCooldownSeconds = 1

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          88,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.Error(t, err)
	require.Nil(t, res)
	require.Nil(t, httpUp.lastReq, "WS 早期断连后不应再回退 HTTP")
	require.Empty(t, recorder.Body.String(), "未产出 token 前上游断连时不应写入下游半截流")
}
// TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP drives an upstream
// that accepts every WS request and immediately closes it with 1011 (internal server
// error). Forward must redial until the reconnect budget is exhausted
// (openAIWSReconnectRetryLimit+1 dials in total) and then fail. Despite "FallbackHTTP" in
// the name, the HTTP upstream must not be called.
//
// This test walks the whole reconnect budget, so the reconnect backoff is set to a few
// milliseconds with no jitter. The test then runs quickly whatever the default backoff is.
func TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var dials atomic.Int32
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dials.Add(1)
		conn, err := wsUpgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() { _ = conn.Close() }()

		var frame map[string]any
		if readErr := conn.ReadJSON(&frame); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}
		// Reject the request by closing with 1011 before sending any event.
		closeFrame := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "")
		_ = conn.WriteControl(websocket.CloseMessage, closeFrame, time.Now().Add(time.Second))
	}))
	defer wsServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	// SSE body for the HTTP upstream stub, which must stay unused.
	const sseBody = "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" +
		"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_retry_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" +
		"data: [DONE]\n\n"
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
			Body:       io.NopCloser(strings.NewReader(sseBody)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.FallbackCooldownSeconds = 1
	// Keep reconnect backoff tiny and deterministic; this test walks the full retry budget.
	wsCfg.RetryBackoffInitialMS = 1
	wsCfg.RetryBackoffMaxMS = 2
	wsCfg.RetryJitterRatio = 0

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          89,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.Error(t, err)
	require.Nil(t, res)
	require.Nil(t, httpUp.lastReq, "WS 重连耗尽后不应再回退 HTTP")
	require.Equal(t, int32(openAIWSReconnectRetryLimit+1), dials.Load())
}
// TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP drives an upstream
// that closes every WS request with 1008 (policy violation). Forward must fail after a
// single dial, since a policy violation is not retried. Despite "FallbackHTTP" in the name,
// it must not call the HTTP upstream. Reconnect backoff is set to a few milliseconds with
// no jitter.
func TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var dials atomic.Int32
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dials.Add(1)
		conn, err := wsUpgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() { _ = conn.Close() }()

		var frame map[string]any
		if readErr := conn.ReadJSON(&frame); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}
		// Reject the request with a 1008 policy-violation close frame.
		closeFrame := websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "")
		_ = conn.WriteControl(websocket.CloseMessage, closeFrame, time.Now().Add(time.Second))
	}))
	defer wsServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const respJSON = `{"id":"resp_policy_fallback","usage":{"input_tokens":1,"output_tokens":1}}`
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(respJSON)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.FallbackCooldownSeconds = 1
	wsCfg.RetryBackoffInitialMS = 1
	wsCfg.RetryBackoffMaxMS = 2
	wsCfg.RetryJitterRatio = 0

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          8901,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.Error(t, err)
	require.Nil(t, res)
	require.Nil(t, httpUp.lastReq, "策略违规关闭后不应回退 HTTP")
	require.Equal(t, int32(1), dials.Load(), "策略违规不应进行 WS 重试")
}
// TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbackHTTP drives
// an upstream that answers every WS request with a websocket_connection_limit_reached
// error event. Forward must redial until the reconnect budget is exhausted
// (openAIWSReconnectRetryLimit+1 dials in total) and then fail. Despite "FallbackHTTP" in
// the name, the HTTP upstream must not be called.
//
// This test walks the whole reconnect budget, so the reconnect backoff is set to a few
// milliseconds with no jitter. The test then runs quickly whatever the default backoff is.
func TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbackHTTP(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var dials atomic.Int32
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dials.Add(1)
		conn, err := wsUpgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() { _ = conn.Close() }()

		var frame map[string]any
		if readErr := conn.ReadJSON(&frame); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}
		// Answer with a connection-limit error event instead of a response.
		_ = conn.WriteJSON(map[string]any{
			"type": "error",
			"error": map[string]any{
				"code":    "websocket_connection_limit_reached",
				"type":    "server_error",
				"message": "websocket connection limit reached",
			},
		})
	}))
	defer wsServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const respJSON = `{"id":"resp_http_retry_limit","usage":{"input_tokens":1,"output_tokens":1}}`
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(respJSON)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.FallbackCooldownSeconds = 1
	// Keep reconnect backoff tiny and deterministic; this test walks the full retry budget.
	wsCfg.RetryBackoffInitialMS = 1
	wsCfg.RetryBackoffMaxMS = 2
	wsCfg.RetryJitterRatio = 0

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          90,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.Error(t, err)
	require.Nil(t, res)
	require.Nil(t, httpUp.lastReq, "触发 websocket_connection_limit_reached 后不应回退 HTTP")
	require.Equal(t, int32(openAIWSReconnectRetryLimit+1), dials.Load())
}
// TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDroppingPreviousResponseID
// drives an upstream that rejects the first WS request with previous_response_not_found
// and completes every later one. Forward must recover with exactly one extra dial, and
// that retried request must not carry previous_response_id. The WS result must be relayed
// to the client (status 200, matching id), and the HTTP upstream must not be used.
func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDroppingPreviousResponseID(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var (
		dials      atomic.Int32
		payloadsMu sync.Mutex
		wsPayloads [][]byte // raw request frames received by the fake upstream, in order
	)
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		attempt := dials.Add(1)
		conn, err := wsUpgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() { _ = conn.Close() }()

		var frame map[string]any
		if readErr := conn.ReadJSON(&frame); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}
		// The frame was just decoded from JSON, so re-encoding it is not expected to fail.
		raw, _ := json.Marshal(frame)
		payloadsMu.Lock()
		wsPayloads = append(wsPayloads, raw)
		payloadsMu.Unlock()

		// Later dials succeed with a completed response...
		reply := map[string]any{
			"type": "response.completed",
			"response": map[string]any{
				"id":    "resp_ws_prev_recover_ok",
				"model": "gpt-5.3-codex",
				"usage": map[string]any{
					"input_tokens":  1,
					"output_tokens": 1,
					"input_tokens_details": map[string]any{
						"cached_tokens": 0,
					},
				},
			},
		}
		// ...but the first dial is rejected with previous_response_not_found.
		if attempt == 1 {
			reply = map[string]any{
				"type": "error",
				"error": map[string]any{
					"code":    "previous_response_not_found",
					"type":    "invalid_request_error",
					"message": "previous response not found",
				},
			}
		}
		_ = conn.WriteJSON(reply)
	}))
	defer wsServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const respJSON = `{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(respJSON)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.FallbackCooldownSeconds = 1

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          91,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.NoError(t, err)
	require.NotNil(t, res)
	require.Equal(t, "resp_ws_prev_recover_ok", res.RequestID)
	require.Nil(t, httpUp.lastReq, "previous_response_not_found 不应回退 HTTP")
	require.Equal(t, int32(2), dials.Load(), "previous_response_not_found 应触发一次去掉 previous_response_id 的恢复重试")
	require.Equal(t, http.StatusOK, recorder.Code)
	require.Equal(t, "resp_ws_prev_recover_ok", gjson.Get(recorder.Body.String(), "id").String())

	// Snapshot the recorded frames under the lock before inspecting them.
	payloadsMu.Lock()
	sent := make([][]byte, len(wsPayloads))
	copy(sent, wsPayloads)
	payloadsMu.Unlock()
	require.Len(t, sent, 2)
	require.True(t, gjson.GetBytes(sent[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id")
	require.False(t, gjson.GetBytes(sent[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
}
// TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryForFunctionCallOutput
// sends a function_call_output input that references a missing previous response. The
// fake upstream always replies previous_response_not_found. Forward must not attempt the
// drop-previous_response_id recovery: it must make exactly one dial with
// previous_response_id intact, and it must not fall back to HTTP. The error must also be
// surfaced to the client as a 400 mentioning "previous response not found".
func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryForFunctionCallOutput(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var (
		dials      atomic.Int32
		payloadsMu sync.Mutex
		wsPayloads [][]byte // raw request frames received by the fake upstream, in order
	)
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		dials.Add(1)
		conn, err := wsUpgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() { _ = conn.Close() }()

		var frame map[string]any
		if readErr := conn.ReadJSON(&frame); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}
		// The frame was just decoded from JSON, so re-encoding it is not expected to fail.
		raw, _ := json.Marshal(frame)
		payloadsMu.Lock()
		wsPayloads = append(wsPayloads, raw)
		payloadsMu.Unlock()

		// Every dial is rejected with previous_response_not_found.
		_ = conn.WriteJSON(map[string]any{
			"type": "error",
			"error": map[string]any{
				"code":    "previous_response_not_found",
				"type":    "invalid_request_error",
				"message": "previous response not found",
			},
		})
	}))
	defer wsServer.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	const respJSON = `{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`
	httpUp := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(respJSON)),
		},
	}

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	wsCfg := &cfg.Gateway.OpenAIWS
	wsCfg.Enabled = true
	wsCfg.OAuthEnabled = true
	wsCfg.APIKeyEnabled = true
	wsCfg.ResponsesWebsocketsV2 = true
	wsCfg.FallbackCooldownSeconds = 1

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpUp,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	account := &Account{
		ID:          92,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key":  "sk-test",
			"base_url": wsServer.URL,
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	payload := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
	res, err := svc.Forward(context.Background(), ginCtx, account, payload)
	require.Error(t, err)
	require.Nil(t, res)
	require.Nil(t, httpUp.lastReq, "previous_response_not_found 不应回退 HTTP")
	require.Equal(t, int32(1), dials.Load(), "function_call_output 场景应跳过 previous_response_not_found 自动恢复")
	require.Equal(t, http.StatusBadRequest, recorder.Code)
	require.Contains(t, strings.ToLower(recorder.Body.String()), "previous response not found")

	// Snapshot the recorded frames under the lock before inspecting them.
	payloadsMu.Lock()
	sent := make([][]byte, len(wsPayloads))
	copy(sent, wsPayloads)
	payloadsMu.Unlock()
	require.Len(t, sent, 1)
	require.True(t, gjson.GetBytes(sent[0], "previous_response_id").Exists())
}
// TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryWithoutPreviousResponseID
// drives Forward against a WSv2 test server that answers every request with a
// previous_response_not_found error event, using a request body that carries no
// previous_response_id. It expects exactly one WS attempt (no recovery retry),
// no HTTP fallback, a 400 written to the client, and that the single captured
// WS frame has no previous_response_id field.
func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryWithoutPreviousResponseID(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var (
		dialCount  atomic.Int32 // HTTP requests that reached the WS test server
		capturedMu sync.Mutex   // guards captured
		captured   [][]byte     // first JSON frame of each WS connection, re-marshalled
	)
	wsUpgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}

	// handleWS upgrades the connection, records the first client frame, and
	// replies with a previous_response_not_found error event.
	handleWS := func(w http.ResponseWriter, r *http.Request) {
		dialCount.Add(1)
		conn, upgradeErr := wsUpgrader.Upgrade(w, r, nil)
		if upgradeErr != nil {
			t.Errorf("upgrade websocket failed: %v", upgradeErr)
			return
		}
		defer func() { _ = conn.Close() }()

		var frame map[string]any
		if readErr := conn.ReadJSON(&frame); readErr != nil {
			t.Errorf("read ws request failed: %v", readErr)
			return
		}
		// A map decoded from JSON always re-marshals, so the error is ignored.
		raw, _ := json.Marshal(frame)
		capturedMu.Lock()
		captured = append(captured, raw)
		capturedMu.Unlock()

		errorEvent := map[string]any{
			"type": "error",
			"error": map[string]any{
				"code":    "previous_response_not_found",
				"type":    "invalid_request_error",
				"message": "previous response not found",
			},
		}
		_ = conn.WriteJSON(errorEvent)
	}
	upstreamWS := httptest.NewServer(http.HandlerFunc(handleWS))
	defer upstreamWS.Close()

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request.Header.Set("User-Agent", "custom-client/1.0")

	// The HTTP upstream stub returns a 200 JSON body; the assertions below
	// require that it is never called (lastReq stays nil).
	httpStub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)),
		},
	}

	// Allow plain-HTTP test URLs and turn on WSv2 for API-key accounts.
	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1

	gateway := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     httpStub,
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
	acct := &Account{
		ID:          93,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: map[string]any{"api_key": "sk-test", "base_url": upstreamWS.URL},
		Extra:       map[string]any{"responses_websockets_v2_enabled": true},
	}

	// The request body deliberately omits previous_response_id.
	payload := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
	result, fwdErr := gateway.Forward(context.Background(), ginCtx, acct, payload)

	require.Error(t, fwdErr)
	require.Nil(t, result)
	// "In WS mode, previous_response_not_found must not fall back to HTTP."
	require.Nil(t, httpStub.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP")
	// "Without previous_response_id the automatic recovery retry must be skipped."
	require.Equal(t, int32(1), dialCount.Load(), "缺少 previous_response_id 时应跳过自动恢复重试")
	require.Equal(t, http.StatusBadRequest, recorder.Code)

	capturedMu.Lock()
	frames := append([][]byte(nil), captured...)
	capturedMu.Unlock()
	require.Len(t, frames, 1)
	require.False(t, gjson.GetBytes(frames[0], "previous_response_id").Exists())
}
// TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOnce
// drives Forward against a WSv2 test server that answers every connection
// with a previous_response_not_found error event, using a request body that
// carries previous_response_id. It expects exactly two WS attempts: the
// original frame still carrying previous_response_id, then one recovery retry
// without it. It also expects no HTTP fallback and a 400 written to the client.
func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOnce(t *testing.T) {
	gin.SetMode(gin.TestMode)

	var attempts atomic.Int32
	var (
		framesMu sync.Mutex
		frames   [][]byte
	)
	// snapshot returns a copy of the recorded frames taken under the lock.
	snapshot := func() [][]byte {
		framesMu.Lock()
		defer framesMu.Unlock()
		return append([][]byte(nil), frames...)
	}

	upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	mockWS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		attempts.Add(1)
		ws, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			t.Errorf("upgrade websocket failed: %v", err)
			return
		}
		defer func() { _ = ws.Close() }()

		var incoming map[string]any
		if err := ws.ReadJSON(&incoming); err != nil {
			t.Errorf("read ws request failed: %v", err)
			return
		}
		// Store the frame before replying so it is visible once Forward returns.
		encoded, _ := json.Marshal(incoming)
		framesMu.Lock()
		frames = append(frames, encoded)
		framesMu.Unlock()

		notFound := map[string]any{
			"code":    "previous_response_not_found",
			"type":    "invalid_request_error",
			"message": "previous response not found",
		}
		_ = ws.WriteJSON(map[string]any{"type": "error", "error": notFound})
	}))
	defer mockWS.Close()

	w := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(w)
	ctx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ctx.Request.Header.Set("User-Agent", "custom-client/1.0")

	// A 200 HTTP stub; the test asserts below that it is never reached.
	fallback := &httpUpstreamRecorder{resp: &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"application/json"}},
		Body:       io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)),
	}}

	conf := &config.Config{}
	conf.Security.URLAllowlist.Enabled = false
	conf.Security.URLAllowlist.AllowInsecureHTTP = true
	conf.Gateway.OpenAIWS.Enabled = true
	conf.Gateway.OpenAIWS.OAuthEnabled = true
	conf.Gateway.OpenAIWS.APIKeyEnabled = true
	conf.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	conf.Gateway.OpenAIWS.FallbackCooldownSeconds = 1

	service := &OpenAIGatewayService{
		cfg:              conf,
		httpUpstream:     fallback,
		openaiWSResolver: NewOpenAIWSProtocolResolver(conf),
		toolCorrector:    NewCodexToolCorrector(),
	}
	creds := map[string]any{"api_key": "sk-test", "base_url": mockWS.URL}
	extra := map[string]any{"responses_websockets_v2_enabled": true}
	acc := &Account{
		ID:          94,
		Name:        "openai-apikey",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Concurrency: 1,
		Credentials: creds,
		Extra:       extra,
	}

	reqBody := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`)
	out, err := service.Forward(context.Background(), ctx, acc, reqBody)

	require.Error(t, err)
	require.Nil(t, out)
	// "In WS mode, previous_response_not_found must not fall back to HTTP."
	require.Nil(t, fallback.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP")
	// "Only one automatic recovery retry is allowed."
	require.Equal(t, int32(2), attempts.Load(), "应只允许一次自动恢复重试")
	require.Equal(t, http.StatusBadRequest, w.Code)

	seen := snapshot()
	require.Len(t, seen, 2)
	// "The first request must contain previous_response_id."
	require.True(t, gjson.GetBytes(seen[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id")
	// "The recovery retry must drop previous_response_id."
	require.False(t, gjson.GetBytes(seen[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
}