sub2api/backend/internal/service/openai_fast_policy_test.go
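DaydreamCoding 30f55a1f72 feat(openai): complete OpenAI Fast/Flex Policy implementation (HTTP + WebSocket + Admin)
Mirroring the fast-mode filtering of Claude BetaPolicy, this adds a pass / filter / block
tri-state policy for the OpenAI upstream service_tier field (priority / flex, including
normalization of the client-side "fast" → "priority"), covering every OpenAI entry point
plus the admin configuration entry point.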

Backend core
- Add the SettingKeyOpenAIFastPolicySettings key and the OpenAIFastPolicyRule /
  OpenAIFastPolicySettings configuration models. A rule has the dimensions
  service_tier × action × scope × model whitelist × fallback action.
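- SettingService.Get/SetOpenAIFastPolicySettings: when the setting is missing, a built-in
  default policy is returned (priority is filtered for all models, the whitelist is empty,
  fallback=pass). Rationale: service_tier=fast is a user-level switch that is orthogonal
  to the model field, so a default locked to specific model slugs would leave a bypass
  ("use gpt-4 + fast to pass priority through to the upstream"). JSON parse failures no
  longer fall back silently; slog.Warn logs the bad data so operators can locate it.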
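- service_tier normalization (trim + ToLower + fast→priority + a priority/flex whitelist)
  and policy evaluation (evaluateOpenAIFastPolicy) are the single source of truth, shared
  by HTTP and WS. The pure function evaluateOpenAIFastPolicyWithSettings is extracted and
  paired with a ctx-bound settings snapshot (withOpenAIFastPolicyContext /
  openAIFastPolicySettingsFromContext): the long-lived WS session entry point prefetches
  the settings once and every frame reuses them, so frames do not each hit settingService.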

HTTP entry points (4)
- Chat Completions, Anthropic-compatible (Messages, including a second hit via
  BetaFastMode→priority), native Responses, and Passthrough Responses all go through
  applyOpenAIFastPolicyToBody. filter uses sjson to delete the top-level service_tier;
  block returns a 403 forbidden_error JSON.
- All 4 entry points use the upstream-side model (the slug after GetMappedModel +
  normalizeOpenAIModelForUpstream + Codex OAuth normalization). This keeps
  chat/messages/native /responses/passthrough from diverging on whitelist hits because
  they see different model values.
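- On the pass path, the client "fast" alias is also normalized to "priority" and written
  back to the body. Otherwise the native /responses and passthrough entry points would
  forward "fast" verbatim to the upstream and get a 400 / rejection (the chat-completions
  entry point's normalizeResponsesBodyServiceTier already behaved this way).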

WebSocket entry points
- Add applyOpenAIFastPolicyToWSResponseCreate: it strictly matches
  type="response.create" and only handles the top-level service_tier. filter deletes the
  field with sjson; block returns a typed *OpenAIFastBlockedError.
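- The ingress path calls it inside parseClientPayload. On a block hit it first writes a
  Realtime-style error event and then returns OpenAIWSClientCloseError
  (StatusPolicyViolation = 1008). The synchronous flush of the underlying WebSocket
  Conn.Write is what guarantees the error arrives before the close.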
- The passthrough path applies the policy to firstClientMessage before RunEntry and
  wraps ReadFrame with openAIWSPolicyEnforcingFrameConn, so the policy runs on every
  client→upstream frame. Later frames without a model field fall back to
  capturedSessionModel. The filter closure also watches the session.model field of
  session.update / session.created frames to refresh capturedSessionModel. This closes
  the mid-session bypass "first frame model=gpt-4o (pass) → session.update switches to
  gpt-5.5 → a response.create without model falls back to gpt-4o".
- passthrough billing: requestServiceTier is extracted from firstClientMessage after the
  policy filter runs. When filter hits, OpenAIForwardResult.ServiceTier is reported as nil
  (default tier). This matches the HTTP entry points (reqBody comes from the post-filter
  map) and WS ingress (payload comes from the post-filter bytes).
- Error event schema: {event_id: "evt_<32hex>", type: "error",
  error: {type: "forbidden_error", code: "policy_violation", message}}, compatible with
  the OpenAI codex client's error-event parsing.

Admin / Frontend
- dto.SystemSettings / UpdateSettingsRequest gain an openai_fast_policy_settings field
  (omitempty), wired into bulk GET/PUT.
- The Gateway tab of the Settings page gains a Fast/Flex Policy form card that configures
  every field: service_tier × action × scope × model whitelist × fallback action.
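- Frontend guards: the openaiFastPolicyLoaded flag allows a write-back only when GET
  actually returned the field, so a rollout or an error cannot overwrite the default rules
  with an empty value. The saveSettings write-back loop skips this field, which is handled
  by dedicated refresh logic. error_message is sent only when action=block, matching the
  backend's omitempty behavior.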

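Tests
- HTTP path: openai_fast_policy_test.go covers the default config (whitelist=[], priority
  filtered for all models) / block with a custom error / scope distinction / filter
  deletes the field / block leaves the body unchanged / block short-circuits the upstream /
  Anthropic BetaFastMode triggering the OpenAI fast policy, among other scenarios.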
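- WebSocket path: openai_fast_policy_ws_test.go covers
    helper units (filter / fast→priority normalization / flex passthrough / block typed
    error / bytes unchanged without service_tier / non-response.create frames untouched /
    empty-type frames untouched / event_id+code field assertions / tolerance of a
    non-string service_tier) +
    a pass-path fast-alias normalization regression +
    ingress end-to-end (no service_tier upstream after filter / after block the client
    receives the error event before close 1008, with 0 upstream writes) +
    passthrough capturedSessionModel fallback cases (first-frame establishment under a
    whitelist policy, a missing model hitting the fallback, documentation of the leak when
    the fallback is absent) +
    a mid-session bypass regression for passthrough session.update / session.created
    rotating capturedSessionModel +
    passthrough billing post-filter ServiceTier and idempotent-filter regressions.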

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 11:15:09 +08:00

package service

import (
	"context"
	"encoding/json"
	"errors"
	"testing"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/stretchr/testify/require"
)

type openAIFastPolicyRepoStub struct {
	values map[string]string
}

func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
	panic("unexpected Get call")
}

func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
	if v, ok := s.values[key]; ok {
		return v, nil
	}
	return "", ErrSettingNotFound
}

func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
	if s.values == nil {
		s.values = map[string]string{}
	}
	s.values[key] = value
	return nil
}

func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
	panic("unexpected GetMultiple call")
}

func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
	panic("unexpected SetMultiple call")
}

func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
	panic("unexpected GetAll call")
}

func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
	panic("unexpected Delete call")
}

func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
	t.Helper()
	repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
	if settings != nil {
		raw, err := json.Marshal(settings)
		require.NoError(t, err)
		repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
	}
	return &OpenAIGatewayService{
		settingService: NewSettingService(repo, &config.Config{}),
	}
}

func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// The default policy applies to every model (empty whitelist), because codex's
	// service_tier=fast is a user-level switch that is orthogonal to the model.

	// gpt-5.5 + priority → filter
	action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// gpt-5.5-turbo + priority → filter
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// gpt-4 + priority → filter (the default policy covers every model)
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// gpt-5.5 + flex → pass (tier doesn't match)
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
	require.Equal(t, BetaPolicyActionPass, action)

	// empty tier → pass
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
	require.Equal(t, BetaPolicyActionPass, action)
}

func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier:    OpenAIFastTierPriority,
			Action:         BetaPolicyActionBlock,
			Scope:          BetaPolicyScopeAll,
			ErrorMessage:   "fast mode is not allowed",
			ModelWhitelist: []string{"gpt-5.5"},
			FallbackAction: BetaPolicyActionPass,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionBlock, action)
	require.Equal(t, "fast mode is not allowed", msg)
}

func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierAny,
			Action:      BetaPolicyActionFilter,
			Scope:       BetaPolicyScopeOAuth,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)

	// OAuth account → rule matches
	oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
	action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// API Key account → rule skipped → pass
	apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionPass, action)
}

func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// gpt-5.5 fast → service_tier stripped
	body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.NotContains(t, string(updated), `"service_tier"`)

	// A client sending "fast" (alias for priority) is also filtered
	body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.NotContains(t, string(updated), `"service_tier"`)

	// gpt-4 + priority → the default policy filters every model, so service_tier is removed
	body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
	require.NoError(t, err)
	require.NotContains(t, string(updated), `"service_tier"`)

	// No service_tier → no-op
	body = []byte(`{"model":"gpt-5.5"}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.Equal(t, string(body), string(updated))
}

// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule verifies that, with the
// extended whitelist, official OpenAI tiers sent explicitly by the client (auto/default/scale)
// pass through to the upstream instead of being silently stripped. The default policy only
// targets priority, so these tiers fall through to the pass branch.
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	for _, tier := range []string{"auto", "default", "scale"} {
		body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
		updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
		require.NoError(t, err, "tier %q should pass without error", tier)
		require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
			"tier %q should be preserved in body under default rule", tier)
	}

	// The evaluate layer should also return pass (the default rule's ServiceTier=priority
	// does not match auto/default/scale).
	for _, tier := range []string{"auto", "default", "scale"} {
		action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
		require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
	}
}

// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers verifies that once an admin
// explicitly configures a ServiceTier=all + Action=filter rule, official tiers such as
// auto/default/scale are stripped as well. This is expected: the first matching rule
// short-circuits, and "all" covers every recognized tier.
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierAny,
			Action:      BetaPolicyActionFilter,
			Scope:       BetaPolicyScopeAll,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
		body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
		updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
		require.NoError(t, err)
		require.NotContains(t, string(updated), `"service_tier"`,
			"tier %q should be stripped under ServiceTier=all + filter rule", tier)
	}
}

// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped verifies that genuinely unknown tiers
// are still stripped: normalize returns nil, so normalizeResponsesBodyServiceTier deletes the
// field. applyOpenAIFastPolicyToBody is a no-op when normTier is empty, because the field can
// no longer be present in a request that has already been normalized. This test calls apply
// directly to check that it does not fail on an unrecognized value.
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// The normalize stage strips unknown values
	require.Nil(t, normalizeOpenAIServiceTier("xxx"))

	// applyOpenAIFastPolicyToBody does not error on an unrecognized tier and passes the body
	// through unchanged (stripping is not its job: the earlier
	// normalizeResponsesBodyServiceTier step has already removed the field).
	body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.Equal(t, string(body), string(updated))
}

func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier:    OpenAIFastTierPriority,
			Action:         BetaPolicyActionBlock,
			Scope:          BetaPolicyScopeAll,
			ErrorMessage:   "fast mode is blocked for gpt-5.5",
			ModelWhitelist: []string{"gpt-5.5"},
			FallbackAction: BetaPolicyActionPass,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.Error(t, err)
	var blocked *OpenAIFastBlockedError
	require.True(t, errors.As(err, &blocked))
	require.Contains(t, blocked.Message, "fast mode is blocked")
	require.Equal(t, string(body), string(updated)) // body not mutated on block
}

func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
	repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
	svc := NewSettingService(repo, &config.Config{})

	// Invalid action rejected
	err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierPriority,
			Action:      "bogus",
			Scope:       BetaPolicyScopeAll,
		}},
	})
	require.Error(t, err)

	// Invalid service_tier rejected
	err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: "turbo",
			Action:      BetaPolicyActionPass,
			Scope:       BetaPolicyScopeAll,
		}},
	})
	require.Error(t, err)

	// Valid settings persisted
	err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierPriority,
			Action:      BetaPolicyActionFilter,
			Scope:       BetaPolicyScopeAll,
		}},
	})
	require.NoError(t, err)

	got, err := svc.GetOpenAIFastPolicySettings(context.Background())
	require.NoError(t, err)
	require.Len(t, got.Rules, 1)
	require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
}