mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
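Mirroring the fast-mode filter implemented for Claude's BetaPolicy, this adds
a three-state pass / filter / block policy for the OpenAI upstream
service_tier field (priority / flex, including normalization of the client
alias "fast" to "priority"). It covers every OpenAI entry point plus the admin
configuration entry point.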
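Backend core
- Add the SettingKeyOpenAIFastPolicySettings, OpenAIFastPolicyRule, and
OpenAIFastPolicySettings config models. Each rule carries the dimensions
service_tier × action × scope × model whitelist × fallback action.
- SettingService.Get/SetOpenAIFastPolicySettings; when the setting is missing,
the built-in default policy is returned (priority is filtered for all models,
the whitelist is empty, fallback=pass). Rationale: service_tier=fast is a
user-level switch that is orthogonal to the model field, so pinning the
default to specific model slugs would leave a bypass path of "use gpt-4 + fast
to pass priority through to the upstream". JSON parse failures no longer fall
back silently; slog.Warn records the dirty data so operators can locate it.
- service_tier normalization (trim + ToLower + fast→priority + whitelist
priority/flex) and policy evaluation (evaluateOpenAIFastPolicy) act as the
single source of truth shared by HTTP and WS. The pure function
evaluateOpenAIFastPolicyWithSettings is extracted and paired with a ctx-bound
settings snapshot (withOpenAIFastPolicyContext /
openAIFastPolicySettingsFromContext), so a long-lived WS session prefetches
the settings once at entry and every frame reuses them instead of hitting
settingService per frame.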
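The normalization steps listed above (trim, lowercase, fast→priority, then a priority/flex whitelist) can be sketched roughly as follows. This is a minimal illustration, not the repository's code: normalizeTier and its (string, bool) return shape are hypothetical names chosen for the example.

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeTier is a hypothetical helper written for illustration only.
// It trims and lowercases the raw value, maps the client alias "fast" to
// "priority", and reports ok=false for anything outside priority/flex.
func normalizeTier(raw string) (tier string, ok bool) {
	t := strings.ToLower(strings.TrimSpace(raw))
	if t == "fast" {
		t = "priority"
	}
	switch t {
	case "priority", "flex":
		return t, true
	default:
		return "", false
	}
}

func main() {
	for _, raw := range []string{" FAST ", "flex", "turbo"} {
		tier, ok := normalizeTier(raw)
		fmt.Printf("%q -> %q (recognized=%v)\n", raw, tier, ok)
	}
}
```

Keeping this step in one function is what lets the HTTP and WS paths agree on which tier a rule is evaluated against.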
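HTTP entry points (4)
- Chat Completions, Anthropic-compatible (Messages, including the second hit
via BetaFastMode→priority), native Responses, and Passthrough Responses all
call applyOpenAIFastPolicyToBody. filter deletes the top-level service_tier
via sjson; block returns a 403 forbidden_error JSON.
- All 4 entry points use the upstream view of the model (the slug after
GetMappedModel + normalizeOpenAIModelForUpstream + Codex OAuth normalize), so
that chat / messages / native /responses / passthrough do not hit the
whitelist differently because they see different model values.
- On the pass path, the client alias "fast" is also normalized to "priority"
and written back into the body. Otherwise the native /responses and
passthrough entry points would forward "fast" verbatim to the upstream and
trigger a 400 / rejection (normalizeResponsesBodyServiceTier on the
chat-completions entry point already behaved this way).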
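The three outcomes described above can be sketched as one small function. This is an illustration under stated assumptions, not the gateway's implementation: applyAction and errBlocked are hypothetical names, and the standard library's encoding/json stands in for sjson, so the filtered body comes back with its keys re-ordered alphabetically.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// errBlocked is a hypothetical sentinel standing in for a typed block error.
var errBlocked = errors.New("service_tier blocked by policy")

// applyAction illustrates the three policy outcomes on a request body:
//   - "pass":   body is returned unchanged
//   - "filter": the top-level service_tier field is deleted
//   - "block":  body is returned unchanged together with errBlocked
func applyAction(body []byte, action string) ([]byte, error) {
	switch action {
	case "block":
		return body, errBlocked
	case "filter":
		var fields map[string]json.RawMessage
		if err := json.Unmarshal(body, &fields); err != nil {
			return body, err
		}
		if _, present := fields["service_tier"]; !present {
			return body, nil // nothing to strip, keep the original bytes
		}
		delete(fields, "service_tier")
		return json.Marshal(fields)
	default:
		return body, nil
	}
}

func main() {
	body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
	out, err := applyAction(body, "filter")
	fmt.Println(string(out), err)
}
```

Returning the untouched body on block matches the behavior the tests in this file assert ("body not mutated on block"); the caller is then free to translate the error into a 403 response.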
WebSocket entry points
- Add applyOpenAIFastPolicyToWSResponseCreate: it strictly matches
type="response.create" and handles only the top-level service_tier. filter
deletes the field with sjson; block returns a typed *OpenAIFastBlockedError.
- The ingress path calls it inside parseClientPayload. On a block hit it first
writes a Realtime-style error event and then returns
OpenAIWSClientCloseError(StatusPolicyViolation=1008), relying on the
synchronous flush of the underlying WebSocket Conn.Write to guarantee that the
error arrives before the close.
- The passthrough path applies the policy to firstClientMessage before
RunEntry, and wraps ReadFrame with openAIWSPolicyEnforcingFrameConn so the
policy runs on every client→upstream frame. Later frames without a model field
fall back to capturedSessionModel. The filter closure also watches the
session.model field of session.update / session.created frames and refreshes
capturedSessionModel, closing the mid-session bypass "first frame
model=gpt-4o (pass) → session.update switches to gpt-5.5 → a response.create
without model falls back to gpt-4o".
- passthrough billing: requestServiceTier is extracted from firstClientMessage
only after the policy filter has run. On a filter hit,
OpenAIForwardResult.ServiceTier reports nil (default tier), which is
consistent with the HTTP entry points (reqBody comes from the post-filter map)
and WS ingress (payload comes from the post-filter bytes).
- Error event schema: {event_id: "evt_<32hex>", type: "error",
error: {type: "forbidden_error", code: "policy_violation", message}},
compatible with the error event parsing in the OpenAI codex client.
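For reference, an event with the schema above could be built along these lines. This is a sketch only: the struct and function names are hypothetical, and generating the 32 hex characters from 16 random bytes is an assumption about how the "evt_<32hex>" ID might be produced.

```go
package main

import (
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"fmt"
)

// wsErrorDetail and wsErrorEvent are hypothetical types mirroring the
// schema {event_id, type, error: {type, code, message}}.
type wsErrorDetail struct {
	Type    string `json:"type"`
	Code    string `json:"code"`
	Message string `json:"message"`
}

type wsErrorEvent struct {
	EventID string        `json:"event_id"`
	Type    string        `json:"type"`
	Error   wsErrorDetail `json:"error"`
}

// newPolicyViolationEvent builds the JSON error event for a blocked frame.
// The ID is "evt_" followed by 32 hex characters (16 random bytes).
func newPolicyViolationEvent(message string) ([]byte, error) {
	var id [16]byte
	if _, err := rand.Read(id[:]); err != nil {
		return nil, err
	}
	return json.Marshal(wsErrorEvent{
		EventID: "evt_" + hex.EncodeToString(id[:]),
		Type:    "error",
		Error: wsErrorDetail{
			Type:    "forbidden_error",
			Code:    "policy_violation",
			Message: message,
		},
	})
}

func main() {
	event, err := newPolicyViolationEvent("fast mode is not allowed")
	if err != nil {
		panic(err)
	}
	fmt.Println(string(event))
}
```

In the flow described above, a payload like this would be written to the client before the connection is closed with status 1008.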
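Admin / Frontend
- dto.SystemSettings / UpdateSettingsRequest gain an
openai_fast_policy_settings field (omitempty), wired into bulk GET/PUT.
- The Gateway tab of the Settings page gains a Fast/Flex Policy form card with
full configuration of service_tier × action × scope × model whitelist ×
fallback action.
- Frontend guard: the openaiFastPolicyLoaded flag allows write-back only when
GET actually returned the field, so a rollout or an error cannot overwrite the
default rules with an empty value. The saveSettings write-back loop skips this
field, which is handled by dedicated refresh logic. error_message is sent only
when action=block, matching the backend's omitempty behavior.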
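Tests
- HTTP path: openai_fast_policy_test.go covers the default config
(whitelist=[], priority filtered for all models) / block with a custom error /
scope distinction / filter deleting the field / block leaving the body
unchanged / block short-circuiting the upstream / Anthropic BetaFastMode
triggering the OpenAI fast policy, among other scenarios.
- WebSocket path: openai_fast_policy_ws_test.go covers
helper units (filter / fast→priority normalization / flex pass-through /
block typed error / bytes unchanged when service_tier is absent /
non-response.create frames untouched / frames with an empty type untouched /
event_id+code field assertions / tolerance of a non-string service_tier) +
a regression for fast alias normalization on the pass path +
ingress end-to-end (the upstream sees no service_tier after filter / after
block the client receives the error event first and then close 1008, with 0
upstream writes) +
passthrough capturedSessionModel fallback cases (established by the first
frame under a whitelist policy, fallback hit when model is missing, and the
documented leak when no fallback exists) +
a regression for the mid-session bypass in which session.update /
session.created rotates capturedSessionModel +
regressions for passthrough billing post-filter ServiceTier and idempotent
filter.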
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
287 lines
12 KiB
Go
package service

import (
	"context"
	"encoding/json"
	"errors"
	"testing"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/stretchr/testify/require"
)

type openAIFastPolicyRepoStub struct {
	values map[string]string
}

func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
	panic("unexpected Get call")
}

func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
	if v, ok := s.values[key]; ok {
		return v, nil
	}
	return "", ErrSettingNotFound
}

func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
	if s.values == nil {
		s.values = map[string]string{}
	}
	s.values[key] = value
	return nil
}

func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
	panic("unexpected GetMultiple call")
}

func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
	panic("unexpected SetMultiple call")
}

func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
	panic("unexpected GetAll call")
}

func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
	panic("unexpected Delete call")
}

func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
	t.Helper()
	repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
	if settings != nil {
		raw, err := json.Marshal(settings)
		require.NoError(t, err)
		repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
	}
	return &OpenAIGatewayService{
		settingService: NewSettingService(repo, &config.Config{}),
	}
}

func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// The default policy applies to all models (the whitelist is empty), because
	// codex's service_tier=fast is a user-level switch orthogonal to the model.
	// gpt-5.5 + priority → filter
	action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// gpt-5.5-turbo + priority → filter
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// gpt-4 + priority → filter (the default policy covers all models)
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// gpt-5.5 + flex → pass (tier doesn't match)
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
	require.Equal(t, BetaPolicyActionPass, action)

	// empty tier → pass
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
	require.Equal(t, BetaPolicyActionPass, action)
}

func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier:    OpenAIFastTierPriority,
			Action:         BetaPolicyActionBlock,
			Scope:          BetaPolicyScopeAll,
			ErrorMessage:   "fast mode is not allowed",
			ModelWhitelist: []string{"gpt-5.5"},
			FallbackAction: BetaPolicyActionPass,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionBlock, action)
	require.Equal(t, "fast mode is not allowed", msg)
}

func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierAny,
			Action:      BetaPolicyActionFilter,
			Scope:       BetaPolicyScopeOAuth,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)

	// OAuth account → rule matches
	oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
	action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionFilter, action)

	// API Key account → rule skipped → pass
	apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
	require.Equal(t, BetaPolicyActionPass, action)
}

func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// gpt-5.5 + priority → service_tier stripped
	body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.NotContains(t, string(updated), `"service_tier"`)

	// Client sending "fast" (alias for priority) also filtered
	body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.NotContains(t, string(updated), `"service_tier"`)

	// gpt-4 + priority → the default policy filters all models, so service_tier is removed
	body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
	require.NoError(t, err)
	require.NotContains(t, string(updated), `"service_tier"`)

	// No service_tier → no-op
	body = []byte(`{"model":"gpt-5.5"}`)
	updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.Equal(t, string(body), string(updated))
}

// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule verifies that,
// after the whitelist was extended, official OpenAI tiers sent explicitly by the
// client (auto/default/scale) pass through to the upstream instead of being
// silently stripped. The default policy targets only priority, so these tiers
// land in the fall-through pass branch.
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	for _, tier := range []string{"auto", "default", "scale"} {
		body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
		updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
		require.NoError(t, err, "tier %q should pass without error", tier)
		require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
			"tier %q should be preserved in body under default rule", tier)
	}

	// The evaluate layer should also return pass (the default rule's
	// ServiceTier=priority does not match auto/default/scale).
	for _, tier := range []string{"auto", "default", "scale"} {
		action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
		require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
	}
}

// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers verifies that once
// an admin explicitly configures a ServiceTier=all + Action=filter rule, official
// tiers such as auto/default/scale are stripped as well. This is expected: the
// first matching rule short-circuits, and "all" covers any recognized tier.
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierAny,
			Action:      BetaPolicyActionFilter,
			Scope:       BetaPolicyScopeAll,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
		body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
		updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
		require.NoError(t, err)
		require.NotContains(t, string(updated), `"service_tier"`,
			"tier %q should be stripped under ServiceTier=all + filter rule", tier)
	}
}

// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped verifies that a truly
// unknown tier is still stripped: normalize returns nil, so
// normalizeResponsesBodyServiceTier deletes the field. applyOpenAIFastPolicyToBody
// itself is a no-op when normTier is empty, because the field can no longer exist
// in a request that already went through the preceding normalization. This test
// calls apply directly to confirm it does not misbehave on unrecognized values.
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// The normalize stage strips unknown values.
	require.Nil(t, normalizeOpenAIServiceTier("xxx"))

	// applyOpenAIFastPolicyToBody does not error on an unrecognized tier and
	// passes the body through unchanged. Stripping is not this function's job:
	// the upper-layer normalizeResponsesBodyServiceTier has already done it.
	body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.NoError(t, err)
	require.Equal(t, string(body), string(updated))
}

func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
	settings := &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier:    OpenAIFastTierPriority,
			Action:         BetaPolicyActionBlock,
			Scope:          BetaPolicyScopeAll,
			ErrorMessage:   "fast mode is blocked for gpt-5.5",
			ModelWhitelist: []string{"gpt-5.5"},
			FallbackAction: BetaPolicyActionPass,
		}},
	}
	svc := newOpenAIGatewayServiceWithSettings(t, settings)
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
	require.Error(t, err)
	var blocked *OpenAIFastBlockedError
	require.True(t, errors.As(err, &blocked))
	require.Contains(t, blocked.Message, "fast mode is blocked")
	require.Equal(t, string(body), string(updated)) // body not mutated on block
}

func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
	repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
	svc := NewSettingService(repo, &config.Config{})

	// Invalid action rejected
	err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierPriority,
			Action:      "bogus",
			Scope:       BetaPolicyScopeAll,
		}},
	})
	require.Error(t, err)

	// Invalid service_tier rejected
	err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: "turbo",
			Action:      BetaPolicyActionPass,
			Scope:       BetaPolicyScopeAll,
		}},
	})
	require.Error(t, err)

	// Valid settings persisted
	err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
		Rules: []OpenAIFastPolicyRule{{
			ServiceTier: OpenAIFastTierPriority,
			Action:      BetaPolicyActionFilter,
			Scope:       BetaPolicyScopeAll,
		}},
	})
	require.NoError(t, err)

	got, err := svc.GetOpenAIFastPolicySettings(context.Background())
	require.NoError(t, err)
	require.Len(t, got.Rules, 1)
	require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
}