sub2api/backend/internal/service/openai_fast_policy_ws_test.go


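feat(openai): full OpenAI Fast/Flex Policy implementation (HTTP + WebSocket + Admin)

Mirroring the Claude BetaPolicy fast-mode filtering implementation, add a three-state pass / filter / block policy for the OpenAI upstream service_tier field (priority / flex, including normalization of the client alias "fast" → "priority"), covering every OpenAI entrypoint plus the admin configuration entrypoint.

Backend core
- Add SettingKeyOpenAIFastPolicySettings and the OpenAIFastPolicyRule / OpenAIFastPolicySettings config model, with the rule dimensions service_tier × action × scope × model whitelist × fallback action.
- SettingService.Get/SetOpenAIFastPolicySettings: when the setting is absent, return the built-in default policy (priority is filtered for all models; whitelist empty, fallback=pass). Rationale: service_tier=fast is a user-level switch that is orthogonal to the model field. Locking the default to specific model slugs would leave a bypass ("use gpt-4 + fast to pass priority through to the upstream"). JSON parse failures no longer fall back silently; slog.Warn logs the bad data so operators can locate it.
- service_tier normalization (trim + ToLower + fast→priority + whitelist of priority/flex) and policy evaluation (evaluateOpenAIFastPolicy) are the single source of truth, shared by HTTP and WS. Extract the pure function evaluateOpenAIFastPolicyWithSettings and pair it with a ctx-bound settings snapshot (withOpenAIFastPolicyContext / openAIFastPolicySettingsFromContext). Long-lived WS sessions prefetch the settings once at entry and reuse them for every frame, instead of hitting settingService on every frame.

HTTP entrypoints (4)
- Chat Completions, Anthropic-compatible (Messages, including the secondary BetaFastMode→priority match), native Responses, and Passthrough Responses all go through applyOpenAIFastPolicyToBody. Filter deletes the top-level service_tier via sjson; block returns a 403 forbidden_error JSON.
- All 4 entrypoints use the upstream-facing model (the slug after GetMappedModel + normalizeOpenAIModelForUpstream + Codex OAuth normalization). This prevents chat / messages / native /responses / passthrough from producing different whitelist matches because each sees a different model.
- On the pass path, also rewrite the client "fast" alias to "priority" in the body. Otherwise the native /responses and passthrough entrypoints would forward "fast" verbatim and the upstream would reject it with a 400. The chat-completions entrypoint's normalizeResponsesBodyServiceTier already behaved this way.

WebSocket entrypoints
- Add applyOpenAIFastPolicyToWSResponseCreate. It matches type="response.create" strictly and only handles the top-level service_tier. Filter deletes the field via sjson; block returns a typed *OpenAIFastBlockedError.
- The ingress path calls it inside parseClientPayload. On a block hit it first writes a Realtime-style error event and then returns OpenAIWSClientCloseError (StatusPolicyViolation = 1008). It relies on the synchronous flush in the underlying WebSocket Conn.Write to guarantee that the error arrives before the close.
- The passthrough path applies the policy to firstClientMessage before RunEntry. It also wraps ReadFrame with openAIWSPolicyEnforcingFrameConn to enforce the policy on every client→upstream frame; follow-up frames without a model field fall back to capturedSessionModel. The filter closure also sniffs session.model on session.update / session.created frames to refresh capturedSessionModel. This closes the mid-session bypass "first frame model=gpt-4o (pass) → session.update switches to gpt-5.5 → response.create without model falls back to gpt-4o".
- Passthrough billing: requestServiceTier is extracted from firstClientMessage after the policy filter runs. On a filter hit, OpenAIForwardResult.ServiceTier therefore reports nil (default tier). This matches the HTTP entrypoints (reqBody comes from the post-filter map) and WS ingress (payload comes from the post-filter bytes).
- Error event schema: {event_id: "evt_<32hex>", type: "error", error: {type: "forbidden_error", code: "policy_violation", message}}, compatible with how the OpenAI Codex client parses error events.

Admin / Frontend
- dto.SystemSettings / UpdateSettingsRequest gain an openai_fast_policy_settings field (omitempty), wired into the bulk GET/PUT.
- The Gateway tab of the Settings page gains a Fast/Flex Policy form card that configures every field: service_tier × action × scope × model whitelist × fallback action.
- Frontend guards:
  - The openaiFastPolicyLoaded flag allows a write-back only when the GET actually returned the field, so a rollout or an error cannot overwrite the default rules with an empty set.
  - The saveSettings write-back loop skips this field; dedicated refresh logic handles it instead.
  - error_message is sent only when action=block, matching the backend's omitempty behavior.

Tests
- HTTP path: openai_fast_policy_test.go covers, among others:
  - default config (whitelist=[], priority filtered for all models);
  - block with a custom error;
  - scope distinction;
  - filter deleting the field;
  - block leaving the body unchanged;
  - block short-circuiting the upstream;
  - Anthropic BetaFastMode triggering the OpenAI fast policy.
- WebSocket path: openai_fast_policy_ws_test.go covers:
  - helper units: filter, fast→priority normalization, flex pass-through, block typed error, bytes unchanged without service_tier, non-response.create frames untouched, empty-type frames untouched, event_id + code assertions, and tolerance of a non-string service_tier;
  - a regression for fast-alias normalization on the pass path;
  - ingress end-to-end: upstream has no service_tier after filter; after block, the client receives the error event before close 1008 and the upstream sees zero writes;
  - passthrough capturedSessionModel fallback cases: established from the first frame under a whitelist policy, a missing model hits the fallback, and the leak without the fallback is documented;
  - a regression for the mid-session bypass in which session.update / session.created rotate capturedSessionModel;
  - regressions for the post-filter ServiceTier in passthrough billing and for idempotent filtering.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>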
2026-04-28 00:34:23 +08:00
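The commit message above describes service_tier normalization and the pass / filter / block evaluation. The following self-contained Go program sketches that behavior. It is not the sub2api implementation:

- The names `Action`, `Rule`, `normalizeServiceTier`, `modelMatches` and `evaluate` are hypothetical.
- `Scope` is omitted.
- Three behaviors are assumptions inferred from the tests below rather than confirmed by the source:
  - the first rule whose tier matches decides;
  - an empty whitelist applies the rule to every model;
  - a whitelist entry ending in `*` acts as a prefix wildcard.

```go
package main

import (
	"fmt"
	"strings"
)

// Action is a hypothetical stand-in for the pass / filter / block actions;
// the real type and constant names in sub2api may differ.
type Action string

const (
	ActionPass   Action = "pass"
	ActionFilter Action = "filter"
	ActionBlock  Action = "block"
)

// Rule mirrors the tier × action × whitelist × fallback dimensions of a
// policy rule. Scope is deliberately left out of this sketch.
type Rule struct {
	ServiceTier    string
	Action         Action
	ModelWhitelist []string // assumption: empty means the rule applies to every model
	FallbackAction Action   // used when the model misses a non-empty whitelist
}

// normalizeServiceTier trims and lowercases the client value, maps the
// "fast" alias to "priority", and returns "" for tiers the policy does not
// manage (for example "auto" or "default").
func normalizeServiceTier(raw string) string {
	tier := strings.ToLower(strings.TrimSpace(raw))
	if tier == "fast" {
		tier = "priority"
	}
	switch tier {
	case "priority", "flex":
		return tier
	default:
		return ""
	}
}

// modelMatches supports exact entries and (assumed) trailing-"*" prefix wildcards.
func modelMatches(whitelist []string, model string) bool {
	for _, pattern := range whitelist {
		if strings.HasSuffix(pattern, "*") {
			if strings.HasPrefix(model, strings.TrimSuffix(pattern, "*")) {
				return true
			}
		} else if pattern == model {
			return true
		}
	}
	return false
}

// evaluate returns the action for one request. Requests with no managed
// tier always pass; otherwise the first rule for that tier decides.
func evaluate(rules []Rule, rawTier, model string) Action {
	tier := normalizeServiceTier(rawTier)
	if tier == "" {
		return ActionPass
	}
	for _, r := range rules {
		if r.ServiceTier != tier {
			continue
		}
		if len(r.ModelWhitelist) == 0 || modelMatches(r.ModelWhitelist, model) {
			return r.Action
		}
		return r.FallbackAction
	}
	return ActionPass
}

func main() {
	// Shaped like the default policy: priority is filtered for every model.
	defaults := []Rule{{ServiceTier: "priority", Action: ActionFilter, FallbackAction: ActionPass}}
	fmt.Println(evaluate(defaults, " Fast ", "gpt-4o")) // filter
	fmt.Println(evaluate(defaults, "flex", "gpt-5.5"))  // pass

	// Shaped like gpt55WhitelistFastPolicy in the tests below.
	whitelisted := []Rule{{
		ServiceTier:    "priority",
		Action:         ActionFilter,
		ModelWhitelist: []string{"gpt-5.5", "gpt-5.5*"},
		FallbackAction: ActionPass,
	}}
	fmt.Println(evaluate(whitelisted, "priority", ""))        // pass: empty model misses the whitelist
	fmt.Println(evaluate(whitelisted, "priority", "gpt-5.5")) // filter
}
```

With a non-empty whitelist, an empty `model` falls through to `FallbackAction`. That is the leak the capturedSessionModel fallback in the passthrough tests guards against.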
package service
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// --- Helper-level (unit) tests for applyOpenAIFastPolicyToWSResponseCreate ---
func TestWSResponseCreate_FilterStripsServiceTier(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority","input":[{"type":"input_text","text":"hi"}]}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.NotContains(t, string(updated), `"service_tier"`, "filter action should strip service_tier")
// Other fields preserved.
require.Equal(t, "response.create", gjson.GetBytes(updated, "type").String())
require.Equal(t, "gpt-5.5", gjson.GetBytes(updated, "model").String())
require.Equal(t, "hi", gjson.GetBytes(updated, "input.0.text").String())
}
func TestWSResponseCreate_FastNormalizedToPriorityThenFiltered(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Verbatim "fast" → normalized to "priority" → matches default rule → filter.
frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.NotContains(t, string(updated), `"service_tier"`)
// Mixed-case + whitespace variant should also normalize and filter.
frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":" Fast "}`)
updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.NotContains(t, string(updated), `"service_tier"`)
}
func TestWSResponseCreate_FlexPassThrough(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Default policy targets priority only; flex is left untouched.
frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.Equal(t, "flex", gjson.GetBytes(updated, "service_tier").String(), "flex frames must reach upstream untouched under default policy")
}
func TestWSResponseCreate_BlockReturnsTypedError(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "ws fast blocked",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.NotNil(t, blocked)
require.Equal(t, "ws fast blocked", blocked.Message)
// On block, payload returned unchanged so caller can inspect / log it.
require.Equal(t, string(frame), string(updated))
}
func TestWSResponseCreate_NoServiceTierUntouched(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
frame := []byte(`{"type":"response.create","model":"gpt-5.5","input":[]}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.Equal(t, string(frame), string(updated), "no service_tier present must result in zero mutation")
}
func TestWSResponseCreate_NonResponseCreateFrameUntouched(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
ModelWhitelist: []string{"*"},
FallbackAction: BetaPolicyActionFilter,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// response.cancel happens to carry a service_tier-shaped field — must not be touched.
frame := []byte(`{"type":"response.cancel","service_tier":"priority"}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.Equal(t, string(frame), string(updated))
}
// TestWSResponseCreate_EmptyTypeFrameUntouched is the A1 regression: the
// helper used to treat empty type as response.create, which risked stripping
// fields from malformed / unknown client events. After the A1 fix only a
// strict "response.create" match triggers policy.
func TestWSResponseCreate_EmptyTypeFrameUntouched(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
ModelWhitelist: []string{"*"},
FallbackAction: BetaPolicyActionFilter,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Frame with no "type" field: must pass through completely unchanged
// even with a service_tier-shaped field present.
frame := []byte(`{"service_tier":"priority","model":"gpt-5.5"}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.Equal(t, string(frame), string(updated), "empty type must NOT be policy-checked — Realtime spec requires type, malformed frames are passed through")
// Explicit empty string also passes through.
frame = []byte(`{"type":"","service_tier":"priority","model":"gpt-5.5"}`)
updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
require.Equal(t, string(frame), string(updated))
}
// TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode is the B1
// regression: the rendered Realtime error event must carry a non-empty
// event_id (so clients can correlate the rejection) and a stable error.code
// ("policy_violation"). The HTTP-side equivalent is the 403 permission_error
// JSON body emitted by writeOpenAIFastPolicyBlockedResponse.
func TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode(t *testing.T) {
bytes := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "blocked because reasons"})
require.NotNil(t, bytes)
require.Equal(t, "error", gjson.GetBytes(bytes, "type").String())
require.Equal(t, "invalid_request_error", gjson.GetBytes(bytes, "error.type").String())
require.Equal(t, "policy_violation", gjson.GetBytes(bytes, "error.code").String())
require.Equal(t, "blocked because reasons", gjson.GetBytes(bytes, "error.message").String())
eventID := gjson.GetBytes(bytes, "event_id").String()
require.NotEmpty(t, eventID, "event_id must be present so clients can correlate the rejection in their logs")
require.True(t, strings.HasPrefix(eventID, "evt_"), "event_id should follow the evt_<rand> Realtime convention; got %q", eventID)
// Sanity check: two consecutive events get distinct IDs.
other := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "second"})
otherID := gjson.GetBytes(other, "event_id").String()
require.NotEqual(t, eventID, otherID, "event_id must be random per-event")
}
// TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe ensures the helper returns
// nil for a nil error (defensive guard for callers that always invoke it).
func TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe(t *testing.T) {
require.Nil(t, buildOpenAIFastPolicyBlockedWSEvent(nil))
}
// --- D5: passthrough wrapper FrameConn — capturedSessionModel fallback ---
// fakePassthroughFrameConn replays a fixed sequence of client frames into the
// policy-enforcing wrapper, then returns errOpenAIWSConnClosed. It records
// every WriteFrame call for write-side assertions (none are expected in the
// D5 tests, since the wrapper only filters reads).
type fakePassthroughFrameConn struct {
reads [][]byte
idx int
writes [][]byte
closeOnce bool
}
func (f *fakePassthroughFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
if f.idx >= len(f.reads) {
return coderws.MessageText, nil, errOpenAIWSConnClosed
}
payload := f.reads[f.idx]
f.idx++
return coderws.MessageText, payload, nil
}
func (f *fakePassthroughFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
cp := append([]byte(nil), payload...)
f.writes = append(f.writes, cp)
return nil
}
func (f *fakePassthroughFrameConn) Close() error {
f.closeOnce = true
return nil
}
// gpt55WhitelistFastPolicy returns a policy that enforces a model whitelist.
// It is used to verify the capturedSessionModel fallback semantics: with the
// default policy's empty whitelist, the fallback path cannot be observed.
func gpt55WhitelistFastPolicy() *OpenAIFastPolicySettings {
return &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
ModelWhitelist: []string{"gpt-5.5", "gpt-5.5*"},
FallbackAction: BetaPolicyActionPass,
}},
}
}
// TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel is
// the D5 regression: in passthrough mode a follow-up response.create frame
// without a "model" field must still hit the policy via the session-level
// model captured from the first frame. Without the fallback an empty model
// would miss a model whitelist and silently leak service_tier=priority
// through to the upstream.
func TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel(t *testing.T) {
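// A whitelist policy is used on purpose here so that the capturedSessionModel
// fallback is observable. With the default policy's empty whitelist, the
// result is the same with or without the fallback, so it cannot cover this
// regression.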
svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Simulate the passthrough adapter capturing model from the first frame.
firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
require.Equal(t, "gpt-5.5", capturedSessionModel)
// Follow-up frame deliberately omits "model" — Realtime allows this.
followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`)
inner := &fakePassthroughFrameConn{
reads: [][]byte{followupFrame},
}
wrapper := &openAIWSPolicyEnforcingFrameConn{
inner: inner,
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if msgType != coderws.MessageText {
return payload, nil, nil
}
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
model = capturedSessionModel
}
return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
},
}
// Read the follow-up frame through the wrapper. The policy MUST still
// trigger filter (gpt-5.5 + priority → filter), so the service_tier
// field is gone by the time the relay sees it.
_, payload, err := wrapper.ReadFrame(context.Background())
require.NoError(t, err)
require.NotContains(t, string(payload), `"service_tier"`,
"D5 regression: empty model on follow-up frame must fall back to capturedSessionModel; whitelist policy filters service_tier=priority for gpt-5.5")
require.Equal(t, "response.create", gjson.GetBytes(payload, "type").String())
}
// TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses pins the
// inverse: when the wrapper has NO capturedSessionModel fallback (model is
// empty per-frame and no fallback is wired up), the policy fails to match
// the model whitelist and the frame leaks through unchanged. This documents
// exactly the leak the D5 fix prevents.
func TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses(t *testing.T) {
// Likewise use the whitelist policy so that the leak is observable.
svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`)
inner := &fakePassthroughFrameConn{reads: [][]byte{followupFrame}}
wrapper := &openAIWSPolicyEnforcingFrameConn{
inner: inner,
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
// NO fallback — emulate the pre-fix behavior.
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
},
}
_, payload, err := wrapper.ReadFrame(context.Background())
require.NoError(t, err)
// Pre-fix: empty model misses ["gpt-5.5","gpt-5.5*"] whitelist → fallback=pass → service_tier kept.
require.Contains(t, string(payload), `"service_tier"`,
"sanity: without capturedSessionModel fallback the leak (D5) reproduces — confirms the fix is load-bearing")
}
// --- Ingress end-to-end test (filter path) ---
// TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream wires up the
// real ProxyResponsesWebSocketFromClient ingress session pipeline against a
// captureConn upstream and asserts that a client frame with service_tier=fast
// is normalized + filtered out before being written upstream. This is the
// integration flavour of TestWSResponseCreate_FilterStripsServiceTier.
func TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_ws_filter_1","model":"gpt-5.5","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
defaultJSON, err := json.Marshal(DefaultOpenAIFastPolicySettings())
require.NoError(t, err)
repo.values[SettingKeyOpenAIFastPolicySettings] = string(defaultJSON)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
settingService: NewSettingService(repo, cfg),
}
account := &Account{
ID: 901,
Name: "openai-ws-filter",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{"api_key": "sk-test"},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() { _ = conn.CloseNow() }()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
_, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() { _ = clientConn.CloseNow() }()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"fast"}`)))
cancelWrite()
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, event, readErr := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, readErr)
require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
select {
case serverErr := <-serverErrCh:
require.NoError(t, serverErr)
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for the ingress websocket to finish")
}
require.Len(t, captureConn.writes, 1, "upstream should receive exactly one response.create")
upstream := captureConn.writes[0]
_, hasServiceTier := upstream["service_tier"]
require.False(t, hasServiceTier, "the response.create received upstream must not contain service_tier (removed by the fast policy filter)")
require.Equal(t, "response.create", upstream["type"])
require.Equal(t, "gpt-5.5", upstream["model"])
}
// TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream is the
// integration flavour of TestWSResponseCreate_BlockReturnsTypedError. It
// asserts that with a custom block rule, the client receives a Realtime-style
// error event AND the upstream FrameConn never receives the offending frame.
func TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
captureConn := &openAIWSCaptureConn{
// No events queued; the upstream should never get written to anyway.
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
blockSettings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "ws priority blocked for testing",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
raw, err := json.Marshal(blockSettings)
require.NoError(t, err)
repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
settingService: NewSettingService(repo, cfg),
}
account := &Account{
ID: 902,
Name: "openai-ws-block",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{"api_key": "sk-test"},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
serverErrCh := make(chan error, 1)
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() { _ = conn.CloseNow() }()
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
_, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
proxyErr := svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
// Mirror the production handler (openai_gateway_handler.go:1325-1328):
// when the proxy returns an OpenAIWSClientCloseError, surface its
// status code to the client via a graceful close handshake. Without
// this the deferred CloseNow() above would tear down the TCP
// connection without sending a close frame, and the C3 timing
// assertion (next read returns CloseStatus=1008) would see EOF
// instead.
var closeErr *OpenAIWSClientCloseError
if errors.As(proxyErr, &closeErr) {
_ = conn.Close(closeErr.StatusCode(), closeErr.Reason())
}
serverErrCh <- proxyErr
}))
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() { _ = clientConn.CloseNow() }()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"priority"}`)))
cancelWrite()
// C3 timing assertion: the FIRST frame the client reads must be the
// error event — not a close frame. coder/websocket@v1.8.14 Conn.Write is
// synchronous (writeFrame Flushes the bufio writer at write.go:307-311
// before returning) and the close handshake re-acquires the same
// writeFrameMu, so this ordering is enforced by the library itself; this
// assertion guards against future refactors that might break it.
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, event, readErr := clientConn.Read(readCtx)
cancelRead()
require.NoError(t, readErr, "first read must succeed and return the error event before any close frame")
require.Equal(t, "error", gjson.GetBytes(event, "type").String())
require.Equal(t, "invalid_request_error", gjson.GetBytes(event, "error.type").String())
// B1 regression: event_id + error.code must be populated.
require.Equal(t, "policy_violation", gjson.GetBytes(event, "error.code").String())
require.NotEmpty(t, gjson.GetBytes(event, "event_id").String(), "event_id must be present so clients can correlate")
require.Contains(t, gjson.GetBytes(event, "error.message").String(), "ws priority blocked for testing")
// Next read must surface the close frame (as a CloseError). This
// asserts the [error event, close] ordering — i.e. the close did NOT
// race ahead of the data frame.
readCtx2, cancelRead2 := context.WithTimeout(context.Background(), 3*time.Second)
_, _, secondReadErr := clientConn.Read(readCtx2)
cancelRead2()
require.Error(t, secondReadErr, "after the error event the connection must surface a close")
require.Equal(t, coderws.StatusPolicyViolation, coderws.CloseStatus(secondReadErr),
"close status must be PolicyViolation; got %v", secondReadErr)
select {
case serverErr := <-serverErrCh:
// Server returns an OpenAIWSClientCloseError — handler closes the WS;
// here we just assert it surfaced as the typed close error.
require.Error(t, serverErr)
var closeErr *OpenAIWSClientCloseError
require.True(t, errors.As(serverErr, &closeErr), "block should return OpenAIWSClientCloseError, got %T: %v", serverErr, serverErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for the ingress to close")
}
// Critical: the offending frame must NEVER reach the upstream.
// captureDialer.DialCount may legitimately be 0 or 1 depending on whether
// the lease was acquired before policy fired; either way, no writes.
require.Empty(t, captureConn.writes, "upstream must not receive response.create after a block hit")
}
// --- HTTP-side gap-filling tests (already covered by existing tests but
// requested to be split out explicitly) ---
// TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream confirms that
// applyOpenAIFastPolicyToBody surfaces a *OpenAIFastBlockedError when the rule
// action is "block", and that the body is left untouched. The caller (chat
// completions / messages handlers) inspects this typed error and skips the
// upstream HTTP call entirely — see openai_gateway_chat_completions.go:175 and
// openai_gateway_messages.go:149.
func TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "priority blocked",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
body := []byte(`{"model":"gpt-5.5","service_tier":"priority","input":[]}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.Error(t, err)
var blocked *OpenAIFastBlockedError
require.True(t, errors.As(err, &blocked), "block must surface as typed error so caller can skip upstream HTTP request")
require.Equal(t, "priority blocked", blocked.Message)
require.Equal(t, string(body), string(updated), "block must not mutate body")
}
// TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy verifies
// the Anthropic-compat entrypoint chain: anthropic-beta: fast-mode → BetaFastMode
// detection → ServiceTier="priority" injection (openai_gateway_messages.go:60)
// → applyOpenAIFastPolicyToBody filter on default policy → upstream body has
// no service_tier. We exercise the same internal pipeline (Anthropic→Responses
// + BetaFastMode + policy) without spinning up a real upstream HTTP server.
func TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Step 1: parse Anthropic request (mirrors openai_gateway_messages.go:38-50).
anthropicBody := []byte(`{"model":"gpt-5.5","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`)
var anthropicReq apicompat.AnthropicRequest
require.NoError(t, json.Unmarshal(anthropicBody, &anthropicReq))
responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
require.NoError(t, err)
// Step 2: BetaFastMode header → service_tier="priority" (mirrors lines 58-61).
headers := http.Header{}
headers.Set("anthropic-beta", claude.BetaFastMode)
require.True(t, containsBetaToken(headers.Get("anthropic-beta"), claude.BetaFastMode))
responsesReq.ServiceTier = "priority"
responsesReq.Model = "gpt-5.5"
// Step 3: marshal & apply fast policy (mirrors line 78 + 149).
responsesBody, err := json.Marshal(responsesReq)
require.NoError(t, err)
require.Equal(t, "priority", gjson.GetBytes(responsesBody, "service_tier").String(), "前置beta 翻译应当注入 priority")
upstreamBody, policyErr := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", responsesBody)
require.NoError(t, policyErr)
// Step 4: assert that policy filtered the field before the upstream HTTP request.
require.NotContains(t, string(upstreamBody), `"service_tier"`, "default policy 命中 gpt-5.5 priority 应当 filter 掉 service_tier")
}

// --- Fix1: passthrough capturedSessionModel must follow session.update ---

// TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel covers the
// fix1 bypass: the client opens with a whitelist-miss model (gpt-4o → pass
// under the gpt-5.5 whitelist), rotates to gpt-5.5 via session.update, then
// sends response.create without "model". Without session.update sniffing,
// the follow-up frame would fall back to the stale gpt-4o capture and pass.
// The fix updates capturedSessionModel from session.update frames, so the
// fallback now resolves to gpt-5.5 and the policy filters service_tier.
func TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// Frame 1: response.create with a whitelist-miss model; the rule's
	// fallback=pass keeps service_tier.
	first := []byte(`{"type":"response.create","model":"gpt-4o","service_tier":"priority"}`)
	// Frame 2: session.update rotates the session model to gpt-5.5.
	rotate := []byte(`{"type":"session.update","session":{"model":"gpt-5.5"}}`)
	// Frame 3: response.create WITHOUT model; it must inherit gpt-5.5.
	followup := []byte(`{"type":"response.create","service_tier":"priority"}`)
	inner := &fakePassthroughFrameConn{reads: [][]byte{first, rotate, followup}}

	// Replicate the production wiring in openai_ws_v2_passthrough_adapter.go
	// so capturedSessionModel state is shared across frames.
	capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, first)
	require.Equal(t, "gpt-4o", capturedSessionModel)
	wrapper := &openAIWSPolicyEnforcingFrameConn{
		inner: inner,
		filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
			if msgType != coderws.MessageText {
				return payload, nil, nil
			}
			if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
				capturedSessionModel = updated
			}
			model := openAIWSPassthroughPolicyModelForFrame(account, payload)
			if model == "" {
				model = capturedSessionModel
			}
			return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
		},
	}

	// Frame 1: gpt-4o misses the whitelist → pass (service_tier preserved).
	_, payload1, err := wrapper.ReadFrame(context.Background())
	require.NoError(t, err)
	require.Contains(t, string(payload1), `"service_tier"`, "frame1: gpt-4o miss whitelist → pass keeps service_tier")

	// Frame 2: session.update is not response.create, so it is forwarded
	// untouched, but as a side effect it rotates capturedSessionModel to gpt-5.5.
	_, payload2, err := wrapper.ReadFrame(context.Background())
	require.NoError(t, err)
	require.Equal(t, string(rotate), string(payload2), "session.update frame is forwarded verbatim")
	require.Equal(t, "gpt-5.5", capturedSessionModel, "fix1: session.update must rotate capturedSessionModel")

	// Frame 3: empty model + newly captured gpt-5.5 → matches whitelist → filter.
	_, payload3, err := wrapper.ReadFrame(context.Background())
	require.NoError(t, err)
	require.NotContains(t, string(payload3), `"service_tier"`,
		"fix1: post-rotate response.create without model must use refreshed capturedSessionModel and trigger filter")
}

// TestPolicyModelFromSessionFrame_OnlySessionUpdate covers the negative
// branches of openAIWSPassthroughPolicyModelFromSessionFrame: only
// client→upstream session.update frames rotate the captured model;
// server→client events (session.created) and unrelated frames must not.
func TestPolicyModelFromSessionFrame_OnlySessionUpdate(t *testing.T) {
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// session.created is a server→client event in the OpenAI Realtime
	// protocol. Clients never send it, so this filter (which only runs on
	// the client→upstream direction) must ignore it even if it appears.
	created := []byte(`{"type":"session.created","session":{"model":"gpt-5.5"}}`)
	require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, created))

	// Non-session.* frames must NOT trigger rotation.
	notSession := []byte(`{"type":"response.create","session":{"model":"gpt-9"}}`)
	require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, notSession))

	// A missing session.model returns empty; the caller keeps the old captured value.
	noModel := []byte(`{"type":"session.update","session":{"voice":"alloy"}}`)
	require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, noModel))
}

// --- Fix2: native /responses normalize "fast" → "priority" on pass ---

// TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias is the fix2
// regression. Before the fix, when action=pass, applyOpenAIFastPolicyToBody
// returned the body unchanged, so a raw "fast" alias would leak to the
// upstream OpenAI API (which does not accept "fast"). The fix normalizes
// "fast" → "priority" on pass too.
func TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias(t *testing.T) {
	// The whitelist deliberately excludes gpt-4, so the fallback action (pass) applies.
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier:    OpenAIFastTierPriority,
			Action:         BetaPolicyActionFilter,
			Scope:          BetaPolicyScopeAll,
			ModelWhitelist: []string{"gpt-5.5"},
			FallbackAction: BetaPolicyActionPass,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// gpt-4 + "fast" → fallback pass. The body must be rewritten to "priority".
	body := []byte(`{"model":"gpt-4","service_tier":"fast"}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
	require.NoError(t, err)
	require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String(),
		"fix2: pass action must still normalize 'fast' → 'priority' so upstream OpenAI accepts the slug")

	// Already-canonical "priority" on pass: zero mutation (byte-equal).
	body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
	require.NoError(t, err)
	require.Equal(t, string(body), string(updated))

	// Mixed-case alias with surrounding whitespace → normalized.
	body = []byte(`{"model":"gpt-4","service_tier":" Fast "}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
	require.NoError(t, err)
	require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String())

	// Unrecognized tier → still a no-op (not normalized, since normTier == "").
	body = []byte(`{"model":"gpt-4","service_tier":"turbo"}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
	require.NoError(t, err)
	require.Equal(t, string(body), string(updated))
}

// --- Fix3: passthrough billing must reflect post-filter service_tier ---

// TestPassthroughBilling_PostFilterServiceTier is the fix3 regression. The
// passthrough adapter (openai_ws_v2_passthrough_adapter.go) now extracts
// requestServiceTier from firstClientMessage AFTER the fast policy has
// rewritten it, so a filter hit causes billing to report nil (default tier)
// instead of the user-requested "priority". This test pins the contract that
// applyOpenAIFastPolicyToWSResponseCreate and extractOpenAIServiceTierFromBody
// must uphold for the adapter's billing path.
func TestPassthroughBilling_PostFilterServiceTier(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	raw := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)

	// Pre-filter sanity: extracting from the raw frame reports "priority",
	// which is what pre-fix billing (incorrectly) recorded and what the
	// adapter must NOT do anymore.
	pre := extractOpenAIServiceTierFromBody(raw)
	require.NotNil(t, pre)
	require.Equal(t, "priority", *pre,
		"sanity: raw first frame carries priority that pre-fix billing would have reported")

	// Apply the policy filter (default rule: gpt-5.5 + priority → filter).
	filtered, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", raw)
	require.NoError(t, err)
	require.Nil(t, blocked)
	require.NotContains(t, string(filtered), `"service_tier"`)

	// Post-filter: extracting from the rewritten frame returns nil. This
	// is the value the adapter now passes to OpenAIForwardResult.ServiceTier,
	// so billing records "default" instead of "priority".
	post := extractOpenAIServiceTierFromBody(filtered)
	require.Nil(t, post, "fix3: post-filter extraction must return nil so passthrough billing reports default tier instead of the requested priority")

	// The byte-level invariant the adapter relies on: filtering an
	// already-filtered frame is a no-op (idempotent), so re-running the
	// policy cannot accidentally re-introduce the field.
	again, blocked2, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", filtered)
	require.NoError(t, err)
	require.Nil(t, blocked2)
	require.Equal(t, string(filtered), string(again),
		"policy is idempotent: filtering an already-filtered frame leaves bytes unchanged")
}

// TestApplyOpenAIFastPolicyToBody_NonStringServiceTier covers the test gap
// flagged in the review: when a client sends service_tier as a non-string
// (number, null, object, etc.), the policy must NOT panic and must NOT
// pretend the field was filtered. Behavior: skip the policy entirely (treat
// as "no usable tier") and forward the body unchanged. This mirrors the HTTP
// entry's type-assertion `reqBody["service_tier"].(string); ok` guard.
func TestApplyOpenAIFastPolicyToBody_NonStringServiceTier(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// None of these values is a JSON string. Even if coerced to text (gjson's
	// String() yields "1" for the number, "" for null, "true" for the boolean,
	// and the raw JSON for the object and array), none is a recognized tier
	// alias, so normalization returns "" and the policy no-ops.
	cases := [][]byte{
		[]byte(`{"model":"gpt-5.5","service_tier":1}`),
		[]byte(`{"model":"gpt-5.5","service_tier":null}`),
		[]byte(`{"model":"gpt-5.5","service_tier":{"nested":"priority"}}`),
		[]byte(`{"model":"gpt-5.5","service_tier":["priority"]}`),
		[]byte(`{"model":"gpt-5.5","service_tier":true}`),
	}
	for _, body := range cases {
		updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
		require.NoError(t, err, "non-string service_tier must not error: %s", string(body))
		require.Equal(t, string(body), string(updated),
			"non-string service_tier must pass through unchanged: %s", string(body))
	}

	// Same guard for the WS response.create entry.
	for _, frame := range cases {
		updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
		require.NoError(t, err, "non-string service_tier ws frame must not error: %s", string(frame))
		require.Nil(t, blocked, "non-string service_tier must not trigger block: %s", string(frame))
		require.Equal(t, string(frame), string(updated),
			"non-string service_tier ws frame must pass through unchanged: %s", string(frame))
	}
}

// TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames covers the
// multi-turn passthrough billing regression: OpenAI Realtime / Responses WS
// allows the client to ship a different service_tier on each response.create
// frame (a per-response field; see codex-rs/core/src/client.rs
// build_responses_request, which re-fills the field on every request). Before
// the fix, the adapter only captured service_tier from firstClientMessage, so
// turn 2/3 billing was wrong. After the fix, the filter closure refreshes an
// atomic.Pointer[string] on every successful response.create frame.
//
// This test pins the four legs of the semantic contract:
//   - turn 1: service_tier=priority hits the default whitelist filter, so
//     after filtering the upstream sees no tier → billing is nil.
//   - turn 2: service_tier=flex passes (the default rule targets priority
//     only), so billing should now reflect "flex".
//   - turn 3: response.create without any service_tier; the upstream will
//     treat it as default, so we mirror that and overwrite billing to nil
//     rather than carry over "flex" from turn 2.
//   - a non-response.create frame (response.cancel here) carrying a stray
//     service_tier-shaped field must NOT clobber the billing pointer.
func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// Mirror the production filter closure (openai_ws_v2_passthrough_adapter.go
	// proxyResponsesWebSocketV2Passthrough) so this test fails if the
	// production code drops the per-frame Store.
	var requestServiceTierPtr atomic.Pointer[string]
	capturedSessionModel := ""
	filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
		if msgType != coderws.MessageText {
			return payload, nil, nil
		}
		if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
			capturedSessionModel = updated
		}
		model := openAIWSPassthroughPolicyModelForFrame(account, payload)
		if model == "" {
			model = capturedSessionModel
		}
		out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
		if policyErr == nil && blocked == nil &&
			strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
			requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
		}
		return out, blocked, policyErr
	}

	// First-frame initialization mirrors the adapter: extract from the
	// post-filter payload so a filter hit on the first frame zeroes billing too.
	firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
	firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", firstFrame)
	require.NoError(t, firstErr)
	require.Nil(t, firstBlocked)
	requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstOut))
	capturedSessionModel = openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
	require.Nil(t, requestServiceTierPtr.Load(),
		"turn 1: filter strips service_tier=priority, billing must reflect upstream-actual nil tier")

	// Turn 2: the client switches to flex, which should pass and update billing.
	turn2 := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`)
	out2, blocked2, err2 := filter(coderws.MessageText, turn2)
	require.NoError(t, err2)
	require.Nil(t, blocked2)
	require.Equal(t, "flex", gjson.GetBytes(out2, "service_tier").String(), "turn 2: flex must pass to upstream untouched")
	tier2 := requestServiceTierPtr.Load()
	require.NotNil(t, tier2, "turn 2: billing must update to reflect flex")
	require.Equal(t, "flex", *tier2)

	// A non-response.create frame with a stray service_tier-shaped field
	// must NOT overwrite the billing pointer (those frames don't carry
	// per-response service_tier in the Realtime spec).
	cancelFrame := []byte(`{"type":"response.cancel","service_tier":"priority"}`)
	_, blockedCancel, errCancel := filter(coderws.MessageText, cancelFrame)
	require.NoError(t, errCancel)
	require.Nil(t, blockedCancel)
	tierAfterCancel := requestServiceTierPtr.Load()
	require.NotNil(t, tierAfterCancel, "response.cancel must not clobber billing tier to nil")
	require.Equal(t, "flex", *tierAfterCancel,
		"non-response.create frames must not update billing tier even if they carry a service_tier-shaped field")

	// Turn 3: response.create without any service_tier. We deliberately
	// overwrite billing back to nil so it tracks what the upstream actually
	// sees on this turn (default tier).
	turn3 := []byte(`{"type":"response.create","model":"gpt-5.5"}`)
	out3, blocked3, err3 := filter(coderws.MessageText, turn3)
	require.NoError(t, err3)
	require.Nil(t, blocked3)
	require.Equal(t, string(turn3), string(out3), "turn 3 has no service_tier — filter must not mutate")
	require.Nil(t, requestServiceTierPtr.Load(),
		"turn 3: response.create without service_tier overwrites billing to nil to match upstream default")
}

// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the
// "block keeps previous" semantic: when the policy returns block on a
// response.create frame, that frame is never sent upstream, so the billing
// tier must keep the previous turn's value rather than getting silently zeroed.
func TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier(t *testing.T) {
	blockSettings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier:    OpenAIFastTierPriority,
			Action:         BetaPolicyActionBlock,
			Scope:          BetaPolicyScopeAll,
			ErrorMessage:   "blocked",
			ModelWhitelist: []string{"gpt-5.5"},
			FallbackAction: BetaPolicyActionPass,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, blockSettings)
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	var requestServiceTierPtr atomic.Pointer[string]
	flexValue := "flex"
	requestServiceTierPtr.Store(&flexValue) // simulate a prior turn billed as flex
	filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
		if msgType != coderws.MessageText {
			return payload, nil, nil
		}
		out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", payload)
		if policyErr == nil && blocked == nil &&
			strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
			requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
		}
		return out, blocked, policyErr
	}

	frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
	_, blocked, err := filter(coderws.MessageText, frame)
	require.NoError(t, err)
	require.NotNil(t, blocked, "policy must block this frame")
	tier := requestServiceTierPtr.Load()
	require.NotNil(t, tier, "blocked frame must not clobber prior billing tier to nil")
	require.Equal(t, "flex", *tier,
		"blocked frame is never sent upstream; billing must retain the previous turn's tier")
}