mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-04 21:20:51 +08:00
Merge pull request #2051 from DaydreamCoding/openai-fast-flex-policy
feat(openai): OpenAI Fast/Flex Policy 完整实现(HTTP + WebSocket + Admin)
This commit is contained in:
@@ -186,7 +186,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) {
|
||||
require.True(t, isAddrInTrustedProxies(addr, prefixes))
|
||||
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
|
||||
}
|
||||
|
||||
// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier 验证 admin
|
||||
// 写入路径会把 ServiceTier 的空字符串/空白/大小写归一化为
|
||||
// service.OpenAIFastTierAny ("all"),避免落盘时 "" 与 "all" 双语义。
|
||||
func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) {
|
||||
t.Run("nil input returns nil", func(t *testing.T) {
|
||||
require.Nil(t, openaiFastPolicySettingsFromDTO(nil))
|
||||
})
|
||||
|
||||
t.Run("empty service_tier becomes 'all'", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{{
|
||||
ServiceTier: "",
|
||||
Action: "filter",
|
||||
Scope: "all",
|
||||
}},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.NotNil(t, out)
|
||||
require.Len(t, out.Rules, 1)
|
||||
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
|
||||
require.Equal(t, "all", out.Rules[0].ServiceTier)
|
||||
})
|
||||
|
||||
t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{{
|
||||
ServiceTier: " ",
|
||||
Action: "pass",
|
||||
Scope: "all",
|
||||
}},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
|
||||
})
|
||||
|
||||
t.Run("uppercase service_tier is lowercased", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{{
|
||||
ServiceTier: "PRIORITY",
|
||||
Action: "filter",
|
||||
Scope: "all",
|
||||
}},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
|
||||
})
|
||||
|
||||
t.Run("non-empty values pass through (lowercased)", func(t *testing.T) {
|
||||
in := &dto.OpenAIFastPolicySettings{
|
||||
Rules: []dto.OpenAIFastPolicyRule{
|
||||
{ServiceTier: "priority", Action: "filter", Scope: "all"},
|
||||
{ServiceTier: "flex", Action: "block", Scope: "oauth"},
|
||||
{ServiceTier: "all", Action: "pass", Scope: "apikey"},
|
||||
},
|
||||
}
|
||||
out := openaiFastPolicySettingsFromDTO(in)
|
||||
require.Len(t, out.Rules, 3)
|
||||
require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
|
||||
require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier)
|
||||
require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -248,9 +248,51 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
|
||||
AffiliateEnabled: settings.AffiliateEnabled,
|
||||
}
|
||||
|
||||
// OpenAI fast policy (stored under a dedicated setting key)
|
||||
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
|
||||
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
|
||||
} else if fastPolicy != nil {
|
||||
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
|
||||
}
|
||||
|
||||
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
|
||||
}
|
||||
|
||||
// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy.
|
||||
func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
rules := make([]dto.OpenAIFastPolicyRule, len(s.Rules))
|
||||
for i, r := range s.Rules {
|
||||
rules[i] = dto.OpenAIFastPolicyRule(r)
|
||||
}
|
||||
return &dto.OpenAIFastPolicySettings{Rules: rules}
|
||||
}
|
||||
|
||||
// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy.
|
||||
//
|
||||
// 规范化 ServiceTier:在 DTO 进入 service 层之前统一把空字符串归一为
|
||||
// service.OpenAIFastTierAny ("all"),避免管理员保存时空串与 "all" 同时
|
||||
// 表达"匹配任意 tier"造成数据库取值的二义性。其它非空值原样透传,由
|
||||
// service.SetOpenAIFastPolicySettings 负责合法值校验。
|
||||
func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
rules := make([]service.OpenAIFastPolicyRule, len(s.Rules))
|
||||
for i, r := range s.Rules {
|
||||
rules[i] = service.OpenAIFastPolicyRule(r)
|
||||
tier := strings.ToLower(strings.TrimSpace(rules[i].ServiceTier))
|
||||
if tier == "" {
|
||||
tier = service.OpenAIFastTierAny
|
||||
}
|
||||
rules[i].ServiceTier = tier
|
||||
}
|
||||
return &service.OpenAIFastPolicySettings{Rules: rules}
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
@@ -452,6 +494,9 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled *bool `json:"affiliate_enabled"`
|
||||
|
||||
// OpenAI fast/flex policy (optional, only updated when provided)
|
||||
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -1350,6 +1395,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Update OpenAI fast policy (stored under dedicated key, only when provided).
|
||||
if req.OpenAIFastPolicySettings != nil {
|
||||
if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update payment configuration (integrated into system settings).
|
||||
// Skip if no payment fields were provided (prevents accidental wipe).
|
||||
if h.paymentConfigService != nil && hasPaymentFields(req) {
|
||||
@@ -1555,6 +1608,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
|
||||
AffiliateEnabled: updatedSettings.AffiliateEnabled,
|
||||
}
|
||||
if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
|
||||
slog.Error("openai_fast_policy_settings_get_failed", "error", err)
|
||||
} else if fastPolicy != nil {
|
||||
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
|
||||
}
|
||||
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,12 @@ func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
panic("unexpected GetValue call")
|
||||
if s.values != nil {
|
||||
if value, ok := s.values[key]; ok {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
|
||||
@@ -198,6 +198,9 @@ type SystemSettings struct {
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
AffiliateEnabled bool `json:"affiliate_enabled"`
|
||||
|
||||
// OpenAI fast/flex policy
|
||||
OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
@@ -294,6 +297,22 @@ type BetaPolicySettings struct {
|
||||
Rules []BetaPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// OpenAIFastPolicyRule OpenAI fast/flex 策略规则 DTO
|
||||
type OpenAIFastPolicyRule struct {
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Action string `json:"action"`
|
||||
Scope string `json:"scope"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
ModelWhitelist []string `json:"model_whitelist,omitempty"`
|
||||
FallbackAction string `json:"fallback_action,omitempty"`
|
||||
FallbackErrorMessage string `json:"fallback_error_message,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIFastPolicySettings OpenAI fast 策略配置 DTO
|
||||
type OpenAIFastPolicySettings struct {
|
||||
Rules []OpenAIFastPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
|
||||
@@ -748,6 +748,16 @@ func TestAPIContracts(t *testing.T) {
|
||||
"payment_visible_method_alipay_enabled": true,
|
||||
"payment_visible_method_wxpay_enabled": false,
|
||||
"openai_advanced_scheduler_enabled": true,
|
||||
"openai_fast_policy_settings": {
|
||||
"rules": [
|
||||
{
|
||||
"service_tier": "priority",
|
||||
"action": "filter",
|
||||
"scope": "all",
|
||||
"fallback_action": "pass"
|
||||
}
|
||||
]
|
||||
},
|
||||
"custom_menu_items": [],
|
||||
"custom_endpoints": [],
|
||||
"payment_enabled": false,
|
||||
@@ -930,6 +940,16 @@ func TestAPIContracts(t *testing.T) {
|
||||
"payment_visible_method_alipay_enabled": false,
|
||||
"payment_visible_method_wxpay_enabled": false,
|
||||
"openai_advanced_scheduler_enabled": false,
|
||||
"openai_fast_policy_settings": {
|
||||
"rules": [
|
||||
{
|
||||
"service_tier": "priority",
|
||||
"action": "filter",
|
||||
"scope": "all",
|
||||
"fallback_action": "pass"
|
||||
}
|
||||
]
|
||||
},
|
||||
"payment_enabled": false,
|
||||
"payment_min_amount": 0,
|
||||
"payment_max_amount": 0,
|
||||
|
||||
@@ -306,6 +306,12 @@ const (
|
||||
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
|
||||
SettingKeyBetaPolicySettings = "beta_policy_settings"
|
||||
|
||||
// SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI
|
||||
// service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but
|
||||
// targets OpenAI's body-level service_tier field instead of Claude's
|
||||
// anthropic-beta header.
|
||||
SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings"
|
||||
|
||||
// =========================
|
||||
// Claude Code Version Check
|
||||
// =========================
|
||||
|
||||
286
backend/internal/service/openai_fast_policy_test.go
Normal file
286
backend/internal/service/openai_fast_policy_test.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIFastPolicyRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if v, ok := s.values[key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
if s.values == nil {
|
||||
s.values = map[string]string{}
|
||||
}
|
||||
s.values[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
|
||||
t.Helper()
|
||||
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
|
||||
if settings != nil {
|
||||
raw, err := json.Marshal(settings)
|
||||
require.NoError(t, err)
|
||||
repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
|
||||
}
|
||||
return &OpenAIGatewayService{
|
||||
settingService: NewSettingService(repo, &config.Config{}),
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
// 默认策略对所有模型生效(whitelist 为空),因为 codex 的 service_tier=fast
|
||||
// 是用户级开关,与 model 正交。
|
||||
// gpt-5.5 + priority → filter
|
||||
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// gpt-5.5-turbo → filter
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// gpt-4 + priority → filter(默认策略覆盖所有模型)
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// gpt-5.5 + flex → pass (tier doesn't match)
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
|
||||
require.Equal(t, BetaPolicyActionPass, action)
|
||||
|
||||
// empty tier → pass
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
|
||||
require.Equal(t, BetaPolicyActionPass, action)
|
||||
}
|
||||
|
||||
func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "fast mode is not allowed",
|
||||
ModelWhitelist: []string{"gpt-5.5"},
|
||||
FallbackAction: BetaPolicyActionPass,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionBlock, action)
|
||||
require.Equal(t, "fast mode is not allowed", msg)
|
||||
}
|
||||
|
||||
func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierAny,
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeOAuth,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
|
||||
// OAuth account → rule matches
|
||||
oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionFilter, action)
|
||||
|
||||
// API Key account → rule skipped → pass
|
||||
apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
|
||||
require.Equal(t, BetaPolicyActionPass, action)
|
||||
}
|
||||
|
||||
func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
// gpt-5.5 fast → service_tier stripped
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`)
|
||||
|
||||
// Client sending "fast" (alias for priority) also filtered
|
||||
body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
|
||||
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`)
|
||||
|
||||
// gpt-4 priority → 默认策略对所有模型 filter,service_tier 被移除
|
||||
body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
|
||||
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`)
|
||||
|
||||
// No service_tier → no-op
|
||||
body = []byte(`{"model":"gpt-5.5"}`)
|
||||
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(updated))
|
||||
}
|
||||
|
||||
// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证扩展白名单后
|
||||
// 客户端显式发送的 OpenAI 官方合法 tier(auto/default/scale)能透传到上游而不被
|
||||
// 静默剥离。默认策略只针对 priority,所以这些 tier 落在 fall-through pass 分支。
|
||||
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
for _, tier := range []string{"auto", "default", "scale"} {
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err, "tier %q should pass without error", tier)
|
||||
require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
|
||||
"tier %q should be preserved in body under default rule", tier)
|
||||
}
|
||||
|
||||
// evaluate 层也应判定为 pass(默认规则 ServiceTier=priority 与 auto/default/scale 不匹配)
|
||||
for _, tier := range []string{"auto", "default", "scale"} {
|
||||
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
|
||||
require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers 验证管理员显式配置
|
||||
// ServiceTier=all + Action=filter 规则后,auto/default/scale 等官方 tier 也会
|
||||
// 被剥离。这是符合预期的——首条匹配 short-circuit,"all" 覆盖任意已识别 tier。
|
||||
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierAny,
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(updated), `"service_tier"`,
|
||||
"tier %q should be stripped under ServiceTier=all + filter rule", tier)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped 验证真未知 tier 仍被剥离
|
||||
// (normalize 返回 nil → normalizeResponsesBodyServiceTier 删除字段;
|
||||
// applyOpenAIFastPolicyToBody 在 normTier 为空时直接 no-op,因为字段已不可能存在
|
||||
// 于经过前置归一化的请求里。这里直接调 apply 验证它对未识别值不会异常)。
|
||||
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
// normalize 阶段会将未知值剥离
|
||||
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
|
||||
|
||||
// applyOpenAIFastPolicyToBody 收到未识别 tier 时不报错,body 透传不变
|
||||
// (不属于本函数职责——上层 normalizeResponsesBodyServiceTier 已剥离)
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(body), string(updated))
|
||||
}
|
||||
|
||||
func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
|
||||
settings := &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "fast mode is blocked for gpt-5.5",
|
||||
ModelWhitelist: []string{"gpt-5.5"},
|
||||
FallbackAction: BetaPolicyActionPass,
|
||||
}},
|
||||
}
|
||||
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
|
||||
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
||||
require.Error(t, err)
|
||||
var blocked *OpenAIFastBlockedError
|
||||
require.True(t, errors.As(err, &blocked))
|
||||
require.Contains(t, blocked.Message, "fast mode is blocked")
|
||||
require.Equal(t, string(body), string(updated)) // body not mutated on block
|
||||
}
|
||||
|
||||
func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
|
||||
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
// Invalid action rejected
|
||||
err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: "bogus",
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// Invalid service_tier rejected
|
||||
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: "turbo",
|
||||
Action: BetaPolicyActionPass,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// Valid settings persisted
|
||||
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := svc.GetOpenAIFastPolicySettings(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, got.Rules, 1)
|
||||
require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
|
||||
}
|
||||
1018
backend/internal/service/openai_fast_policy_ws_test.go
Normal file
1018
backend/internal/service/openai_fast_policy_ws_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -171,6 +171,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
}
|
||||
|
||||
// 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
responsesBody = updatedBody
|
||||
|
||||
// 5. Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
|
||||
@@ -19,8 +19,22 @@ func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "flex", req.ServiceTier)
|
||||
|
||||
// OpenAI 官方合法 tier 应被透传保留。
|
||||
req.ServiceTier = "auto"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "auto", req.ServiceTier)
|
||||
|
||||
req.ServiceTier = "default"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "default", req.ServiceTier)
|
||||
|
||||
req.ServiceTier = "scale"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "scale", req.ServiceTier)
|
||||
|
||||
// 真未知值仍被剥离。
|
||||
req.ServiceTier = "turbo"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Empty(t, req.ServiceTier)
|
||||
}
|
||||
|
||||
@@ -37,8 +51,25 @@ func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
|
||||
require.Equal(t, "flex", tier)
|
||||
require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
// OpenAI 官方 tier 直接保留在 body 中(透传上游)。
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "auto", tier)
|
||||
require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "default", tier)
|
||||
require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "scale", tier)
|
||||
require.Equal(t, "scale", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
// 真未知值才会被删除。
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`))
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, tier)
|
||||
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
|
||||
}
|
||||
|
||||
@@ -143,6 +143,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
}
|
||||
|
||||
// 4c. Apply OpenAI fast policy (may filter service_tier or block the request).
|
||||
// Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed
|
||||
// on the body-level service_tier field (priority/flex).
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
responsesBody = updatedBody
|
||||
|
||||
// 5. Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
|
||||
@@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
@@ -826,18 +827,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) {
|
||||
require.Equal(t, "priority", *got)
|
||||
})
|
||||
|
||||
t.Run("default ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("default"))
|
||||
t.Run("openai official tiers preserved", func(t *testing.T) {
|
||||
// OpenAI 官方文档定义的合法 tier 值都应被透传保留,避免因白名单过窄
|
||||
// 静默剥离客户端显式发送的合法字段。Codex 客户端只发 priority/flex,
|
||||
// 所以扩大白名单对 Codex 流量零影响(见 codex-rs/core/src/client.rs)。
|
||||
for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} {
|
||||
got := normalizeOpenAIServiceTier(tier)
|
||||
require.NotNil(t, got, "tier %q should not be normalized to nil", tier)
|
||||
require.Equal(t, tier, *got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid ignored", func(t *testing.T) {
|
||||
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
|
||||
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractOpenAIServiceTier(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
|
||||
require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"}))
|
||||
require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"}))
|
||||
require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"}))
|
||||
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
|
||||
require.Nil(t, extractOpenAIServiceTier(nil))
|
||||
}
|
||||
@@ -845,7 +857,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) {
|
||||
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
|
||||
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
|
||||
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
|
||||
require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`)))
|
||||
require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
|
||||
require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`)))
|
||||
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
|
||||
}
|
||||
|
||||
|
||||
@@ -334,6 +334,7 @@ type OpenAIGatewayService struct {
|
||||
resolver *ModelPricingResolver
|
||||
channelService *ChannelService
|
||||
balanceNotifyService *BalanceNotifyService
|
||||
settingService *SettingService
|
||||
|
||||
openaiWSPoolOnce sync.Once
|
||||
openaiWSStateStoreOnce sync.Once
|
||||
@@ -372,6 +373,7 @@ func NewOpenAIGatewayService(
|
||||
resolver *ModelPricingResolver,
|
||||
channelService *ChannelService,
|
||||
balanceNotifyService *BalanceNotifyService,
|
||||
settingService *SettingService,
|
||||
) *OpenAIGatewayService {
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -402,6 +404,7 @@ func NewOpenAIGatewayService(
|
||||
resolver: resolver,
|
||||
channelService: channelService,
|
||||
balanceNotifyService: balanceNotifyService,
|
||||
settingService: settingService,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||
}
|
||||
@@ -2310,6 +2313,48 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
disablePatch()
|
||||
}
|
||||
|
||||
// Apply OpenAI fast policy (参照 Claude BetaPolicy 的 fast-mode 过滤):
|
||||
// 针对 body 的 service_tier 字段("priority" 即 fast,"flex"),按策略
|
||||
// 执行 filter(删除字段)或 block(拒绝请求)。对 gpt-5.5 等模型屏蔽
|
||||
// fast 时在此生效。
|
||||
//
|
||||
// 注意:
|
||||
// 1. 此处统一使用 upstreamModel(已经过 GetMappedModel +
|
||||
// normalizeOpenAIModelForUpstream + Codex OAuth normalize),与
|
||||
// chat-completions / messages 入口保持一致,避免不同入口因为模型
|
||||
// 维度不同而出现 whitelist 命中差异。
|
||||
// 2. action=pass 时也要把 raw "fast" 归一化为 "priority" 写回 body,
|
||||
// 否则 native /responses 入口透传 "fast" 给上游会被拒。chat-
|
||||
// completions 入口由 normalizeResponsesBodyServiceTier 完成同一
|
||||
// 行为,这里手工实现等效逻辑。
|
||||
if rawTier, ok := reqBody["service_tier"].(string); ok {
|
||||
if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" {
|
||||
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier)
|
||||
switch action {
|
||||
case BetaPolicyActionBlock:
|
||||
msg := errMsg
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel)
|
||||
}
|
||||
blocked := &OpenAIFastBlockedError{Message: msg}
|
||||
writeOpenAIFastPolicyBlockedResponse(c, blocked)
|
||||
return nil, blocked
|
||||
case BetaPolicyActionFilter:
|
||||
delete(reqBody, "service_tier")
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
default:
|
||||
// pass:若客户端传的是别名 "fast",归一化为 "priority"
|
||||
// 后写回 body,确保上游收到的是其能识别的规范值。
|
||||
if normTier != rawTier {
|
||||
reqBody["service_tier"] = normTier
|
||||
bodyModified = true
|
||||
markPatchSet("service_tier", normTier)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
if bodyModified {
|
||||
serializedByPatch := false
|
||||
@@ -2758,6 +2803,26 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
body = sanitizedBody
|
||||
}
|
||||
|
||||
// Apply OpenAI fast policy to the passthrough body (filter/block by service_tier).
|
||||
// 统一使用 upstream 视角的 model:透传路径下 body 已经过 compact 映射 +
|
||||
// OAuth normalize,body 中的 model 字段即上游真正会看到的 slug。
|
||||
// 这样可以与 chat-completions / messages / native /responses 入口的
|
||||
// upstreamModel 保持一致,避免 whitelist 命中差异。当 body 中没有
|
||||
// model 字段时退回 reqModel。
|
||||
policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
|
||||
if policyModel == "" {
|
||||
policyModel = reqModel
|
||||
}
|
||||
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body)
|
||||
if policyErr != nil {
|
||||
var blocked *OpenAIFastBlockedError
|
||||
if errors.As(policyErr, &blocked) {
|
||||
writeOpenAIFastPolicyBlockedResponse(c, blocked)
|
||||
}
|
||||
return nil, policyErr
|
||||
}
|
||||
body = updatedBody
|
||||
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
|
||||
account.ID,
|
||||
@@ -5590,14 +5655,319 @@ func normalizeOpenAIServiceTier(raw string) *string {
|
||||
if value == "fast" {
|
||||
value = "priority"
|
||||
}
|
||||
// 放过 OpenAI 官方文档定义的所有合法 tier 值:priority/flex/auto/default/scale。
|
||||
// 对 Codex 客户端零影响(Codex 只发 priority 或 flex,见 codex-rs/core/src/client.rs),
|
||||
// 但能让直连 OpenAI SDK 的用户透传 auto/default/scale 以便抓包/调试。
|
||||
// 真未知值仍返回 nil,由 normalizeResponsesBodyServiceTier 从 body 中删除。
|
||||
switch value {
|
||||
case "priority", "flex":
|
||||
case "priority", "flex", "auto", "default", "scale":
|
||||
return &value
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIFastBlockedError indicates a request was rejected by the OpenAI fast
|
||||
// policy (action=block). Mirrors BetaBlockedError on the Claude side.
|
||||
type OpenAIFastBlockedError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *OpenAIFastBlockedError) Error() string { return e.Message }
|
||||
|
||||
// evaluateOpenAIFastPolicy returns the action and error message that should be
|
||||
// applied for a request with the given account/model/service_tier. When the
|
||||
// policy service is unavailable or no rule matches, it returns
|
||||
// (BetaPolicyActionPass, "") so callers can short-circuit safely.
|
||||
//
|
||||
// Matching rules:
|
||||
// - Scope filters by account type (all / oauth / apikey / bedrock)
|
||||
// - ServiceTier must be empty (= any), "all", or equal the normalized tier
|
||||
// - ModelWhitelist narrows the rule to specific models; FallbackAction
|
||||
// handles the non-matching case (default: pass)
|
||||
//
|
||||
// 与 Claude BetaPolicy 的差异(保留首条匹配 short-circuit):
|
||||
// - BetaPolicy 处理的是 anthropic-beta header 中的 token 集合,不同
|
||||
// 规则可能针对不同 token,filter 需要累加成 set;block 则 first-match。
|
||||
// - OpenAI fast policy 操作的是单个字段 service_tier:filter 即删字段,
|
||||
// 没有可累加的对象。一次请求只携带一个 service_tier,规则的 tier
|
||||
// 维度天然互斥;同一 (scope, tier) 下若多条规则的 model whitelist
|
||||
// 发生重叠,admin 可通过规则顺序明确意图。因此采用 first-match 而
|
||||
// 非 BetaPolicy 那样的"block 覆盖 filter 覆盖 pass"语义。
|
||||
func (s *OpenAIGatewayService) evaluateOpenAIFastPolicy(ctx context.Context, account *Account, model, serviceTier string) (action, errMsg string) {
|
||||
if s == nil || s.settingService == nil {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
tier := strings.ToLower(strings.TrimSpace(serviceTier))
|
||||
if tier == "" {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
settings := openAIFastPolicySettingsFromContext(ctx)
|
||||
if settings == nil {
|
||||
fetched, err := s.settingService.GetOpenAIFastPolicySettings(ctx)
|
||||
if err != nil || fetched == nil {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
settings = fetched
|
||||
}
|
||||
return evaluateOpenAIFastPolicyWithSettings(settings, account, model, tier)
|
||||
}
|
||||
|
||||
// evaluateOpenAIFastPolicyWithSettings is the pure-function core extracted so
|
||||
// long-lived sessions (e.g. WS) can prefetch settings once and avoid hitting
|
||||
// the settingService on every frame. See WSSession entry and
|
||||
// openAIFastPolicySettingsFromContext for the caching glue.
|
||||
func evaluateOpenAIFastPolicyWithSettings(settings *OpenAIFastPolicySettings, account *Account, model, tier string) (action, errMsg string) {
|
||||
if settings == nil {
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
isOAuth := account != nil && account.IsOAuth()
|
||||
isBedrock := account != nil && account.IsBedrock()
|
||||
for _, rule := range settings.Rules {
|
||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||
continue
|
||||
}
|
||||
ruleTier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
|
||||
if ruleTier != "" && ruleTier != OpenAIFastTierAny && ruleTier != tier {
|
||||
continue
|
||||
}
|
||||
eff := BetaPolicyRule{
|
||||
Action: rule.Action,
|
||||
ErrorMessage: rule.ErrorMessage,
|
||||
ModelWhitelist: rule.ModelWhitelist,
|
||||
FallbackAction: rule.FallbackAction,
|
||||
FallbackErrorMessage: rule.FallbackErrorMessage,
|
||||
}
|
||||
return resolveRuleAction(eff, model)
|
||||
}
|
||||
return BetaPolicyActionPass, ""
|
||||
}
|
||||
|
||||
// openAIFastPolicyCtxKey 是 context 中预取的 OpenAIFastPolicySettings 缓存
|
||||
// 键,仅用于 WebSocket 长会话内多帧复用同一份策略快照,避免每帧 DB 命中。
|
||||
//
|
||||
// Trade-off:策略变更不会影响当前 WS session(只影响新 session)。这是
|
||||
// 有意为之 —— 对长会话来说,"策略一致性"比"立刻生效"更重要,且 Claude
|
||||
// BetaPolicy 的 gin.Context 缓存也是同样取舍。需要 hot-reload 时管理员
|
||||
// 可以通过踢断 session 强制刷新。
|
||||
type openAIFastPolicyCtxKeyType struct{}
|
||||
|
||||
var openAIFastPolicyCtxKey = openAIFastPolicyCtxKeyType{}
|
||||
|
||||
// withOpenAIFastPolicyContext 将一份 settings 快照绑定到 context,供该 ctx
|
||||
// 衍生 goroutine 中的 evaluateOpenAIFastPolicy 复用。
|
||||
func withOpenAIFastPolicyContext(ctx context.Context, settings *OpenAIFastPolicySettings) context.Context {
|
||||
if ctx == nil || settings == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, openAIFastPolicyCtxKey, settings)
|
||||
}
|
||||
|
||||
func openAIFastPolicySettingsFromContext(ctx context.Context) *OpenAIFastPolicySettings {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
if v, ok := ctx.Value(openAIFastPolicyCtxKey).(*OpenAIFastPolicySettings); ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyOpenAIFastPolicyToBody applies the OpenAI fast policy to a raw request
|
||||
// body. When action=filter it removes the service_tier field; when
|
||||
// action=block it returns (body, *OpenAIFastBlockedError). On pass it
|
||||
// normalizes the service_tier value (e.g. client alias "fast" → "priority"),
|
||||
// rewriting the body so the upstream receives a slug it recognizes.
|
||||
//
|
||||
// Rationale for normalize-on-pass: chat-completions / messages 入口在调用本
|
||||
// 函数之前已经通过 normalizeResponsesBodyServiceTier 把 service_tier 归一化
|
||||
// 到了上游可识别值;passthrough(OpenAI 自动透传) / native /responses 等
|
||||
// 入口没有这一前置步骤,pass 路径下若不在此处归一化,"fast" 就会被原样
|
||||
// 透传到 OpenAI 上游导致 400/拒绝。把归一化收敛到本函数,所有入口行为一致。
|
||||
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToBody(ctx context.Context, account *Account, model string, body []byte) ([]byte, error) {
|
||||
if len(body) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
rawTier := gjson.GetBytes(body, "service_tier").String()
|
||||
if rawTier == "" {
|
||||
return body, nil
|
||||
}
|
||||
normTier := normalizedOpenAIServiceTierValue(rawTier)
|
||||
if normTier == "" {
|
||||
return body, nil
|
||||
}
|
||||
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
|
||||
switch action {
|
||||
case BetaPolicyActionBlock:
|
||||
msg := errMsg
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
|
||||
}
|
||||
return body, &OpenAIFastBlockedError{Message: msg}
|
||||
case BetaPolicyActionFilter:
|
||||
trimmed, err := sjson.DeleteBytes(body, "service_tier")
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("strip service_tier from body: %w", err)
|
||||
}
|
||||
return trimmed, nil
|
||||
default:
|
||||
// pass:把别名(如 "fast")写回为规范值("priority")。
|
||||
if normTier == rawTier {
|
||||
return body, nil
|
||||
}
|
||||
updated, err := sjson.SetBytes(body, "service_tier", normTier)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("normalize service_tier on pass: %w", err)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
}
|
||||
|
||||
// writeOpenAIFastPolicyBlockedResponse writes a 403 JSON response for a
|
||||
// request blocked by the OpenAI fast policy.
|
||||
func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlockedError) {
|
||||
if c == nil || err == nil {
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "permission_error",
|
||||
"message": err.Message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// applyOpenAIFastPolicyToWSResponseCreate evaluates the OpenAI fast policy
|
||||
// against a single client→upstream WebSocket frame whose top-level
|
||||
// "type"=="response.create". It mirrors the HTTP-side
|
||||
// applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses
|
||||
// WS payload:
|
||||
//
|
||||
// - pass: returns frame unchanged (newBytes == frame, blocked == nil)
|
||||
// - filter: returns a copy with top-level service_tier removed
|
||||
// - block: returns (frame, *OpenAIFastBlockedError)
|
||||
//
|
||||
// Only frames whose "type" field strictly equals "response.create" are
|
||||
// inspected/mutated. Any other frame type — including the empty string —
|
||||
// passes through untouched. The OpenAI Realtime client-event spec requires
|
||||
// "type" to be set, so an empty type is treated as a malformed frame we do
|
||||
// not police; the upstream is the source of truth for rejecting it.
|
||||
//
|
||||
// service_tier lives at the top level of response.create — same as the
|
||||
// Responses HTTP body shape (see openai_gateway_chat_completions.go:304 +
|
||||
// extractOpenAIServiceTierFromBody at line 5593, and the test fixture at
|
||||
// openai_ws_forwarder_ingress_session_test.go:402). We therefore only need
|
||||
// to inspect / strip the top-level field; there is no nested form in the
|
||||
// schema today.
|
||||
//
|
||||
// The caller is responsible for choosing the upstream model passed in —
|
||||
// this helper does not re-derive it.
|
||||
func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
model string,
|
||||
frame []byte,
|
||||
) ([]byte, *OpenAIFastBlockedError, error) {
|
||||
if len(frame) == 0 {
|
||||
return frame, nil, nil
|
||||
}
|
||||
if !gjson.ValidBytes(frame) {
|
||||
return frame, nil, nil
|
||||
}
|
||||
frameType := strings.TrimSpace(gjson.GetBytes(frame, "type").String())
|
||||
// Strict match: only response.create is policy-checked. Empty / other
|
||||
// types pass through untouched so we never accidentally strip fields
|
||||
// from response.cancel, conversation.item.create, or any future
|
||||
// client-event the spec adds. The Realtime spec requires "type" on
|
||||
// every client event, so an empty type is malformed input — let the
|
||||
// upstream reject it rather than guessing at our layer.
|
||||
if frameType != "response.create" {
|
||||
return frame, nil, nil
|
||||
}
|
||||
rawTier := gjson.GetBytes(frame, "service_tier").String()
|
||||
if rawTier == "" {
|
||||
return frame, nil, nil
|
||||
}
|
||||
normTier := normalizedOpenAIServiceTierValue(rawTier)
|
||||
if normTier == "" {
|
||||
return frame, nil, nil
|
||||
}
|
||||
action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
|
||||
switch action {
|
||||
case BetaPolicyActionBlock:
|
||||
msg := errMsg
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
|
||||
}
|
||||
return frame, &OpenAIFastBlockedError{Message: msg}, nil
|
||||
case BetaPolicyActionFilter:
|
||||
trimmed, err := sjson.DeleteBytes(frame, "service_tier")
|
||||
if err != nil {
|
||||
return frame, nil, fmt.Errorf("strip service_tier from ws frame: %w", err)
|
||||
}
|
||||
return trimmed, nil, nil
|
||||
default:
|
||||
return frame, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// newOpenAIFastPolicyWSEventID returns a Realtime-style event_id for a
|
||||
// server-emitted error event. Matches the loose "evt_<rand>" convention used
|
||||
// by upstream Realtime servers; the exact value is not load-bearing and is
|
||||
// only required for client-side log correlation. We reuse the existing
|
||||
// google/uuid dependency rather than pulling a new one.
|
||||
func newOpenAIFastPolicyWSEventID() string {
|
||||
id, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
// Extremely unlikely; fall back to a fixed prefix so the field is
|
||||
// still non-empty and the schema stays self-consistent.
|
||||
return "evt_openai_fast_policy"
|
||||
}
|
||||
// Strip dashes so it visually matches "evt_<hex>" rather than UUID v4
|
||||
// canonical form, mirroring what real Realtime traces look like.
|
||||
return "evt_" + strings.ReplaceAll(id.String(), "-", "")
|
||||
}
|
||||
|
||||
// buildOpenAIFastPolicyBlockedWSEvent renders an OpenAI Realtime/Responses
|
||||
// style "error" event payload for a request blocked by the OpenAI fast
|
||||
// policy. The shape mirrors Realtime error events as observed in upstream
|
||||
// traces and per the spec's server "error" event:
|
||||
//
|
||||
// {
|
||||
// "event_id": "evt_<random>",
|
||||
// "type": "error",
|
||||
// "error": {
|
||||
// "type": "invalid_request_error",
|
||||
// "code": "policy_violation",
|
||||
// "message": "..."
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// event_id lets clients correlate the rejection in their logs; "code" gives
|
||||
// programmatic clients a stable identifier (HTTP-side equivalent is the
|
||||
// 403 permission_error JSON body).
|
||||
func buildOpenAIFastPolicyBlockedWSEvent(err *OpenAIFastBlockedError) []byte {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
eventID := newOpenAIFastPolicyWSEventID()
|
||||
payload, mErr := json.Marshal(map[string]any{
|
||||
"event_id": eventID,
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": "invalid_request_error",
|
||||
"code": "policy_violation",
|
||||
"message": err.Message,
|
||||
},
|
||||
})
|
||||
if mErr != nil {
|
||||
// Fallback to a minimal hand-rolled payload; Marshal of the literal
|
||||
// shape above should never fail in practice.
|
||||
return []byte(`{"event_id":"` + eventID + `","type":"error","error":{"type":"invalid_request_error","code":"policy_violation","message":"openai fast policy blocked this request"}}`)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) {
|
||||
if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) {
|
||||
return body, false, nil
|
||||
|
||||
@@ -2366,6 +2366,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
return errors.New("token is empty")
|
||||
}
|
||||
|
||||
// 预取一次 OpenAI Fast Policy settings,绑定到 ctx,让该 WS session
|
||||
// 内所有帧的 evaluateOpenAIFastPolicy 调用复用同一份快照,避免每帧
|
||||
// 进入 DB / settingRepo。Trade-off 见 withOpenAIFastPolicyContext 注释。
|
||||
if s.settingService != nil {
|
||||
if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil {
|
||||
ctx = withOpenAIFastPolicyContext(ctx, settings)
|
||||
}
|
||||
}
|
||||
|
||||
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
||||
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
|
||||
ingressMode := OpenAIWSIngressModeCtxPool
|
||||
@@ -2524,6 +2533,44 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
normalized = next
|
||||
}
|
||||
|
||||
// Apply OpenAI Fast Policy on the response.create frame using the same
|
||||
// evaluator/normalize/scope rules as the HTTP entrypoints. This is the
|
||||
// single integration point for all WS ingress turns (first + follow-up
|
||||
// frames flow through here).
|
||||
//
|
||||
// Model fallback: parseClientPayload above rejects any frame whose
|
||||
// "model" field is missing (line ~2493-2500), so by the time we
|
||||
// reach this point upstreamModel is always derived from a non-empty
|
||||
// per-frame model. The capturedSessionModel fallback used in the
|
||||
// passthrough adapter is therefore not needed in this path.
|
||||
policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
|
||||
if policyErr != nil {
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
|
||||
}
|
||||
if blocked != nil {
|
||||
// Send a Realtime-style error event to the client first, then
|
||||
// signal the handler to close the connection with PolicyViolation.
|
||||
// We intentionally do NOT forward this frame upstream.
|
||||
//
|
||||
// coder/websocket@v1.8.14 Conn.Write is synchronous and flushes
|
||||
// the underlying bufio writer before returning (write.go:42 →
|
||||
// 307-311), and the subsequent close handshake re-acquires the
|
||||
// same writeFrameMu, so the error event is guaranteed to reach
|
||||
// the kernel send buffer before any close frame is queued.
|
||||
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
|
||||
if eventBytes != nil {
|
||||
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||||
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
|
||||
cancel()
|
||||
}
|
||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
blocked.Message,
|
||||
blocked,
|
||||
)
|
||||
}
|
||||
normalized = policyApplied
|
||||
|
||||
return openAIWSClientPayload{
|
||||
payloadRaw: normalized,
|
||||
rawForHash: trimmed,
|
||||
|
||||
@@ -618,6 +618,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
|
||||
|
||||
@@ -21,6 +21,109 @@ type openAIWSClientFrameConn struct {
|
||||
conn *coderws.Conn
|
||||
}
|
||||
|
||||
// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
|
||||
// every client→upstream frame through the OpenAI Fast Policy. It is the
|
||||
// passthrough-relay equivalent of the parseClientPayload integration in the
|
||||
// ingress session path. filter returns:
|
||||
// - newPayload, nil, nil: forward the (possibly mutated) payload
|
||||
// - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error
|
||||
// event via onBlock and surfaces a transport-level error so the relay
|
||||
// stops reading from the client.
|
||||
// - _, _, err: a transport error other than block.
|
||||
type openAIWSPolicyEnforcingFrameConn struct {
|
||||
inner openaiwsv2.FrameConn
|
||||
filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
|
||||
onBlock func(blocked *OpenAIFastBlockedError)
|
||||
}
|
||||
|
||||
var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
|
||||
|
||||
func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
|
||||
if c == nil || c.inner == nil {
|
||||
return coderws.MessageText, nil, errOpenAIWSConnClosed
|
||||
}
|
||||
msgType, payload, err := c.inner.ReadFrame(ctx)
|
||||
if err != nil {
|
||||
return msgType, payload, err
|
||||
}
|
||||
if c.filter == nil {
|
||||
return msgType, payload, nil
|
||||
}
|
||||
updated, blocked, filterErr := c.filter(msgType, payload)
|
||||
if filterErr != nil {
|
||||
return msgType, payload, filterErr
|
||||
}
|
||||
if blocked != nil {
|
||||
if c.onBlock != nil {
|
||||
c.onBlock(blocked)
|
||||
}
|
||||
return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
|
||||
}
|
||||
return msgType, updated, nil
|
||||
}
|
||||
|
||||
func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
|
||||
if c == nil || c.inner == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
}
|
||||
return c.inner.WriteFrame(ctx, msgType, payload)
|
||||
}
|
||||
|
||||
func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
|
||||
if c == nil || c.inner == nil {
|
||||
return nil
|
||||
}
|
||||
return c.inner.Close()
|
||||
}
|
||||
|
||||
// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective
|
||||
// model name that should be passed to evaluateOpenAIFastPolicy for a single
|
||||
// passthrough WS frame. Mirrors the HTTP-side normalization
|
||||
// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
|
||||
// matches model whitelists identically.
|
||||
func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
|
||||
if account == nil || len(payload) == 0 {
|
||||
return ""
|
||||
}
|
||||
original := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||
if original == "" {
|
||||
return ""
|
||||
}
|
||||
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
||||
}
|
||||
|
||||
// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model
|
||||
// derived from a session.update frame's session.model field. Returns "" when
|
||||
// the frame is not a session.update event or carries no session.model. Used
|
||||
// by the per-frame policy filter (client→upstream direction) to keep
|
||||
// capturedSessionModel in sync with the session-level model the client may
|
||||
// rotate mid-session.
|
||||
//
|
||||
// Realtime / Responses WS lets the client change the session model after
|
||||
// the WS handshake via:
|
||||
//
|
||||
// {"type":"session.update","session":{"model":"gpt-5.5", ...}}
|
||||
//
|
||||
// If we only capture the model from the very first frame, a client can ship
|
||||
// gpt-4o on the first response.create (whitelisted as pass), then
|
||||
// session.update to gpt-5.5, then send response.create without "model" so
|
||||
// the per-frame resolver returns "" and the stale capturedSessionModel falls
|
||||
// back to gpt-4o — defeating the gpt-5.5 fast-policy filter.
|
||||
func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
|
||||
if account == nil || len(payload) == 0 {
|
||||
return ""
|
||||
}
|
||||
frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
||||
if frameType != "session.update" {
|
||||
return ""
|
||||
}
|
||||
original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
|
||||
if original == "" {
|
||||
return ""
|
||||
}
|
||||
return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
|
||||
}
|
||||
|
||||
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
|
||||
|
||||
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
|
||||
@@ -77,7 +180,6 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
return errors.New("token is empty")
|
||||
}
|
||||
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
|
||||
requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage)
|
||||
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
|
||||
@@ -88,6 +190,59 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
len(firstClientMessage),
|
||||
)
|
||||
|
||||
// Apply OpenAI Fast Policy on the first response.create frame. Subsequent
|
||||
// frames are filtered via a wrapping FrameConn below so every client→
|
||||
// upstream frame goes through the same policy evaluator/normalize/scope as
|
||||
// HTTP entrypoints.
|
||||
//
|
||||
// We capture the session-level model from the first frame here so the
|
||||
// per-frame filter (below) can fall back to it when a follow-up frame
|
||||
// omits "model" — Realtime clients are allowed to send response.create
|
||||
// without re-stating the model, in which case the upstream uses the model
|
||||
// negotiated at session.update time. Without this fallback, an empty
|
||||
// model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
|
||||
// silently passed through, defeating the policy on every frame after
|
||||
// the first.
|
||||
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
|
||||
updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
|
||||
if policyErr != nil {
|
||||
return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
|
||||
}
|
||||
if blocked != nil {
|
||||
// coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
|
||||
// writeFrameMu, writes the entire frame, and Flushes the underlying
|
||||
// bufio writer before returning (write.go:42 → write.go:307-311).
|
||||
// The subsequent close handshake re-acquires the same writeFrameMu
|
||||
// to send the close frame, so the error event is guaranteed to
|
||||
// reach the kernel send buffer before any close frame is queued.
|
||||
// No explicit flush hop is required here.
|
||||
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
|
||||
if eventBytes != nil {
|
||||
writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||||
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
|
||||
cancelWrite()
|
||||
}
|
||||
return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
|
||||
}
|
||||
firstClientMessage = updatedFirst
|
||||
|
||||
// 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter
|
||||
// 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
|
||||
// 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
|
||||
// "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
|
||||
// 与 WS ingress(openai_ws_forwarder.go:2991 取自 payload)的语义一致。
|
||||
//
|
||||
// 多轮 passthrough:OpenAI Realtime / Responses WS 协议允许客户端在
|
||||
// 同一连接的不同 response.create 帧上发送不同 service_tier(参考
|
||||
// codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
|
||||
// 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
|
||||
// goroutine)和 OnTurnComplete / final result(runUpstreamToClient
|
||||
// goroutine)之间同步当前 turn 的 service_tier。
|
||||
// extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型,
|
||||
// 可直接 Store/Load 而无需额外封装。
|
||||
var requestServiceTierPtr atomic.Pointer[string]
|
||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
|
||||
|
||||
wsURL, err := s.buildOpenAIResponsesWSURL(account)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build ws url: %w", err)
|
||||
@@ -152,9 +307,72 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
}
|
||||
|
||||
completedTurns := atomic.Int32{}
|
||||
policyClientConn := &openAIWSPolicyEnforcingFrameConn{
|
||||
inner: &openAIWSClientFrameConn{conn: clientConn},
|
||||
// 注意线程安全:filter 仅在 runClientToUpstream 这一条
|
||||
// goroutine 中被调用(passthrough_relay.go: ReadFrame loop),
|
||||
// capturedSessionModel 的读写都发生在该 goroutine 内,因此无需
|
||||
// 加锁/原子化。
|
||||
filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
|
||||
if msgType != coderws.MessageText {
|
||||
return payload, nil, nil
|
||||
}
|
||||
// 在评估策略前先刷新 capturedSessionModel:客户端可能通过
|
||||
// session.update 修改 session-level model(Realtime /
|
||||
// Responses WS 协议允许),如果不刷新就会出现
|
||||
// "首帧 model=gpt-4o(pass)→ session.update 改成 gpt-5.5
|
||||
// → 不带 model 的 response.create fallback 到 gpt-4o" 的
|
||||
// 绕过路径。这里只看 session.update 事件中的 session.model
|
||||
// 字段,response.create 自己的 model 仍然由其本帧字段决定。
|
||||
if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
|
||||
capturedSessionModel = updated
|
||||
}
|
||||
// Per-frame model first; if the client omits "model" on a
|
||||
// follow-up frame (legal in Realtime), fall back to the
|
||||
// session-level model captured from the first frame so the
|
||||
// model whitelist still resolves. An empty model would miss
|
||||
// any whitelist and silently fall back to pass.
|
||||
model := openAIWSPassthroughPolicyModelForFrame(account, payload)
|
||||
if model == "" {
|
||||
model = capturedSessionModel
|
||||
}
|
||||
out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
|
||||
// 多轮 passthrough billing:仅在成功(non-block / non-err)
|
||||
// 的 response.create 帧上更新 requestServiceTierPtr,使用
|
||||
// filter 处理后的 payload,与首帧 policy-after-extract 语义
|
||||
// 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
|
||||
// - 非 response.create 帧(response.cancel /
|
||||
// conversation.item.create / session.update 等)不携带
|
||||
// per-response service_tier,不应覆盖前一轮值。
|
||||
// - blocked != nil:该帧不会发送上游,billing tier 应保持
|
||||
// 上一轮值。
|
||||
// - policyErr != nil:异常路径,保持上一轮值。
|
||||
// - 不带 service_tier 的 response.create 会让
|
||||
// extractOpenAIServiceTierFromBody 返回 nil;这里有意
|
||||
// 覆盖(Store(nil)),因为 OpenAI 上游对该帧实际不传
|
||||
// service_tier 时按 default 处理,billing 应如实反映。
|
||||
if policyErr == nil && blocked == nil &&
|
||||
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
|
||||
requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
|
||||
}
|
||||
return out, blocked, policyErr
|
||||
},
|
||||
onBlock: func(blocked *OpenAIFastBlockedError) {
|
||||
// See note above on Conn.Write being synchronous w.r.t. flush;
|
||||
// no explicit flush is required to ensure the error event lands
|
||||
// before the close frame.
|
||||
eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
|
||||
if eventBytes == nil {
|
||||
return
|
||||
}
|
||||
writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
|
||||
_ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
|
||||
cancel()
|
||||
},
|
||||
}
|
||||
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
|
||||
Ctx: ctx,
|
||||
ClientConn: &openAIWSClientFrameConn{conn: clientConn},
|
||||
ClientConn: policyClientConn,
|
||||
UpstreamConn: upstreamFrameConn,
|
||||
FirstClientMessage: firstClientMessage,
|
||||
Options: openaiwsv2.RelayOptions{
|
||||
@@ -179,7 +397,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: turn.RequestModel,
|
||||
ServiceTier: requestServiceTier,
|
||||
ServiceTier: requestServiceTierPtr.Load(),
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
@@ -227,7 +445,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: relayResult.RequestModel,
|
||||
ServiceTier: requestServiceTier,
|
||||
ServiceTier: requestServiceTierPtr.Load(),
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
|
||||
@@ -3259,6 +3259,84 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
|
||||
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
|
||||
}
|
||||
|
||||
// GetOpenAIFastPolicySettings 获取 OpenAI fast 策略配置
|
||||
func (s *SettingService) GetOpenAIFastPolicySettings(ctx context.Context) (*OpenAIFastPolicySettings, error) {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyOpenAIFastPolicySettings)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return DefaultOpenAIFastPolicySettings(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("get openai fast policy settings: %w", err)
|
||||
}
|
||||
if value == "" {
|
||||
return DefaultOpenAIFastPolicySettings(), nil
|
||||
}
|
||||
|
||||
var settings OpenAIFastPolicySettings
|
||||
if err := json.Unmarshal([]byte(value), &settings); err != nil {
|
||||
// JSON 损坏时静默 fallback 到默认配置会让策略意外失效(管理员配
|
||||
// 置的 block/filter 规则被忽略)。记录 Warn 让运维能在出现异常
|
||||
// 行为时定位到 settings 表里的脏数据。
|
||||
slog.Warn("failed to unmarshal openai fast policy settings, falling back to defaults",
|
||||
"error", err,
|
||||
"key", SettingKeyOpenAIFastPolicySettings)
|
||||
return DefaultOpenAIFastPolicySettings(), nil
|
||||
}
|
||||
|
||||
return &settings, nil
|
||||
}
|
||||
|
||||
// SetOpenAIFastPolicySettings 设置 OpenAI fast 策略配置
|
||||
func (s *SettingService) SetOpenAIFastPolicySettings(ctx context.Context, settings *OpenAIFastPolicySettings) error {
|
||||
if settings == nil {
|
||||
return fmt.Errorf("settings cannot be nil")
|
||||
}
|
||||
|
||||
validActions := map[string]bool{
|
||||
BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
|
||||
}
|
||||
validScopes := map[string]bool{
|
||||
BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
|
||||
}
|
||||
validTiers := map[string]bool{
|
||||
OpenAIFastTierAny: true, OpenAIFastTierPriority: true, OpenAIFastTierFlex: true,
|
||||
}
|
||||
|
||||
for i, rule := range settings.Rules {
|
||||
tier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
|
||||
if tier == "" {
|
||||
tier = OpenAIFastTierAny
|
||||
}
|
||||
if !validTiers[tier] {
|
||||
return fmt.Errorf("rule[%d]: invalid service_tier %q", i, rule.ServiceTier)
|
||||
}
|
||||
settings.Rules[i].ServiceTier = tier
|
||||
if !validActions[rule.Action] {
|
||||
return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
|
||||
}
|
||||
if !validScopes[rule.Scope] {
|
||||
return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
|
||||
}
|
||||
for j, pattern := range rule.ModelWhitelist {
|
||||
trimmed := strings.TrimSpace(pattern)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("rule[%d]: model_whitelist[%d] cannot be empty", i, j)
|
||||
}
|
||||
settings.Rules[i].ModelWhitelist[j] = trimmed
|
||||
}
|
||||
if rule.FallbackAction != "" && !validActions[rule.FallbackAction] {
|
||||
return fmt.Errorf("rule[%d]: invalid fallback_action %q", i, rule.FallbackAction)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal openai fast policy settings: %w", err)
|
||||
}
|
||||
|
||||
return s.settingRepo.Set(ctx, SettingKeyOpenAIFastPolicySettings, string(data))
|
||||
}
|
||||
|
||||
// SetStreamTimeoutSettings 设置流超时处理配置
|
||||
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
|
||||
if settings == nil {
|
||||
|
||||
@@ -405,3 +405,57 @@ func DefaultBetaPolicySettings() *BetaPolicySettings {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI Fast Policy 策略常量
|
||||
// OpenAI 的 "fast 模式" 通过请求体中的 service_tier 字段识别:
|
||||
// - "priority"(客户端可传 "fast",归一化为 "priority"):fast 模式
|
||||
// - "flex":低优先级模式
|
||||
// - 省略:normal 默认
|
||||
//
|
||||
// 本策略复用 BetaPolicyAction*/BetaPolicyScope* 常量语义,只是匹配键从
|
||||
// anthropic-beta header 换成 body 的 service_tier 字段。
|
||||
const (
|
||||
OpenAIFastTierAny = "all" // 匹配任意已识别的 service_tier
|
||||
OpenAIFastTierPriority = "priority" // 仅匹配 fast(priority)
|
||||
OpenAIFastTierFlex = "flex" // 仅匹配 flex
|
||||
)
|
||||
|
||||
// OpenAIFastPolicyRule 单条 OpenAI fast/flex 策略规则
|
||||
type OpenAIFastPolicyRule struct {
|
||||
ServiceTier string `json:"service_tier"` // "priority" | "flex" | "auto" | "default" | "scale" | "all"
|
||||
Action string `json:"action"` // "pass" | "filter" | "block"
|
||||
Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock"
|
||||
ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
|
||||
ModelWhitelist []string `json:"model_whitelist,omitempty"` // 模型匹配模式列表(为空=对所有模型生效)
|
||||
FallbackAction string `json:"fallback_action,omitempty"` // 未匹配白名单的模型的处理方式
|
||||
FallbackErrorMessage string `json:"fallback_error_message,omitempty"` // 未匹配白名单时的自定义错误消息 (fallback_action=block 时生效)
|
||||
}
|
||||
|
||||
// OpenAIFastPolicySettings OpenAI fast 策略配置
|
||||
type OpenAIFastPolicySettings struct {
|
||||
Rules []OpenAIFastPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// DefaultOpenAIFastPolicySettings 返回默认的 OpenAI fast 策略配置。
|
||||
// 默认对所有模型的 priority(fast)请求执行 filter,即剔除 service_tier 字段,
|
||||
// 让上游按 normal 优先级处理。
|
||||
//
|
||||
// 为什么 ModelWhitelist 为空(=对所有模型生效):
|
||||
// codex 客户端的 service_tier=fast 是用户级开关,与 model 字段正交。即使
|
||||
// 用户使用 gpt-4 + fast,priority 配额仍会被消耗。如果默认规则只锁
|
||||
// gpt-5.5*,"用 gpt-4 + fast 透传 priority 上游" 这条路径就会绕过策略。
|
||||
// 与 codex 真实语义对齐,默认对所有模型生效;管理员若需要只针对特定
|
||||
// 模型,可在 admin UI 中显式配置 model_whitelist。
|
||||
func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings {
|
||||
return &OpenAIFastPolicySettings{
|
||||
Rules: []OpenAIFastPolicyRule{
|
||||
{
|
||||
ServiceTier: OpenAIFastTierPriority,
|
||||
Action: BetaPolicyActionFilter,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ModelWhitelist: []string{},
|
||||
FallbackAction: BetaPolicyActionPass,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -484,6 +484,9 @@ export interface SystemSettings {
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
affiliate_enabled: boolean;
|
||||
|
||||
// OpenAI fast/flex policy
|
||||
openai_fast_policy_settings?: OpenAIFastPolicySettings;
|
||||
}
|
||||
|
||||
export interface UpdateSettingsRequest {
|
||||
@@ -648,6 +651,9 @@ export interface UpdateSettingsRequest {
|
||||
|
||||
// Affiliate (邀请返利) feature switch
|
||||
affiliate_enabled?: boolean;
|
||||
|
||||
// OpenAI fast/flex policy
|
||||
openai_fast_policy_settings?: OpenAIFastPolicySettings;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -875,6 +881,29 @@ export async function updateRectifierSettings(
|
||||
return data;
|
||||
}
|
||||
|
||||
// ==================== OpenAI Fast Policy Settings ====================
|
||||
|
||||
/**
|
||||
* OpenAI fast/flex policy rule interface.
|
||||
* Matches backend dto.OpenAIFastPolicyRule.
|
||||
*/
|
||||
export interface OpenAIFastPolicyRule {
|
||||
service_tier: "all" | "priority" | "flex";
|
||||
action: "pass" | "filter" | "block";
|
||||
scope: "all" | "oauth" | "apikey" | "bedrock";
|
||||
error_message?: string;
|
||||
model_whitelist?: string[];
|
||||
fallback_action?: "pass" | "filter" | "block";
|
||||
fallback_error_message?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAI fast/flex policy settings interface.
|
||||
*/
|
||||
export interface OpenAIFastPolicySettings {
|
||||
rules: OpenAIFastPolicyRule[];
|
||||
}
|
||||
|
||||
// ==================== Beta Policy Settings ====================
|
||||
|
||||
/**
|
||||
|
||||
@@ -5535,6 +5535,38 @@ export default {
|
||||
presetOpusOnlyDesc: 'Pass for Opus, filter others',
|
||||
commonPatterns: 'Common patterns'
|
||||
},
|
||||
openaiFastPolicy: {
|
||||
title: 'OpenAI Fast/Flex Policy',
|
||||
description: 'Intercept, filter, or pass OpenAI fast(priority) / flex requests based on the request body service_tier field. Applies to the OpenAI gateway only.',
|
||||
empty: 'No rules configured. Click the button below to add one.',
|
||||
ruleHeader: 'Rule #{index}',
|
||||
removeRule: 'Remove rule',
|
||||
addRule: 'Add rule',
|
||||
saveHint: 'Saved together with system settings (click the global Save button at the bottom of the page).',
|
||||
serviceTier: 'service_tier match',
|
||||
tierAll: 'All tiers',
|
||||
tierPriority: 'priority (fast)',
|
||||
tierFlex: 'flex',
|
||||
action: 'Action',
|
||||
actionPass: 'Pass (keep service_tier)',
|
||||
actionFilter: 'Filter (remove service_tier)',
|
||||
actionBlock: 'Block (reject request)',
|
||||
scope: 'Scope',
|
||||
scopeAll: 'All accounts',
|
||||
scopeOAuth: 'OAuth only',
|
||||
scopeAPIKey: 'API Key only',
|
||||
scopeBedrock: 'Bedrock only',
|
||||
errorMessage: 'Error message',
|
||||
errorMessagePlaceholder: 'Custom error message when blocked',
|
||||
errorMessageHint: 'Leave empty for the default message.',
|
||||
modelWhitelist: 'Model whitelist',
|
||||
modelWhitelistHint: 'Leave empty to apply to all models. Supports exact match and wildcard prefix (e.g., gpt-5.5*).',
|
||||
modelPatternPlaceholder: 'e.g., gpt-5.5 or gpt-5.5*',
|
||||
addModelPattern: 'Add model pattern',
|
||||
fallbackAction: 'Fallback action',
|
||||
fallbackActionHint: 'Action for models not matching the whitelist.',
|
||||
fallbackErrorMessagePlaceholder: 'Custom error message when non-whitelisted models are blocked'
|
||||
},
|
||||
wechatConnect: {
|
||||
title: 'WeChat Connect',
|
||||
description: 'Third-party login configuration for WeChat Open Platform or Official Account / Mini Program.',
|
||||
|
||||
@@ -5695,6 +5695,38 @@ export default {
|
||||
presetOpusOnlyDesc: 'Opus 透传,其他模型过滤',
|
||||
commonPatterns: '常用模式'
|
||||
},
|
||||
openaiFastPolicy: {
|
||||
title: 'OpenAI Fast/Flex 策略',
|
||||
description: '基于请求体 service_tier 字段拦截/过滤/透传 OpenAI fast(priority) 与 flex 请求;仅作用于 OpenAI 网关。',
|
||||
empty: '尚未配置任何规则。点击下方按钮新增。',
|
||||
ruleHeader: '规则 #{index}',
|
||||
removeRule: '删除规则',
|
||||
addRule: '新增规则',
|
||||
saveHint: '保存时随系统设置一起提交(点击页面底部「保存」按钮)。',
|
||||
serviceTier: 'service_tier 匹配',
|
||||
tierAll: '全部 tier',
|
||||
tierPriority: 'priority(fast)',
|
||||
tierFlex: 'flex',
|
||||
action: '处理方式',
|
||||
actionPass: '透传(保留 service_tier)',
|
||||
actionFilter: '过滤(移除 service_tier)',
|
||||
actionBlock: '拦截(拒绝请求)',
|
||||
scope: '生效范围',
|
||||
scopeAll: '全部账号',
|
||||
scopeOAuth: '仅 OAuth 账号',
|
||||
scopeAPIKey: '仅 API Key 账号',
|
||||
scopeBedrock: '仅 Bedrock 账号',
|
||||
errorMessage: '错误消息',
|
||||
errorMessagePlaceholder: '拦截时返回的自定义错误消息',
|
||||
errorMessageHint: '留空则使用默认错误消息。',
|
||||
modelWhitelist: '模型白名单',
|
||||
modelWhitelistHint: '留空表示对所有模型生效;支持精确匹配与通配符(如 gpt-5.5*)。',
|
||||
modelPatternPlaceholder: '例如: gpt-5.5 或 gpt-5.5*',
|
||||
addModelPattern: '添加模型规则',
|
||||
fallbackAction: '未匹配模型处理方式',
|
||||
fallbackActionHint: '当请求模型不在白名单中时的处理方式。',
|
||||
fallbackErrorMessagePlaceholder: '未匹配模型被拦截时返回的自定义错误消息'
|
||||
},
|
||||
wechatConnect: {
|
||||
title: '微信登录',
|
||||
description: '用于微信开放平台或公众号/小程序的第三方登录配置。',
|
||||
|
||||
@@ -949,6 +949,285 @@
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
<!-- OpenAI Fast/Flex Policy Settings -->
|
||||
<div class="card">
|
||||
<div
|
||||
class="border-b border-gray-100 px-6 py-4 dark:border-dark-700"
|
||||
>
|
||||
<h2 class="text-lg font-semibold text-gray-900 dark:text-white">
|
||||
{{ t("admin.settings.openaiFastPolicy.title") }}
|
||||
</h2>
|
||||
<p class="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ t("admin.settings.openaiFastPolicy.description") }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="space-y-5 p-6">
|
||||
<!-- Empty state -->
|
||||
<div
|
||||
v-if="openaiFastPolicyForm.rules.length === 0"
|
||||
class="rounded-lg border border-dashed border-gray-200 p-6 text-center text-sm text-gray-500 dark:border-dark-600 dark:text-gray-400"
|
||||
>
|
||||
{{ t("admin.settings.openaiFastPolicy.empty") }}
|
||||
</div>
|
||||
|
||||
<!-- Rule Cards -->
|
||||
<div
|
||||
v-for="(rule, ruleIndex) in openaiFastPolicyForm.rules"
|
||||
:key="ruleIndex"
|
||||
class="rounded-lg border border-gray-200 p-4 dark:border-dark-600"
|
||||
>
|
||||
<div class="mb-3 flex items-center justify-between">
|
||||
<span
|
||||
class="text-sm font-medium text-gray-900 dark:text-white"
|
||||
>
|
||||
{{
|
||||
t("admin.settings.openaiFastPolicy.ruleHeader", {
|
||||
index: ruleIndex + 1,
|
||||
})
|
||||
}}
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
@click="removeOpenAIFastPolicyRule(ruleIndex)"
|
||||
class="rounded p-1 text-red-400 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
|
||||
:title="t('admin.settings.openaiFastPolicy.removeRule')"
|
||||
>
|
||||
<svg
|
||||
class="h-4 w-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-1 gap-4 md:grid-cols-3">
|
||||
<!-- Service Tier -->
|
||||
<div>
|
||||
<label
|
||||
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
|
||||
>
|
||||
{{ t("admin.settings.openaiFastPolicy.serviceTier") }}
|
||||
</label>
|
||||
<Select
|
||||
:modelValue="rule.service_tier"
|
||||
@update:modelValue="
|
||||
rule.service_tier = $event as
|
||||
| 'all'
|
||||
| 'priority'
|
||||
| 'flex'
|
||||
"
|
||||
:options="openaiFastPolicyTierOptions"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Action -->
|
||||
<div>
|
||||
<label
|
||||
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
|
||||
>
|
||||
{{ t("admin.settings.openaiFastPolicy.action") }}
|
||||
</label>
|
||||
<Select
|
||||
:modelValue="rule.action"
|
||||
@update:modelValue="
|
||||
rule.action = $event as 'pass' | 'filter' | 'block'
|
||||
"
|
||||
:options="openaiFastPolicyActionOptions"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Scope -->
|
||||
<div>
|
||||
<label
|
||||
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
|
||||
>
|
||||
{{ t("admin.settings.openaiFastPolicy.scope") }}
|
||||
</label>
|
||||
<Select
|
||||
:modelValue="rule.scope"
|
||||
@update:modelValue="
|
||||
rule.scope = $event as
|
||||
| 'all'
|
||||
| 'oauth'
|
||||
| 'apikey'
|
||||
| 'bedrock'
|
||||
"
|
||||
:options="openaiFastPolicyScopeOptions"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Error Message (only when action=block) -->
|
||||
<div v-if="rule.action === 'block'" class="mt-3">
|
||||
<label
|
||||
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
|
||||
>
|
||||
{{ t("admin.settings.openaiFastPolicy.errorMessage") }}
|
||||
</label>
|
||||
<input
|
||||
v-model="rule.error_message"
|
||||
type="text"
|
||||
class="input"
|
||||
:placeholder="
|
||||
t(
|
||||
'admin.settings.openaiFastPolicy.errorMessagePlaceholder',
|
||||
)
|
||||
"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-gray-400 dark:text-gray-500">
|
||||
{{ t("admin.settings.openaiFastPolicy.errorMessageHint") }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Model Whitelist -->
|
||||
<div class="mt-3">
|
||||
<label
|
||||
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
|
||||
>
|
||||
{{ t("admin.settings.openaiFastPolicy.modelWhitelist") }}
|
||||
</label>
|
||||
<p class="mb-2 text-xs text-gray-400 dark:text-gray-500">
|
||||
{{
|
||||
t("admin.settings.openaiFastPolicy.modelWhitelistHint")
|
||||
}}
|
||||
</p>
|
||||
<div
|
||||
v-for="(_, patternIdx) in rule.model_whitelist || []"
|
||||
:key="patternIdx"
|
||||
class="mb-1.5 flex items-center gap-2"
|
||||
>
|
||||
<input
|
||||
v-model="rule.model_whitelist![patternIdx]"
|
||||
type="text"
|
||||
class="input input-sm flex-1"
|
||||
:placeholder="
|
||||
t(
|
||||
'admin.settings.openaiFastPolicy.modelPatternPlaceholder',
|
||||
)
|
||||
"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
@click="
|
||||
removeOpenAIFastPolicyModelPattern(rule, patternIdx)
|
||||
"
|
||||
class="shrink-0 rounded p-1 text-red-400 transition-colors hover:bg-red-50 hover:text-red-600 dark:hover:bg-red-900/20"
|
||||
>
|
||||
<svg
|
||||
class="h-4 w-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="addOpenAIFastPolicyModelPattern(rule)"
|
||||
class="mb-2 inline-flex items-center gap-1 text-xs text-primary-600 transition-colors hover:text-primary-700 dark:text-primary-400 dark:hover:text-primary-300"
|
||||
>
|
||||
<svg
|
||||
class="h-3.5 w-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 4v16m8-8H4"
|
||||
/>
|
||||
</svg>
|
||||
{{ t("admin.settings.openaiFastPolicy.addModelPattern") }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Fallback Action (only when model_whitelist is non-empty) -->
|
||||
<div
|
||||
v-if="
|
||||
rule.model_whitelist && rule.model_whitelist.length > 0
|
||||
"
|
||||
class="mt-3"
|
||||
>
|
||||
<label
|
||||
class="mb-1 block text-xs font-medium text-gray-600 dark:text-gray-400"
|
||||
>
|
||||
{{ t("admin.settings.openaiFastPolicy.fallbackAction") }}
|
||||
</label>
|
||||
<Select
|
||||
:modelValue="rule.fallback_action || 'pass'"
|
||||
@update:modelValue="
|
||||
rule.fallback_action = $event as
|
||||
| 'pass'
|
||||
| 'filter'
|
||||
| 'block'
|
||||
"
|
||||
:options="openaiFastPolicyActionOptions"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-gray-400 dark:text-gray-500">
|
||||
{{
|
||||
t("admin.settings.openaiFastPolicy.fallbackActionHint")
|
||||
}}
|
||||
</p>
|
||||
<div v-if="rule.fallback_action === 'block'" class="mt-2">
|
||||
<input
|
||||
v-model="rule.fallback_error_message"
|
||||
type="text"
|
||||
class="input"
|
||||
:placeholder="
|
||||
t(
|
||||
'admin.settings.openaiFastPolicy.fallbackErrorMessagePlaceholder',
|
||||
)
|
||||
"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Add Rule Button -->
|
||||
<div>
|
||||
<button
|
||||
type="button"
|
||||
@click="addOpenAIFastPolicyRule"
|
||||
class="btn btn-secondary btn-sm inline-flex items-center gap-1"
|
||||
>
|
||||
<svg
|
||||
class="h-4 w-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 4v16m8-8H4"
|
||||
/>
|
||||
</svg>
|
||||
{{ t("admin.settings.openaiFastPolicy.addRule") }}
|
||||
</button>
|
||||
<p class="mt-2 text-xs text-gray-400 dark:text-gray-500">
|
||||
{{ t("admin.settings.openaiFastPolicy.saveHint") }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<!-- /Tab: Gateway -->
|
||||
|
||||
@@ -5199,6 +5478,7 @@ import type {
|
||||
SystemSettings,
|
||||
UpdateSettingsRequest,
|
||||
DefaultSubscriptionSetting,
|
||||
OpenAIFastPolicyRule,
|
||||
WeChatConnectMode,
|
||||
WebSearchEmulationConfig,
|
||||
WebSearchProviderConfig,
|
||||
@@ -5337,6 +5617,14 @@ const betaPolicyForm = reactive({
|
||||
}>,
|
||||
});
|
||||
|
||||
// OpenAI Fast/Flex Policy 状态
|
||||
const openaiFastPolicyForm = reactive({
|
||||
rules: [] as OpenAIFastPolicyRule[],
|
||||
});
|
||||
// 标记 openai_fast_policy_settings 是否已成功从后端加载,
|
||||
// 避免后端 GET 出错或字段缺失时,保存把默认规则覆盖成空数组。
|
||||
const openaiFastPolicyLoaded = ref(false);
|
||||
|
||||
const tablePageSizeMin = 5;
|
||||
const tablePageSizeMax = 1000;
|
||||
const tablePageSizeDefault = 20;
|
||||
@@ -6116,6 +6404,23 @@ async function loadSettings() {
|
||||
);
|
||||
form.oidc_connect_client_secret = "";
|
||||
|
||||
// Load OpenAI fast/flex policy rules from bulk settings.
|
||||
// 仅当 payload 真的包含该字段时填充并标记为已加载;否则保持表单空值,
|
||||
// 让 saveSettings 在未加载时跳过该字段,防止覆盖后端默认规则。
|
||||
if (
|
||||
settings.openai_fast_policy_settings &&
|
||||
Array.isArray(settings.openai_fast_policy_settings.rules)
|
||||
) {
|
||||
openaiFastPolicyForm.rules =
|
||||
settings.openai_fast_policy_settings.rules.map((rule) => ({
|
||||
...rule,
|
||||
model_whitelist: rule.model_whitelist
|
||||
? [...rule.model_whitelist]
|
||||
: [],
|
||||
}));
|
||||
openaiFastPolicyLoaded.value = true;
|
||||
}
|
||||
|
||||
// Load web search emulation config separately
|
||||
await loadWebSearchConfig();
|
||||
} catch (error: unknown) {
|
||||
@@ -6460,10 +6765,39 @@ async function saveSettings() {
|
||||
affiliate_enabled: form.affiliate_enabled,
|
||||
};
|
||||
|
||||
// 仅当 openai_fast_policy_settings 已成功从后端加载时才回写,
|
||||
// 否则省略整个字段,让后端保留既有规则(含默认值)。
|
||||
if (openaiFastPolicyLoaded.value) {
|
||||
payload.openai_fast_policy_settings = {
|
||||
rules: openaiFastPolicyForm.rules.map((rule) => {
|
||||
const whitelist = (rule.model_whitelist || [])
|
||||
.map((p) => p.trim())
|
||||
.filter((p) => p !== "");
|
||||
const hasWhitelist = whitelist.length > 0;
|
||||
return {
|
||||
service_tier: rule.service_tier,
|
||||
action: rule.action,
|
||||
scope: rule.scope,
|
||||
error_message:
|
||||
rule.action === "block" ? rule.error_message : undefined,
|
||||
model_whitelist: hasWhitelist ? whitelist : undefined,
|
||||
fallback_action: hasWhitelist
|
||||
? rule.fallback_action || "pass"
|
||||
: undefined,
|
||||
fallback_error_message:
|
||||
hasWhitelist && rule.fallback_action === "block"
|
||||
? rule.fallback_error_message
|
||||
: undefined,
|
||||
};
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults);
|
||||
|
||||
const updated = await adminAPI.settings.updateSettings(payload);
|
||||
for (const [key, value] of Object.entries(updated)) {
|
||||
if (key === "openai_fast_policy_settings") continue;
|
||||
if (value !== null && value !== undefined) {
|
||||
(form as Record<string, unknown>)[key] = value;
|
||||
}
|
||||
@@ -6507,6 +6841,20 @@ async function saveSettings() {
|
||||
form.wechat_connect_mode,
|
||||
);
|
||||
form.oidc_connect_client_secret = "";
|
||||
// Refresh OpenAI fast/flex policy from server response
|
||||
if (
|
||||
updated.openai_fast_policy_settings &&
|
||||
Array.isArray(updated.openai_fast_policy_settings.rules)
|
||||
) {
|
||||
openaiFastPolicyForm.rules =
|
||||
updated.openai_fast_policy_settings.rules.map((rule) => ({
|
||||
...rule,
|
||||
model_whitelist: rule.model_whitelist
|
||||
? [...rule.model_whitelist]
|
||||
: [],
|
||||
}));
|
||||
openaiFastPolicyLoaded.value = true;
|
||||
}
|
||||
// Save web search emulation config separately (errors handled internally)
|
||||
const wsOk = await saveWebSearchConfig();
|
||||
// Refresh cached settings so sidebar/header update immediately
|
||||
@@ -6846,6 +7194,61 @@ async function loadBetaPolicySettings() {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== OpenAI Fast/Flex Policy ====================
|
||||
|
||||
const openaiFastPolicyTierOptions = computed(() => [
|
||||
{ value: "all", label: t("admin.settings.openaiFastPolicy.tierAll") },
|
||||
{
|
||||
value: "priority",
|
||||
label: t("admin.settings.openaiFastPolicy.tierPriority"),
|
||||
},
|
||||
{ value: "flex", label: t("admin.settings.openaiFastPolicy.tierFlex") },
|
||||
]);
|
||||
|
||||
const openaiFastPolicyActionOptions = computed(() => [
|
||||
{ value: "pass", label: t("admin.settings.openaiFastPolicy.actionPass") },
|
||||
{ value: "filter", label: t("admin.settings.openaiFastPolicy.actionFilter") },
|
||||
{ value: "block", label: t("admin.settings.openaiFastPolicy.actionBlock") },
|
||||
]);
|
||||
|
||||
const openaiFastPolicyScopeOptions = computed(() => [
|
||||
{ value: "all", label: t("admin.settings.openaiFastPolicy.scopeAll") },
|
||||
{ value: "oauth", label: t("admin.settings.openaiFastPolicy.scopeOAuth") },
|
||||
{ value: "apikey", label: t("admin.settings.openaiFastPolicy.scopeAPIKey") },
|
||||
{
|
||||
value: "bedrock",
|
||||
label: t("admin.settings.openaiFastPolicy.scopeBedrock"),
|
||||
},
|
||||
]);
|
||||
|
||||
function addOpenAIFastPolicyRule() {
|
||||
openaiFastPolicyForm.rules.push({
|
||||
service_tier: "priority",
|
||||
action: "filter",
|
||||
scope: "all",
|
||||
error_message: "",
|
||||
model_whitelist: [],
|
||||
fallback_action: "pass",
|
||||
fallback_error_message: "",
|
||||
});
|
||||
}
|
||||
|
||||
function removeOpenAIFastPolicyRule(index: number) {
|
||||
openaiFastPolicyForm.rules.splice(index, 1);
|
||||
}
|
||||
|
||||
function addOpenAIFastPolicyModelPattern(rule: OpenAIFastPolicyRule) {
|
||||
if (!rule.model_whitelist) rule.model_whitelist = [];
|
||||
rule.model_whitelist.push("");
|
||||
}
|
||||
|
||||
function removeOpenAIFastPolicyModelPattern(
|
||||
rule: OpenAIFastPolicyRule,
|
||||
idx: number,
|
||||
) {
|
||||
rule.model_whitelist?.splice(idx, 1);
|
||||
}
|
||||
|
||||
async function saveBetaPolicySettings() {
|
||||
betaPolicySaving.value = true;
|
||||
try {
|
||||
|
||||
Reference in New Issue
Block a user