Files
sub2api/backend/internal/service/openai_fast_policy_test.go

287 lines
12 KiB
Go
Raw Normal View History

feat(openai): OpenAI Fast/Flex Policy 完整实现(HTTP + WebSocket + Admin) 对称参照 Claude BetaPolicy 的 fast-mode 过滤实现,新增针对 OpenAI 上游 service_tier 字段(priority / flex,含客户端 "fast" → "priority" 归一化)的 pass / filter / block 三态策略,覆盖全部 OpenAI 入口 + admin 配置入口。 后端核心 - 新增 SettingKeyOpenAIFastPolicySettings、OpenAIFastPolicyRule、 OpenAIFastPolicySettings 配置模型,含规则的 service_tier × action × scope × 模型白名单 × fallback action 维度。 - SettingService.Get/SetOpenAIFastPolicySettings;缺失时返回内置默认策略 (所有模型的 priority 走 filter,whitelist 为空,fallback=pass)。设计 依据:service_tier=fast 是用户级开关,与 model 字段正交,默认锁定特定 model slug 会留下"用 gpt-4 + fast 透传 priority 上游"的绕过路径。JSON 解析失败不再静默 fallback,slog.Warn 记录脏数据,便于运维定位。 - service_tier 归一化(trim + ToLower + fast→priority + 白名单 priority/flex) 与策略评估(evaluateOpenAIFastPolicy)作为唯一真实来源,HTTP / WS 共用。 抽出纯函数 evaluateOpenAIFastPolicyWithSettings,配合 ctx-bound settings 快照(withOpenAIFastPolicyContext / openAIFastPolicySettingsFromContext), WS 长会话入口预取一次后所有帧复用,避免每帧打到 settingService。 HTTP 入口(4 个) - Chat Completions、Anthropic 兼容(Messages,含 BetaFastMode→priority 二次 命中)、原生 Responses、Passthrough Responses 全部接入 applyOpenAIFastPolicyToBody,filter 走 sjson 顶层删除 service_tier,block 返回 403 forbidden_error JSON。 - 4 入口统一使用 upstream 视角的 model(GetMappedModel + normalizeOpenAIModelForUpstream + Codex OAuth normalize 后的 slug), 避免 chat/messages/native /responses/passthrough 因为 model 维度不同 造成 whitelist 命中差异。 - 在 pass 路径也把客户端 "fast" 别名归一化为 "priority" 写回 body, 否则 native /responses 与 passthrough 入口会把 "fast" 原样透传给上游 导致 400/拒绝(chat-completions 入口的 normalizeResponsesBodyServiceTier 此前已具备同等行为)。 WebSocket 入口 - 新增 applyOpenAIFastPolicyToWSResponseCreate:严格匹配 type="response.create",仅处理顶层 service_tier;filter 用 sjson 删字段, block 返回 typed *OpenAIFastBlockedError。 - ingress 路径在 parseClientPayload 内调用,block 命中先 Write Realtime 风格 error event 再返回 OpenAIWSClientCloseError(StatusPolicyViolation =1008),依赖底层 WebSocket Conn.Write 的同步 flush 保证 error 先于 close。 - passthrough 路径在 RunEntry 前对 firstClientMessage 应用策略,并通过 
openAIWSPolicyEnforcingFrameConn 包装 ReadFrame 对每个 client→upstream 帧执行策略;后续帧无 model 字段时回退到 capturedSessionModel。 filter 闭包内同时侦测 session.update / session.created 帧的 session.model 字段刷新 capturedSessionModel,封堵"首帧 model=gpt-4o(pass)→ session.update 改为 gpt-5.5 → 不带 model 的 response.create fallback 到 gpt-4o"的 mid-session 绕过路径。 - passthrough billing:requestServiceTier 在策略 filter 之后再从 firstClientMessage 提取,filter 命中时 OpenAIForwardResult.ServiceTier 上报 nil(default tier),与 HTTP 入口(reqBody 来自 post-filter map) / WS ingress(payload 来自 post-filter bytes)的语义一致。 - 错误事件 schema:{event_id: "evt_<32hex>", type: "error", error: {type: "forbidden_error", code: "policy_violation", message}}, 与 OpenAI codex 客户端 error event 解析兼容。 Admin / Frontend - dto.SystemSettings / UpdateSettingsRequest 新增 openai_fast_policy_settings 字段(omitempty),bulk GET/PUT 接入。 - Settings 页 Gateway 页签新增 Fast/Flex Policy 表单卡片: service_tier × action × scope × 模型白名单 × fallback action 全字段配置。 - 前端守门:openaiFastPolicyLoaded 标志仅在 GET 真带回字段时才允许回写, 避免 rollout/错误把默认规则覆盖成空;saveSettings 回写循环 skip 该字段, 由专用刷新逻辑处理;仅 action=block 时发送 error_message,匹配后端 omitempty 行为。 测试 - HTTP 路径:openai_fast_policy_test.go 覆盖默认配置(whitelist=[],所有 模型 priority filter)/ block 自定义错误 / scope 区分 / filter 删字段 / block 不改 body / block 短路上游 / Anthropic BetaFastMode 触发 OpenAI fast policy 等场景。 - WebSocket 路径:openai_fast_policy_ws_test.go 覆盖 helper 单元(filter / fast→priority 归一化 / flex 透传 / block typed error / 无 service_tier 字节不变 / 非 response.create 帧不动 / 空 type 帧不动 / event_id+code 字段断言 / 非字符串 service_tier 容错)+ pass 路径 fast 别名归一化回归 + ingress 端到端(filter 后上游不含 service_tier / block 后客户端先收 error event 再收 close 1008 且上游 0 写)+ passthrough capturedSessionModel fallback 用例(whitelist 策略下首帧 建立、缺 model 命中 fallback、缺少 fallback 时的 leak 文档化)+ passthrough session.update / session.created 旋转 capturedSessionModel 的 mid-session 绕过回归 + passthrough billing post-filter ServiceTier 与 idempotent filter 回归。 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 00:34:23 +08:00
package service
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// openAIFastPolicyRepoStub is an in-memory settings repository used by the
// tests in this file. Only GetValue and Set are functional; every other
// method panics so that an unexpected repository access fails loudly.
type openAIFastPolicyRepoStub struct {
	// values holds the stored settings, keyed by setting key.
	values map[string]string
}

// Get is not expected to be called by these tests and always panics.
func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
	panic("unexpected Get call")
}

// GetValue returns the value stored under key, or ErrSettingNotFound when
// the key has never been set.
func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
	stored, found := s.values[key]
	if !found {
		return "", ErrSettingNotFound
	}
	return stored, nil
}

// Set stores value under key, allocating the backing map on first use so a
// zero-value stub is usable. It never returns an error.
func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
	if s.values == nil {
		s.values = make(map[string]string)
	}
	s.values[key] = value
	return nil
}

// GetMultiple is not expected to be called by these tests and always panics.
func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
	panic("unexpected GetMultiple call")
}

// SetMultiple is not expected to be called by these tests and always panics.
func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
	panic("unexpected SetMultiple call")
}

// GetAll is not expected to be called by these tests and always panics.
func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
	panic("unexpected GetAll call")
}

// Delete is not expected to be called by these tests and always panics.
func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
	panic("unexpected Delete call")
}
// newOpenAIGatewayServiceWithSettings builds an OpenAIGatewayService whose
// setting service is backed by an in-memory stub repository. When settings is
// non-nil it is JSON-encoded and stored under
// SettingKeyOpenAIFastPolicySettings; when nil the repository starts empty.
func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
	t.Helper()

	seed := map[string]string{}
	if settings != nil {
		encoded, err := json.Marshal(settings)
		require.NoError(t, err)
		seed[SettingKeyOpenAIFastPolicySettings] = string(encoded)
	}

	settingSvc := NewSettingService(&openAIFastPolicyRepoStub{values: seed}, &config.Config{})
	return &OpenAIGatewayService{settingService: settingSvc}
}
// TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority checks the
// default settings: the priority tier evaluates to filter for every model
// tried, while flex and an empty tier evaluate to pass.
func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
	ctx := context.Background()
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// The default policy applies to every model (its whitelist is empty)
	// because codex's service_tier=fast is a user-level switch that is
	// orthogonal to the model. gpt-4 is included to show the default policy
	// is not limited to the gpt-5.5 family.
	for _, model := range []string{"gpt-5.5", "gpt-5.5-turbo", "gpt-4"} {
		action, _ := svc.evaluateOpenAIFastPolicy(ctx, account, model, OpenAIFastTierPriority)
		require.Equal(t, BetaPolicyActionFilter, action,
			"model %q with priority tier should be filtered", model)
	}

	// flex does not match the default rule's tier, and an empty tier means
	// the client did not request one; both are expected to pass.
	for _, tier := range []string{OpenAIFastTierFlex, ""} {
		action, _ := svc.evaluateOpenAIFastPolicy(ctx, account, "gpt-5.5", tier)
		require.Equal(t, BetaPolicyActionPass, action,
			"tier %q should evaluate to pass", tier)
	}
}
func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "fast mode is not allowed",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionBlock, action)
require.Equal(t, "fast mode is not allowed", msg)
}
func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierAny,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeOAuth,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
// OAuth account → rule matches
oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionFilter, action)
// API Key account → rule skipped → pass
apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionPass, action)
}
// TestApplyOpenAIFastPolicyToBody_FilterRemovesField checks that, under the
// default settings, a priority service_tier (or its client alias "fast") is
// removed from the request body, and that a body without service_tier comes
// back unchanged.
func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
	ctx := context.Background()
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	stripped := []struct {
		model string
		body  string
	}{
		// gpt-5.5 with priority: service_tier is stripped.
		{model: "gpt-5.5", body: `{"model":"gpt-5.5","service_tier":"priority","messages":[]}`},
		// The client alias "fast" (for priority) is filtered as well.
		{model: "gpt-5.5", body: `{"model":"gpt-5.5","service_tier":"fast"}`},
		// gpt-4 with priority: the default policy filters every model.
		{model: "gpt-4", body: `{"model":"gpt-4","service_tier":"priority"}`},
	}
	for _, tc := range stripped {
		updated, err := svc.applyOpenAIFastPolicyToBody(ctx, account, tc.model, []byte(tc.body))
		require.NoError(t, err, "body %s should be filtered without error", tc.body)

		got := string(updated)
		require.NotContains(t, got, `"service_tier"`,
			"service_tier should be stripped from %s", tc.body)
		// NotContains alone would also pass on an empty or truncated body;
		// make sure the rest of the request is still there.
		require.Contains(t, got, `"model":"`+tc.model+`"`,
			"model field should survive filtering of %s", tc.body)
	}

	// No service_tier in the body: nothing to filter, body comes back as-is.
	body := []byte(`{"model":"gpt-5.5"}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(ctx, account, "gpt-5.5", body)
	require.NoError(t, err)
	require.Equal(t, string(body), string(updated))
}
// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule verifies
// that, with the extended tier whitelist, the official OpenAI tiers a client
// may send explicitly (auto/default/scale) reach the upstream body instead of
// being silently stripped. The default policy only targets priority, so these
// tiers end up in the fall-through pass branch.
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
	ctx := context.Background()
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	officialTiers := []string{"auto", "default", "scale"}

	// Body level: the service_tier field must come back untouched.
	for _, tier := range officialTiers {
		tierField := `"service_tier":"` + tier + `"`
		body := []byte(`{"model":"gpt-5.5",` + tierField + `}`)

		updated, err := svc.applyOpenAIFastPolicyToBody(ctx, account, "gpt-5.5", body)
		require.NoError(t, err, "tier %q should pass without error", tier)
		require.Contains(t, string(updated), tierField,
			"tier %q should be preserved in body under default rule", tier)
	}

	// Evaluate level: the default rule's ServiceTier=priority does not match
	// auto/default/scale, so the verdict must be pass as well.
	for _, tier := range officialTiers {
		action, _ := svc.evaluateOpenAIFastPolicy(ctx, account, "gpt-5.5", tier)
		require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
	}
}
// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers 验证管理员显式配置
// ServiceTier=all + Action=filter 规则后auto/default/scale 等官方 tier 也会
// 被剥离。这是符合预期的——首条匹配 short-circuit"all" 覆盖任意已识别 tier。
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierAny,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.NotContains(t, string(updated), `"service_tier"`,
"tier %q should be stripped under ServiceTier=all + filter rule", tier)
}
}
// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped covers a genuinely
// unknown tier. Stripping happens at the normalize stage (normalize returns
// nil, after which normalizeResponsesBodyServiceTier deletes the field);
// applyOpenAIFastPolicyToBody itself is a no-op when the normalized tier is
// empty, since the field can no longer be present in a request that already
// went through that normalization. Calling apply directly here only verifies
// that it does not misbehave on an unrecognized value.
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// The normalize stage rejects the unknown value.
	require.Nil(t, normalizeOpenAIServiceTier("xxx"))

	// Given an unrecognized tier, applyOpenAIFastPolicyToBody returns no
	// error and hands the body back unchanged. Removing the field is not its
	// job; the upstream normalizeResponsesBodyServiceTier step does that.
	original := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
	result, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", original)
	require.NoError(t, err)
	require.Equal(t, string(original), string(result))
}
func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "fast mode is blocked for gpt-5.5",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.Error(t, err)
var blocked *OpenAIFastBlockedError
require.True(t, errors.As(err, &blocked))
require.Contains(t, blocked.Message, "fast mode is blocked")
require.Equal(t, string(body), string(updated)) // body not mutated on block
}
func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
svc := NewSettingService(repo, &config.Config{})
// Invalid action rejected
err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: "bogus",
Scope: BetaPolicyScopeAll,
}},
})
require.Error(t, err)
// Invalid service_tier rejected
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: "turbo",
Action: BetaPolicyActionPass,
Scope: BetaPolicyScopeAll,
}},
})
require.Error(t, err)
// Valid settings persisted
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
}},
})
require.NoError(t, err)
got, err := svc.GetOpenAIFastPolicySettings(context.Background())
require.NoError(t, err)
require.Len(t, got.Rules, 1)
require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
}