mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-05-05 05:30:44 +08:00
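Mirroring the fast-mode filtering in Claude BetaPolicy, add a three-state
pass / filter / block policy for the OpenAI upstream service_tier field
(priority / flex, including normalization of the client alias "fast" →
"priority"), covering every OpenAI entry point plus the admin settings
entry point.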
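Backend core
- Add the SettingKeyOpenAIFastPolicySettings, OpenAIFastPolicyRule, and
OpenAIFastPolicySettings configuration models; each rule spans the
dimensions service_tier × action × scope × model whitelist × fallback
action.
- SettingService.Get/SetOpenAIFastPolicySettings; when the setting is
missing, return the built-in default policy (priority is filtered for all
models, the whitelist is empty, fallback=pass). Design rationale:
service_tier=fast is a user-level switch orthogonal to the model field, so
locking the default to specific model slugs would leave a bypass path in
which "gpt-4 + fast" passes priority through to the upstream. JSON parse
failures no longer fall back silently; slog.Warn records the dirty data so
operators can locate it.
- service_tier normalization (trim + ToLower + fast→priority + whitelist
priority/flex) and policy evaluation (evaluateOpenAIFastPolicy) are the
single source of truth, shared by HTTP and WS. The pure function
evaluateOpenAIFastPolicyWithSettings is extracted and paired with a
ctx-bound settings snapshot (withOpenAIFastPolicyContext /
openAIFastPolicySettingsFromContext): long-lived WS session entry points
prefetch the settings once and reuse them for every frame, avoiding a
settingService hit per frame.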
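
The normalization and evaluation steps listed above can be sketched as a
pair of pure functions. This is a minimal illustration, not the code in
this commit: the Rule type, the first-match-wins ordering, the whitelist /
fallback semantics, and the decision to pass unrecognized tiers are all
assumptions, and scope is left out entirely.

```go
package main

import (
	"fmt"
	"strings"
)

// Action is the three-state outcome of a policy decision.
type Action string

const (
	ActionPass   Action = "pass"
	ActionFilter Action = "filter"
	ActionBlock  Action = "block"
)

// Rule is a hypothetical, simplified policy rule (scope omitted).
type Rule struct {
	ServiceTier    string   // "priority", "flex", or "all"
	Action         Action   // applied when the model matches
	ModelWhitelist []string // assumed: empty means every model matches
	FallbackAction Action   // assumed: applied when the whitelist misses
}

// normalizeServiceTier trims and lowercases the raw value, maps the client
// alias "fast" to "priority", and rejects anything outside priority/flex.
func normalizeServiceTier(raw string) (string, bool) {
	tier := strings.ToLower(strings.TrimSpace(raw))
	if tier == "fast" {
		tier = "priority"
	}
	switch tier {
	case "priority", "flex":
		return tier, true
	default:
		return "", false
	}
}

// evaluate is pure: it depends only on its arguments, so a settings
// snapshot taken once per session can be reused for every frame.
func evaluate(rules []Rule, rawTier, model string) Action {
	tier, ok := normalizeServiceTier(rawTier)
	if !ok {
		return ActionPass // assumption: unrecognized tiers are not policed
	}
	for _, r := range rules {
		if r.ServiceTier != "all" && r.ServiceTier != tier {
			continue
		}
		if len(r.ModelWhitelist) == 0 {
			return r.Action
		}
		for _, m := range r.ModelWhitelist {
			if m == model {
				return r.Action
			}
		}
		return r.FallbackAction
	}
	return ActionPass
}

func main() {
	// Shaped like the built-in default described above: priority is
	// filtered for every model, whitelist empty, fallback pass.
	defaults := []Rule{{ServiceTier: "priority", Action: ActionFilter, FallbackAction: ActionPass}}
	fmt.Println(evaluate(defaults, " FAST ", "gpt-4")) // filter
	fmt.Println(evaluate(defaults, "flex", "gpt-4"))   // pass
}
```

Because the default rule carries no model whitelist, "gpt-4 + fast" is
filtered like any other model, which is the bypass the design rationale
above is meant to close.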
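HTTP entry points (4)
- Chat Completions, Anthropic-compatible (Messages, including the
second-pass hit BetaFastMode→priority), native Responses, and Passthrough
Responses all call applyOpenAIFastPolicyToBody; filter deletes the
top-level service_tier via sjson, and block returns a 403 forbidden_error
JSON.
- All 4 entry points use the upstream view of the model (the slug after
GetMappedModel + normalizeOpenAIModelForUpstream + Codex OAuth normalize),
so that chat / messages / native /responses / passthrough do not diverge
on whitelist matching because they see different model dimensions.
- On the pass path, the client alias "fast" is also normalized to
"priority" and written back to the body; otherwise the native /responses
and passthrough entry points would forward "fast" verbatim to the upstream
and get a 400/rejection (normalizeResponsesBodyServiceTier on the
chat-completions entry point already behaved this way).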
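
The three body-level outcomes described above (pass with alias rewrite,
filter, block) can be sketched as follows. This is a stand-in, not
applyOpenAIFastPolicyToBody itself: the function name, signature, and the
exact shape of the 403 error body are assumptions, and encoding/json
replaces sjson to keep the example dependency-free. Unlike a targeted
delete, re-encoding the object sorts its keys.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

// applyToBody is a hypothetical stand-in for applyOpenAIFastPolicyToBody.
// It returns the body to forward, the HTTP status, and an error body that
// is non-nil only when the request is blocked.
func applyToBody(body []byte, action string) ([]byte, int, []byte) {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		return body, http.StatusOK, nil // not a JSON object: forward untouched
	}
	raw, present := fields["service_tier"]
	if !present {
		return body, http.StatusOK, nil
	}
	switch action {
	case "filter":
		delete(fields, "service_tier") // only the top-level key is removed
	case "block":
		// Assumed shape; the source only says "403 forbidden_error JSON".
		errBody, _ := json.Marshal(map[string]map[string]string{
			"error": {
				"type":    "forbidden_error",
				"message": "service_tier is not allowed by policy",
			},
		})
		return body, http.StatusForbidden, errBody // body is left unchanged
	default: // pass: still rewrite the client alias "fast" to "priority"
		var tier string
		if json.Unmarshal(raw, &tier) != nil ||
			!strings.EqualFold(strings.TrimSpace(tier), "fast") {
			return body, http.StatusOK, nil
		}
		fields["service_tier"] = json.RawMessage(`"priority"`)
	}
	rewritten, err := json.Marshal(fields)
	if err != nil {
		return body, http.StatusOK, nil
	}
	return rewritten, http.StatusOK, nil
}

func main() {
	in := []byte(`{"model":"gpt-4","service_tier":"fast","stream":true}`)
	for _, action := range []string{"pass", "filter", "block"} {
		out, status, errBody := applyToBody(in, action)
		fmt.Println(action, status, string(out), string(errBody))
	}
}
```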
WebSocket entry points
- Add applyOpenAIFastPolicyToWSResponseCreate: it strictly matches
type="response.create" and handles only the top-level service_tier; filter
deletes the field via sjson, and block returns a typed
*OpenAIFastBlockedError.
- The ingress path calls it inside parseClientPayload; on a block hit it
first writes a Realtime-style error event and then returns
OpenAIWSClientCloseError(StatusPolicyViolation=1008), relying on the
synchronous flush of the underlying WebSocket Conn.Write to guarantee that
the error arrives before the close.
- The passthrough path applies the policy to firstClientMessage before
RunEntry and wraps ReadFrame with openAIWSPolicyEnforcingFrameConn so that
the policy runs on every client→upstream frame; later frames without a
model field fall back to capturedSessionModel. The filter closure also
watches the session.model field of session.update / session.created frames
and refreshes capturedSessionModel, closing the mid-session bypass "first
frame model=gpt-4o (pass) → session.update switches to gpt-5.5 → a
response.create without model falls back to gpt-4o".
- passthrough billing: requestServiceTier is extracted from
firstClientMessage after the policy filter has run; on a filter hit
OpenAIForwardResult.ServiceTier reports nil (default tier), consistent
with the HTTP entry points (reqBody comes from the post-filter map) and
WS ingress (payload comes from the post-filter bytes).
- Error event schema: {event_id: "evt_<32hex>", type: "error",
error: {type: "forbidden_error", code: "policy_violation", message}},
compatible with error event parsing in the OpenAI codex client.
Admin / Frontend
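- dto.SystemSettings / UpdateSettingsRequest gain an
openai_fast_policy_settings field (omitempty), wired into bulk GET/PUT.
- The Gateway tab of the Settings page gains a Fast/Flex Policy form card
that configures every field: service_tier × action × scope × model
whitelist × fallback action.
- Frontend guard: the openaiFastPolicyLoaded flag allows write-back only
when the GET actually returned the field, so a rollout or an error cannot
overwrite the default rules with an empty set; the saveSettings write-back
loop skips this field, which dedicated refresh logic handles instead;
error_message is sent only when action=block, matching the backend's
omitempty behavior.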
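Tests
- HTTP path: openai_fast_policy_test.go covers the default configuration
(whitelist=[], priority filtered for all models) / block with a custom
error / scope distinctions / filter deleting the field / block leaving the
body unchanged / block short-circuiting the upstream / Anthropic
BetaFastMode triggering the OpenAI fast policy, among other scenarios.
- WebSocket path: openai_fast_policy_ws_test.go covers
helper units (filter / fast→priority normalization / flex pass-through /
block typed error / bytes unchanged when service_tier is absent /
non-response.create frames untouched / frames with an empty type
untouched / event_id+code field assertions / tolerance of a non-string
service_tier) +
a regression for fast-alias normalization on the pass path +
ingress end-to-end (upstream carries no service_tier after filter / after
block the client receives the error event before close 1008 and the
upstream sees 0 writes) +
passthrough capturedSessionModel fallback cases (under a whitelist policy:
established by the first frame, fallback hit when model is missing, and
the leak that occurs without a fallback, documented) +
a regression for the mid-session bypass in which passthrough
session.update / session.created rotates capturedSessionModel +
regressions for passthrough billing post-filter ServiceTier and the
idempotent filter.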
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
289 lines
9.0 KiB
Go
package admin

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"net/netip"
	"testing"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/handler/dto"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/require"
)

func TestParseTimeRange(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil)
	c.Request = req

	// end_date is inclusive, so the returned end is midnight of the following day.
	start, end := parseTimeRange(c)
	require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start)
	require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end)

	// An unparsable start_date must still yield a non-zero range.
	req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil)
	c.Request = req
	start, end = parseTimeRange(c)
	require.False(t, start.IsZero())
	require.False(t, end.IsZero())
}

func TestParseOpsViewParam(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil)
	require.Equal(t, opsListViewExcluded, parseOpsViewParam(c))

	c2, _ := gin.CreateTestContext(w)
	c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil)
	require.Equal(t, opsListViewAll, parseOpsViewParam(c2))

	// An unrecognized view falls back to the errors view.
	c3, _ := gin.CreateTestContext(w)
	c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil)
	require.Equal(t, opsListViewErrors, parseOpsViewParam(c3))

	// A nil context yields the empty view.
	require.Equal(t, "", parseOpsViewParam(nil))
}

func TestParseOpsDuration(t *testing.T) {
	dur, ok := parseOpsDuration("1h")
	require.True(t, ok)
	require.Equal(t, time.Hour, dur)

	_, ok = parseOpsDuration("invalid")
	require.False(t, ok)
}

func TestParseOpsOpenAITokenStatsDuration(t *testing.T) {
	tests := []struct {
		input string
		want  time.Duration
		ok    bool
	}{
		{input: "30m", want: 30 * time.Minute, ok: true},
		{input: "1h", want: time.Hour, ok: true},
		{input: "1d", want: 24 * time.Hour, ok: true},
		{input: "15d", want: 15 * 24 * time.Hour, ok: true},
		{input: "30d", want: 30 * 24 * time.Hour, ok: true},
		{input: "7d", want: 0, ok: false},
	}

	for _, tt := range tests {
		got, ok := parseOpsOpenAITokenStatsDuration(tt.input)
		require.Equal(t, tt.ok, ok, "input=%s", tt.input)
		require.Equal(t, tt.want, got, "input=%s", tt.input)
	}
}

func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodGet, "/", nil)

	before := time.Now().UTC()
	filter, err := parseOpsOpenAITokenStatsFilter(c)
	after := time.Now().UTC()

	require.NoError(t, err)
	require.NotNil(t, filter)
	require.Equal(t, "30d", filter.TimeRange)
	require.Equal(t, 1, filter.Page)
	require.Equal(t, 20, filter.PageSize)
	require.Equal(t, 0, filter.TopN)
	require.Nil(t, filter.GroupID)
	require.Equal(t, "", filter.Platform)
	require.True(t, filter.StartTime.Before(filter.EndTime))
	require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second)
	require.WithinDuration(t, after, filter.EndTime, 2*time.Second)
}

func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(
		http.MethodGet,
		"/?time_range=1h&platform=openai&group_id=12&top_n=50",
		nil,
	)

	filter, err := parseOpsOpenAITokenStatsFilter(c)
	require.NoError(t, err)
	require.Equal(t, "1h", filter.TimeRange)
	require.Equal(t, "openai", filter.Platform)
	require.NotNil(t, filter.GroupID)
	require.Equal(t, int64(12), *filter.GroupID)
	require.Equal(t, 50, filter.TopN)
	require.Equal(t, 0, filter.Page)
	require.Equal(t, 0, filter.PageSize)
}

func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) {
	tests := []string{
		"/?time_range=7d",
		"/?group_id=0",
		"/?group_id=abc",
		"/?top_n=0",
		"/?top_n=101",
		"/?top_n=10&page=1",
		"/?top_n=10&page_size=20",
		"/?page=0",
		"/?page_size=0",
		"/?page_size=101",
	}

	gin.SetMode(gin.TestMode)
	for _, rawURL := range tests {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil)

		_, err := parseOpsOpenAITokenStatsFilter(c)
		require.Error(t, err, "url=%s", rawURL)
	}
}

func TestParseOpsTimeRange(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	now := time.Now().UTC()
	startStr := now.Add(-time.Hour).Format(time.RFC3339)
	endStr := now.Format(time.RFC3339)
	c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil)
	start, end, err := parseOpsTimeRange(c, "1h")
	require.NoError(t, err)
	require.True(t, start.Before(end))

	c2, _ := gin.CreateTestContext(w)
	c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil)
	_, _, err = parseOpsTimeRange(c2, "1h")
	require.Error(t, err)
}

func TestParseOpsRealtimeWindow(t *testing.T) {
	dur, label, ok := parseOpsRealtimeWindow("5m")
	require.True(t, ok)
	require.Equal(t, 5*time.Minute, dur)
	require.Equal(t, "5min", label)

	_, _, ok = parseOpsRealtimeWindow("invalid")
	require.False(t, ok)
}

func TestPickThroughputBucketSeconds(t *testing.T) {
	require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute))
	require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour))
	require.Equal(t, 3600, pickThroughputBucketSeconds(48*time.Hour))
}

func TestParseOpsQueryMode(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil)
	require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c))
	require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil))
}

func TestOpsAlertRuleValidation(t *testing.T) {
	raw := map[string]json.RawMessage{
		"name":        json.RawMessage(`"High error rate"`),
		"metric_type": json.RawMessage(`"error_rate"`),
		"operator":    json.RawMessage(`">"`),
		"threshold":   json.RawMessage(`90`),
	}

	validated, err := validateOpsAlertRulePayload(raw)
	require.NoError(t, err)
	require.Equal(t, "High error rate", validated.Name)

	_, err = validateOpsAlertRulePayload(map[string]json.RawMessage{})
	require.Error(t, err)

	require.True(t, isPercentOrRateMetric("error_rate"))
	require.False(t, isPercentOrRateMetric("concurrency_queue_depth"))
}

func TestOpsWSHelpers(t *testing.T) {
	prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid")
	require.Len(t, prefixes, 1)
	require.Len(t, invalid, 1)

	host := hostWithoutPort("example.com:443")
	require.Equal(t, "example.com", host)

	addr := netip.MustParseAddr("10.0.0.1")
	require.True(t, isAddrInTrustedProxies(addr, prefixes))
	require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}

// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier verifies that the
// admin write path normalizes ServiceTier: empty and whitespace-only values
// become service.OpenAIFastTierAny ("all") and casing is lowercased, so that
// "" and "all" are not persisted as two spellings of the same meaning.
func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) {
	t.Run("nil input returns nil", func(t *testing.T) {
		require.Nil(t, openaiFastPolicySettingsFromDTO(nil))
	})

	t.Run("empty service_tier becomes 'all'", func(t *testing.T) {
		in := &dto.OpenAIFastPolicySettings{
			Rules: []dto.OpenAIFastPolicyRule{{
				ServiceTier: "",
				Action:      "filter",
				Scope:       "all",
			}},
		}
		out := openaiFastPolicySettingsFromDTO(in)
		require.NotNil(t, out)
		require.Len(t, out.Rules, 1)
		require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
		require.Equal(t, "all", out.Rules[0].ServiceTier)
	})

	t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) {
		in := &dto.OpenAIFastPolicySettings{
			Rules: []dto.OpenAIFastPolicyRule{{
				ServiceTier: " ",
				Action:      "pass",
				Scope:       "all",
			}},
		}
		out := openaiFastPolicySettingsFromDTO(in)
		require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
	})

	t.Run("uppercase service_tier is lowercased", func(t *testing.T) {
		in := &dto.OpenAIFastPolicySettings{
			Rules: []dto.OpenAIFastPolicyRule{{
				ServiceTier: "PRIORITY",
				Action:      "filter",
				Scope:       "all",
			}},
		}
		out := openaiFastPolicySettingsFromDTO(in)
		require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
	})

	t.Run("non-empty values pass through (lowercased)", func(t *testing.T) {
		in := &dto.OpenAIFastPolicySettings{
			Rules: []dto.OpenAIFastPolicyRule{
				{ServiceTier: "priority", Action: "filter", Scope: "all"},
				{ServiceTier: "flex", Action: "block", Scope: "oauth"},
				{ServiceTier: "all", Action: "pass", Scope: "apikey"},
			},
		}
		out := openaiFastPolicySettingsFromDTO(in)
		require.Len(t, out.Rules, 3)
		require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
		require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier)
		require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier)
	})
}
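
The rule-side normalization that the last test exercises could look
roughly like the sketch below. It is inferred from the test's assertions
only; normalizeRuleTier and tierAny are hypothetical names, and the real
openaiFastPolicySettingsFromDTO may do more than this.

```go
package main

import (
	"fmt"
	"strings"
)

// tierAny stands in for service.OpenAIFastTierAny, which the test pins to "all".
const tierAny = "all"

// normalizeRuleTier trims and lowercases a rule's service_tier and maps an
// empty result to tierAny, so "" and "all" are stored as one value.
func normalizeRuleTier(raw string) string {
	tier := strings.ToLower(strings.TrimSpace(raw))
	if tier == "" {
		return tierAny
	}
	return tier
}

func main() {
	for _, raw := range []string{"", "   ", "PRIORITY", "flex", "all"} {
		fmt.Printf("%q -> %q\n", raw, normalizeRuleTier(raw))
	}
}
```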