diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f767bbea..14b7db28 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -186,7 +186,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService) geminiMessagesCompatService := 
service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) diff --git a/backend/internal/handler/admin/admin_helpers_test.go b/backend/internal/handler/admin/admin_helpers_test.go index 3833d32e..6df49154 100644 --- a/backend/internal/handler/admin/admin_helpers_test.go +++ b/backend/internal/handler/admin/admin_helpers_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) { require.True(t, isAddrInTrustedProxies(addr, prefixes)) require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes)) } + +// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier 验证 admin +// 写入路径会把 ServiceTier 的空字符串/空白/大小写归一化为 +// service.OpenAIFastTierAny ("all"),避免落盘时 "" 与 "all" 双语义。 +func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) { + t.Run("nil input returns nil", func(t *testing.T) { + require.Nil(t, openaiFastPolicySettingsFromDTO(nil)) + }) + + t.Run("empty service_tier becomes 'all'", func(t *testing.T) { + in := &dto.OpenAIFastPolicySettings{ + Rules: []dto.OpenAIFastPolicyRule{{ + ServiceTier: "", + Action: "filter", + Scope: "all", + }}, + } + out := openaiFastPolicySettingsFromDTO(in) + require.NotNil(t, out) + require.Len(t, out.Rules, 1) + require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier) + require.Equal(t, "all", 
out.Rules[0].ServiceTier) + }) + + t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) { + in := &dto.OpenAIFastPolicySettings{ + Rules: []dto.OpenAIFastPolicyRule{{ + ServiceTier: " ", + Action: "pass", + Scope: "all", + }}, + } + out := openaiFastPolicySettingsFromDTO(in) + require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier) + }) + + t.Run("uppercase service_tier is lowercased", func(t *testing.T) { + in := &dto.OpenAIFastPolicySettings{ + Rules: []dto.OpenAIFastPolicyRule{{ + ServiceTier: "PRIORITY", + Action: "filter", + Scope: "all", + }}, + } + out := openaiFastPolicySettingsFromDTO(in) + require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier) + }) + + t.Run("non-empty values pass through (lowercased)", func(t *testing.T) { + in := &dto.OpenAIFastPolicySettings{ + Rules: []dto.OpenAIFastPolicyRule{ + {ServiceTier: "priority", Action: "filter", Scope: "all"}, + {ServiceTier: "flex", Action: "block", Scope: "oauth"}, + {ServiceTier: "all", Action: "pass", Scope: "apikey"}, + }, + } + out := openaiFastPolicySettingsFromDTO(in) + require.Len(t, out.Rules, 3) + require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier) + require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier) + require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier) + }) +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 320dbd6b..d6580191 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -248,9 +248,51 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { AffiliateEnabled: settings.AffiliateEnabled, } + + // OpenAI fast policy (stored under a dedicated setting key) + if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil { + slog.Error("openai_fast_policy_settings_get_failed", "error", err) + } else if 
fastPolicy != nil { + payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy) + } + response.Success(c, systemSettingsResponseData(payload, authSourceDefaults)) } +// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy. +func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings { + if s == nil { + return nil + } + rules := make([]dto.OpenAIFastPolicyRule, len(s.Rules)) + for i, r := range s.Rules { + rules[i] = dto.OpenAIFastPolicyRule(r) + } + return &dto.OpenAIFastPolicySettings{Rules: rules} +} + +// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy. +// +// 规范化 ServiceTier:在 DTO 进入 service 层之前统一把空字符串归一为 +// service.OpenAIFastTierAny ("all"),避免管理员保存时空串与 "all" 同时 +// 表达"匹配任意 tier"造成数据库取值的二义性。其它非空值原样透传,由 +// service.SetOpenAIFastPolicySettings 负责合法值校验。 +func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings { + if s == nil { + return nil + } + rules := make([]service.OpenAIFastPolicyRule, len(s.Rules)) + for i, r := range s.Rules { + rules[i] = service.OpenAIFastPolicyRule(r) + tier := strings.ToLower(strings.TrimSpace(rules[i].ServiceTier)) + if tier == "" { + tier = service.OpenAIFastTierAny + } + rules[i].ServiceTier = tier + } + return &service.OpenAIFastPolicySettings{Rules: rules} +} + // UpdateSettingsRequest 更新设置请求 type UpdateSettingsRequest struct { // 注册设置 @@ -452,6 +494,9 @@ type UpdateSettingsRequest struct { // Affiliate (邀请返利) feature switch AffiliateEnabled *bool `json:"affiliate_enabled"` + + // OpenAI fast/flex policy (optional, only updated when provided) + OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"` } // UpdateSettings 更新系统设置 @@ -1350,6 +1395,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { return } + // Update OpenAI fast policy (stored under dedicated key, only when provided). 
+ if req.OpenAIFastPolicySettings != nil { + if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil { + response.BadRequest(c, err.Error()) + return + } + } + // Update payment configuration (integrated into system settings). // Skip if no payment fields were provided (prevents accidental wipe). if h.paymentConfigService != nil && hasPaymentFields(req) { @@ -1555,6 +1608,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { AffiliateEnabled: updatedSettings.AffiliateEnabled, } + if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil { + slog.Error("openai_fast_policy_settings_get_failed", "error", err) + } else if fastPolicy != nil { + payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy) + } response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults)) } diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go index 9a33a93a..085fd2ca 100644 --- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -26,7 +26,12 @@ func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service. 
} func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) { - panic("unexpected GetValue call") + if s.values != nil { + if value, ok := s.values[key]; ok { + return value, nil + } + } + return "", nil } func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error { diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 92ae4dc6..b865d703 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -198,6 +198,9 @@ type SystemSettings struct { // Affiliate (邀请返利) feature switch AffiliateEnabled bool `json:"affiliate_enabled"` + + // OpenAI fast/flex policy + OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"` } type DefaultSubscriptionSetting struct { @@ -294,6 +297,22 @@ type BetaPolicySettings struct { Rules []BetaPolicyRule `json:"rules"` } +// OpenAIFastPolicyRule OpenAI fast/flex 策略规则 DTO +type OpenAIFastPolicyRule struct { + ServiceTier string `json:"service_tier"` + Action string `json:"action"` + Scope string `json:"scope"` + ErrorMessage string `json:"error_message,omitempty"` + ModelWhitelist []string `json:"model_whitelist,omitempty"` + FallbackAction string `json:"fallback_action,omitempty"` + FallbackErrorMessage string `json:"fallback_error_message,omitempty"` +} + +// OpenAIFastPolicySettings OpenAI fast 策略配置 DTO +type OpenAIFastPolicySettings struct { + Rules []OpenAIFastPolicyRule `json:"rules"` +} + // ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. // Returns empty slice on empty/invalid input. 
func ParseCustomMenuItems(raw string) []CustomMenuItem { diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index ca6fd0cc..f24a1677 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -748,6 +748,16 @@ func TestAPIContracts(t *testing.T) { "payment_visible_method_alipay_enabled": true, "payment_visible_method_wxpay_enabled": false, "openai_advanced_scheduler_enabled": true, + "openai_fast_policy_settings": { + "rules": [ + { + "service_tier": "priority", + "action": "filter", + "scope": "all", + "fallback_action": "pass" + } + ] + }, "custom_menu_items": [], "custom_endpoints": [], "payment_enabled": false, @@ -930,6 +940,16 @@ func TestAPIContracts(t *testing.T) { "payment_visible_method_alipay_enabled": false, "payment_visible_method_wxpay_enabled": false, "openai_advanced_scheduler_enabled": false, + "openai_fast_policy_settings": { + "rules": [ + { + "service_tier": "priority", + "action": "filter", + "scope": "all", + "fallback_action": "pass" + } + ] + }, "payment_enabled": false, "payment_min_amount": 0, "payment_max_amount": 0, diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 0ef4a486..e1b175c3 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -306,6 +306,12 @@ const ( // SettingKeyBetaPolicySettings stores JSON config for beta policy rules. SettingKeyBetaPolicySettings = "beta_policy_settings" + // SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI + // service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but + // targets OpenAI's body-level service_tier field instead of Claude's + // anthropic-beta header. 
+ SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings" + // ========================= // Claude Code Version Check // ========================= diff --git a/backend/internal/service/openai_fast_policy_test.go b/backend/internal/service/openai_fast_policy_test.go new file mode 100644 index 00000000..b52da614 --- /dev/null +++ b/backend/internal/service/openai_fast_policy_test.go @@ -0,0 +1,286 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type openAIFastPolicyRepoStub struct { + values map[string]string +} + +func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", ErrSettingNotFound +} + +func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error { + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + return nil +} + +func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService { + t.Helper() + repo := &openAIFastPolicyRepoStub{values: map[string]string{}} + if settings != nil { + raw, err := json.Marshal(settings) + 
require.NoError(t, err) + repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw) + } + return &OpenAIGatewayService{ + settingService: NewSettingService(repo, &config.Config{}), + } +} + +func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // 默认策略对所有模型生效(whitelist 为空),因为 codex 的 service_tier=fast + // 是用户级开关,与 model 正交。 + // gpt-5.5 + priority → filter + action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionFilter, action) + + // gpt-5.5-turbo → filter + action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionFilter, action) + + // gpt-4 + priority → filter(默认策略覆盖所有模型) + action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionFilter, action) + + // gpt-5.5 + flex → pass (tier doesn't match) + action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex) + require.Equal(t, BetaPolicyActionPass, action) + + // empty tier → pass + action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "") + require.Equal(t, BetaPolicyActionPass, action) +} + +func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "fast mode is not allowed", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} 
+ + action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionBlock, action) + require.Equal(t, "fast mode is not allowed", msg) +} + +func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierAny, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeOAuth, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + + // OAuth account → rule matches + oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth} + action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionFilter, action) + + // API Key account → rule skipped → pass + apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority) + require.Equal(t, BetaPolicyActionPass, action) +} + +func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // gpt-5.5 fast → service_tier stripped + body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.NotContains(t, string(updated), `"service_tier"`) + + // Client sending "fast" (alias for priority) also filtered + body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.NotContains(t, string(updated), `"service_tier"`) + + // gpt-4 priority → 默认策略对所有模型 filter,service_tier 
被移除 + body = []byte(`{"model":"gpt-4","service_tier":"priority"}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body) + require.NoError(t, err) + require.NotContains(t, string(updated), `"service_tier"`) + + // No service_tier → no-op + body = []byte(`{"model":"gpt-5.5"}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.Equal(t, string(body), string(updated)) +} + +// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证扩展白名单后 +// 客户端显式发送的 OpenAI 官方合法 tier(auto/default/scale)能透传到上游而不被 +// 静默剥离。默认策略只针对 priority,所以这些 tier 落在 fall-through pass 分支。 +func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + for _, tier := range []string{"auto", "default", "scale"} { + body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err, "tier %q should pass without error", tier) + require.Contains(t, string(updated), `"service_tier":"`+tier+`"`, + "tier %q should be preserved in body under default rule", tier) + } + + // evaluate 层也应判定为 pass(默认规则 ServiceTier=priority 与 auto/default/scale 不匹配) + for _, tier := range []string{"auto", "default", "scale"} { + action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier) + require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier) + } +} + +// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers 验证管理员显式配置 +// ServiceTier=all + Action=filter 规则后,auto/default/scale 等官方 tier 也会 +// 被剥离。这是符合预期的——首条匹配 short-circuit,"all" 覆盖任意已识别 tier。 +func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) { + settings := 
&OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierAny, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} { + body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.NotContains(t, string(updated), `"service_tier"`, + "tier %q should be stripped under ServiceTier=all + filter rule", tier) + } +} + +// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped 验证真未知 tier 仍被剥离 +// (normalize 返回 nil → normalizeResponsesBodyServiceTier 删除字段; +// applyOpenAIFastPolicyToBody 在 normTier 为空时直接 no-op,因为字段已不可能存在 +// 于经过前置归一化的请求里。这里直接调 apply 验证它对未识别值不会异常)。 +func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // normalize 阶段会将未知值剥离 + require.Nil(t, normalizeOpenAIServiceTier("xxx")) + + // applyOpenAIFastPolicyToBody 收到未识别 tier 时不报错,body 透传不变 + // (不属于本函数职责——上层 normalizeResponsesBodyServiceTier 已剥离) + body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.Equal(t, string(body), string(updated)) +} + +func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "fast mode is blocked for gpt-5.5", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: 
BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.Error(t, err) + var blocked *OpenAIFastBlockedError + require.True(t, errors.As(err, &blocked)) + require.Contains(t, blocked.Message, "fast mode is blocked") + require.Equal(t, string(body), string(updated)) // body not mutated on block +} + +func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) { + repo := &openAIFastPolicyRepoStub{values: map[string]string{}} + svc := NewSettingService(repo, &config.Config{}) + + // Invalid action rejected + err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: "bogus", + Scope: BetaPolicyScopeAll, + }}, + }) + require.Error(t, err) + + // Invalid service_tier rejected + err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: "turbo", + Action: BetaPolicyActionPass, + Scope: BetaPolicyScopeAll, + }}, + }) + require.Error(t, err) + + // Valid settings persisted + err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + }}, + }) + require.NoError(t, err) + + got, err := svc.GetOpenAIFastPolicySettings(context.Background()) + require.NoError(t, err) + require.Len(t, got.Rules, 1) + require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier) +} diff --git a/backend/internal/service/openai_fast_policy_ws_test.go b/backend/internal/service/openai_fast_policy_ws_test.go new file mode 100644 index 00000000..3316a242 --- /dev/null 
+++ b/backend/internal/service/openai_fast_policy_ws_test.go @@ -0,0 +1,1018 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// --- Helper-level (unit) tests for applyOpenAIFastPolicyToWSResponseCreate --- + +func TestWSResponseCreate_FilterStripsServiceTier(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority","input":[{"type":"input_text","text":"hi"}]}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.NotContains(t, string(updated), `"service_tier"`, "filter action should strip service_tier") + // Other fields preserved. + require.Equal(t, "response.create", gjson.GetBytes(updated, "type").String()) + require.Equal(t, "gpt-5.5", gjson.GetBytes(updated, "model").String()) + require.Equal(t, "hi", gjson.GetBytes(updated, "input.0.text").String()) +} + +func TestWSResponseCreate_FastNormalizedToPriorityThenFiltered(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Verbatim "fast" → normalized to "priority" → matches default rule → filter. 
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.NotContains(t, string(updated), `"service_tier"`) + + // Mixed-case + whitespace variant should also normalize and filter. + frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":" Fast "}`) + updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.NotContains(t, string(updated), `"service_tier"`) +} + +func TestWSResponseCreate_FlexPassThrough(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Default policy targets priority only; flex is left untouched. 
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, "flex", gjson.GetBytes(updated, "service_tier").String(), "flex frames must reach upstream untouched under default policy") +} + +func TestWSResponseCreate_BlockReturnsTypedError(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "ws fast blocked", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.NotNil(t, blocked) + require.Equal(t, "ws fast blocked", blocked.Message) + // On block, payload returned unchanged so caller can inspect / log it. 
+ require.Equal(t, string(frame), string(updated)) +} + +func TestWSResponseCreate_NoServiceTierUntouched(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + frame := []byte(`{"type":"response.create","model":"gpt-5.5","input":[]}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, string(frame), string(updated), "no service_tier present must result in zero mutation") +} + +func TestWSResponseCreate_NonResponseCreateFrameUntouched(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + ModelWhitelist: []string{"*"}, + FallbackAction: BetaPolicyActionFilter, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // response.cancel happens to carry a service_tier-shaped field — must not be touched. + frame := []byte(`{"type":"response.cancel","service_tier":"priority"}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, string(frame), string(updated)) +} + +// TestWSResponseCreate_EmptyTypeFrameUntouched is the A1 regression: the +// helper used to treat empty type as response.create, which risked stripping +// fields from malformed / unknown client events. After the A1 fix only a +// strict "response.create" match triggers policy. 
+func TestWSResponseCreate_EmptyTypeFrameUntouched(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + ModelWhitelist: []string{"*"}, + FallbackAction: BetaPolicyActionFilter, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Frame with no "type" field: must pass through completely unchanged + // even with a service_tier-shaped field present. + frame := []byte(`{"service_tier":"priority","model":"gpt-5.5"}`) + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, string(frame), string(updated), "empty type must NOT be policy-checked — Realtime spec requires type, malformed frames are passed through") + + // Explicit empty string also passes through. + frame = []byte(`{"type":"","service_tier":"priority","model":"gpt-5.5"}`) + updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, string(frame), string(updated)) +} + +// TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode is the B1 +// regression: the rendered Realtime error event must carry a non-empty +// event_id (so clients can correlate the rejection) and a stable error.code +// ("policy_violation"). The HTTP-side equivalent is the 403 permission_error +// JSON body emitted by writeOpenAIFastPolicyBlockedResponse. 
+func TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode(t *testing.T) {
+	bytes := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "blocked because reasons"}) // NOTE(review): identifier shadows the stdlib bytes package within this test
+	require.NotNil(t, bytes)
+
+	require.Equal(t, "error", gjson.GetBytes(bytes, "type").String())
+	require.Equal(t, "invalid_request_error", gjson.GetBytes(bytes, "error.type").String())
+	require.Equal(t, "policy_violation", gjson.GetBytes(bytes, "error.code").String())
+	require.Equal(t, "blocked because reasons", gjson.GetBytes(bytes, "error.message").String())
+
+	eventID := gjson.GetBytes(bytes, "event_id").String()
+	require.NotEmpty(t, eventID, "event_id must be present so clients can correlate the rejection in their logs")
+	require.True(t, strings.HasPrefix(eventID, "evt_"), "event_id should follow the evt_ Realtime convention; got %q", eventID)
+
+	// Sanity check: two consecutive events get distinct IDs.
+	other := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "second"})
+	otherID := gjson.GetBytes(other, "event_id").String()
+	require.NotEqual(t, eventID, otherID, "event_id must be random per-event")
+}
+
+// TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe ensures the helper returns
+// nil for a nil error (defensive guard for callers that always invoke it).
+func TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe(t *testing.T) {
+	require.Nil(t, buildOpenAIFastPolicyBlockedWSEvent(nil))
+}
+
+// --- D5: passthrough wrapper FrameConn — capturedSessionModel fallback ---
+
+// fakePassthroughFrameConn replays a fixed sequence of client frames into the
+// policy-enforcing wrapper, then returns errOpenAIWSConnClosed. Captures all
+// Write attempts for write-side assertions (none expected in the D5 test,
+// since the wrapper only filters reads).
+type fakePassthroughFrameConn struct {
+	reads     [][]byte
+	idx       int
+	writes    [][]byte
+	closeOnce bool
+}
+
+func (f *fakePassthroughFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+	if f.idx >= len(f.reads) {
+		return coderws.MessageText, nil, errOpenAIWSConnClosed
+	}
+	payload := f.reads[f.idx]
+	f.idx++
+	return coderws.MessageText, payload, nil
+}
+
+func (f *fakePassthroughFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
+	cp := append([]byte(nil), payload...)
+	f.writes = append(f.writes, cp)
+	return nil
+}
+
+func (f *fakePassthroughFrameConn) Close() error {
+	f.closeOnce = true
+	return nil
+}
+
+// gpt55WhitelistFastPolicy returns a policy that forces a model whitelist; it is
+// used to verify the capturedSessionModel fallback semantics (with the default
+// policy's empty whitelist the fallback path cannot be observed).
+func gpt55WhitelistFastPolicy() *OpenAIFastPolicySettings {
+	return &OpenAIFastPolicySettings{
+		Rules: []OpenAIFastPolicyRule{{
+			ServiceTier:    OpenAIFastTierPriority,
+			Action:         BetaPolicyActionFilter,
+			Scope:          BetaPolicyScopeAll,
+			ModelWhitelist: []string{"gpt-5.5", "gpt-5.5*"},
+			FallbackAction: BetaPolicyActionPass,
+		}},
+	}
+}
+
+// TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel is
+// the D5 regression: in passthrough mode a follow-up response.create frame
+// without a "model" field must still hit the policy via the session-level
+// model captured from the first frame. Without the fallback an empty model
+// would miss a model whitelist and silently leak service_tier=priority
+// through to the upstream.
+func TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel(t *testing.T) {
+	// Deliberately use a whitelist-bearing policy so the capturedSessionModel
+	// fallback is observable (the default policy's whitelist is empty, so the
+	// result is identical either way and cannot cover this regression).
+	svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
+	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+	// Simulate the passthrough adapter capturing model from the first frame.
+	firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
+	capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
+	require.Equal(t, "gpt-5.5", capturedSessionModel)
+
+	// Follow-up frame deliberately omits "model" — Realtime allows this.
+	followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`)
+
+	inner := &fakePassthroughFrameConn{
+		reads: [][]byte{followupFrame},
+	}
+	wrapper := &openAIWSPolicyEnforcingFrameConn{
+		inner: inner,
+		filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+			if msgType != coderws.MessageText {
+				return payload, nil, nil
+			}
+			model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+			if model == "" {
+				model = capturedSessionModel
+			}
+			return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
+		},
+	}
+
+	// Read the follow-up frame through the wrapper. The policy MUST still
+	// trigger filter (gpt-5.5 + priority → filter), so the service_tier
+	// field is gone by the time the relay sees it.
+	_, payload, err := wrapper.ReadFrame(context.Background())
+	require.NoError(t, err)
+	require.NotContains(t, string(payload), `"service_tier"`,
+		"D5 regression: empty model on follow-up frame must fall back to capturedSessionModel; whitelist policy filters service_tier=priority for gpt-5.5")
+	require.Equal(t, "response.create", gjson.GetBytes(payload, "type").String())
+}
+
+// TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses pins the
+// inverse: when the wrapper has NO capturedSessionModel fallback (model is
+// empty per-frame and no fallback is wired up), the policy fails to match
+// the model whitelist and the frame leaks through unchanged. This documents
+// exactly the leak the D5 fix prevents.
+func TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses(t *testing.T) {
+	// Likewise use the whitelist-bearing policy so the leak is observable.
+	svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
+	account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+	followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`)
+	inner := &fakePassthroughFrameConn{reads: [][]byte{followupFrame}}
+	wrapper := &openAIWSPolicyEnforcingFrameConn{
+		inner: inner,
+		filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+			// NO fallback — emulate the pre-fix behavior.
+			model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+			return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
+		},
+	}
+
+	_, payload, err := wrapper.ReadFrame(context.Background())
+	require.NoError(t, err)
+	// Pre-fix: empty model misses ["gpt-5.5","gpt-5.5*"] whitelist → fallback=pass → service_tier kept.
+ require.Contains(t, string(payload), `"service_tier"`, + "sanity: without capturedSessionModel fallback the leak (D5) reproduces — confirms the fix is load-bearing") +} + +// --- Ingress end-to-end test (filter path) --- + +// TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream wires up the +// real ProxyResponsesWebSocketFromClient ingress session pipeline against a +// captureConn upstream and asserts that a client frame with service_tier=fast +// is normalized + filtered out before being written upstream. This is the +// integration flavour of TestWSResponseCreate_FilterStripsServiceTier. +func TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_ws_filter_1","model":"gpt-5.5","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + repo := &openAIFastPolicyRepoStub{values: map[string]string{}} + defaultJSON, err := json.Marshal(DefaultOpenAIFastPolicySettings()) + require.NoError(t, err) + repo.values[SettingKeyOpenAIFastPolicySettings] = string(defaultJSON) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: 
&httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + settingService: NewSettingService(repo, cfg), + } + + account := &Account{ + ID: 901, + Name: "openai-ws-filter", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-test"}, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { _ = conn.CloseNow() }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + _, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"fast"}`))) + cancelWrite() + + readCtx, 
cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Len(t, captureConn.writes, 1, "上游应只收到一条 response.create") + upstream := captureConn.writes[0] + _, hasServiceTier := upstream["service_tier"] + require.False(t, hasServiceTier, "上游收到的 response.create 不应包含 service_tier 字段(已被 fast policy filter 删除)") + require.Equal(t, "response.create", upstream["type"]) + require.Equal(t, "gpt-5.5", upstream["model"]) +} + +// TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream is the +// integration flavour of TestWSResponseCreate_BlockReturnsTypedError. It +// asserts that with a custom block rule, the client receives a Realtime-style +// error event AND the upstream FrameConn never receives the offending frame. 
+func TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + // No events queued; the upstream should never get written to anyway. + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + blockSettings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "ws priority blocked for testing", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + repo := &openAIFastPolicyRepoStub{values: map[string]string{}} + raw, err := json.Marshal(blockSettings) + require.NoError(t, err) + repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + settingService: NewSettingService(repo, cfg), + } + + account := &Account{ + ID: 902, + Name: "openai-ws-block", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: 
map[string]any{"api_key": "sk-test"}, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { _ = conn.CloseNow() }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + _, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + proxyErr := svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + // Mirror the production handler (openai_gateway_handler.go:1325-1328): + // when the proxy returns an OpenAIWSClientCloseError, surface its + // status code to the client via a graceful close handshake. Without + // this the deferred CloseNow() above would tear down the TCP + // connection without sending a close frame, and the C3 timing + // assertion (next read returns CloseStatus=1008) would see EOF + // instead. 
+ var closeErr *OpenAIWSClientCloseError + if errors.As(proxyErr, &closeErr) { + _ = conn.Close(closeErr.StatusCode(), closeErr.Reason()) + } + serverErrCh <- proxyErr + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { _ = clientConn.CloseNow() }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"priority"}`))) + cancelWrite() + + // C3 timing assertion: the FIRST frame the client reads must be the + // error event — not a close frame. coder/websocket@v1.8.14 Conn.Write is + // synchronous (writeFrame Flushes the bufio writer at write.go:307-311 + // before returning) and the close handshake re-acquires the same + // writeFrameMu, so this ordering is enforced by the library itself; this + // assertion guards against future refactors that might break it. + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr, "first read must succeed and return the error event before any close frame") + require.Equal(t, "error", gjson.GetBytes(event, "type").String()) + require.Equal(t, "invalid_request_error", gjson.GetBytes(event, "error.type").String()) + // B1 regression: event_id + error.code must be populated. 
+ require.Equal(t, "policy_violation", gjson.GetBytes(event, "error.code").String()) + require.NotEmpty(t, gjson.GetBytes(event, "event_id").String(), "event_id must be present so clients can correlate") + require.Contains(t, gjson.GetBytes(event, "error.message").String(), "ws priority blocked for testing") + + // Next read must surface the close frame (as a CloseError). This + // asserts the [error event, close] ordering — i.e. the close did NOT + // race ahead of the data frame. + readCtx2, cancelRead2 := context.WithTimeout(context.Background(), 3*time.Second) + _, _, secondReadErr := clientConn.Read(readCtx2) + cancelRead2() + require.Error(t, secondReadErr, "after the error event the connection must surface a close") + require.Equal(t, coderws.StatusPolicyViolation, coderws.CloseStatus(secondReadErr), + "close status must be PolicyViolation; got %v", secondReadErr) + + select { + case serverErr := <-serverErrCh: + // Server returns an OpenAIWSClientCloseError — handler closes the WS; + // here we just assert it surfaced as the typed close error. + require.Error(t, serverErr) + var closeErr *OpenAIWSClientCloseError + require.True(t, errors.As(serverErr, &closeErr), "block 应返回 OpenAIWSClientCloseError,得到 %T: %v", serverErr, serverErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress 关闭超时") + } + + // Critical: the offending frame must NEVER reach the upstream. + // captureDialer.DialCount may legitimately be 0 or 1 depending on whether + // the lease was acquired before policy fired; either way, no writes. 
+ require.Empty(t, captureConn.writes, "block 命中后上游不应收到 response.create") +} + +// --- HTTP-side gap-filling tests (already covered by existing tests but +// requested to be split out explicitly) --- + +// TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream confirms that +// applyOpenAIFastPolicyToBody surfaces a *OpenAIFastBlockedError when the rule +// action is "block", and that the body is left untouched. The caller (chat +// completions / messages handlers) inspects this typed error and skips the +// upstream HTTP call entirely — see openai_gateway_chat_completions.go:175 and +// openai_gateway_messages.go:149. +func TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream(t *testing.T) { + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "priority blocked", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + body := []byte(`{"model":"gpt-5.5","service_tier":"priority","input":[]}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.Error(t, err) + var blocked *OpenAIFastBlockedError + require.True(t, errors.As(err, &blocked), "block must surface as typed error so caller can skip upstream HTTP request") + require.Equal(t, "priority blocked", blocked.Message) + require.Equal(t, string(body), string(updated), "block must not mutate body") +} + +// TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy verifies +// the Anthropic-compat entrypoint chain: anthropic-beta: fast-mode → BetaFastMode +// detection → ServiceTier="priority" injection (openai_gateway_messages.go:60) +// → applyOpenAIFastPolicyToBody filter on default policy → upstream body has +// no service_tier. 
We exercise the same internal pipeline (Anthropic→Responses +// + BetaFastMode + policy) without spinning up a real upstream HTTP server. +func TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Step 1: parse Anthropic request (mirrors openai_gateway_messages.go:38-50). + anthropicBody := []byte(`{"model":"gpt-5.5","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`) + var anthropicReq apicompat.AnthropicRequest + require.NoError(t, json.Unmarshal(anthropicBody, &anthropicReq)) + responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq) + require.NoError(t, err) + + // Step 2: BetaFastMode header → service_tier="priority" (mirrors line 58-61). + headers := http.Header{} + headers.Set("anthropic-beta", claude.BetaFastMode) + require.True(t, containsBetaToken(headers.Get("anthropic-beta"), claude.BetaFastMode)) + responsesReq.ServiceTier = "priority" + responsesReq.Model = "gpt-5.5" + + // Step 3: marshal & apply fast policy (mirrors line 78 + 149). + responsesBody, err := json.Marshal(responsesReq) + require.NoError(t, err) + require.Equal(t, "priority", gjson.GetBytes(responsesBody, "service_tier").String(), "前置:beta 翻译应当注入 priority") + + upstreamBody, policyErr := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", responsesBody) + require.NoError(t, policyErr) + + // Step 4: assert that policy filtered the field before the upstream HTTP request. 
+ require.NotContains(t, string(upstreamBody), `"service_tier"`, "default policy 命中 gpt-5.5 priority 应当 filter 掉 service_tier") +} + +// --- Fix1: passthrough capturedSessionModel must follow session.update --- + +// TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel covers the +// fix1 bypass: client opens with a whitelist-miss model (gpt-4o → pass under +// gpt-5.5 whitelist), rotates to gpt-5.5 via session.update, then sends +// response.create without "model". Without the session.update sniffing the +// follow-up frame would fall back to the stale gpt-4o capture and pass — the +// fix updates capturedSessionModel from session.* events so the fallback now +// resolves to gpt-5.5 and the policy filters service_tier. +func TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Frame 1: response.create with whitelist-miss model — under default + // rule fallback=pass, service_tier stays. + first := []byte(`{"type":"response.create","model":"gpt-4o","service_tier":"priority"}`) + // Frame 2: session.update rotates the session model to gpt-5.5. + rotate := []byte(`{"type":"session.update","session":{"model":"gpt-5.5"}}`) + // Frame 3: response.create WITHOUT model — must inherit gpt-5.5. + followup := []byte(`{"type":"response.create","service_tier":"priority"}`) + + inner := &fakePassthroughFrameConn{reads: [][]byte{first, rotate, followup}} + + // Replicate the production wiring in openai_ws_v2_passthrough_adapter.go + // so capturedSessionModel state is shared across frames. 
+ capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, first) + require.Equal(t, "gpt-4o", capturedSessionModel) + wrapper := &openAIWSPolicyEnforcingFrameConn{ + inner: inner, + filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + if msgType != coderws.MessageText { + return payload, nil, nil + } + if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { + capturedSessionModel = updated + } + model := openAIWSPassthroughPolicyModelForFrame(account, payload) + if model == "" { + model = capturedSessionModel + } + return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload) + }, + } + + // Frame 1: gpt-4o miss whitelist → pass (service_tier preserved). + _, payload1, err := wrapper.ReadFrame(context.Background()) + require.NoError(t, err) + require.Contains(t, string(payload1), `"service_tier"`, "frame1: gpt-4o miss whitelist → pass keeps service_tier") + + // Frame 2: session.update — not response.create, untouched, but its + // side effect updates capturedSessionModel to gpt-5.5. + _, payload2, err := wrapper.ReadFrame(context.Background()) + require.NoError(t, err) + require.Equal(t, string(rotate), string(payload2), "session.update frame is forwarded verbatim") + require.Equal(t, "gpt-5.5", capturedSessionModel, "fix1: session.update must rotate capturedSessionModel") + + // Frame 3: empty model + new captured gpt-5.5 → matches whitelist → filter. 
+ _, payload3, err := wrapper.ReadFrame(context.Background()) + require.NoError(t, err) + require.NotContains(t, string(payload3), `"service_tier"`, + "fix1: post-rotate response.create without model must use refreshed capturedSessionModel and trigger filter") +} + +// TestPolicyModelFromSessionFrame_OnlySessionUpdate covers the negative +// branches of openAIWSPassthroughPolicyModelFromSessionFrame: only +// client→upstream session.update frames rotate the captured model; +// server→client events (session.created) and unrelated frames must not. +func TestPolicyModelFromSessionFrame_OnlySessionUpdate(t *testing.T) { + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // session.created is a server→client event in the OpenAI Realtime + // protocol — clients never send it, so this filter (which only runs on + // the client→upstream direction) must ignore it even if it appears. + created := []byte(`{"type":"session.created","session":{"model":"gpt-5.5"}}`) + require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, created)) + + // Non-session.* frames must NOT trigger rotation. + notSession := []byte(`{"type":"response.create","session":{"model":"gpt-9"}}`) + require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, notSession)) + + // Missing session.model returns empty — caller keeps the old captured value. + noModel := []byte(`{"type":"session.update","session":{"voice":"alloy"}}`) + require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, noModel)) +} + +// --- Fix2: native /responses normalize "fast" → "priority" on pass --- + +// TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias is the fix2 +// regression. Before the fix, when action=pass, applyOpenAIFastPolicyToBody +// returned the body unchanged so a raw "fast" alias would leak to the +// upstream OpenAI API (which does not accept "fast"). The fix normalizes +// "fast" → "priority" on pass too. 
+func TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias(t *testing.T) { + // Use a policy that deliberately misses gpt-4 so the action is pass. + settings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, settings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // gpt-4 + "fast" → fallback pass. Body must be rewritten to "priority". + body := []byte(`{"model":"gpt-4","service_tier":"fast"}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body) + require.NoError(t, err) + require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String(), + "fix2: pass action must still normalize 'fast' → 'priority' so upstream OpenAI accepts the slug") + + // Already-canonical "priority" on pass: zero mutation (byte-equal). + body = []byte(`{"model":"gpt-4","service_tier":"priority"}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body) + require.NoError(t, err) + require.Equal(t, string(body), string(updated)) + + // Mixed-case alias → normalized. + body = []byte(`{"model":"gpt-4","service_tier":" Fast "}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body) + require.NoError(t, err) + require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String()) + + // Unrecognized tier → still no-op (not normalized, since normTier == ""). 
+ body = []byte(`{"model":"gpt-4","service_tier":"turbo"}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body) + require.NoError(t, err) + require.Equal(t, string(body), string(updated)) +} + +// --- Fix3: passthrough billing must reflect post-filter service_tier --- + +// TestPassthroughBilling_PostFilterServiceTier is the fix3 regression. The +// passthrough adapter (openai_ws_v2_passthrough_adapter.go) now extracts +// requestServiceTier from firstClientMessage AFTER applyOpenAIFastPolicy +// has rewritten it, so a filter hit causes billing to report nil (default +// tier) instead of the user-requested "priority". This test pins the +// contract those two helpers must uphold for the adapter's billing path. +func TestPassthroughBilling_PostFilterServiceTier(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + raw := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`) + + // Pre-filter sanity: extracting from the raw frame would (incorrectly, + // pre-fix) report "priority" — this is the very thing the adapter + // must NOT do anymore. + pre := extractOpenAIServiceTierFromBody(raw) + require.NotNil(t, pre) + require.Equal(t, "priority", *pre, + "sanity: raw first frame carries priority that pre-fix billing would have reported") + + // Apply policy filter (default rule: gpt-5.5 + priority → filter). + filtered, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", raw) + require.NoError(t, err) + require.Nil(t, blocked) + require.NotContains(t, string(filtered), `"service_tier"`) + + // Post-filter: extracting from the rewritten frame returns nil. This + // is the value the adapter now passes to OpenAIForwardResult.ServiceTier, + // so billing records "default" instead of "priority". 
+ post := extractOpenAIServiceTierFromBody(filtered) + require.Nil(t, post, "fix3: post-filter extraction must return nil so passthrough billing reports default tier instead of the requested priority") + + // And the byte-level invariant the adapter relies on: filtering an + // already-filtered frame is a no-op (idempotent), so re-running the + // policy doesn't accidentally re-introduce the field. + again, blocked2, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", filtered) + require.NoError(t, err) + require.Nil(t, blocked2) + require.Equal(t, string(filtered), string(again), + "policy is idempotent: filtering an already-filtered frame leaves bytes unchanged") +} + +// TestApplyOpenAIFastPolicyToBody_NonStringServiceTier covers the test gap +// flagged in the review: when a client sends service_tier as a non-string +// (number, null, object, etc.) the policy must NOT panic and must NOT +// pretend the field was filtered. Behavior: skip policy entirely (treat as +// "no usable tier"), forward body unchanged. This mirrors the HTTP entry's +// type-assertion `reqBody["service_tier"].(string); ok` guard. +func TestApplyOpenAIFastPolicyToBody_NonStringServiceTier(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Number — gjson .String() coerces to "1" which is not a recognized + // tier alias; normalize returns "" → policy no-ops. 
+ cases := [][]byte{ + []byte(`{"model":"gpt-5.5","service_tier":1}`), + []byte(`{"model":"gpt-5.5","service_tier":null}`), + []byte(`{"model":"gpt-5.5","service_tier":{"nested":"priority"}}`), + []byte(`{"model":"gpt-5.5","service_tier":["priority"]}`), + []byte(`{"model":"gpt-5.5","service_tier":true}`), + } + for _, body := range cases { + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err, "non-string service_tier must not error: %s", string(body)) + require.Equal(t, string(body), string(updated), + "non-string service_tier must pass through unchanged: %s", string(body)) + } + + // Same guard for the WS response.create entry. + for _, body := range cases { + frame := body + updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err, "non-string service_tier ws frame must not error: %s", string(frame)) + require.Nil(t, blocked, "non-string service_tier must not trigger block: %s", string(frame)) + require.Equal(t, string(frame), string(updated), + "non-string service_tier ws frame must pass through unchanged: %s", string(frame)) + } +} + +// TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames covers the +// multi-turn passthrough billing regression: OpenAI Realtime / Responses WS +// allows the client to ship a different service_tier on each response.create +// frame (per-response field, see codex-rs/core/src/client.rs +// build_responses_request which re-fills the field on every request). Before +// the fix the adapter only captured service_tier from firstClientMessage so +// turn 2/3 billing was wrong. After the fix the filter closure refreshes an +// atomic.Pointer[string] on every successful response.create frame. 
+// +// This test pins the four legs of the semantic contract: +// - turn 1: service_tier=priority hits the default whitelist filter, so +// after filter the upstream sees no tier → billing is nil. +// - turn 2: service_tier=flex passes (default rule targets priority only), +// billing should now reflect "flex". +// - turn 3: response.create without any service_tier — the upstream will +// treat it as default; we choose to mirror that and overwrite billing +// to nil rather than carry over "flex" from turn 2. +// - non-response.create frame (response.cancel here) carrying a stray +// service_tier-shaped field must NOT clobber the billing pointer. +func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + // Mirror the production filter closure (openai_ws_v2_passthrough_adapter.go + // proxyResponsesWebSocketV2Passthrough) so this test fails if the + // production code drops the per-frame Store. 
+ var requestServiceTierPtr atomic.Pointer[string] + capturedSessionModel := "" + filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + if msgType != coderws.MessageText { + return payload, nil, nil + } + if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { + capturedSessionModel = updated + } + model := openAIWSPassthroughPolicyModelForFrame(account, payload) + if model == "" { + model = capturedSessionModel + } + out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload) + if policyErr == nil && blocked == nil && + strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { + requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out)) + } + return out, blocked, policyErr + } + + // First-frame initialization mirrors the adapter: extract from the + // post-filter payload so a filter-on-first-frame zeroes billing too. + firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`) + firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", firstFrame) + require.NoError(t, firstErr) + require.Nil(t, firstBlocked) + requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstOut)) + capturedSessionModel = openAIWSPassthroughPolicyModelForFrame(account, firstFrame) + require.Nil(t, requestServiceTierPtr.Load(), + "turn 1: filter strips service_tier=priority, billing must reflect upstream-actual nil tier") + + // Turn 2: client switches to flex, should pass and update billing. 
+ turn2 := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`) + out2, blocked2, err2 := filter(coderws.MessageText, turn2) + require.NoError(t, err2) + require.Nil(t, blocked2) + require.Equal(t, "flex", gjson.GetBytes(out2, "service_tier").String(), "turn 2: flex must pass to upstream untouched") + tier2 := requestServiceTierPtr.Load() + require.NotNil(t, tier2, "turn 2: billing must update to reflect flex") + require.Equal(t, "flex", *tier2) + + // A non-response.create frame with a stray service_tier-shaped field + // must NOT overwrite the billing pointer (those frames don't carry + // per-response service_tier in the Realtime spec). + cancelFrame := []byte(`{"type":"response.cancel","service_tier":"priority"}`) + _, blockedCancel, errCancel := filter(coderws.MessageText, cancelFrame) + require.NoError(t, errCancel) + require.Nil(t, blockedCancel) + tierAfterCancel := requestServiceTierPtr.Load() + require.NotNil(t, tierAfterCancel, "response.cancel must not clobber billing tier to nil") + require.Equal(t, "flex", *tierAfterCancel, + "non-response.create frames must not update billing tier even if they carry a service_tier-shaped field") + + // Turn 3: response.create without any service_tier. We deliberately + // overwrite billing back to nil so it tracks what the upstream actually + // sees on this turn (default tier). 
+ turn3 := []byte(`{"type":"response.create","model":"gpt-5.5"}`) + out3, blocked3, err3 := filter(coderws.MessageText, turn3) + require.NoError(t, err3) + require.Nil(t, blocked3) + require.Equal(t, string(turn3), string(out3), "turn 3 has no service_tier — filter must not mutate") + require.Nil(t, requestServiceTierPtr.Load(), + "turn 3: response.create without service_tier overwrites billing to nil to match upstream default") +} + +// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the +// "block keeps previous" semantic: when policy returns block on a +// response.create frame, that frame is never sent upstream, so billing tier +// must keep the previous turn's value rather than getting silently zeroed. +func TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier(t *testing.T) { + blockSettings := &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionBlock, + Scope: BetaPolicyScopeAll, + ErrorMessage: "blocked", + ModelWhitelist: []string{"gpt-5.5"}, + FallbackAction: BetaPolicyActionPass, + }}, + } + svc := newOpenAIGatewayServiceWithSettings(t, blockSettings) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + var requestServiceTierPtr atomic.Pointer[string] + flexValue := "flex" + requestServiceTierPtr.Store(&flexValue) // simulate prior turn billed as flex + + filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + if msgType != coderws.MessageText { + return payload, nil, nil + } + out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", payload) + if policyErr == nil && blocked == nil && + strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { + requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out)) + } + return out, blocked, policyErr + } + + frame := 
[]byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`) + _, blocked, err := filter(coderws.MessageText, frame) + require.NoError(t, err) + require.NotNil(t, blocked, "policy must block this frame") + + tier := requestServiceTierPtr.Load() + require.NotNil(t, tier, "blocked frame must not clobber prior billing tier to nil") + require.Equal(t, "flex", *tier, + "blocked frame is never sent upstream; billing must retain the previous turn's tier") +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 663066a3..5822ae4c 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -171,6 +171,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( } } + // 4b. Apply OpenAI fast policy (may filter service_tier or block the request). + updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody) + if policyErr != nil { + var blocked *OpenAIFastBlockedError + if errors.As(policyErr, &blocked) { + writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message) + } + return nil, policyErr + } + responsesBody = updatedBody + // 5. 
Get access token token, _, err := s.GetAccessToken(ctx, account) if err != nil { diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go index a00fb71c..6846e03a 100644 --- a/backend/internal/service/openai_gateway_chat_completions_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_test.go @@ -19,8 +19,22 @@ func TestNormalizeResponsesRequestServiceTier(t *testing.T) { normalizeResponsesRequestServiceTier(req) require.Equal(t, "flex", req.ServiceTier) + // OpenAI 官方合法 tier 应被透传保留。 + req.ServiceTier = "auto" + normalizeResponsesRequestServiceTier(req) + require.Equal(t, "auto", req.ServiceTier) + req.ServiceTier = "default" normalizeResponsesRequestServiceTier(req) + require.Equal(t, "default", req.ServiceTier) + + req.ServiceTier = "scale" + normalizeResponsesRequestServiceTier(req) + require.Equal(t, "scale", req.ServiceTier) + + // 真未知值仍被剥离。 + req.ServiceTier = "turbo" + normalizeResponsesRequestServiceTier(req) require.Empty(t, req.ServiceTier) } @@ -37,8 +51,25 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) { require.Equal(t, "flex", tier) require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String()) + // OpenAI 官方 tier 直接保留在 body 中(透传上游)。 + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`)) + require.NoError(t, err) + require.Equal(t, "auto", tier) + require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String()) + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`)) require.NoError(t, err) + require.Equal(t, "default", tier) + require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String()) + + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`)) + require.NoError(t, err) + require.Equal(t, "scale", tier) + require.Equal(t, "scale", gjson.GetBytes(body, 
"service_tier").String()) + + // 真未知值才会被删除。 + body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`)) + require.NoError(t, err) require.Empty(t, tier) require.False(t, gjson.GetBytes(body, "service_tier").Exists()) } diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 2a0a72eb..4e0ebb2e 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -143,6 +143,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } } + // 4c. Apply OpenAI fast policy (may filter service_tier or block the request). + // Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed + // on the body-level service_tier field (priority/flex). + updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody) + if policyErr != nil { + var blocked *OpenAIFastBlockedError + if errors.As(policyErr, &blocked) { + writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message) + } + return nil, policyErr + } + responsesBody = updatedBody + // 5. 
Get access token token, _, err := s.GetAccessToken(ctx, account) if err != nil { diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 9665c4c8..47ff4e3b 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U nil, nil, nil, + nil, ) svc.userGroupRateResolver = newUserGroupRateResolver( rateRepo, @@ -826,18 +827,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) { require.Equal(t, "priority", *got) }) - t.Run("default ignored", func(t *testing.T) { - require.Nil(t, normalizeOpenAIServiceTier("default")) + t.Run("openai official tiers preserved", func(t *testing.T) { + // OpenAI 官方文档定义的合法 tier 值都应被透传保留,避免因白名单过窄 + // 静默剥离客户端显式发送的合法字段。Codex 客户端只发 priority/flex, + // 所以扩大白名单对 Codex 流量零影响(见 codex-rs/core/src/client.rs)。 + for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} { + got := normalizeOpenAIServiceTier(tier) + require.NotNil(t, got, "tier %q should not be normalized to nil", tier) + require.Equal(t, tier, *got) + } }) t.Run("invalid ignored", func(t *testing.T) { require.Nil(t, normalizeOpenAIServiceTier("turbo")) + require.Nil(t, normalizeOpenAIServiceTier("xxx")) }) } func TestExtractOpenAIServiceTier(t *testing.T) { require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"})) require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"})) + require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"})) + require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"})) + require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"})) require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1})) 
require.Nil(t, extractOpenAIServiceTier(nil)) } @@ -845,7 +857,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) { func TestExtractOpenAIServiceTierFromBody(t *testing.T) { require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`))) require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`))) - require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`))) + require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`))) + require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`))) + require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`))) + require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`))) require.Nil(t, extractOpenAIServiceTierFromBody(nil)) } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 13e3ddab..a7407476 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -334,6 +334,7 @@ type OpenAIGatewayService struct { resolver *ModelPricingResolver channelService *ChannelService balanceNotifyService *BalanceNotifyService + settingService *SettingService openaiWSPoolOnce sync.Once openaiWSStateStoreOnce sync.Once @@ -372,6 +373,7 @@ func NewOpenAIGatewayService( resolver *ModelPricingResolver, channelService *ChannelService, balanceNotifyService *BalanceNotifyService, + settingService *SettingService, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ accountRepo: accountRepo, @@ -402,6 +404,7 @@ func NewOpenAIGatewayService( resolver: resolver, channelService: channelService, balanceNotifyService: balanceNotifyService, + settingService: settingService, responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: 
newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } @@ -2310,6 +2313,48 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco disablePatch() } + // Apply OpenAI fast policy (参照 Claude BetaPolicy 的 fast-mode 过滤): + // 针对 body 的 service_tier 字段("priority" 即 fast,"flex"),按策略 + // 执行 filter(删除字段)或 block(拒绝请求)。对 gpt-5.5 等模型屏蔽 + // fast 时在此生效。 + // + // 注意: + // 1. 此处统一使用 upstreamModel(已经过 GetMappedModel + + // normalizeOpenAIModelForUpstream + Codex OAuth normalize),与 + // chat-completions / messages 入口保持一致,避免不同入口因为模型 + // 维度不同而出现 whitelist 命中差异。 + // 2. action=pass 时也要把 raw "fast" 归一化为 "priority" 写回 body, + // 否则 native /responses 入口透传 "fast" 给上游会被拒。chat- + // completions 入口由 normalizeResponsesBodyServiceTier 完成同一 + // 行为,这里手工实现等效逻辑。 + if rawTier, ok := reqBody["service_tier"].(string); ok { + if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" { + action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier) + switch action { + case BetaPolicyActionBlock: + msg := errMsg + if msg == "" { + msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel) + } + blocked := &OpenAIFastBlockedError{Message: msg} + writeOpenAIFastPolicyBlockedResponse(c, blocked) + return nil, blocked + case BetaPolicyActionFilter: + delete(reqBody, "service_tier") + bodyModified = true + disablePatch() + default: + // pass:若客户端传的是别名 "fast",归一化为 "priority" + // 后写回 body,确保上游收到的是其能识别的规范值。 + if normTier != rawTier { + reqBody["service_tier"] = normTier + bodyModified = true + markPatchSet("service_tier", normTier) + } + } + } + } + // Re-serialize body only if modified if bodyModified { serializedByPatch := false @@ -2758,6 +2803,26 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( body = sanitizedBody } + // Apply OpenAI fast policy to the passthrough body (filter/block by service_tier). 
+ // 统一使用 upstream 视角的 model:透传路径下 body 已经过 compact 映射 + + // OAuth normalize,body 中的 model 字段即上游真正会看到的 slug。 + // 这样可以与 chat-completions / messages / native /responses 入口的 + // upstreamModel 保持一致,避免 whitelist 命中差异。当 body 中没有 + // model 字段时退回 reqModel。 + policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String()) + if policyModel == "" { + policyModel = reqModel + } + updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body) + if policyErr != nil { + var blocked *OpenAIFastBlockedError + if errors.As(policyErr, &blocked) { + writeOpenAIFastPolicyBlockedResponse(c, blocked) + } + return nil, policyErr + } + body = updatedBody + logger.LegacyPrintf("service.openai_gateway", "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", account.ID, @@ -5590,14 +5655,319 @@ func normalizeOpenAIServiceTier(raw string) *string { if value == "fast" { value = "priority" } + // 放过 OpenAI 官方文档定义的所有合法 tier 值:priority/flex/auto/default/scale。 + // 对 Codex 客户端零影响(Codex 只发 priority 或 flex,见 codex-rs/core/src/client.rs), + // 但能让直连 OpenAI SDK 的用户透传 auto/default/scale 以便抓包/调试。 + // 真未知值仍返回 nil,由 normalizeResponsesBodyServiceTier 从 body 中删除。 switch value { - case "priority", "flex": + case "priority", "flex", "auto", "default", "scale": return &value default: return nil } } +// OpenAIFastBlockedError indicates a request was rejected by the OpenAI fast +// policy (action=block). Mirrors BetaBlockedError on the Claude side. +type OpenAIFastBlockedError struct { + Message string +} + +func (e *OpenAIFastBlockedError) Error() string { return e.Message } + +// evaluateOpenAIFastPolicy returns the action and error message that should be +// applied for a request with the given account/model/service_tier. When the +// policy service is unavailable or no rule matches, it returns +// (BetaPolicyActionPass, "") so callers can short-circuit safely. 
+// +// Matching rules: +// - Scope filters by account type (all / oauth / apikey / bedrock) +// - ServiceTier must be empty (= any), "all", or equal the normalized tier +// - ModelWhitelist narrows the rule to specific models; FallbackAction +// handles the non-matching case (default: pass) +// +// 与 Claude BetaPolicy 的差异(保留首条匹配 short-circuit): +// - BetaPolicy 处理的是 anthropic-beta header 中的 token 集合,不同 +// 规则可能针对不同 token,filter 需要累加成 set;block 则 first-match。 +// - OpenAI fast policy 操作的是单个字段 service_tier:filter 即删字段, +// 没有可累加的对象。一次请求只携带一个 service_tier,规则的 tier +// 维度天然互斥;同一 (scope, tier) 下若多条规则的 model whitelist +// 发生重叠,admin 可通过规则顺序明确意图。因此采用 first-match 而 +// 非 BetaPolicy 那样的"block 覆盖 filter 覆盖 pass"语义。 +func (s *OpenAIGatewayService) evaluateOpenAIFastPolicy(ctx context.Context, account *Account, model, serviceTier string) (action, errMsg string) { + if s == nil || s.settingService == nil { + return BetaPolicyActionPass, "" + } + tier := strings.ToLower(strings.TrimSpace(serviceTier)) + if tier == "" { + return BetaPolicyActionPass, "" + } + settings := openAIFastPolicySettingsFromContext(ctx) + if settings == nil { + fetched, err := s.settingService.GetOpenAIFastPolicySettings(ctx) + if err != nil || fetched == nil { + return BetaPolicyActionPass, "" + } + settings = fetched + } + return evaluateOpenAIFastPolicyWithSettings(settings, account, model, tier) +} + +// evaluateOpenAIFastPolicyWithSettings is the pure-function core extracted so +// long-lived sessions (e.g. WS) can prefetch settings once and avoid hitting +// the settingService on every frame. See WSSession entry and +// openAIFastPolicySettingsFromContext for the caching glue. 
+func evaluateOpenAIFastPolicyWithSettings(settings *OpenAIFastPolicySettings, account *Account, model, tier string) (action, errMsg string) {
+	if settings == nil {
+		return BetaPolicyActionPass, ""
+	}
+	isOAuth := account != nil && account.IsOAuth()
+	isBedrock := account != nil && account.IsBedrock()
+	for _, rule := range settings.Rules {
+		if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
+			continue
+		}
+		ruleTier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
+		if ruleTier != "" && ruleTier != OpenAIFastTierAny && ruleTier != tier {
+			continue
+		}
+		eff := BetaPolicyRule{
+			Action:               rule.Action,
+			ErrorMessage:         rule.ErrorMessage,
+			ModelWhitelist:       rule.ModelWhitelist,
+			FallbackAction:       rule.FallbackAction,
+			FallbackErrorMessage: rule.FallbackErrorMessage,
+		}
+		return resolveRuleAction(eff, model)
+	}
+	return BetaPolicyActionPass, ""
+}
+
+// openAIFastPolicyCtxKey 是 context 中预取的 OpenAIFastPolicySettings 缓存
+// 键,仅用于 WebSocket 长会话内多帧复用同一份策略快照,避免每帧 DB 命中。
+//
+// Trade-off:策略变更不会影响当前 WS session(只影响新 session)。这是
+// 有意为之 —— 对长会话来说,"策略一致性"比"立刻生效"更重要,且 Claude
+// BetaPolicy 的 gin.Context 缓存也是同样取舍。需要 hot-reload 时管理员
+// 可以通过踢断 session 强制刷新。
+type openAIFastPolicyCtxKeyType struct{}
+
+var openAIFastPolicyCtxKey = openAIFastPolicyCtxKeyType{}
+
+// withOpenAIFastPolicyContext 将一份 settings 快照绑定到 context,供该 ctx
+// 衍生 goroutine 中的 evaluateOpenAIFastPolicy 复用。
+func withOpenAIFastPolicyContext(ctx context.Context, settings *OpenAIFastPolicySettings) context.Context {
+	if ctx == nil || settings == nil {
+		return ctx
+	}
+	return context.WithValue(ctx, openAIFastPolicyCtxKey, settings)
+}
+
+func openAIFastPolicySettingsFromContext(ctx context.Context) *OpenAIFastPolicySettings {
+	if ctx == nil {
+		return nil
+	}
+	if v, ok := ctx.Value(openAIFastPolicyCtxKey).(*OpenAIFastPolicySettings); ok {
+		return v
+	}
+	return nil
+}
+
+// applyOpenAIFastPolicyToBody applies the OpenAI fast policy to a raw request
When action=filter it removes the service_tier field; when
+// action=block it returns (body, *OpenAIFastBlockedError). On pass it
+// normalizes the service_tier value (e.g. client alias "fast" → "priority"),
+// rewriting the body so the upstream receives a slug it recognizes.
+//
+// Rationale for normalize-on-pass: chat-completions / messages 入口在调用本
+// 函数之前已经通过 normalizeResponsesBodyServiceTier 把 service_tier 归一化
+// 到了上游可识别值;passthrough(OpenAI 自动透传) / native /responses 等
+// 入口没有这一前置步骤,pass 路径下若不在此处归一化,"fast" 就会被原样
+// 透传到 OpenAI 上游导致 400/拒绝。把归一化收敛到本函数,所有入口行为一致。
+func (s *OpenAIGatewayService) applyOpenAIFastPolicyToBody(ctx context.Context, account *Account, model string, body []byte) ([]byte, error) {
+	if len(body) == 0 {
+		return body, nil
+	}
+	rawTier := gjson.GetBytes(body, "service_tier").String()
+	if rawTier == "" {
+		return body, nil
+	}
+	normTier := normalizedOpenAIServiceTierValue(rawTier)
+	if normTier == "" {
+		return body, nil
+	}
+	action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
+	switch action {
+	case BetaPolicyActionBlock:
+		msg := errMsg
+		if msg == "" {
+			msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
+		}
+		return body, &OpenAIFastBlockedError{Message: msg}
+	case BetaPolicyActionFilter:
+		trimmed, err := sjson.DeleteBytes(body, "service_tier")
+		if err != nil {
+			return body, fmt.Errorf("strip service_tier from body: %w", err)
+		}
+		return trimmed, nil
+	default:
+		// pass:把别名(如 "fast")写回为规范值("priority")。
+		if normTier == rawTier {
+			return body, nil
+		}
+		updated, err := sjson.SetBytes(body, "service_tier", normTier)
+		if err != nil {
+			return body, fmt.Errorf("normalize service_tier on pass: %w", err)
+		}
+		return updated, nil
+	}
+}
+
+// writeOpenAIFastPolicyBlockedResponse writes a 403 JSON response for a
+// request blocked by the OpenAI fast policy.
+func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlockedError) {
+	if c == nil || err == nil {
+		return
+	}
+	c.JSON(http.StatusForbidden, gin.H{
+		"error": gin.H{
+			"type":    "permission_error",
+			"message": err.Message,
+		},
+	})
+}
+
+// applyOpenAIFastPolicyToWSResponseCreate evaluates the OpenAI fast policy
+// against a single client→upstream WebSocket frame whose top-level
+// "type"=="response.create". It mirrors the HTTP-side
+// applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses
+// WS payload:
+//
+//   - pass: returns frame unchanged (newBytes == frame, blocked == nil)
+//   - filter: returns a copy with top-level service_tier removed
+//   - block: returns (frame, *OpenAIFastBlockedError)
+//
+// Only frames whose "type" field strictly equals "response.create" are
+// inspected/mutated. Any other frame type — including the empty string —
+// passes through untouched. The OpenAI Realtime client-event spec requires
+// "type" to be set, so an empty type is treated as a malformed frame we do
+// not police; the upstream is the source of truth for rejecting it.
+//
+// service_tier lives at the top level of response.create — same as the
+// Responses HTTP body shape (see openai_gateway_chat_completions.go:304 +
+// extractOpenAIServiceTierFromBody at line 5593, and the test fixture at
+// openai_ws_forwarder_ingress_session_test.go:402). We therefore only need
+// to inspect / strip the top-level field; there is no nested form in the
+// schema today.
+//
+// The caller is responsible for choosing the upstream model passed in —
+// this helper does not re-derive it.
+func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate(
+	ctx context.Context,
+	account *Account,
+	model string,
+	frame []byte,
+) ([]byte, *OpenAIFastBlockedError, error) {
+	if len(frame) == 0 {
+		return frame, nil, nil
+	}
+	if !gjson.ValidBytes(frame) {
+		return frame, nil, nil
+	}
+	frameType := strings.TrimSpace(gjson.GetBytes(frame, "type").String())
+	// Strict match: only response.create is policy-checked. Empty / other
+	// types pass through untouched so we never accidentally strip fields
+	// from response.cancel, conversation.item.create, or any future
+	// client-event the spec adds. The Realtime spec requires "type" on
+	// every client event, so an empty type is malformed input — let the
+	// upstream reject it rather than guessing at our layer.
+	if frameType != "response.create" {
+		return frame, nil, nil
+	}
+	rawTier := gjson.GetBytes(frame, "service_tier").String()
+	if rawTier == "" {
+		return frame, nil, nil
+	}
+	normTier := normalizedOpenAIServiceTierValue(rawTier)
+	if normTier == "" {
+		return frame, nil, nil
+	}
+	action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
+	switch action {
+	case BetaPolicyActionBlock:
+		msg := errMsg
+		if msg == "" {
+			msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
+		}
+		return frame, &OpenAIFastBlockedError{Message: msg}, nil
+	case BetaPolicyActionFilter:
+		trimmed, err := sjson.DeleteBytes(frame, "service_tier")
+		if err != nil {
+			return frame, nil, fmt.Errorf("strip service_tier from ws frame: %w", err)
+		}
+		return trimmed, nil, nil
+	default:
+		return frame, nil, nil
+	}
+}
+
+// newOpenAIFastPolicyWSEventID returns a Realtime-style event_id for a
+// server-emitted error event. Matches the loose "evt_" convention used
+// by upstream Realtime servers; the exact value is not load-bearing and is
+// only required for client-side log correlation.
We reuse the existing
+// google/uuid dependency rather than pulling a new one.
+func newOpenAIFastPolicyWSEventID() string {
+	id, err := uuid.NewRandom()
+	if err != nil {
+		// Extremely unlikely; fall back to a fixed prefix so the field is
+		// still non-empty and the schema stays self-consistent.
+		return "evt_openai_fast_policy"
+	}
+	// Strip dashes so it visually matches "evt_" rather than UUID v4
+	// canonical form, mirroring what real Realtime traces look like.
+	return "evt_" + strings.ReplaceAll(id.String(), "-", "")
+}
+
+// buildOpenAIFastPolicyBlockedWSEvent renders an OpenAI Realtime/Responses
+// style "error" event payload for a request blocked by the OpenAI fast
+// policy. The shape mirrors Realtime error events as observed in upstream
+// traces and per the spec's server "error" event:
+//
+//	{
+//	  "event_id": "evt_",
+//	  "type": "error",
+//	  "error": {
+//	    "type": "invalid_request_error",
+//	    "code": "policy_violation",
+//	    "message": "..."
+//	  }
+//	}
+//
+// event_id lets clients correlate the rejection in their logs; "code" gives
+// programmatic clients a stable identifier (HTTP-side equivalent is the
+// 403 permission_error JSON body).
+func buildOpenAIFastPolicyBlockedWSEvent(err *OpenAIFastBlockedError) []byte {
+	if err == nil {
+		return nil
+	}
+	eventID := newOpenAIFastPolicyWSEventID()
+	payload, mErr := json.Marshal(map[string]any{
+		"event_id": eventID,
+		"type":     "error",
+		"error": map[string]any{
+			"type":    "invalid_request_error",
+			"code":    "policy_violation",
+			"message": err.Message,
+		},
+	})
+	if mErr != nil {
+		// Fallback to a minimal hand-rolled payload; Marshal of the literal
+		// shape above should never fail in practice.
+ return []byte(`{"event_id":"` + eventID + `","type":"error","error":{"type":"invalid_request_error","code":"policy_violation","message":"openai fast policy blocked this request"}}`) + } + return payload +} + func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) { if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) { return body, false, nil diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 8c0222e2..dedbce1e 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2366,6 +2366,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return errors.New("token is empty") } + // 预取一次 OpenAI Fast Policy settings,绑定到 ctx,让该 WS session + // 内所有帧的 evaluateOpenAIFastPolicy 调用复用同一份快照,避免每帧 + // 进入 DB / settingRepo。Trade-off 见 withOpenAIFastPolicyContext 注释。 + if s.settingService != nil { + if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil { + ctx = withOpenAIFastPolicyContext(ctx, settings) + } + } + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled ingressMode := OpenAIWSIngressModeCtxPool @@ -2524,6 +2533,44 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( normalized = next } + // Apply OpenAI Fast Policy on the response.create frame using the same + // evaluator/normalize/scope rules as the HTTP entrypoints. This is the + // single integration point for all WS ingress turns (first + follow-up + // frames flow through here). + // + // Model fallback: parseClientPayload above rejects any frame whose + // "model" field is missing (line ~2493-2500), so by the time we + // reach this point upstreamModel is always derived from a non-empty + // per-frame model. 
The capturedSessionModel fallback used in the + // passthrough adapter is therefore not needed in this path. + policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized) + if policyErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr) + } + if blocked != nil { + // Send a Realtime-style error event to the client first, then + // signal the handler to close the connection with PolicyViolation. + // We intentionally do NOT forward this frame upstream. + // + // coder/websocket@v1.8.14 Conn.Write is synchronous and flushes + // the underlying bufio writer before returning (write.go:42 → + // 307-311), and the subsequent close handshake re-acquires the + // same writeFrameMu, so the error event is guaranteed to reach + // the kernel send buffer before any close frame is queued. + eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked) + if eventBytes != nil { + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes) + cancel() + } + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + blocked.Message, + blocked, + ) + } + normalized = policyApplied + return openAIWSClientPayload{ payloadRaw: normalized, rawForHash: trimmed, diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index 66e5db93..f3936de1 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -618,6 +618,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, ) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go 
b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index cda2e351..3dbb199a 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -21,6 +21,109 @@ type openAIWSClientFrameConn struct { conn *coderws.Conn } +// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs +// every client→upstream frame through the OpenAI Fast Policy. It is the +// passthrough-relay equivalent of the parseClientPayload integration in the +// ingress session path. filter returns: +// - newPayload, nil, nil: forward the (possibly mutated) payload +// - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error +// event via onBlock and surfaces a transport-level error so the relay +// stops reading from the client. +// - _, _, err: a transport error other than block. +type openAIWSPolicyEnforcingFrameConn struct { + inner openaiwsv2.FrameConn + filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) + onBlock func(blocked *OpenAIFastBlockedError) +} + +var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil) + +func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.inner == nil { + return coderws.MessageText, nil, errOpenAIWSConnClosed + } + msgType, payload, err := c.inner.ReadFrame(ctx) + if err != nil { + return msgType, payload, err + } + if c.filter == nil { + return msgType, payload, nil + } + updated, blocked, filterErr := c.filter(msgType, payload) + if filterErr != nil { + return msgType, payload, filterErr + } + if blocked != nil { + if c.onBlock != nil { + c.onBlock(blocked) + } + return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked) + } + return msgType, updated, nil +} + +func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType 
coderws.MessageType, payload []byte) error { + if c == nil || c.inner == nil { + return errOpenAIWSConnClosed + } + return c.inner.WriteFrame(ctx, msgType, payload) +} + +func (c *openAIWSPolicyEnforcingFrameConn) Close() error { + if c == nil || c.inner == nil { + return nil + } + return c.inner.Close() +} + +// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective +// model name that should be passed to evaluateOpenAIFastPolicy for a single +// passthrough WS frame. Mirrors the HTTP-side normalization +// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path +// matches model whitelists identically. +func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string { + if account == nil || len(payload) == 0 { + return "" + } + original := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + if original == "" { + return "" + } + return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original)) +} + +// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model +// derived from a session.update frame's session.model field. Returns "" when +// the frame is not a session.update event or carries no session.model. Used +// by the per-frame policy filter (client→upstream direction) to keep +// capturedSessionModel in sync with the session-level model the client may +// rotate mid-session. +// +// Realtime / Responses WS lets the client change the session model after +// the WS handshake via: +// +// {"type":"session.update","session":{"model":"gpt-5.5", ...}} +// +// If we only capture the model from the very first frame, a client can ship +// gpt-4o on the first response.create (whitelisted as pass), then +// session.update to gpt-5.5, then send response.create without "model" so +// the per-frame resolver returns "" and the stale capturedSessionModel falls +// back to gpt-4o — defeating the gpt-5.5 fast-policy filter. 
+func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string { + if account == nil || len(payload) == 0 { + return "" + } + frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String()) + if frameType != "session.update" { + return "" + } + original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String()) + if original == "" { + return "" + } + return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original)) +} + const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2" var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil) @@ -77,7 +180,6 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( return errors.New("token is empty") } requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String()) - requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage) requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String()) logOpenAIWSV2Passthrough( "relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d", @@ -88,6 +190,59 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( len(firstClientMessage), ) + // Apply OpenAI Fast Policy on the first response.create frame. Subsequent + // frames are filtered via a wrapping FrameConn below so every client→ + // upstream frame goes through the same policy evaluator/normalize/scope as + // HTTP entrypoints. + // + // We capture the session-level model from the first frame here so the + // per-frame filter (below) can fall back to it when a follow-up frame + // omits "model" — Realtime clients are allowed to send response.create + // without re-stating the model, in which case the upstream uses the model + // negotiated at session.update time. 
Without this fallback, an empty + // model would miss any configured model whitelist and be + // silently passed through, defeating the policy on every frame after + // the first. + capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage) + updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage) + if policyErr != nil { + return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr) + } + if blocked != nil { + // coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires + // writeFrameMu, writes the entire frame, and Flushes the underlying + // bufio writer before returning (write.go:42 → write.go:307-311). + // The subsequent close handshake re-acquires the same writeFrameMu + // to send the close frame, so the error event is guaranteed to + // reach the kernel send buffer before any close frame is queued. + // No explicit flush hop is required here. 
+ eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked) + if eventBytes != nil { + writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes) + cancelWrite() + } + return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked) + } + firstClientMessage = updatedFirst + + // 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter + // 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当 + // 反映上游实际处理的 tier(nil = default),而不是用户最初请求的 + // "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody)) + // 与 WS ingress(openai_ws_forwarder.go:2991 取自 payload)的语义一致。 + // + // 多轮 passthrough:OpenAI Realtime / Responses WS 协议允许客户端在 + // 同一连接的不同 response.create 帧上发送不同 service_tier(参考 + // codex-rs/core/src/client.rs build_responses_request 每次重新填值)。 + // 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream + // goroutine)和 OnTurnComplete / final result(runUpstreamToClient + // goroutine)之间同步当前 turn 的 service_tier。 + // extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型, + // 可直接 Store/Load 而无需额外封装。 + var requestServiceTierPtr atomic.Pointer[string] + requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage)) + wsURL, err := s.buildOpenAIResponsesWSURL(account) if err != nil { return fmt.Errorf("build ws url: %w", err) @@ -152,9 +307,72 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( } completedTurns := atomic.Int32{} + policyClientConn := &openAIWSPolicyEnforcingFrameConn{ + inner: &openAIWSClientFrameConn{conn: clientConn}, + // 注意线程安全:filter 仅在 runClientToUpstream 这一条 + // goroutine 中被调用(passthrough_relay.go: ReadFrame loop), + // capturedSessionModel 的读写都发生在该 goroutine 内,因此无需 + // 加锁/原子化。 + filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) { + if msgType != coderws.MessageText { + return payload, nil, nil + } + // 在评估策略前先刷新 capturedSessionModel:客户端可能通过 + 
// session.update 修改 session-level model(Realtime / + // Responses WS 协议允许),如果不刷新就会出现 + // "首帧 model=gpt-4o(pass)→ session.update 改成 gpt-5.5 + // → 不带 model 的 response.create fallback 到 gpt-4o" 的 + // 绕过路径。这里只看 session.update 事件中的 session.model + // 字段,response.create 自己的 model 仍然由其本帧字段决定。 + if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" { + capturedSessionModel = updated + } + // Per-frame model first; if the client omits "model" on a + // follow-up frame (legal in Realtime), fall back to the + // session-level model captured from the first frame so the + // model whitelist still resolves. An empty model would miss + // any whitelist and silently fall back to pass. + model := openAIWSPassthroughPolicyModelForFrame(account, payload) + if model == "" { + model = capturedSessionModel + } + out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload) + // 多轮 passthrough billing:仅在成功(non-block / non-err) + // 的 response.create 帧上更新 requestServiceTierPtr,使用 + // filter 处理后的 payload,与首帧 policy-after-extract 语义 + // 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。 + // - 非 response.create 帧(response.cancel / + // conversation.item.create / session.update 等)不携带 + // per-response service_tier,不应覆盖前一轮值。 + // - blocked != nil:该帧不会发送上游,billing tier 应保持 + // 上一轮值。 + // - policyErr != nil:异常路径,保持上一轮值。 + // - 不带 service_tier 的 response.create 会让 + // extractOpenAIServiceTierFromBody 返回 nil;这里有意 + // 覆盖(Store(nil)),因为 OpenAI 上游对该帧实际不传 + // service_tier 时按 default 处理,billing 应如实反映。 + if policyErr == nil && blocked == nil && + strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" { + requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out)) + } + return out, blocked, policyErr + }, + onBlock: func(blocked *OpenAIFastBlockedError) { + // See note above on Conn.Write being synchronous w.r.t. 
flush; + // no explicit flush is required to ensure the error event lands + // before the close frame. + eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked) + if eventBytes == nil { + return + } + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes) + cancel() + }, + } relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{ Ctx: ctx, - ClientConn: &openAIWSClientFrameConn{conn: clientConn}, + ClientConn: policyClientConn, UpstreamConn: upstreamFrameConn, FirstClientMessage: firstClientMessage, Options: openaiwsv2.RelayOptions{ @@ -179,7 +397,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( CacheReadInputTokens: turn.Usage.CacheReadInputTokens, }, Model: turn.RequestModel, - ServiceTier: requestServiceTier, + ServiceTier: requestServiceTierPtr.Load(), Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders), @@ -227,7 +445,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, }, Model: relayResult.RequestModel, - ServiceTier: requestServiceTier, + ServiceTier: requestServiceTierPtr.Load(), Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders), diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 33316031..966b4b84 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -3259,6 +3259,84 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data)) } +// GetOpenAIFastPolicySettings 获取 OpenAI fast 策略配置 +func (s *SettingService) GetOpenAIFastPolicySettings(ctx context.Context) (*OpenAIFastPolicySettings, error) { + value, err := s.settingRepo.GetValue(ctx, SettingKeyOpenAIFastPolicySettings) + if err != nil { + 
if errors.Is(err, ErrSettingNotFound) { + return DefaultOpenAIFastPolicySettings(), nil + } + return nil, fmt.Errorf("get openai fast policy settings: %w", err) + } + if value == "" { + return DefaultOpenAIFastPolicySettings(), nil + } + + var settings OpenAIFastPolicySettings + if err := json.Unmarshal([]byte(value), &settings); err != nil { + // JSON 损坏时静默 fallback 到默认配置会让策略意外失效(管理员配 + // 置的 block/filter 规则被忽略)。记录 Warn 让运维能在出现异常 + // 行为时定位到 settings 表里的脏数据。 + slog.Warn("failed to unmarshal openai fast policy settings, falling back to defaults", + "error", err, + "key", SettingKeyOpenAIFastPolicySettings) + return DefaultOpenAIFastPolicySettings(), nil + } + + return &settings, nil +} + +// SetOpenAIFastPolicySettings 设置 OpenAI fast 策略配置 +func (s *SettingService) SetOpenAIFastPolicySettings(ctx context.Context, settings *OpenAIFastPolicySettings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + validActions := map[string]bool{ + BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true, + } + validScopes := map[string]bool{ + BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true, + } + validTiers := map[string]bool{ + OpenAIFastTierAny: true, OpenAIFastTierPriority: true, OpenAIFastTierFlex: true, + } + + for i, rule := range settings.Rules { + tier := strings.ToLower(strings.TrimSpace(rule.ServiceTier)) + if tier == "" { + tier = OpenAIFastTierAny + } + if !validTiers[tier] { + return fmt.Errorf("rule[%d]: invalid service_tier %q", i, rule.ServiceTier) + } + settings.Rules[i].ServiceTier = tier + if !validActions[rule.Action] { + return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action) + } + if !validScopes[rule.Scope] { + return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope) + } + for j, pattern := range rule.ModelWhitelist { + trimmed := strings.TrimSpace(pattern) + if trimmed == "" { + return fmt.Errorf("rule[%d]: 
model_whitelist[%d] cannot be empty", i, j) + } + settings.Rules[i].ModelWhitelist[j] = trimmed + } + if rule.FallbackAction != "" && !validActions[rule.FallbackAction] { + return fmt.Errorf("rule[%d]: invalid fallback_action %q", i, rule.FallbackAction) + } + } + + data, err := json.Marshal(settings) + if err != nil { + return fmt.Errorf("marshal openai fast policy settings: %w", err) + } + + return s.settingRepo.Set(ctx, SettingKeyOpenAIFastPolicySettings, string(data)) +} + // SetStreamTimeoutSettings 设置流超时处理配置 func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error { if settings == nil { diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 5ec7d313..c0962ff0 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -405,3 +405,57 @@ func DefaultBetaPolicySettings() *BetaPolicySettings { }, } } + +// OpenAI Fast Policy 策略常量 +// OpenAI 的 "fast 模式" 通过请求体中的 service_tier 字段识别: +// - "priority"(客户端可传 "fast",归一化为 "priority"):fast 模式 +// - "flex":低优先级模式 +// - 省略:normal 默认 +// +// 本策略复用 BetaPolicyAction*/BetaPolicyScope* 常量语义,只是匹配键从 +// anthropic-beta header 换成 body 的 service_tier 字段。 +const ( + OpenAIFastTierAny = "all" // 匹配任意已识别的 service_tier + OpenAIFastTierPriority = "priority" // 仅匹配 fast(priority) + OpenAIFastTierFlex = "flex" // 仅匹配 flex +) + +// OpenAIFastPolicyRule 单条 OpenAI fast/flex 策略规则 +type OpenAIFastPolicyRule struct { + ServiceTier string `json:"service_tier"` // "all" | "priority" | "flex"(与 SetOpenAIFastPolicySettings 的 validTiers 保持一致) + Action string `json:"action"` // "pass" | "filter" | "block" + Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock" + ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效) + ModelWhitelist []string `json:"model_whitelist,omitempty"` // 模型匹配模式列表(为空=对所有模型生效) + FallbackAction string `json:"fallback_action,omitempty"` // 未匹配白名单的模型的处理方式 + 
FallbackErrorMessage string `json:"fallback_error_message,omitempty"` // 未匹配白名单时的自定义错误消息 (fallback_action=block 时生效) +} + +// OpenAIFastPolicySettings OpenAI fast 策略配置 +type OpenAIFastPolicySettings struct { + Rules []OpenAIFastPolicyRule `json:"rules"` +} + +// DefaultOpenAIFastPolicySettings 返回默认的 OpenAI fast 策略配置。 +// 默认对所有模型的 priority(fast)请求执行 filter,即剔除 service_tier 字段, +// 让上游按 normal 优先级处理。 +// +// 为什么 ModelWhitelist 为空(=对所有模型生效): +// codex 客户端的 service_tier=fast 是用户级开关,与 model 字段正交。即使 +// 用户使用 gpt-4 + fast,priority 配额仍会被消耗。如果默认规则只锁 +// gpt-5.5*,"用 gpt-4 + fast 透传 priority 上游" 这条路径就会绕过策略。 +// 与 codex 真实语义对齐,默认对所有模型生效;管理员若需要只针对特定 +// 模型,可在 admin UI 中显式配置 model_whitelist。 +func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings { + return &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{ + { + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + ModelWhitelist: []string{}, + FallbackAction: BetaPolicyActionPass, + }, + }, + } +} diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index defbab43..e8ab6af5 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -484,6 +484,9 @@ export interface SystemSettings { // Affiliate (邀请返利) feature switch affiliate_enabled: boolean; + + // OpenAI fast/flex policy + openai_fast_policy_settings?: OpenAIFastPolicySettings; } export interface UpdateSettingsRequest { @@ -648,6 +651,9 @@ export interface UpdateSettingsRequest { // Affiliate (邀请返利) feature switch affiliate_enabled?: boolean; + + // OpenAI fast/flex policy + openai_fast_policy_settings?: OpenAIFastPolicySettings; } /** @@ -875,6 +881,29 @@ export async function updateRectifierSettings( return data; } +// ==================== OpenAI Fast Policy Settings ==================== + +/** + * OpenAI fast/flex policy rule interface. + * Matches backend dto.OpenAIFastPolicyRule. 
+ */ +export interface OpenAIFastPolicyRule { + service_tier: "all" | "priority" | "flex"; + action: "pass" | "filter" | "block"; + scope: "all" | "oauth" | "apikey" | "bedrock"; + error_message?: string; + model_whitelist?: string[]; + fallback_action?: "pass" | "filter" | "block"; + fallback_error_message?: string; +} + +/** + * OpenAI fast/flex policy settings interface. + */ +export interface OpenAIFastPolicySettings { + rules: OpenAIFastPolicyRule[]; +} + // ==================== Beta Policy Settings ==================== /** diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 6f445986..c66ca55b 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -5535,6 +5535,38 @@ export default { presetOpusOnlyDesc: 'Pass for Opus, filter others', commonPatterns: 'Common patterns' }, + openaiFastPolicy: { + title: 'OpenAI Fast/Flex Policy', + description: 'Intercept, filter, or pass OpenAI fast(priority) / flex requests based on the request body service_tier field. Applies to the OpenAI gateway only.', + empty: 'No rules configured. Click the button below to add one.', + ruleHeader: 'Rule #{index}', + removeRule: 'Remove rule', + addRule: 'Add rule', + saveHint: 'Saved together with system settings (click the global Save button at the bottom of the page).', + serviceTier: 'service_tier match', + tierAll: 'All tiers', + tierPriority: 'priority (fast)', + tierFlex: 'flex', + action: 'Action', + actionPass: 'Pass (keep service_tier)', + actionFilter: 'Filter (remove service_tier)', + actionBlock: 'Block (reject request)', + scope: 'Scope', + scopeAll: 'All accounts', + scopeOAuth: 'OAuth only', + scopeAPIKey: 'API Key only', + scopeBedrock: 'Bedrock only', + errorMessage: 'Error message', + errorMessagePlaceholder: 'Custom error message when blocked', + errorMessageHint: 'Leave empty for the default message.', + modelWhitelist: 'Model whitelist', + modelWhitelistHint: 'Leave empty to apply to all models. 
Supports exact match and wildcard prefix (e.g., gpt-5.5*).', + modelPatternPlaceholder: 'e.g., gpt-5.5 or gpt-5.5*', + addModelPattern: 'Add model pattern', + fallbackAction: 'Fallback action', + fallbackActionHint: 'Action for models not matching the whitelist.', + fallbackErrorMessagePlaceholder: 'Custom error message when non-whitelisted models are blocked' + }, wechatConnect: { title: 'WeChat Connect', description: 'Third-party login configuration for WeChat Open Platform or Official Account / Mini Program.', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index e399530b..77d1c93c 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -5695,6 +5695,38 @@ export default { presetOpusOnlyDesc: 'Opus 透传,其他模型过滤', commonPatterns: '常用模式' }, + openaiFastPolicy: { + title: 'OpenAI Fast/Flex 策略', + description: '基于请求体 service_tier 字段拦截/过滤/透传 OpenAI fast(priority) 与 flex 请求;仅作用于 OpenAI 网关。', + empty: '尚未配置任何规则。点击下方按钮新增。', + ruleHeader: '规则 #{index}', + removeRule: '删除规则', + addRule: '新增规则', + saveHint: '保存时随系统设置一起提交(点击页面底部「保存」按钮)。', + serviceTier: 'service_tier 匹配', + tierAll: '全部 tier', + tierPriority: 'priority(fast)', + tierFlex: 'flex', + action: '处理方式', + actionPass: '透传(保留 service_tier)', + actionFilter: '过滤(移除 service_tier)', + actionBlock: '拦截(拒绝请求)', + scope: '生效范围', + scopeAll: '全部账号', + scopeOAuth: '仅 OAuth 账号', + scopeAPIKey: '仅 API Key 账号', + scopeBedrock: '仅 Bedrock 账号', + errorMessage: '错误消息', + errorMessagePlaceholder: '拦截时返回的自定义错误消息', + errorMessageHint: '留空则使用默认错误消息。', + modelWhitelist: '模型白名单', + modelWhitelistHint: '留空表示对所有模型生效;支持精确匹配与通配符(如 gpt-5.5*)。', + modelPatternPlaceholder: '例如: gpt-5.5 或 gpt-5.5*', + addModelPattern: '添加模型规则', + fallbackAction: '未匹配模型处理方式', + fallbackActionHint: '当请求模型不在白名单中时的处理方式。', + fallbackErrorMessagePlaceholder: '未匹配模型被拦截时返回的自定义错误消息' + }, wechatConnect: { title: '微信登录', description: '用于微信开放平台或公众号/小程序的第三方登录配置。', diff --git 
a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 90d10b9a..ad0587b8 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -949,6 +949,285 @@ + +
+
+

+ {{ t("admin.settings.openaiFastPolicy.title") }} +

+

+ {{ t("admin.settings.openaiFastPolicy.description") }} +

+
+
+ +
+ {{ t("admin.settings.openaiFastPolicy.empty") }} +
+ + +
+
+ + {{ + t("admin.settings.openaiFastPolicy.ruleHeader", { + index: ruleIndex + 1, + }) + }} + + +
+ +
+ +
+ + +
+ + +
+ + +

+ {{ t("admin.settings.openaiFastPolicy.errorMessageHint") }} +

+
+ + +
+ +

+ {{ + t("admin.settings.openaiFastPolicy.modelWhitelistHint") + }} +

+
+ + +
+ +
+ + +
+ + +
+
+
+ + +
+ +

+ {{ t("admin.settings.openaiFastPolicy.saveHint") }} +

+
+
+
@@ -5199,6 +5478,7 @@ import type { SystemSettings, UpdateSettingsRequest, DefaultSubscriptionSetting, + OpenAIFastPolicyRule, WeChatConnectMode, WebSearchEmulationConfig, WebSearchProviderConfig, @@ -5337,6 +5617,14 @@ const betaPolicyForm = reactive({ }>, }); +// OpenAI Fast/Flex Policy 状态 +const openaiFastPolicyForm = reactive({ + rules: [] as OpenAIFastPolicyRule[], +}); +// 标记 openai_fast_policy_settings 是否已成功从后端加载, +// 避免后端 GET 出错或字段缺失时,保存把默认规则覆盖成空数组。 +const openaiFastPolicyLoaded = ref(false); + const tablePageSizeMin = 5; const tablePageSizeMax = 1000; const tablePageSizeDefault = 20; @@ -6116,6 +6404,23 @@ async function loadSettings() { ); form.oidc_connect_client_secret = ""; + // Load OpenAI fast/flex policy rules from bulk settings. + // 仅当 payload 真的包含该字段时填充并标记为已加载;否则保持表单空值, + // 让 saveSettings 在未加载时跳过该字段,防止覆盖后端默认规则。 + if ( + settings.openai_fast_policy_settings && + Array.isArray(settings.openai_fast_policy_settings.rules) + ) { + openaiFastPolicyForm.rules = + settings.openai_fast_policy_settings.rules.map((rule) => ({ + ...rule, + model_whitelist: rule.model_whitelist + ? [...rule.model_whitelist] + : [], + })); + openaiFastPolicyLoaded.value = true; + } + // Load web search emulation config separately await loadWebSearchConfig(); } catch (error: unknown) { @@ -6460,10 +6765,39 @@ async function saveSettings() { affiliate_enabled: form.affiliate_enabled, }; + // 仅当 openai_fast_policy_settings 已成功从后端加载时才回写, + // 否则省略整个字段,让后端保留既有规则(含默认值)。 + if (openaiFastPolicyLoaded.value) { + payload.openai_fast_policy_settings = { + rules: openaiFastPolicyForm.rules.map((rule) => { + const whitelist = (rule.model_whitelist || []) + .map((p) => p.trim()) + .filter((p) => p !== ""); + const hasWhitelist = whitelist.length > 0; + return { + service_tier: rule.service_tier, + action: rule.action, + scope: rule.scope, + error_message: + rule.action === "block" ? rule.error_message : undefined, + model_whitelist: hasWhitelist ? 
whitelist : undefined, + fallback_action: hasWhitelist + ? rule.fallback_action || "pass" + : undefined, + fallback_error_message: + hasWhitelist && rule.fallback_action === "block" + ? rule.fallback_error_message + : undefined, + }; + }), + }; + } + appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults); const updated = await adminAPI.settings.updateSettings(payload); for (const [key, value] of Object.entries(updated)) { + if (key === "openai_fast_policy_settings") continue; if (value !== null && value !== undefined) { (form as Record)[key] = value; } @@ -6507,6 +6841,20 @@ async function saveSettings() { form.wechat_connect_mode, ); form.oidc_connect_client_secret = ""; + // Refresh OpenAI fast/flex policy from server response + if ( + updated.openai_fast_policy_settings && + Array.isArray(updated.openai_fast_policy_settings.rules) + ) { + openaiFastPolicyForm.rules = + updated.openai_fast_policy_settings.rules.map((rule) => ({ + ...rule, + model_whitelist: rule.model_whitelist + ? 
[...rule.model_whitelist] + : [], + })); + openaiFastPolicyLoaded.value = true; + } // Save web search emulation config separately (errors handled internally) const wsOk = await saveWebSearchConfig(); // Refresh cached settings so sidebar/header update immediately @@ -6846,6 +7194,61 @@ async function loadBetaPolicySettings() { } } +// ==================== OpenAI Fast/Flex Policy ==================== + +const openaiFastPolicyTierOptions = computed(() => [ + { value: "all", label: t("admin.settings.openaiFastPolicy.tierAll") }, + { + value: "priority", + label: t("admin.settings.openaiFastPolicy.tierPriority"), + }, + { value: "flex", label: t("admin.settings.openaiFastPolicy.tierFlex") }, +]); + +const openaiFastPolicyActionOptions = computed(() => [ + { value: "pass", label: t("admin.settings.openaiFastPolicy.actionPass") }, + { value: "filter", label: t("admin.settings.openaiFastPolicy.actionFilter") }, + { value: "block", label: t("admin.settings.openaiFastPolicy.actionBlock") }, +]); + +const openaiFastPolicyScopeOptions = computed(() => [ + { value: "all", label: t("admin.settings.openaiFastPolicy.scopeAll") }, + { value: "oauth", label: t("admin.settings.openaiFastPolicy.scopeOAuth") }, + { value: "apikey", label: t("admin.settings.openaiFastPolicy.scopeAPIKey") }, + { + value: "bedrock", + label: t("admin.settings.openaiFastPolicy.scopeBedrock"), + }, +]); + +function addOpenAIFastPolicyRule() { + openaiFastPolicyForm.rules.push({ + service_tier: "priority", + action: "filter", + scope: "all", + error_message: "", + model_whitelist: [], + fallback_action: "pass", + fallback_error_message: "", + }); +} + +function removeOpenAIFastPolicyRule(index: number) { + openaiFastPolicyForm.rules.splice(index, 1); +} + +function addOpenAIFastPolicyModelPattern(rule: OpenAIFastPolicyRule) { + if (!rule.model_whitelist) rule.model_whitelist = []; + rule.model_whitelist.push(""); +} + +function removeOpenAIFastPolicyModelPattern( + rule: OpenAIFastPolicyRule, + idx: 
number, +) { + rule.model_whitelist?.splice(idx, 1); +} + async function saveBetaPolicySettings() { betaPolicySaving.value = true; try {