mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-05 13:40:44 +08:00
287 lines
12 KiB
Go
287 lines
12 KiB
Go
|
|
package service
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"encoding/json"
|
|||
|
|
"errors"
|
|||
|
|
"testing"
|
|||
|
|
|
|||
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|||
|
|
"github.com/stretchr/testify/require"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
type openAIFastPolicyRepoStub struct {
|
|||
|
|
values map[string]string
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
|||
|
|
panic("unexpected Get call")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
|||
|
|
if v, ok := s.values[key]; ok {
|
|||
|
|
return v, nil
|
|||
|
|
}
|
|||
|
|
return "", ErrSettingNotFound
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
|
|||
|
|
if s.values == nil {
|
|||
|
|
s.values = map[string]string{}
|
|||
|
|
}
|
|||
|
|
s.values[key] = value
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
|||
|
|
panic("unexpected GetMultiple call")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
|||
|
|
panic("unexpected SetMultiple call")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
|||
|
|
panic("unexpected GetAll call")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
|
|||
|
|
panic("unexpected Delete call")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
|
|||
|
|
t.Helper()
|
|||
|
|
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
|
|||
|
|
if settings != nil {
|
|||
|
|
raw, err := json.Marshal(settings)
|
|||
|
|
require.NoError(t, err)
|
|||
|
|
repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
|
|||
|
|
}
|
|||
|
|
return &OpenAIGatewayService{
|
|||
|
|
settingService: NewSettingService(repo, &config.Config{}),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
|
|||
|
|
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
|||
|
|
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
|||
|
|
|
|||
|
|
// 默认策略对所有模型生效(whitelist 为空),因为 codex 的 service_tier=fast
|
|||
|
|
// 是用户级开关,与 model 正交。
|
|||
|
|
// gpt-5.5 + priority → filter
|
|||
|
|
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
|
|||
|
|
require.Equal(t, BetaPolicyActionFilter, action)
|
|||
|
|
|
|||
|
|
// gpt-5.5-turbo → filter
|
|||
|
|
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
|
|||
|
|
require.Equal(t, BetaPolicyActionFilter, action)
|
|||
|
|
|
|||
|
|
// gpt-4 + priority → filter(默认策略覆盖所有模型)
|
|||
|
|
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
|
|||
|
|
require.Equal(t, BetaPolicyActionFilter, action)
|
|||
|
|
|
|||
|
|
// gpt-5.5 + flex → pass (tier doesn't match)
|
|||
|
|
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
|
|||
|
|
require.Equal(t, BetaPolicyActionPass, action)
|
|||
|
|
|
|||
|
|
// empty tier → pass
|
|||
|
|
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
|
|||
|
|
require.Equal(t, BetaPolicyActionPass, action)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
|
|||
|
|
settings := &OpenAIFastPolicySettings{
|
|||
|
|
Rules: []OpenAIFastPolicyRule{{
|
|||
|
|
ServiceTier: OpenAIFastTierPriority,
|
|||
|
|
Action: BetaPolicyActionBlock,
|
|||
|
|
Scope: BetaPolicyScopeAll,
|
|||
|
|
ErrorMessage: "fast mode is not allowed",
|
|||
|
|
ModelWhitelist: []string{"gpt-5.5"},
|
|||
|
|
FallbackAction: BetaPolicyActionPass,
|
|||
|
|
}},
|
|||
|
|
}
|
|||
|
|
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
|||
|
|
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
|||
|
|
|
|||
|
|
action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
|
|||
|
|
require.Equal(t, BetaPolicyActionBlock, action)
|
|||
|
|
require.Equal(t, "fast mode is not allowed", msg)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
|
|||
|
|
settings := &OpenAIFastPolicySettings{
|
|||
|
|
Rules: []OpenAIFastPolicyRule{{
|
|||
|
|
ServiceTier: OpenAIFastTierAny,
|
|||
|
|
Action: BetaPolicyActionFilter,
|
|||
|
|
Scope: BetaPolicyScopeOAuth,
|
|||
|
|
}},
|
|||
|
|
}
|
|||
|
|
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
|||
|
|
|
|||
|
|
// OAuth account → rule matches
|
|||
|
|
oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
|||
|
|
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
|
|||
|
|
require.Equal(t, BetaPolicyActionFilter, action)
|
|||
|
|
|
|||
|
|
// API Key account → rule skipped → pass
|
|||
|
|
apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
|||
|
|
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
|
|||
|
|
require.Equal(t, BetaPolicyActionPass, action)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
|
|||
|
|
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
|||
|
|
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
|||
|
|
|
|||
|
|
// gpt-5.5 fast → service_tier stripped
|
|||
|
|
body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
|
|||
|
|
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
|||
|
|
require.NoError(t, err)
|
|||
|
|
require.NotContains(t, string(updated), `"service_tier"`)
|
|||
|
|
|
|||
|
|
// Client sending "fast" (alias for priority) also filtered
|
|||
|
|
body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
|
|||
|
|
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
|||
|
|
require.NoError(t, err)
|
|||
|
|
require.NotContains(t, string(updated), `"service_tier"`)
|
|||
|
|
|
|||
|
|
// gpt-4 priority → 默认策略对所有模型 filter,service_tier 被移除
|
|||
|
|
body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
|
|||
|
|
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
|
|||
|
|
require.NoError(t, err)
|
|||
|
|
require.NotContains(t, string(updated), `"service_tier"`)
|
|||
|
|
|
|||
|
|
// No service_tier → no-op
|
|||
|
|
body = []byte(`{"model":"gpt-5.5"}`)
|
|||
|
|
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
|||
|
|
require.NoError(t, err)
|
|||
|
|
require.Equal(t, string(body), string(updated))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证扩展白名单后
|
|||
|
|
// 客户端显式发送的 OpenAI 官方合法 tier(auto/default/scale)能透传到上游而不被
|
|||
|
|
// 静默剥离。默认策略只针对 priority,所以这些 tier 落在 fall-through pass 分支。
|
|||
|
|
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
|
|||
|
|
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
|||
|
|
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
|||
|
|
|
|||
|
|
for _, tier := range []string{"auto", "default", "scale"} {
|
|||
|
|
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
|
|||
|
|
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
|||
|
|
require.NoError(t, err, "tier %q should pass without error", tier)
|
|||
|
|
require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
|
|||
|
|
"tier %q should be preserved in body under default rule", tier)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// evaluate 层也应判定为 pass(默认规则 ServiceTier=priority 与 auto/default/scale 不匹配)
|
|||
|
|
for _, tier := range []string{"auto", "default", "scale"} {
|
|||
|
|
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
|
|||
|
|
require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers 验证管理员显式配置
|
|||
|
|
// ServiceTier=all + Action=filter 规则后,auto/default/scale 等官方 tier 也会
|
|||
|
|
// 被剥离。这是符合预期的——首条匹配 short-circuit,"all" 覆盖任意已识别 tier。
|
|||
|
|
func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
|
|||
|
|
settings := &OpenAIFastPolicySettings{
|
|||
|
|
Rules: []OpenAIFastPolicyRule{{
|
|||
|
|
ServiceTier: OpenAIFastTierAny,
|
|||
|
|
Action: BetaPolicyActionFilter,
|
|||
|
|
Scope: BetaPolicyScopeAll,
|
|||
|
|
}},
|
|||
|
|
}
|
|||
|
|
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
|||
|
|
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
|||
|
|
|
|||
|
|
for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
|
|||
|
|
body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
|
|||
|
|
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
|||
|
|
require.NoError(t, err)
|
|||
|
|
require.NotContains(t, string(updated), `"service_tier"`,
|
|||
|
|
"tier %q should be stripped under ServiceTier=all + filter rule", tier)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped 验证真未知 tier 仍被剥离
|
|||
|
|
// (normalize 返回 nil → normalizeResponsesBodyServiceTier 删除字段;
|
|||
|
|
// applyOpenAIFastPolicyToBody 在 normTier 为空时直接 no-op,因为字段已不可能存在
|
|||
|
|
// 于经过前置归一化的请求里。这里直接调 apply 验证它对未识别值不会异常)。
|
|||
|
|
func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
|
|||
|
|
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
|
|||
|
|
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
|||
|
|
|
|||
|
|
// normalize 阶段会将未知值剥离
|
|||
|
|
require.Nil(t, normalizeOpenAIServiceTier("xxx"))
|
|||
|
|
|
|||
|
|
// applyOpenAIFastPolicyToBody 收到未识别 tier 时不报错,body 透传不变
|
|||
|
|
// (不属于本函数职责——上层 normalizeResponsesBodyServiceTier 已剥离)
|
|||
|
|
body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
|
|||
|
|
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
|||
|
|
require.NoError(t, err)
|
|||
|
|
require.Equal(t, string(body), string(updated))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
|
|||
|
|
settings := &OpenAIFastPolicySettings{
|
|||
|
|
Rules: []OpenAIFastPolicyRule{{
|
|||
|
|
ServiceTier: OpenAIFastTierPriority,
|
|||
|
|
Action: BetaPolicyActionBlock,
|
|||
|
|
Scope: BetaPolicyScopeAll,
|
|||
|
|
ErrorMessage: "fast mode is blocked for gpt-5.5",
|
|||
|
|
ModelWhitelist: []string{"gpt-5.5"},
|
|||
|
|
FallbackAction: BetaPolicyActionPass,
|
|||
|
|
}},
|
|||
|
|
}
|
|||
|
|
svc := newOpenAIGatewayServiceWithSettings(t, settings)
|
|||
|
|
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
|||
|
|
|
|||
|
|
body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
|
|||
|
|
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
|
|||
|
|
require.Error(t, err)
|
|||
|
|
var blocked *OpenAIFastBlockedError
|
|||
|
|
require.True(t, errors.As(err, &blocked))
|
|||
|
|
require.Contains(t, blocked.Message, "fast mode is blocked")
|
|||
|
|
require.Equal(t, string(body), string(updated)) // body not mutated on block
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
|
|||
|
|
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
|
|||
|
|
svc := NewSettingService(repo, &config.Config{})
|
|||
|
|
|
|||
|
|
// Invalid action rejected
|
|||
|
|
err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
|||
|
|
Rules: []OpenAIFastPolicyRule{{
|
|||
|
|
ServiceTier: OpenAIFastTierPriority,
|
|||
|
|
Action: "bogus",
|
|||
|
|
Scope: BetaPolicyScopeAll,
|
|||
|
|
}},
|
|||
|
|
})
|
|||
|
|
require.Error(t, err)
|
|||
|
|
|
|||
|
|
// Invalid service_tier rejected
|
|||
|
|
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
|||
|
|
Rules: []OpenAIFastPolicyRule{{
|
|||
|
|
ServiceTier: "turbo",
|
|||
|
|
Action: BetaPolicyActionPass,
|
|||
|
|
Scope: BetaPolicyScopeAll,
|
|||
|
|
}},
|
|||
|
|
})
|
|||
|
|
require.Error(t, err)
|
|||
|
|
|
|||
|
|
// Valid settings persisted
|
|||
|
|
err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
|
|||
|
|
Rules: []OpenAIFastPolicyRule{{
|
|||
|
|
ServiceTier: OpenAIFastTierPriority,
|
|||
|
|
Action: BetaPolicyActionFilter,
|
|||
|
|
Scope: BetaPolicyScopeAll,
|
|||
|
|
}},
|
|||
|
|
})
|
|||
|
|
require.NoError(t, err)
|
|||
|
|
|
|||
|
|
got, err := svc.GetOpenAIFastPolicySettings(context.Background())
|
|||
|
|
require.NoError(t, err)
|
|||
|
|
require.Len(t, got.Rules, 1)
|
|||
|
|
require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
|
|||
|
|
}
|