Merge pull request #2051 from DaydreamCoding/openai-fast-flex-policy

feat(openai): OpenAI Fast/Flex Policy 完整实现(HTTP + WebSocket + Admin)
This commit is contained in:
Wesley Liddick
2026-04-28 12:14:43 +08:00
committed by GitHub
23 changed files with 2820 additions and 10 deletions

View File

@@ -186,7 +186,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)

View File

@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
@@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) {
require.True(t, isAddrInTrustedProxies(addr, prefixes))
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}
// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier verifies that the
// admin write path normalizes empty/blank/upper-case ServiceTier values to
// service.OpenAIFastTierAny ("all"), so "" and "all" never coexist on disk
// with two spellings of the same meaning.
func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) {
	t.Run("nil input returns nil", func(t *testing.T) {
		require.Nil(t, openaiFastPolicySettingsFromDTO(nil))
	})

	// singleRule builds a one-rule DTO with the given tier/action and an
	// "all" scope, the shape shared by the normalization subtests below.
	singleRule := func(tier, action string) *dto.OpenAIFastPolicySettings {
		return &dto.OpenAIFastPolicySettings{
			Rules: []dto.OpenAIFastPolicyRule{{
				ServiceTier: tier,
				Action:      action,
				Scope:       "all",
			}},
		}
	}

	t.Run("empty service_tier becomes 'all'", func(t *testing.T) {
		out := openaiFastPolicySettingsFromDTO(singleRule("", "filter"))
		require.NotNil(t, out)
		require.Len(t, out.Rules, 1)
		require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
		// Pins the constant's literal value as well as its identity.
		require.Equal(t, "all", out.Rules[0].ServiceTier)
	})

	t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) {
		out := openaiFastPolicySettingsFromDTO(singleRule(" ", "pass"))
		require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
	})

	t.Run("uppercase service_tier is lowercased", func(t *testing.T) {
		out := openaiFastPolicySettingsFromDTO(singleRule("PRIORITY", "filter"))
		require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
	})

	t.Run("non-empty values pass through (lowercased)", func(t *testing.T) {
		in := &dto.OpenAIFastPolicySettings{
			Rules: []dto.OpenAIFastPolicyRule{
				{ServiceTier: "priority", Action: "filter", Scope: "all"},
				{ServiceTier: "flex", Action: "block", Scope: "oauth"},
				{ServiceTier: "all", Action: "pass", Scope: "apikey"},
			},
		}
		out := openaiFastPolicySettingsFromDTO(in)
		require.Len(t, out.Rules, 3)
		require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
		require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier)
		require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier)
	})
}

View File

@@ -248,9 +248,51 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
AffiliateEnabled: settings.AffiliateEnabled,
}
// OpenAI fast policy (stored under a dedicated setting key)
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
} else if fastPolicy != nil {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy.
// A nil input yields nil so callers can chain on optional settings.
func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings {
	if s == nil {
		return nil
	}
	converted := make([]dto.OpenAIFastPolicyRule, 0, len(s.Rules))
	for _, rule := range s.Rules {
		converted = append(converted, dto.OpenAIFastPolicyRule(rule))
	}
	return &dto.OpenAIFastPolicySettings{Rules: converted}
}
// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy.
//
// ServiceTier normalization: before the DTO reaches the service layer, an
// empty (or whitespace-only) tier is rewritten to service.OpenAIFastTierAny
// ("all"). Otherwise an admin save could persist both "" and "all" with the
// identical "match any tier" meaning, leaving the stored value ambiguous.
// Non-empty values are passed through lowercased; legal-value validation
// remains the job of service.SetOpenAIFastPolicySettings.
func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings {
	if s == nil {
		return nil
	}
	converted := make([]service.OpenAIFastPolicyRule, len(s.Rules))
	for i := range s.Rules {
		rule := service.OpenAIFastPolicyRule(s.Rules[i])
		tier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
		if tier == "" {
			tier = service.OpenAIFastTierAny
		}
		rule.ServiceTier = tier
		converted[i] = rule
	}
	return &service.OpenAIFastPolicySettings{Rules: converted}
}
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
@@ -452,6 +494,9 @@ type UpdateSettingsRequest struct {
// Affiliate (邀请返利) feature switch
AffiliateEnabled *bool `json:"affiliate_enabled"`
// OpenAI fast/flex policy (optional, only updated when provided)
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
// UpdateSettings 更新系统设置
@@ -1350,6 +1395,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
// Update OpenAI fast policy (stored under dedicated key, only when provided).
if req.OpenAIFastPolicySettings != nil {
if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil {
response.BadRequest(c, err.Error())
return
}
}
// Update payment configuration (integrated into system settings).
// Skip if no payment fields were provided (prevents accidental wipe).
if h.paymentConfigService != nil && hasPaymentFields(req) {
@@ -1555,6 +1608,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
AffiliateEnabled: updatedSettings.AffiliateEnabled,
}
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
} else if fastPolicy != nil {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}

View File

@@ -26,7 +26,12 @@ func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.
}
// GetValue returns the stored value for key, or "" when the key is absent.
// The previous unconditional panic made the map lookup below unreachable dead
// code; it is removed because handler tests now hit this method routinely —
// presumably via the fast-policy read path added to GetSettings/UpdateSettings
// (TODO confirm against the handler call sites).
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
	if s.values != nil {
		if value, ok := s.values[key]; ok {
			return value, nil
		}
	}
	return "", nil
}
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {

View File

@@ -198,6 +198,9 @@ type SystemSettings struct {
// Affiliate (邀请返利) feature switch
AffiliateEnabled bool `json:"affiliate_enabled"`
// OpenAI fast/flex policy
OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
type DefaultSubscriptionSetting struct {
@@ -294,6 +297,22 @@ type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"`
}
// OpenAIFastPolicyRule is the admin-facing DTO for one OpenAI fast/flex
// (service_tier) policy rule. It converts 1:1 to/from the service-layer rule
// (see openaiFastPolicySettingsFromDTO / openaiFastPolicySettingsToDTO).
type OpenAIFastPolicyRule struct {
	// ServiceTier the rule matches; "all" (or empty, normalized on write)
	// matches any tier.
	ServiceTier string `json:"service_tier"`
	// Action applied on match — e.g. "pass", "filter", "block".
	Action string `json:"action"`
	// Scope restricts the rule by account kind — e.g. "all", "oauth", "apikey".
	Scope string `json:"scope"`
	// ErrorMessage returned to the client when a matching rule blocks.
	ErrorMessage string `json:"error_message,omitempty"`
	// ModelWhitelist narrows the rule to specific models; empty covers all.
	ModelWhitelist []string `json:"model_whitelist,omitempty"`
	// FallbackAction applies when ModelWhitelist does not match the model.
	FallbackAction string `json:"fallback_action,omitempty"`
	// FallbackErrorMessage — presumably the message used when FallbackAction
	// blocks; verify against the service-layer evaluator.
	FallbackErrorMessage string `json:"fallback_error_message,omitempty"`
}
// OpenAIFastPolicySettings is the admin-facing DTO for the whole OpenAI fast
// policy: an ordered list of rules (the service layer evaluates them
// first-match-wins, so order carries intent).
type OpenAIFastPolicySettings struct {
	Rules []OpenAIFastPolicyRule `json:"rules"`
}
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input.
func ParseCustomMenuItems(raw string) []CustomMenuItem {

View File

@@ -748,6 +748,16 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_alipay_enabled": true,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": true,
"openai_fast_policy_settings": {
"rules": [
{
"service_tier": "priority",
"action": "filter",
"scope": "all",
"fallback_action": "pass"
}
]
},
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
@@ -930,6 +940,16 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_alipay_enabled": false,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
"openai_fast_policy_settings": {
"rules": [
{
"service_tier": "priority",
"action": "filter",
"scope": "all",
"fallback_action": "pass"
}
]
},
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,

View File

@@ -306,6 +306,12 @@ const (
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings = "beta_policy_settings"
// SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI
// service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but
// targets OpenAI's body-level service_tier field instead of Claude's
// anthropic-beta header.
SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings"
// =========================
// Claude Code Version Check
// =========================

View File

@@ -0,0 +1,286 @@
package service
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// openAIFastPolicyRepoStub is an in-memory setting repository for the
// fast-policy tests. Only GetValue and Set are expected to be exercised;
// every other method panics so an unintended call fails loudly.
type openAIFastPolicyRepoStub struct {
	values map[string]string
}

func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
	panic("unexpected Get call")
}

// GetValue looks the key up in the in-memory map and mirrors the real
// repository's not-found contract via ErrSettingNotFound.
func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
	value, ok := s.values[key]
	if !ok {
		return "", ErrSettingNotFound
	}
	return value, nil
}

// Set stores the value, lazily allocating the map so the zero stub works.
func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
	if s.values == nil {
		s.values = make(map[string]string)
	}
	s.values[key] = value
	return nil
}

func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
	panic("unexpected GetMultiple call")
}

func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
	panic("unexpected SetMultiple call")
}

func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
	panic("unexpected GetAll call")
}

func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
	panic("unexpected Delete call")
}
// newOpenAIGatewayServiceWithSettings builds a minimal OpenAIGatewayService
// whose settingService reads from an in-memory stub, optionally pre-seeded
// with the given fast-policy settings marshalled under the dedicated key.
func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
	t.Helper()
	values := map[string]string{}
	if settings != nil {
		raw, err := json.Marshal(settings)
		require.NoError(t, err)
		values[SettingKeyOpenAIFastPolicySettings] = string(raw)
	}
	return &OpenAIGatewayService{
		settingService: NewSettingService(&openAIFastPolicyRepoStub{values: values}, &config.Config{}),
	}
}
// TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority checks that
// the default policy filters service_tier=priority for every model, while
// flex and an empty tier fall through to pass.
func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	ctx := context.Background()

	// The default policy applies to every model (empty whitelist): codex's
	// service_tier=fast is a per-user switch, orthogonal to the model.
	for _, model := range []string{"gpt-5.5", "gpt-5.5-turbo", "gpt-4"} {
		action, _ := svc.evaluateOpenAIFastPolicy(ctx, account, model, OpenAIFastTierPriority)
		require.Equal(t, BetaPolicyActionFilter, action)
	}

	// flex does not match the priority-only default rule → pass.
	flexAction, _ := svc.evaluateOpenAIFastPolicy(ctx, account, "gpt-5.5", OpenAIFastTierFlex)
	require.Equal(t, BetaPolicyActionPass, flexAction)

	// An empty tier short-circuits to pass.
	emptyAction, _ := svc.evaluateOpenAIFastPolicy(ctx, account, "gpt-5.5", "")
	require.Equal(t, BetaPolicyActionPass, emptyAction)
}
func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "fast mode is not allowed",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionBlock, action)
require.Equal(t, "fast mode is not allowed", msg)
}
func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierAny,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeOAuth,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
// OAuth account → rule matches
oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionFilter, action)
// API Key account → rule skipped → pass
apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
require.Equal(t, BetaPolicyActionPass, action)
}
// TestApplyOpenAIFastPolicyToBody_FilterRemovesField checks that the default
// policy strips service_tier from the body for priority (and its "fast"
// alias) on any model, and leaves bodies without the field untouched.
func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	ctx := context.Background()

	strippedCases := []struct {
		model string
		body  string
	}{
		// gpt-5.5 + priority → field removed.
		{"gpt-5.5", `{"model":"gpt-5.5","service_tier":"priority","messages":[]}`},
		// "fast" (alias for priority) is filtered too.
		{"gpt-5.5", `{"model":"gpt-5.5","service_tier":"fast"}`},
		// The default policy covers every model, so gpt-4 is filtered as well.
		{"gpt-4", `{"model":"gpt-4","service_tier":"priority"}`},
	}
	for _, tc := range strippedCases {
		updated, err := svc.applyOpenAIFastPolicyToBody(ctx, account, tc.model, []byte(tc.body))
		require.NoError(t, err)
		require.NotContains(t, string(updated), `"service_tier"`)
	}

	// Without a service_tier field the body passes through unchanged.
	original := []byte(`{"model":"gpt-5.5"}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(ctx, account, "gpt-5.5", original)
	require.NoError(t, err)
	require.Equal(t, string(original), string(updated))
}
// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule verifies
// that, after widening the whitelist, OpenAI's official tiers explicitly sent
// by clients (auto/default/scale) are forwarded upstream rather than being
// silently stripped. The default policy only targets priority, so these tiers
// land in the fall-through pass branch.
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	ctx := context.Background()
	officialTiers := []string{"auto", "default", "scale"}

	for _, tier := range officialTiers {
		payload := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
		updated, err := svc.applyOpenAIFastPolicyToBody(ctx, account, "gpt-5.5", payload)
		require.NoError(t, err, "tier %q should pass without error", tier)
		require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
			"tier %q should be preserved in body under default rule", tier)
	}

	// The evaluate layer agrees: the default rule (ServiceTier=priority)
	// does not match auto/default/scale.
	for _, tier := range officialTiers {
		action, _ := svc.evaluateOpenAIFastPolicy(ctx, account, "gpt-5.5", tier)
		require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
	}
}
// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers 验证管理员显式配置
// ServiceTier=all + Action=filter 规则后auto/default/scale 等官方 tier 也会
// 被剥离。这是符合预期的——首条匹配 short-circuit"all" 覆盖任意已识别 tier。
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierAny,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
require.NotContains(t, string(updated), `"service_tier"`,
"tier %q should be stripped under ServiceTier=all + filter rule", tier)
}
}
// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped covers genuinely
// unknown tiers: normalizeOpenAIServiceTier returns nil, so the upstream
// normalizeResponsesBodyServiceTier pass deletes the field before the policy
// runs. applyOpenAIFastPolicyToBody itself treats an unrecognized tier as a
// no-op (the field cannot survive the preceding normalization), so calling it
// directly must neither error nor mutate the body.
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
	svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}

	// The normalize stage rejects unknown values outright.
	require.Nil(t, normalizeOpenAIServiceTier("xxx"))

	// Feeding an unrecognized tier straight into apply is error-free and the
	// body passes through untouched (stripping is the normalizer's job).
	original := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
	updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", original)
	require.NoError(t, err)
	require.Equal(t, string(original), string(updated))
}
func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
settings := &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "fast mode is blocked for gpt-5.5",
ModelWhitelist: []string{"gpt-5.5"},
FallbackAction: BetaPolicyActionPass,
}},
}
svc := newOpenAIGatewayServiceWithSettings(t, settings)
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.Error(t, err)
var blocked *OpenAIFastBlockedError
require.True(t, errors.As(err, &blocked))
require.Contains(t, blocked.Message, "fast mode is blocked")
require.Equal(t, string(body), string(updated)) // body not mutated on block
}
func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
svc := NewSettingService(repo, &config.Config{})
// Invalid action rejected
err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: "bogus",
Scope: BetaPolicyScopeAll,
}},
})
require.Error(t, err)
// Invalid service_tier rejected
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: "turbo",
Action: BetaPolicyActionPass,
Scope: BetaPolicyScopeAll,
}},
})
require.Error(t, err)
// Valid settings persisted
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
ServiceTier: OpenAIFastTierPriority,
Action: BetaPolicyActionFilter,
Scope: BetaPolicyScopeAll,
}},
})
require.NoError(t, err)
got, err := svc.GetOpenAIFastPolicySettings(context.Background())
require.NoError(t, err)
require.Len(t, got.Rules, 1)
require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
}

File diff suppressed because it is too large Load Diff

View File

@@ -171,6 +171,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
}
}
// 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
}
return nil, policyErr
}
responsesBody = updatedBody
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {

View File

@@ -19,8 +19,22 @@ func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "flex", req.ServiceTier)
// OpenAI 官方合法 tier 应被透传保留。
req.ServiceTier = "auto"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "auto", req.ServiceTier)
req.ServiceTier = "default"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "default", req.ServiceTier)
req.ServiceTier = "scale"
normalizeResponsesRequestServiceTier(req)
require.Equal(t, "scale", req.ServiceTier)
// 真未知值仍被剥离。
req.ServiceTier = "turbo"
normalizeResponsesRequestServiceTier(req)
require.Empty(t, req.ServiceTier)
}
@@ -37,8 +51,25 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
require.Equal(t, "flex", tier)
require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
// OpenAI 官方 tier 直接保留在 body 中(透传上游)。
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`))
require.NoError(t, err)
require.Equal(t, "auto", tier)
require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String())
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
require.NoError(t, err)
require.Equal(t, "default", tier)
require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String())
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`))
require.NoError(t, err)
require.Equal(t, "scale", tier)
require.Equal(t, "scale", gjson.GetBytes(body, "service_tier").String())
// 真未知值才会被删除。
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`))
require.NoError(t, err)
require.Empty(t, tier)
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
}

View File

@@ -143,6 +143,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}
}
// 4c. Apply OpenAI fast policy (may filter service_tier or block the request).
// Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed
// on the body-level service_tier field (priority/flex).
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message)
}
return nil, policyErr
}
responsesBody = updatedBody
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {

View File

@@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
nil,
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
@@ -826,18 +827,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *got)
})
t.Run("default ignored", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("default"))
t.Run("openai official tiers preserved", func(t *testing.T) {
// OpenAI 官方文档定义的合法 tier 值都应被透传保留,避免因白名单过窄
// 静默剥离客户端显式发送的合法字段。Codex 客户端只发 priority/flex
// 所以扩大白名单对 Codex 流量零影响(见 codex-rs/core/src/client.rs
for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} {
got := normalizeOpenAIServiceTier(tier)
require.NotNil(t, got, "tier %q should not be normalized to nil", tier)
require.Equal(t, tier, *got)
}
})
t.Run("invalid ignored", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
})
}
func TestExtractOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"}))
require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"}))
require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"}))
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
require.Nil(t, extractOpenAIServiceTier(nil))
}
@@ -845,7 +857,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) {
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`)))
require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
}

View File

@@ -334,6 +334,7 @@ type OpenAIGatewayService struct {
resolver *ModelPricingResolver
channelService *ChannelService
balanceNotifyService *BalanceNotifyService
settingService *SettingService
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@@ -372,6 +373,7 @@ func NewOpenAIGatewayService(
resolver *ModelPricingResolver,
channelService *ChannelService,
balanceNotifyService *BalanceNotifyService,
settingService *SettingService,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -402,6 +404,7 @@ func NewOpenAIGatewayService(
resolver: resolver,
channelService: channelService,
balanceNotifyService: balanceNotifyService,
settingService: settingService,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
@@ -2310,6 +2313,48 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
disablePatch()
}
// Apply OpenAI fast policy (参照 Claude BetaPolicy 的 fast-mode 过滤)
// 针对 body 的 service_tier 字段("priority" 即 fast"flex"),按策略
// 执行 filter删除字段或 block拒绝请求。对 gpt-5.5 等模型屏蔽
// fast 时在此生效。
//
// 注意:
// 1. 此处统一使用 upstreamModel已经过 GetMappedModel +
// normalizeOpenAIModelForUpstream + Codex OAuth normalize
// chat-completions / messages 入口保持一致,避免不同入口因为模型
// 维度不同而出现 whitelist 命中差异。
// 2. action=pass 时也要把 raw "fast" 归一化为 "priority" 写回 body
// 否则 native /responses 入口透传 "fast" 给上游会被拒。chat-
// completions 入口由 normalizeResponsesBodyServiceTier 完成同一
// 行为,这里手工实现等效逻辑。
if rawTier, ok := reqBody["service_tier"].(string); ok {
if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" {
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier)
switch action {
case BetaPolicyActionBlock:
msg := errMsg
if msg == "" {
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel)
}
blocked := &OpenAIFastBlockedError{Message: msg}
writeOpenAIFastPolicyBlockedResponse(c, blocked)
return nil, blocked
case BetaPolicyActionFilter:
delete(reqBody, "service_tier")
bodyModified = true
disablePatch()
default:
// pass若客户端传的是别名 "fast",归一化为 "priority"
// 后写回 body确保上游收到的是其能识别的规范值。
if normTier != rawTier {
reqBody["service_tier"] = normTier
bodyModified = true
markPatchSet("service_tier", normTier)
}
}
}
}
// Re-serialize body only if modified
if bodyModified {
serializedByPatch := false
@@ -2758,6 +2803,26 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
body = sanitizedBody
}
// Apply OpenAI fast policy to the passthrough body (filter/block by service_tier).
// 统一使用 upstream 视角的 model透传路径下 body 已经过 compact 映射 +
// OAuth normalizebody 中的 model 字段即上游真正会看到的 slug。
// 这样可以与 chat-completions / messages / native /responses 入口的
// upstreamModel 保持一致,避免 whitelist 命中差异。当 body 中没有
// model 字段时退回 reqModel。
policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
if policyModel == "" {
policyModel = reqModel
}
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body)
if policyErr != nil {
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
writeOpenAIFastPolicyBlockedResponse(c, blocked)
}
return nil, policyErr
}
body = updatedBody
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID,
@@ -5590,14 +5655,319 @@ func normalizeOpenAIServiceTier(raw string) *string {
if value == "fast" {
value = "priority"
}
// 放过 OpenAI 官方文档定义的所有合法 tier 值priority/flex/auto/default/scale。
// 对 Codex 客户端零影响Codex 只发 priority 或 flex见 codex-rs/core/src/client.rs
// 但能让直连 OpenAI SDK 的用户透传 auto/default/scale 以便抓包/调试。
// 真未知值仍返回 nil由 normalizeResponsesBodyServiceTier 从 body 中删除。
switch value {
case "priority", "flex":
case "priority", "flex", "auto", "default", "scale":
return &value
default:
return nil
}
}
// OpenAIFastBlockedError indicates a request was rejected by the OpenAI fast
// policy (action=block). Mirrors BetaBlockedError on the Claude side; callers
// detect it with errors.As and translate it into an HTTP 403 response.
type OpenAIFastBlockedError struct {
	// Message is the admin-configured (or default) human-readable reason.
	Message string
}

// Error implements the error interface by returning the configured message.
func (e *OpenAIFastBlockedError) Error() string { return e.Message }
// evaluateOpenAIFastPolicy returns the action and error message that should be
// applied for a request with the given account/model/service_tier. When the
// policy service is unavailable or no rule matches, it returns
// (BetaPolicyActionPass, "") so callers can short-circuit safely.
//
// Matching rules:
//   - Scope filters by account type (all / oauth / apikey / bedrock)
//   - ServiceTier must be empty (= any), "all", or equal the normalized tier
//   - ModelWhitelist narrows the rule to specific models; FallbackAction
//     handles the non-matching case (default: pass)
//
// Differences from Claude's BetaPolicy (first-match short-circuit retained):
//   - BetaPolicy processes the set of tokens in the anthropic-beta header,
//     where different rules may target different tokens: filters must
//     accumulate into a set while block is first-match.
//   - The OpenAI fast policy operates on a single field, service_tier:
//     "filter" simply deletes the field, so there is nothing to accumulate.
//     A request carries exactly one service_tier, making the tier dimension
//     of rules naturally mutually exclusive; if several rules under the same
//     (scope, tier) have overlapping model whitelists, the admin expresses
//     intent through rule order. Hence first-match wins here instead of
//     BetaPolicy's "block overrides filter overrides pass" semantics.
func (s *OpenAIGatewayService) evaluateOpenAIFastPolicy(ctx context.Context, account *Account, model, serviceTier string) (action, errMsg string) {
	// No setting service means no policy source: pass everything.
	if s == nil || s.settingService == nil {
		return BetaPolicyActionPass, ""
	}
	// An absent/blank tier can never match a rule; short-circuit to pass.
	tier := strings.ToLower(strings.TrimSpace(serviceTier))
	if tier == "" {
		return BetaPolicyActionPass, ""
	}
	// Long-lived callers (e.g. WS sessions) may stash prefetched settings in
	// the context; otherwise fetch from the setting service. Lookup errors
	// deliberately degrade to pass rather than failing the request.
	settings := openAIFastPolicySettingsFromContext(ctx)
	if settings == nil {
		fetched, err := s.settingService.GetOpenAIFastPolicySettings(ctx)
		if err != nil || fetched == nil {
			return BetaPolicyActionPass, ""
		}
		settings = fetched
	}
	return evaluateOpenAIFastPolicyWithSettings(settings, account, model, tier)
}
// evaluateOpenAIFastPolicyWithSettings is the pure-function core of the fast
// policy evaluator, extracted so long-lived sessions (e.g. WS) can prefetch
// settings once and avoid hitting the settingService on every frame. See the
// WSSession entry and openAIFastPolicySettingsFromContext for the caching
// glue. Rules are evaluated in order; the first rule whose scope and tier
// both match decides the outcome via resolveRuleAction.
func evaluateOpenAIFastPolicyWithSettings(settings *OpenAIFastPolicySettings, account *Account, model, tier string) (action, errMsg string) {
	if settings == nil {
		return BetaPolicyActionPass, ""
	}
	var isOAuth, isBedrock bool
	if account != nil {
		isOAuth = account.IsOAuth()
		isBedrock = account.IsBedrock()
	}
	for _, rule := range settings.Rules {
		if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
			continue
		}
		ruleTier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
		// An empty or "all" rule tier matches any request tier.
		tierMatches := ruleTier == "" || ruleTier == OpenAIFastTierAny || ruleTier == tier
		if !tierMatches {
			continue
		}
		return resolveRuleAction(BetaPolicyRule{
			Action:               rule.Action,
			ErrorMessage:         rule.ErrorMessage,
			ModelWhitelist:       rule.ModelWhitelist,
			FallbackAction:       rule.FallbackAction,
			FallbackErrorMessage: rule.FallbackErrorMessage,
		}, model)
	}
	return BetaPolicyActionPass, ""
}
// openAIFastPolicyCtxKeyType is the private context key type for the cached
// *OpenAIFastPolicySettings snapshot. The cache is only used by long-lived
// WebSocket sessions so that every frame in one session reuses a single
// policy snapshot instead of hitting the DB per frame.
//
// Trade-off: a policy change does not affect in-flight WS sessions, only new
// ones. This is deliberate — for long sessions, policy consistency matters
// more than immediate effect, and the Claude BetaPolicy gin.Context cache
// makes the same trade. When hot-reload is needed, an admin can force a
// refresh by kicking the session.
type openAIFastPolicyCtxKeyType struct{}

var openAIFastPolicyCtxKey = openAIFastPolicyCtxKeyType{}
// withOpenAIFastPolicyContext binds a settings snapshot to ctx so that
// evaluateOpenAIFastPolicy calls in goroutines derived from this ctx reuse
// it. A nil ctx or nil settings leaves the context untouched.
func withOpenAIFastPolicyContext(ctx context.Context, settings *OpenAIFastPolicySettings) context.Context {
	switch {
	case ctx == nil, settings == nil:
		return ctx
	default:
		return context.WithValue(ctx, openAIFastPolicyCtxKey, settings)
	}
}
// openAIFastPolicySettingsFromContext returns the settings snapshot cached
// on ctx by withOpenAIFastPolicyContext, or nil when none is present.
func openAIFastPolicySettingsFromContext(ctx context.Context) *OpenAIFastPolicySettings {
	if ctx == nil {
		return nil
	}
	// A failed type assertion yields the nil zero value, which is exactly
	// the "no snapshot cached" answer.
	settings, _ := ctx.Value(openAIFastPolicyCtxKey).(*OpenAIFastPolicySettings)
	return settings
}
// applyOpenAIFastPolicyToBody applies the OpenAI fast policy to a raw request
// body. When action=filter it removes the service_tier field; when
// action=block it returns (body, *OpenAIFastBlockedError). On pass it
// normalizes the service_tier value (e.g. client alias "fast" → "priority"),
// rewriting the body so the upstream receives a slug it recognizes.
//
// Rationale for normalize-on-pass: the chat-completions / messages entry
// points already run normalizeResponsesBodyServiceTier before calling this
// function, but the passthrough (OpenAI auto-passthrough) and native
// /responses entry points have no such pre-step. Without normalizing here on
// the pass path, "fast" would be forwarded verbatim and rejected by the
// OpenAI upstream (400). Centralizing normalization in this function keeps
// every entry point's behavior consistent.
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToBody(ctx context.Context, account *Account, model string, body []byte) ([]byte, error) {
	if len(body) == 0 {
		return body, nil
	}
	rawTier := gjson.GetBytes(body, "service_tier").String()
	if rawTier == "" {
		// No tier present — nothing for the policy to act on.
		return body, nil
	}
	normTier := normalizedOpenAIServiceTierValue(rawTier)
	if normTier == "" {
		// Unrecognized tier value: leave the body alone; downstream
		// normalization decides what to do with it.
		return body, nil
	}
	action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
	switch action {
	case BetaPolicyActionBlock:
		msg := errMsg
		if msg == "" {
			msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
		}
		return body, &OpenAIFastBlockedError{Message: msg}
	case BetaPolicyActionFilter:
		trimmed, err := sjson.DeleteBytes(body, "service_tier")
		if err != nil {
			return body, fmt.Errorf("strip service_tier from body: %w", err)
		}
		return trimmed, nil
	default:
		// pass: rewrite client aliases (e.g. "fast") to the canonical value
		// ("priority") so the upstream receives a slug it recognizes.
		if normTier == rawTier {
			return body, nil
		}
		updated, err := sjson.SetBytes(body, "service_tier", normTier)
		if err != nil {
			return body, fmt.Errorf("normalize service_tier on pass: %w", err)
		}
		return updated, nil
	}
}
// writeOpenAIFastPolicyBlockedResponse renders the HTTP 403 JSON body for a
// request rejected by the OpenAI fast policy, using the permission_error
// shape. No-op when either argument is nil.
func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlockedError) {
	if c == nil || err == nil {
		return
	}
	payload := gin.H{
		"error": gin.H{
			"type":    "permission_error",
			"message": err.Message,
		},
	}
	c.JSON(http.StatusForbidden, payload)
}
// applyOpenAIFastPolicyToWSResponseCreate applies the OpenAI fast policy to a
// single client→upstream WebSocket frame whose top-level "type" equals
// "response.create". It is the WS counterpart of applyOpenAIFastPolicyToBody:
//
//   - pass:   returns the frame unchanged (blocked == nil)
//   - filter: returns a copy with the top-level service_tier removed
//   - block:  returns (frame, *OpenAIFastBlockedError)
//
// Only frames strictly typed "response.create" are inspected or mutated.
// Every other frame type — including the empty string, which the Realtime
// client-event spec treats as malformed — passes through untouched, so we
// never accidentally strip fields from response.cancel,
// conversation.item.create, or any future client event; the upstream is the
// source of truth for rejecting malformed frames.
//
// service_tier lives at the top level of response.create, the same as the
// Responses HTTP body shape, so only the top-level field needs checking;
// there is no nested form in the schema today. The caller is responsible for
// choosing the upstream model passed in — this helper does not re-derive it.
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate(
	ctx context.Context,
	account *Account,
	model string,
	frame []byte,
) ([]byte, *OpenAIFastBlockedError, error) {
	if len(frame) == 0 || !gjson.ValidBytes(frame) {
		return frame, nil, nil
	}
	evtType := strings.TrimSpace(gjson.GetBytes(frame, "type").String())
	if evtType != "response.create" {
		return frame, nil, nil
	}
	tierRaw := gjson.GetBytes(frame, "service_tier").String()
	if tierRaw == "" {
		return frame, nil, nil
	}
	tier := normalizedOpenAIServiceTierValue(tierRaw)
	if tier == "" {
		return frame, nil, nil
	}
	action, customMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, tier)
	switch action {
	case BetaPolicyActionBlock:
		msg := customMsg
		if msg == "" {
			msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", tier, model)
		}
		return frame, &OpenAIFastBlockedError{Message: msg}, nil
	case BetaPolicyActionFilter:
		stripped, err := sjson.DeleteBytes(frame, "service_tier")
		if err != nil {
			return frame, nil, fmt.Errorf("strip service_tier from ws frame: %w", err)
		}
		return stripped, nil, nil
	default:
		return frame, nil, nil
	}
}
// newOpenAIFastPolicyWSEventID produces a Realtime-style event_id
// ("evt_<hex>") for a server-emitted error event. The exact value is not
// load-bearing — it only needs to be unique enough for client-side log
// correlation — and it reuses the existing google/uuid dependency rather
// than pulling in a new one.
func newOpenAIFastPolicyWSEventID() string {
	raw, err := uuid.NewRandom()
	if err != nil {
		// Practically unreachable; keep the field non-empty so the event
		// schema stays self-consistent.
		return "evt_openai_fast_policy"
	}
	// Drop the dashes so the ID reads like the "evt_<hex>" values seen in
	// real Realtime traces rather than a canonical UUID v4.
	return "evt_" + strings.ReplaceAll(raw.String(), "-", "")
}
// buildOpenAIFastPolicyBlockedWSEvent renders an OpenAI Realtime/Responses
// style "error" event for a request blocked by the OpenAI fast policy:
//
//	{
//	  "event_id": "evt_<random>",
//	  "type": "error",
//	  "error": {
//	    "type": "invalid_request_error",
//	    "code": "policy_violation",
//	    "message": "..."
//	  }
//	}
//
// event_id lets clients correlate the rejection in their logs; "code" gives
// programmatic clients a stable identifier (the HTTP-side equivalent is the
// 403 permission_error JSON body). Returns nil when err is nil.
func buildOpenAIFastPolicyBlockedWSEvent(err *OpenAIFastBlockedError) []byte {
	if err == nil {
		return nil
	}
	eventID := newOpenAIFastPolicyWSEventID()
	event := map[string]any{
		"event_id": eventID,
		"type":     "error",
		"error": map[string]any{
			"type":    "invalid_request_error",
			"code":    "policy_violation",
			"message": err.Message,
		},
	}
	payload, marshalErr := json.Marshal(event)
	if marshalErr != nil {
		// Marshal of the literal shape above should never fail in practice;
		// fall back to a minimal hand-rolled payload so the client still
		// receives an event.
		return []byte(`{"event_id":"` + eventID + `","type":"error","error":{"type":"invalid_request_error","code":"policy_violation","message":"openai fast policy blocked this request"}}`)
	}
	return payload
}
func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) {
if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) {
return body, false, nil

View File

@@ -2366,6 +2366,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return errors.New("token is empty")
}
// 预取一次 OpenAI Fast Policy settings绑定到 ctx让该 WS session
// 内所有帧的 evaluateOpenAIFastPolicy 调用复用同一份快照,避免每帧
// 进入 DB / settingRepo。Trade-off 见 withOpenAIFastPolicyContext 注释。
if s.settingService != nil {
if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil {
ctx = withOpenAIFastPolicyContext(ctx, settings)
}
}
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
ingressMode := OpenAIWSIngressModeCtxPool
@@ -2524,6 +2533,44 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
normalized = next
}
// Apply OpenAI Fast Policy on the response.create frame using the same
// evaluator/normalize/scope rules as the HTTP entrypoints. This is the
// single integration point for all WS ingress turns (first + follow-up
// frames flow through here).
//
// Model fallback: parseClientPayload above rejects any frame whose
// "model" field is missing (line ~2493-2500), so by the time we
// reach this point upstreamModel is always derived from a non-empty
// per-frame model. The capturedSessionModel fallback used in the
// passthrough adapter is therefore not needed in this path.
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
if policyErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
}
if blocked != nil {
// Send a Realtime-style error event to the client first, then
// signal the handler to close the connection with PolicyViolation.
// We intentionally do NOT forward this frame upstream.
//
// coder/websocket@v1.8.14 Conn.Write is synchronous and flushes
// the underlying bufio writer before returning (write.go:42 →
// 307-311), and the subsequent close handshake re-acquires the
// same writeFrameMu, so the error event is guaranteed to reach
// the kernel send buffer before any close frame is queued.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes != nil {
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancel()
}
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
blocked.Message,
blocked,
)
}
normalized = policyApplied
return openAIWSClientPayload{
payloadRaw: normalized,
rawForHash: trimmed,

View File

@@ -618,6 +618,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)

View File

@@ -21,6 +21,109 @@ type openAIWSClientFrameConn struct {
conn *coderws.Conn
}
// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
// every client→upstream frame through the OpenAI Fast Policy. It is the
// passthrough-relay equivalent of the parseClientPayload integration in the
// ingress session path. filter returns:
//   - newPayload, nil, nil: forward the (possibly mutated) payload
//   - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error
//     event via onBlock and surfaces a transport-level error so the relay
//     stops reading from the client.
//   - _, _, err: a transport error other than block.
type openAIWSPolicyEnforcingFrameConn struct {
	// inner is the wrapped client-side connection all reads/writes delegate to.
	inner openaiwsv2.FrameConn
	// filter is invoked on every frame read from the client; nil disables policing.
	filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
	// onBlock, when non-nil, is called before surfacing a block verdict so the
	// caller can emit a Realtime-style error event to the client.
	onBlock func(blocked *OpenAIFastBlockedError)
}

var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
// ReadFrame reads one frame from the wrapped client connection and runs it
// through the policy filter before handing it to the relay. On a block
// verdict it fires onBlock (which emits the error event) and returns a
// PolicyViolation close error so the relay stops reading from the client.
func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
	if c == nil || c.inner == nil {
		return coderws.MessageText, nil, errOpenAIWSConnClosed
	}
	msgType, payload, readErr := c.inner.ReadFrame(ctx)
	switch {
	case readErr != nil:
		return msgType, payload, readErr
	case c.filter == nil:
		return msgType, payload, nil
	}
	filtered, blocked, filterErr := c.filter(msgType, payload)
	if filterErr != nil {
		return msgType, payload, filterErr
	}
	if blocked == nil {
		return msgType, filtered, nil
	}
	if c.onBlock != nil {
		c.onBlock(blocked)
	}
	return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
}
// WriteFrame forwards an upstream→client frame to the wrapped connection
// unchanged; the policy only polices the client→upstream direction.
func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
	if c == nil || c.inner == nil {
		return errOpenAIWSConnClosed
	}
	return c.inner.WriteFrame(ctx, msgType, payload)
}
// Close closes the wrapped connection. A nil receiver or inner connection is
// treated as already closed and returns nil.
func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
	if c == nil || c.inner == nil {
		return nil
	}
	return c.inner.Close()
}
// openAIWSPassthroughPolicyModelForFrame resolves the upstream-perspective
// model name to pass to evaluateOpenAIFastPolicy for one passthrough WS
// frame. It mirrors the HTTP-side normalization (account.GetMappedModel +
// normalizeOpenAIModelForUpstream) so model whitelists match identically
// across transports. Returns "" when account is nil or the frame carries no
// model.
func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
	if account == nil || len(payload) == 0 {
		return ""
	}
	requested := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
	if requested == "" {
		return ""
	}
	mapped := account.GetMappedModel(requested)
	return normalizeOpenAIModelForUpstream(account, mapped)
}
// openAIWSPassthroughPolicyModelFromSessionFrame extracts the upstream model
// from a session.update frame's session.model field. It returns "" for any
// other frame type, or when account is nil or no session.model is present.
// The per-frame policy filter (client→upstream direction) uses it to keep
// capturedSessionModel in sync with a session-level model the client may
// rotate mid-session.
//
// Realtime / Responses WS lets the client change the session model after the
// handshake via:
//
//	{"type":"session.update","session":{"model":"gpt-5.5", ...}}
//
// Capturing the model only from the very first frame would open a bypass: a
// client could ship gpt-4o on the first response.create (whitelisted as
// pass), session.update to gpt-5.5, then send response.create without
// "model" so the per-frame resolver returns "" and the stale
// capturedSessionModel falls back to gpt-4o — defeating the gpt-5.5
// fast-policy filter.
func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
	if account == nil || len(payload) == 0 {
		return ""
	}
	if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "session.update" {
		return ""
	}
	sessionModel := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
	if sessionModel == "" {
		return ""
	}
	return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(sessionModel))
}
// openaiWSV2PassthroughModeFields tags passthrough-relay log lines with the
// WS mode and router version for easy grepping.
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"

var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
@@ -77,7 +180,6 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
return errors.New("token is empty")
}
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage)
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
logOpenAIWSV2Passthrough(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
@@ -88,6 +190,59 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
len(firstClientMessage),
)
// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
// frames are filtered via a wrapping FrameConn below so every client→
// upstream frame goes through the same policy evaluator/normalize/scope as
// HTTP entrypoints.
//
// We capture the session-level model from the first frame here so the
// per-frame filter (below) can fall back to it when a follow-up frame
// omits "model" — Realtime clients are allowed to send response.create
// without re-stating the model, in which case the upstream uses the model
// negotiated at session.update time. Without this fallback, an empty
// model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
// silently passed through, defeating the policy on every frame after
// the first.
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
if policyErr != nil {
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
}
if blocked != nil {
// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
// writeFrameMu, writes the entire frame, and Flushes the underlying
// bufio writer before returning (write.go:42 → write.go:307-311).
// The subsequent close handshake re-acquires the same writeFrameMu
// to send the close frame, so the error event is guaranteed to
// reach the kernel send buffer before any close frame is queued.
// No explicit flush hop is required here.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes != nil {
writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancelWrite()
}
return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
}
firstClientMessage = updatedFirst
// 在 policy filter 之后再提取 service_tier 用于 billing 上报filter
// 命中时 service_tier 已经从 firstClientMessage 中删除billing 应当
// 反映上游实际处理的 tiernil = default而不是用户最初请求的
// "priority"。HTTP 入口line ~2728 extractOpenAIServiceTier(reqBody)
// 与 WS ingressopenai_ws_forwarder.go:2991 取自 payload的语义一致。
//
// 多轮 passthroughOpenAI Realtime / Responses WS 协议允许客户端在
// 同一连接的不同 response.create 帧上发送不同 service_tier参考
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
// 因此使用 atomic.Pointer[string] 在 filterrunClientToUpstream
// goroutine和 OnTurnComplete / final resultrunUpstreamToClient
// goroutine之间同步当前 turn 的 service_tier。
// extractOpenAIServiceTierFromBody 返回 *string本身是指针类型
// 可直接 Store/Load 而无需额外封装。
var requestServiceTierPtr atomic.Pointer[string]
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
@@ -152,9 +307,72 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
}
completedTurns := atomic.Int32{}
policyClientConn := &openAIWSPolicyEnforcingFrameConn{
inner: &openAIWSClientFrameConn{conn: clientConn},
// 注意线程安全filter 仅在 runClientToUpstream 这一条
// goroutine 中被调用passthrough_relay.go: ReadFrame loop
// capturedSessionModel 的读写都发生在该 goroutine 内,因此无需
// 加锁/原子化。
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
if msgType != coderws.MessageText {
return payload, nil, nil
}
// 在评估策略前先刷新 capturedSessionModel客户端可能通过
// session.update 修改 session-level modelRealtime /
// Responses WS 协议允许),如果不刷新就会出现
// "首帧 model=gpt-4opass→ session.update 改成 gpt-5.5
// → 不带 model 的 response.create fallback 到 gpt-4o" 的
// 绕过路径。这里只看 session.update 事件中的 session.model
// 字段response.create 自己的 model 仍然由其本帧字段决定。
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
capturedSessionModel = updated
}
// Per-frame model first; if the client omits "model" on a
// follow-up frame (legal in Realtime), fall back to the
// session-level model captured from the first frame so the
// model whitelist still resolves. An empty model would miss
// any whitelist and silently fall back to pass.
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
if model == "" {
model = capturedSessionModel
}
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
// 多轮 passthrough billing仅在成功non-block / non-err
// 的 response.create 帧上更新 requestServiceTierPtr使用
// filter 处理后的 payload与首帧 policy-after-extract 语义
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
// - 非 response.create 帧response.cancel /
// conversation.item.create / session.update 等)不携带
// per-response service_tier不应覆盖前一轮值。
// - blocked != nil该帧不会发送上游billing tier 应保持
// 上一轮值。
// - policyErr != nil异常路径保持上一轮值。
// - 不带 service_tier 的 response.create 会让
// extractOpenAIServiceTierFromBody 返回 nil这里有意
// 覆盖Store(nil)),因为 OpenAI 上游对该帧实际不传
// service_tier 时按 default 处理billing 应如实反映。
if policyErr == nil && blocked == nil &&
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
}
return out, blocked, policyErr
},
onBlock: func(blocked *OpenAIFastBlockedError) {
// See note above on Conn.Write being synchronous w.r.t. flush;
// no explicit flush is required to ensure the error event lands
// before the close frame.
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
if eventBytes == nil {
return
}
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
cancel()
},
}
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
ClientConn: &openAIWSClientFrameConn{conn: clientConn},
ClientConn: policyClientConn,
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
@@ -179,7 +397,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
ServiceTier: requestServiceTier,
ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
@@ -227,7 +445,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
ServiceTier: requestServiceTier,
ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),

View File

@@ -3259,6 +3259,84 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
}
// GetOpenAIFastPolicySettings loads the OpenAI fast policy configuration.
// It falls back to the built-in defaults when the setting row is missing,
// empty, or contains corrupt JSON; only a repository error is surfaced.
func (s *SettingService) GetOpenAIFastPolicySettings(ctx context.Context) (*OpenAIFastPolicySettings, error) {
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpenAIFastPolicySettings)
	switch {
	case errors.Is(err, ErrSettingNotFound):
		return DefaultOpenAIFastPolicySettings(), nil
	case err != nil:
		return nil, fmt.Errorf("get openai fast policy settings: %w", err)
	case raw == "":
		return DefaultOpenAIFastPolicySettings(), nil
	}
	settings := &OpenAIFastPolicySettings{}
	if err := json.Unmarshal([]byte(raw), settings); err != nil {
		// Falling back silently on corrupt JSON would make an admin's
		// block/filter rules vanish without a trace, so log at Warn with the
		// setting key so operators can locate the dirty row.
		slog.Warn("failed to unmarshal openai fast policy settings, falling back to defaults",
			"error", err,
			"key", SettingKeyOpenAIFastPolicySettings)
		return DefaultOpenAIFastPolicySettings(), nil
	}
	return settings, nil
}
// SetOpenAIFastPolicySettings validates, normalizes, and persists the OpenAI
// fast policy configuration.
//
// Normalization happens in place on the caller's settings: each rule's
// ServiceTier is trimmed/lowercased (empty → OpenAIFastTierAny) and model
// whitelist patterns are trimmed. Validation rejects unknown tiers, actions,
// scopes, and fallback actions, and empty whitelist patterns.
func (s *SettingService) SetOpenAIFastPolicySettings(ctx context.Context, settings *OpenAIFastPolicySettings) error {
	if settings == nil {
		// errors.New, not fmt.Errorf: no formatting directives in a constant
		// message (staticcheck S1039).
		return errors.New("settings cannot be nil")
	}
	validActions := map[string]bool{
		BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
	}
	validScopes := map[string]bool{
		BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
	}
	validTiers := map[string]bool{
		OpenAIFastTierAny: true, OpenAIFastTierPriority: true, OpenAIFastTierFlex: true,
	}
	for i, rule := range settings.Rules {
		// Normalize before validating so " Priority " is accepted; an empty
		// tier means "match any tier".
		tier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
		if tier == "" {
			tier = OpenAIFastTierAny
		}
		if !validTiers[tier] {
			return fmt.Errorf("rule[%d]: invalid service_tier %q", i, rule.ServiceTier)
		}
		settings.Rules[i].ServiceTier = tier
		if !validActions[rule.Action] {
			return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
		}
		if !validScopes[rule.Scope] {
			return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
		}
		for j, pattern := range rule.ModelWhitelist {
			trimmed := strings.TrimSpace(pattern)
			if trimmed == "" {
				return fmt.Errorf("rule[%d]: model_whitelist[%d] cannot be empty", i, j)
			}
			settings.Rules[i].ModelWhitelist[j] = trimmed
		}
		// FallbackAction is optional; empty defers to the evaluator's default.
		if rule.FallbackAction != "" && !validActions[rule.FallbackAction] {
			return fmt.Errorf("rule[%d]: invalid fallback_action %q", i, rule.FallbackAction)
		}
	}
	data, err := json.Marshal(settings)
	if err != nil {
		return fmt.Errorf("marshal openai fast policy settings: %w", err)
	}
	return s.settingRepo.Set(ctx, SettingKeyOpenAIFastPolicySettings, string(data))
}
// SetStreamTimeoutSettings 设置流超时处理配置
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
if settings == nil {

View File

@@ -405,3 +405,57 @@ func DefaultBetaPolicySettings() *BetaPolicySettings {
},
}
}
// OpenAI fast policy constants.
//
// OpenAI "fast mode" is identified via the request body's service_tier field:
//   - "priority" (clients may send the alias "fast", normalized to "priority"): fast mode
//   - "flex": low-priority mode
//   - omitted: normal (default)
//
// The policy reuses the BetaPolicyAction*/BetaPolicyScope* constant
// semantics; only the matching key changes from the anthropic-beta header to
// the body's service_tier field.
const (
	OpenAIFastTierAny      = "all"      // matches any recognized service_tier
	OpenAIFastTierPriority = "priority" // matches only fast (priority)
	OpenAIFastTierFlex     = "flex"     // matches only flex
)
// OpenAIFastPolicyRule is a single OpenAI fast/flex policy rule.
type OpenAIFastPolicyRule struct {
	ServiceTier          string   `json:"service_tier"`                     // "priority" | "flex" | "all" (empty normalized to "all"); see SetOpenAIFastPolicySettings validation
	Action               string   `json:"action"`                           // "pass" | "filter" | "block"
	Scope                string   `json:"scope"`                            // "all" | "oauth" | "apikey" | "bedrock"
	ErrorMessage         string   `json:"error_message,omitempty"`          // custom error message (used when action=block)
	ModelWhitelist       []string `json:"model_whitelist,omitempty"`        // model match patterns (empty = applies to all models)
	FallbackAction       string   `json:"fallback_action,omitempty"`        // action for models not matching the whitelist
	FallbackErrorMessage string   `json:"fallback_error_message,omitempty"` // custom error message when a non-whitelisted model is blocked (fallback_action=block)
}
// OpenAIFastPolicySettings is the OpenAI fast policy configuration: an
// ordered rule list evaluated first-match-wins.
type OpenAIFastPolicySettings struct {
	Rules []OpenAIFastPolicyRule `json:"rules"`
}
// DefaultOpenAIFastPolicySettings returns the default OpenAI fast policy:
// filter (strip service_tier) on priority/fast requests for every model and
// every account scope, so the upstream handles them at normal priority.
//
// Why ModelWhitelist is empty (= applies to all models): the codex client's
// service_tier=fast is a user-level switch orthogonal to the model field.
// Even gpt-4 + fast consumes priority quota, so locking the default rule to
// gpt-5.5* would leave a "gpt-4 + fast passes priority upstream" bypass.
// To align with codex's real semantics the default covers every model;
// admins who need a narrower rule can configure model_whitelist in the
// admin UI.
func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings {
	rules := []OpenAIFastPolicyRule{
		{
			ServiceTier:    OpenAIFastTierPriority,
			Action:         BetaPolicyActionFilter,
			Scope:          BetaPolicyScopeAll,
			ModelWhitelist: []string{},
			FallbackAction: BetaPolicyActionPass,
		},
	}
	return &OpenAIFastPolicySettings{Rules: rules}
}

View File

@@ -484,6 +484,9 @@ export interface SystemSettings {
// Affiliate (邀请返利) feature switch
affiliate_enabled: boolean;
// OpenAI fast/flex policy
openai_fast_policy_settings?: OpenAIFastPolicySettings;
}
export interface UpdateSettingsRequest {
@@ -648,6 +651,9 @@ export interface UpdateSettingsRequest {
// Affiliate (邀请返利) feature switch
affiliate_enabled?: boolean;
// OpenAI fast/flex policy
openai_fast_policy_settings?: OpenAIFastPolicySettings;
}
/**
@@ -875,6 +881,29 @@ export async function updateRectifierSettings(
return data;
}
// ==================== OpenAI Fast Policy Settings ====================

/**
 * OpenAI fast/flex policy rule interface.
 * Matches backend dto.OpenAIFastPolicyRule.
 */
export interface OpenAIFastPolicyRule {
  /** Tier this rule matches; "all" matches any recognized tier. */
  service_tier: "all" | "priority" | "flex";
  /** What to do on a match: keep, strip, or reject service_tier. */
  action: "pass" | "filter" | "block";
  /** Account types the rule applies to. */
  scope: "all" | "oauth" | "apikey" | "bedrock";
  /** Custom error message shown when action is "block". */
  error_message?: string;
  /** Model patterns this rule is narrowed to; empty/absent = all models. */
  model_whitelist?: string[];
  /** Action for models not matching the whitelist (backend default: pass). */
  fallback_action?: "pass" | "filter" | "block";
  /** Custom error message when fallback_action is "block". */
  fallback_error_message?: string;
}
/**
 * OpenAI fast/flex policy settings interface.
 * Rules are evaluated in order; the backend applies the first matching rule.
 */
export interface OpenAIFastPolicySettings {
  rules: OpenAIFastPolicyRule[];
}
// ==================== Beta Policy Settings ====================
/**

View File

@@ -5535,6 +5535,38 @@ export default {
presetOpusOnlyDesc: 'Pass for Opus, filter others',
commonPatterns: 'Common patterns'
},
// Admin settings UI strings for the OpenAI fast/flex policy editor (English).
openaiFastPolicy: {
  title: 'OpenAI Fast/Flex Policy',
  description: 'Intercept, filter, or pass OpenAI fast(priority) / flex requests based on the request body service_tier field. Applies to the OpenAI gateway only.',
  empty: 'No rules configured. Click the button below to add one.',
  ruleHeader: 'Rule #{index}',
  removeRule: 'Remove rule',
  addRule: 'Add rule',
  saveHint: 'Saved together with system settings (click the global Save button at the bottom of the page).',
  serviceTier: 'service_tier match',
  tierAll: 'All tiers',
  tierPriority: 'priority (fast)',
  tierFlex: 'flex',
  action: 'Action',
  actionPass: 'Pass (keep service_tier)',
  actionFilter: 'Filter (remove service_tier)',
  actionBlock: 'Block (reject request)',
  scope: 'Scope',
  scopeAll: 'All accounts',
  scopeOAuth: 'OAuth only',
  scopeAPIKey: 'API Key only',
  scopeBedrock: 'Bedrock only',
  errorMessage: 'Error message',
  errorMessagePlaceholder: 'Custom error message when blocked',
  errorMessageHint: 'Leave empty for the default message.',
  modelWhitelist: 'Model whitelist',
  modelWhitelistHint: 'Leave empty to apply to all models. Supports exact match and wildcard prefix (e.g., gpt-5.5*).',
  modelPatternPlaceholder: 'e.g., gpt-5.5 or gpt-5.5*',
  addModelPattern: 'Add model pattern',
  fallbackAction: 'Fallback action',
  fallbackActionHint: 'Action for models not matching the whitelist.',
  fallbackErrorMessagePlaceholder: 'Custom error message when non-whitelisted models are blocked'
},
wechatConnect: {
title: 'WeChat Connect',
description: 'Third-party login configuration for WeChat Open Platform or Official Account / Mini Program.',

View File

@@ -5695,6 +5695,38 @@ export default {
presetOpusOnlyDesc: 'Opus 透传,其他模型过滤',
commonPatterns: '常用模式'
},
// Admin settings UI strings for the OpenAI fast/flex policy editor (zh-CN).
// Fix: tierPriority / actionPass / actionFilter had garbled or unbalanced
// parentheses; restored as balanced half-width pairs, matching the style of
// the sibling actionBlock label.
openaiFastPolicy: {
  title: 'OpenAI Fast/Flex 策略',
  description: '基于请求体 service_tier 字段拦截/过滤/透传 OpenAI fast(priority) 与 flex 请求;仅作用于 OpenAI 网关。',
  empty: '尚未配置任何规则。点击下方按钮新增。',
  ruleHeader: '规则 #{index}',
  removeRule: '删除规则',
  addRule: '新增规则',
  saveHint: '保存时随系统设置一起提交(点击页面底部「保存」按钮)。',
  serviceTier: 'service_tier 匹配',
  tierAll: '全部 tier',
  tierPriority: 'priority(fast)',
  tierFlex: 'flex',
  action: '处理方式',
  actionPass: '透传(保留 service_tier)',
  actionFilter: '过滤(移除 service_tier)',
  actionBlock: '拦截(拒绝请求)',
  scope: '生效范围',
  scopeAll: '全部账号',
  scopeOAuth: '仅 OAuth 账号',
  scopeAPIKey: '仅 API Key 账号',
  scopeBedrock: '仅 Bedrock 账号',
  errorMessage: '错误消息',
  errorMessagePlaceholder: '拦截时返回的自定义错误消息',
  errorMessageHint: '留空则使用默认错误消息。',
  modelWhitelist: '模型白名单',
  modelWhitelistHint: '留空表示对所有模型生效;支持精确匹配与通配符(如 gpt-5.5*)。',
  modelPatternPlaceholder: '例如: gpt-5.5 或 gpt-5.5*',
  addModelPattern: '添加模型规则',
  fallbackAction: '未匹配模型处理方式',
  fallbackActionHint: '当请求模型不在白名单中时的处理方式。',
  fallbackErrorMessagePlaceholder: '未匹配模型被拦截时返回的自定义错误消息'
},
wechatConnect: {
title: '微信登录',
description: '用于微信开放平台或公众号/小程序的第三方登录配置。',

View File

@@ -949,6 +949,285 @@
</template>
</div>
</div>
<!-- OpenAI Fast/Flex Policy Settings -->
<!--
  Admin editor for the OpenAI fast/flex policy rule list. State lives in
  openaiFastPolicyForm; rules are submitted with the page-level
  system-settings save (see the saveHint text below), not a per-card save.
-->
<div class="card">
  <div
    class="border-b border-gray-100 px-6 py-4 dark:border-dark-700"
  >
    <h2 class="text-lg font-semibold text-gray-900 dark:text-white">
      {{ t("admin.settings.openaiFastPolicy.title") }}
    </h2>
    <p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
      {{ t("admin.settings.openaiFastPolicy.description") }}
    </p>
  </div>
  <div class="space-y-5 p-6">
    <!-- Empty state -->
    <div
      v-if="openaiFastPolicyForm.rules.length === 0"
      class="rounded-lg border border-dashed border-gray-200 p-6 text-center text-sm text-gray-500 dark:border-dark-600 dark:text-gray-400"
    >
      {{ t("admin.settings.openaiFastPolicy.empty") }}
    </div>
    <!-- Rule Cards -->
    <!-- NOTE(review): :key is the array index, which is fine while rules are
         only appended or removed by index; confirm before adding reordering. -->
    <div
      v-for="(rule, ruleIndex) in openaiFastPolicyForm.rules"
      :key="ruleIndex"
      class="rounded-lg border border-gray-200 p-4 dark:border-dark-600"
    >
      <!-- Rule header: numbered label plus delete button -->
      <div class="mb-3 flex items-center justify-between">
        <span
          class="text-sm font-medium text-gray-900 dark:text-white"
        >
          {{
            t("admin.settings.openaiFastPolicy.ruleHeader", {
              index: ruleIndex + 1,
            })
          }}
        </span>
        <button
          type="button"
          @click="removeOpenAIFastPolicyRule(ruleIndex)"
          class="rounded p-1 text-red-400 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
          :title="t('admin.settings.openaiFastPolicy.removeRule')"
        >
          <svg
            class="h-4 w-4"
            fill="none"
            viewBox="0 0 24 24"
            stroke="currentColor"
            stroke-width="2"
          >
            <path
              stroke-linecap="round"
              stroke-linejoin="round"
              d="M6 18L18 6M6 6l12 12"
            />
          </svg>
        </button>
      </div>
      <!-- Core selectors: tier match, action, account scope -->
      <div class="grid grid-cols-1 gap-4 md:grid-cols-3">
        <!-- Service Tier -->
        <div>
          <label
            class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
          >
            {{ t("admin.settings.openaiFastPolicy.serviceTier") }}
          </label>
          <Select
            :modelValue="rule.service_tier"
            @update:modelValue="
              rule.service_tier = $event as
                | 'all'
                | 'priority'
                | 'flex'
            "
            :options="openaiFastPolicyTierOptions"
          />
        </div>
        <!-- Action -->
        <div>
          <label
            class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
          >
            {{ t("admin.settings.openaiFastPolicy.action") }}
          </label>
          <Select
            :modelValue="rule.action"
            @update:modelValue="
              rule.action = $event as 'pass' | 'filter' | 'block'
            "
            :options="openaiFastPolicyActionOptions"
          />
        </div>
        <!-- Scope -->
        <div>
          <label
            class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
          >
            {{ t("admin.settings.openaiFastPolicy.scope") }}
          </label>
          <Select
            :modelValue="rule.scope"
            @update:modelValue="
              rule.scope = $event as
                | 'all'
                | 'oauth'
                | 'apikey'
                | 'bedrock'
            "
            :options="openaiFastPolicyScopeOptions"
          />
        </div>
      </div>
      <!-- Error Message (only when action=block) -->
      <div v-if="rule.action === 'block'" class="mt-3">
        <label
          class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
        >
          {{ t("admin.settings.openaiFastPolicy.errorMessage") }}
        </label>
        <input
          v-model="rule.error_message"
          type="text"
          class="input"
          :placeholder="
            t(
              'admin.settings.openaiFastPolicy.errorMessagePlaceholder',
            )
          "
        />
        <p class="mt-1 text-xs text-gray-400 dark:text-gray-500">
          {{ t("admin.settings.openaiFastPolicy.errorMessageHint") }}
        </p>
      </div>
      <!-- Model Whitelist: one text input per pattern plus an add button -->
      <div class="mt-3">
        <label
          class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
        >
          {{ t("admin.settings.openaiFastPolicy.modelWhitelist") }}
        </label>
        <p class="mb-2 text-xs text-gray-400 dark:text-gray-500">
          {{
            t("admin.settings.openaiFastPolicy.modelWhitelistHint")
          }}
        </p>
        <div
          v-for="(_, patternIdx) in rule.model_whitelist || []"
          :key="patternIdx"
          class="mb-1.5 flex items-center gap-2"
        >
          <input
            v-model="rule.model_whitelist![patternIdx]"
            type="text"
            class="input input-sm flex-1"
            :placeholder="
              t(
                'admin.settings.openaiFastPolicy.modelPatternPlaceholder',
              )
            "
          />
          <button
            type="button"
            @click="
              removeOpenAIFastPolicyModelPattern(rule, patternIdx)
            "
            class="shrink-0 rounded p-1 text-red-400 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
          >
            <svg
              class="h-4 w-4"
              fill="none"
              viewBox="0 0 24 24"
              stroke="currentColor"
              stroke-width="2"
            >
              <path
                stroke-linecap="round"
                stroke-linejoin="round"
                d="M6 18L18 6M6 6l12 12"
              />
            </svg>
          </button>
        </div>
        <button
          type="button"
          @click="addOpenAIFastPolicyModelPattern(rule)"
          class="mb-2 inline-flex items-center gap-1 text-xs text-primary-600 transition-colors hover:text-primary-700 dark:text-primary-400 dark:hover:text-primary-300"
        >
          <svg
            class="h-3.5 w-3.5"
            fill="none"
            viewBox="0 0 24 24"
            stroke="currentColor"
            stroke-width="2"
          >
            <path
              stroke-linecap="round"
              stroke-linejoin="round"
              d="M12 4v16m8-8H4"
            />
          </svg>
          {{ t("admin.settings.openaiFastPolicy.addModelPattern") }}
        </button>
      </div>
      <!-- Fallback Action (only when model_whitelist is non-empty):
           handling for requests whose model is NOT whitelisted -->
      <div
        v-if="
          rule.model_whitelist && rule.model_whitelist.length > 0
        "
        class="mt-3"
      >
        <label
          class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
        >
          {{ t("admin.settings.openaiFastPolicy.fallbackAction") }}
        </label>
        <Select
          :modelValue="rule.fallback_action || 'pass'"
          @update:modelValue="
            rule.fallback_action = $event as
              | 'pass'
              | 'filter'
              | 'block'
          "
          :options="openaiFastPolicyActionOptions"
        />
        <p class="mt-1 text-xs text-gray-400 dark:text-gray-500">
          {{
            t("admin.settings.openaiFastPolicy.fallbackActionHint")
          }}
        </p>
        <div v-if="rule.fallback_action === 'block'" class="mt-2">
          <input
            v-model="rule.fallback_error_message"
            type="text"
            class="input"
            :placeholder="
              t(
                'admin.settings.openaiFastPolicy.fallbackErrorMessagePlaceholder',
              )
            "
          />
        </div>
      </div>
    </div>
    <!-- Add Rule Button -->
    <div>
      <button
        type="button"
        @click="addOpenAIFastPolicyRule"
        class="btn btn-secondary btn-sm inline-flex items-center gap-1"
      >
        <svg
          class="h-4 w-4"
          fill="none"
          viewBox="0 0 24 24"
          stroke="currentColor"
          stroke-width="2"
        >
          <path
            stroke-linecap="round"
            stroke-linejoin="round"
            d="M12 4v16m8-8H4"
          />
        </svg>
        {{ t("admin.settings.openaiFastPolicy.addRule") }}
      </button>
      <p class="mt-2 text-xs text-gray-400 dark:text-gray-500">
        {{ t("admin.settings.openaiFastPolicy.saveHint") }}
      </p>
    </div>
  </div>
</div>
</div>
<!-- /Tab: Gateway -->
@@ -5199,6 +5478,7 @@ import type {
SystemSettings,
UpdateSettingsRequest,
DefaultSubscriptionSetting,
OpenAIFastPolicyRule,
WeChatConnectMode,
WebSearchEmulationConfig,
WebSearchProviderConfig,
@@ -5337,6 +5617,14 @@ const betaPolicyForm = reactive({
}>,
});
// OpenAI Fast/Flex Policy editor state (rule list edited by the settings UI).
const openaiFastPolicyForm = reactive({
  rules: [] as OpenAIFastPolicyRule[],
});
// Whether openai_fast_policy_settings was successfully loaded from the
// backend. Guards against a save overwriting the backend's default rules
// with an empty array when the GET failed or the field was absent.
const openaiFastPolicyLoaded = ref(false);
// Allowed range and default for the table page-size setting
// (presumably validated against admin input elsewhere — confirm at usage sites).
const tablePageSizeMin = 5;
const tablePageSizeMax = 1000;
const tablePageSizeDefault = 20;
@@ -6116,6 +6404,23 @@ async function loadSettings() {
);
form.oidc_connect_client_secret = "";
// Load OpenAI fast/flex policy rules from bulk settings.
// 仅当 payload 真的包含该字段时填充并标记为已加载;否则保持表单空值,
// 让 saveSettings 在未加载时跳过该字段,防止覆盖后端默认规则。
if (
settings.openai_fast_policy_settings &&
Array.isArray(settings.openai_fast_policy_settings.rules)
) {
openaiFastPolicyForm.rules =
settings.openai_fast_policy_settings.rules.map((rule) => ({
...rule,
model_whitelist: rule.model_whitelist
? [...rule.model_whitelist]
: [],
}));
openaiFastPolicyLoaded.value = true;
}
// Load web search emulation config separately
await loadWebSearchConfig();
} catch (error: unknown) {
@@ -6460,10 +6765,39 @@ async function saveSettings() {
affiliate_enabled: form.affiliate_enabled,
};
// 仅当 openai_fast_policy_settings 已成功从后端加载时才回写,
// 否则省略整个字段,让后端保留既有规则(含默认值)。
if (openaiFastPolicyLoaded.value) {
payload.openai_fast_policy_settings = {
rules: openaiFastPolicyForm.rules.map((rule) => {
const whitelist = (rule.model_whitelist || [])
.map((p) => p.trim())
.filter((p) => p !== "");
const hasWhitelist = whitelist.length > 0;
return {
service_tier: rule.service_tier,
action: rule.action,
scope: rule.scope,
error_message:
rule.action === "block" ? rule.error_message : undefined,
model_whitelist: hasWhitelist ? whitelist : undefined,
fallback_action: hasWhitelist
? rule.fallback_action || "pass"
: undefined,
fallback_error_message:
hasWhitelist && rule.fallback_action === "block"
? rule.fallback_error_message
: undefined,
};
}),
};
}
appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults);
const updated = await adminAPI.settings.updateSettings(payload);
for (const [key, value] of Object.entries(updated)) {
if (key === "openai_fast_policy_settings") continue;
if (value !== null && value !== undefined) {
(form as Record<string, unknown>)[key] = value;
}
@@ -6507,6 +6841,20 @@ async function saveSettings() {
form.wechat_connect_mode,
);
form.oidc_connect_client_secret = "";
// Refresh OpenAI fast/flex policy from server response
if (
updated.openai_fast_policy_settings &&
Array.isArray(updated.openai_fast_policy_settings.rules)
) {
openaiFastPolicyForm.rules =
updated.openai_fast_policy_settings.rules.map((rule) => ({
...rule,
model_whitelist: rule.model_whitelist
? [...rule.model_whitelist]
: [],
}));
openaiFastPolicyLoaded.value = true;
}
// Save web search emulation config separately (errors handled internally)
const wsOk = await saveWebSearchConfig();
// Refresh cached settings so sidebar/header update immediately
@@ -6846,6 +7194,61 @@ async function loadBetaPolicySettings() {
}
}
// ==================== OpenAI Fast/Flex Policy ====================

/** Builds a <Select> option whose label comes from the fast-policy i18n group. */
function openaiFastPolicyOption(value: string, labelKey: string) {
  return { value, label: t(`admin.settings.openaiFastPolicy.${labelKey}`) };
}

// Option lists are computed so labels re-resolve when the locale changes.
const openaiFastPolicyTierOptions = computed(() => [
  openaiFastPolicyOption("all", "tierAll"),
  openaiFastPolicyOption("priority", "tierPriority"),
  openaiFastPolicyOption("flex", "tierFlex"),
]);
const openaiFastPolicyActionOptions = computed(() => [
  openaiFastPolicyOption("pass", "actionPass"),
  openaiFastPolicyOption("filter", "actionFilter"),
  openaiFastPolicyOption("block", "actionBlock"),
]);
const openaiFastPolicyScopeOptions = computed(() => [
  openaiFastPolicyOption("all", "scopeAll"),
  openaiFastPolicyOption("oauth", "scopeOAuth"),
  openaiFastPolicyOption("apikey", "scopeAPIKey"),
  openaiFastPolicyOption("bedrock", "scopeBedrock"),
]);
function addOpenAIFastPolicyRule() {
openaiFastPolicyForm.rules.push({
service_tier: "priority",
action: "filter",
scope: "all",
error_message: "",
model_whitelist: [],
fallback_action: "pass",
fallback_error_message: "",
});
}
/** Deletes the rule at the given position from the editor form. */
function removeOpenAIFastPolicyRule(ruleIdx: number) {
  const { rules } = openaiFastPolicyForm;
  rules.splice(ruleIdx, 1);
}
/**
 * Adds an empty model-pattern input to the rule, lazily creating the
 * whitelist array on first use.
 */
function addOpenAIFastPolicyModelPattern(rule: OpenAIFastPolicyRule) {
  (rule.model_whitelist ??= []).push("");
}
/**
 * Removes the model pattern at the given index from the rule's whitelist.
 * A rule without a whitelist array is left untouched.
 */
function removeOpenAIFastPolicyModelPattern(
  rule: OpenAIFastPolicyRule,
  patternIdx: number,
) {
  const patterns = rule.model_whitelist;
  if (patterns) {
    patterns.splice(patternIdx, 1);
  }
}
async function saveBetaPolicySettings() {
betaPolicySaving.value = true;
try {