Merge tag 'v0.1.90' into merge/upstream-v0.1.90

注册邮箱域名白名单策略上线,后台大数据场景性能大幅优化。

- 注册邮箱域名白名单:支持管理员配置允许注册的邮箱域名策略
- Keys 页面表单筛选:用户 /keys 页面支持按条件筛选 API Key
- Settings 页面分 Tab 拆分:管理后台设置页面按功能模块分 Tab 展示

- 后台大数据场景加载性能优化:仪表盘/用户/账号/Ops 页面大数据集加载显著提速
- Usage 大表分页优化:默认避免全量 COUNT(*),大幅降低分页查询耗时
- 消除重复的 normalizeAccountIDList,补充新增组件的单元测试
- 清理无用文件和过时文档,精简项目结构
- EmailVerifyView 硬编码英文字符串替换为 i18n 调用

- 修复 Anthropic 平台无限流重置时间的 429 误标记账号限流问题
- 修复自定义菜单页面管理员视角菜单不生效问题
- 修复 Ops 错误详情弹窗未展示真实上游 payload 的问题
- 修复充值/订阅菜单 icon 显示问题

# Conflicts:
#	.gitignore
#	backend/cmd/server/VERSION
#	backend/ent/group.go
#	backend/ent/runtime/runtime.go
#	backend/ent/schema/group.go
#	backend/go.sum
#	backend/internal/handler/admin/account_handler.go
#	backend/internal/handler/admin/dashboard_handler.go
#	backend/internal/pkg/usagestats/usage_log_types.go
#	backend/internal/repository/group_repo.go
#	backend/internal/repository/usage_log_repo.go
#	backend/internal/server/middleware/security_headers.go
#	backend/internal/server/router.go
#	backend/internal/service/account_usage_service.go
#	backend/internal/service/admin_service_bulk_update_test.go
#	backend/internal/service/dashboard_service.go
#	backend/internal/service/gateway_service.go
#	frontend/src/api/admin/dashboard.ts
#	frontend/src/components/account/BulkEditAccountModal.vue
#	frontend/src/components/charts/GroupDistributionChart.vue
#	frontend/src/components/layout/AppSidebar.vue
#	frontend/src/i18n/locales/en.ts
#	frontend/src/i18n/locales/zh.ts
#	frontend/src/views/admin/GroupsView.vue
#	frontend/src/views/admin/SettingsView.vue
#	frontend/src/views/admin/UsageView.vue
#	frontend/src/views/user/PurchaseSubscriptionView.vue
This commit is contained in:
erio
2026-03-04 19:58:38 +08:00
461 changed files with 63392 additions and 6617 deletions

View File

@@ -3,11 +3,14 @@ package service
import (
"encoding/json"
"hash/fnv"
"reflect"
"sort"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
@@ -50,6 +53,14 @@ type Account struct {
AccountGroups []AccountGroup
GroupIDs []int64
Groups []*Group
// model_mapping 热路径缓存(非持久化字段)
modelMappingCache map[string]string
modelMappingCacheReady bool
modelMappingCacheCredentialsPtr uintptr
modelMappingCacheRawPtr uintptr
modelMappingCacheRawLen int
modelMappingCacheRawSig uint64
}
type TempUnschedulableRule struct {
@@ -349,6 +360,39 @@ func parseTempUnschedInt(value any) int {
}
// GetModelMapping resolves the account's model mapping, memoizing the result
// on the Account itself (the cache fields are non-persistent hot-path state).
//
// Cache validity is checked in two stages: cheap identity checks first
// (Credentials map pointer, raw mapping pointer and length), then a content
// signature of the raw mapping. Only when everything matches is the cached
// map returned; otherwise the mapping is re-resolved and the cache refreshed.
//
// NOTE(review): the cached map is returned directly, so all callers share one
// map instance — mutating the result would poison the cache. The cache fields
// are written without synchronization; presumably callers do not share an
// Account across goroutines — confirm before relying on this in concurrent code.
func (a *Account) GetModelMapping() map[string]string {
	credentialsPtr := mapPtr(a.Credentials)
	rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
	rawPtr := mapPtr(rawMapping)
	rawLen := len(rawMapping)
	rawSig := uint64(0)
	rawSigReady := false
	// Identity checks passed: verify the content signature before trusting the
	// cache, since the same map may have been mutated in place.
	if a.modelMappingCacheReady &&
		a.modelMappingCacheCredentialsPtr == credentialsPtr &&
		a.modelMappingCacheRawPtr == rawPtr &&
		a.modelMappingCacheRawLen == rawLen {
		rawSig = modelMappingSignature(rawMapping)
		rawSigReady = true
		if a.modelMappingCacheRawSig == rawSig {
			return a.modelMappingCache
		}
	}
	mapping := a.resolveModelMapping(rawMapping)
	// Avoid recomputing the signature when the identity checks already did.
	if !rawSigReady {
		rawSig = modelMappingSignature(rawMapping)
	}
	// Refresh the cache with the newly resolved mapping and its identity keys.
	a.modelMappingCache = mapping
	a.modelMappingCacheReady = true
	a.modelMappingCacheCredentialsPtr = credentialsPtr
	a.modelMappingCacheRawPtr = rawPtr
	a.modelMappingCacheRawLen = rawLen
	a.modelMappingCacheRawSig = rawSig
	return mapping
}
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
if a.Credentials == nil {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
@@ -356,32 +400,31 @@ func (a *Account) GetModelMapping() map[string]string {
}
return nil
}
raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil {
if len(rawMapping) == 0 {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
}
return nil
}
if m, ok := raw.(map[string]any); ok {
result := make(map[string]string)
for k, v := range m {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
if a.Platform == domain.PlatformAntigravity {
ensureAntigravityDefaultPassthroughs(result, []string{
"gemini-3-flash",
"gemini-3.1-pro-high",
"gemini-3.1-pro-low",
})
}
return result
result := make(map[string]string)
for k, v := range rawMapping {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
if a.Platform == domain.PlatformAntigravity {
ensureAntigravityDefaultPassthroughs(result, []string{
"gemini-3-flash",
"gemini-3.1-pro-high",
"gemini-3.1-pro-low",
})
}
return result
}
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
@@ -389,6 +432,37 @@ func (a *Account) GetModelMapping() map[string]string {
return nil
}
// mapPtr returns an identity token for a map header: the runtime pointer of
// its backing storage, or 0 for a nil map. The value is used only for
// cache-invalidation identity comparison and is never dereferenced.
func mapPtr(m map[string]any) uintptr {
	if m != nil {
		return reflect.ValueOf(m).Pointer()
	}
	return 0
}
// modelMappingSignature computes an order-independent FNV-1a fingerprint of a
// raw model_mapping. Keys are hashed in sorted order; each entry is framed as
// key, a 0x00 separator, the value (string values verbatim, non-string values
// as a 0x01 marker), and a 0xff terminator. An empty or nil mapping hashes to 0.
func modelMappingSignature(rawMapping map[string]any) uint64 {
	if len(rawMapping) == 0 {
		return 0
	}
	sortedKeys := make([]string, 0, len(rawMapping))
	for key := range rawMapping {
		sortedKeys = append(sortedKeys, key)
	}
	sort.Strings(sortedKeys)
	digest := fnv.New64a()
	for _, key := range sortedKeys {
		_, _ = digest.Write([]byte(key))
		_, _ = digest.Write([]byte{0})
		if text, isString := rawMapping[key].(string); isString {
			_, _ = digest.Write([]byte(text))
		} else {
			_, _ = digest.Write([]byte{1})
		}
		_, _ = digest.Write([]byte{0xff})
	}
	return digest.Sum64()
}
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
if mapping == nil || model == "" {
return
@@ -742,6 +816,159 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool {
return false
}
// IsOpenAIResponsesWebSocketV2Enabled reports whether Responses WebSocket v2
// is enabled for this OpenAI account.
//
// Type-specific new fields:
//   - OAuth accounts: accounts.extra.openai_oauth_responses_websockets_v2_enabled
//   - API Key accounts: accounts.extra.openai_apikey_responses_websockets_v2_enabled
//
// Compatibility fields:
//   - accounts.extra.responses_websockets_v2_enabled
//   - accounts.extra.openai_ws_enabled (historical switch)
//
// Priority:
//  1. read the type-specific field for the account's type
//  2. fall back to the compatibility fields when the type-specific one is absent
func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
	// Non-OpenAI accounts (or ones without extra data) are always off.
	if a == nil || !a.IsOpenAI() || a.Extra == nil {
		return false
	}
	if a.IsOpenAIOAuth() {
		if enabled, ok := a.Extra["openai_oauth_responses_websockets_v2_enabled"].(bool); ok {
			return enabled
		}
	}
	if a.IsOpenAIApiKey() {
		if enabled, ok := a.Extra["openai_apikey_responses_websockets_v2_enabled"].(bool); ok {
			return enabled
		}
	}
	// Legacy fallbacks, checked only when no type-specific value was present.
	if enabled, ok := a.Extra["responses_websockets_v2_enabled"].(bool); ok {
		return enabled
	}
	if enabled, ok := a.Extra["openai_ws_enabled"].(bool); ok {
		return enabled
	}
	return false
}
// Valid WSv2 ingress modes for OpenAI accounts.
const (
	OpenAIWSIngressModeOff       = "off"
	OpenAIWSIngressModeShared    = "shared"
	OpenAIWSIngressModeDedicated = "dedicated"
)

// normalizeOpenAIWSIngressMode canonicalizes a WS ingress mode string
// (trimmed, case-insensitive). Unrecognized values yield "".
func normalizeOpenAIWSIngressMode(mode string) string {
	normalized := strings.ToLower(strings.TrimSpace(mode))
	switch normalized {
	case OpenAIWSIngressModeOff, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
		return normalized
	default:
		return ""
	}
}

// normalizeOpenAIWSIngressDefaultMode behaves like normalizeOpenAIWSIngressMode
// but falls back to "shared" for unrecognized values instead of "".
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
	normalized := normalizeOpenAIWSIngressMode(mode)
	if normalized == "" {
		return OpenAIWSIngressModeShared
	}
	return normalized
}
// ResolveOpenAIResponsesWebSocketV2Mode returns the account's effective mode
// under WSv2 ingress: off / shared / dedicated.
//
// Priority:
//  1. type-specific mode field (new, string)
//  2. type-specific enabled field (legacy, bool)
//  3. compatibility enabled fields (legacy, bool)
//  4. defaultMode (invalid values fall back to shared)
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
	resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
	// Non-OpenAI accounts are never eligible for WS ingress.
	if a == nil || !a.IsOpenAI() {
		return OpenAIWSIngressModeOff
	}
	if a.Extra == nil {
		return resolvedDefault
	}
	// resolveModeString reads a string-valued mode field; missing, non-string,
	// or unrecognized values are treated as absent.
	resolveModeString := func(key string) (string, bool) {
		raw, ok := a.Extra[key]
		if !ok {
			return "", false
		}
		mode, ok := raw.(string)
		if !ok {
			return "", false
		}
		normalized := normalizeOpenAIWSIngressMode(mode)
		if normalized == "" {
			return "", false
		}
		return normalized, true
	}
	// resolveBoolMode maps a legacy bool switch: true -> shared, false -> off.
	resolveBoolMode := func(key string) (string, bool) {
		raw, ok := a.Extra[key]
		if !ok {
			return "", false
		}
		enabled, ok := raw.(bool)
		if !ok {
			return "", false
		}
		if enabled {
			return OpenAIWSIngressModeShared, true
		}
		return OpenAIWSIngressModeOff, true
	}
	if a.IsOpenAIOAuth() {
		if mode, ok := resolveModeString("openai_oauth_responses_websockets_v2_mode"); ok {
			return mode
		}
		if mode, ok := resolveBoolMode("openai_oauth_responses_websockets_v2_enabled"); ok {
			return mode
		}
	}
	if a.IsOpenAIApiKey() {
		if mode, ok := resolveModeString("openai_apikey_responses_websockets_v2_mode"); ok {
			return mode
		}
		if mode, ok := resolveBoolMode("openai_apikey_responses_websockets_v2_enabled"); ok {
			return mode
		}
	}
	// Legacy compatibility switches, shared across account types.
	if mode, ok := resolveBoolMode("responses_websockets_v2_enabled"); ok {
		return mode
	}
	if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
		return mode
	}
	return resolvedDefault
}
// IsOpenAIWSForceHTTPEnabled reports the account-level "force HTTP" switch.
// Field: accounts.extra.openai_ws_force_http. Only OpenAI accounts with extra
// data can enable it.
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
	if a == nil || !a.IsOpenAI() || a.Extra == nil {
		return false
	}
	if enabled, ok := a.Extra["openai_ws_force_http"].(bool); ok {
		return enabled
	}
	return false
}
// IsOpenAIWSAllowStoreRecoveryEnabled reports the account-level store-recovery
// switch. Field: accounts.extra.openai_ws_allow_store_recovery. Only OpenAI
// accounts with extra data can enable it.
func (a *Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool {
	if a == nil || !a.IsOpenAI() || a.Extra == nil {
		return false
	}
	if enabled, ok := a.Extra["openai_ws_allow_store_recovery"].(bool); ok {
		return enabled
	}
	return false
}
// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。
func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
@@ -806,6 +1033,26 @@ func (a *Account) IsTLSFingerprintEnabled() bool {
return false
}
// GetUserMsgQueueMode returns the per-account user message queue mode:
// "serialize" (strict queue), "throttle" (soft rate limiting), or "" meaning
// "not set — use the global configuration".
func (a *Account) GetUserMsgQueueMode() string {
	if a.Extra == nil {
		return ""
	}
	// The new user_msg_queue_mode field wins; values outside the whitelist are
	// treated as unset so the global configuration applies.
	if mode, ok := a.Extra["user_msg_queue_mode"].(string); ok && mode != "" {
		switch mode {
		case config.UMQModeSerialize, config.UMQModeThrottle:
			return mode
		}
		return ""
	}
	// Backward compatibility: user_msg_queue_enabled == true meant "serialize".
	if enabled, ok := a.Extra["user_msg_queue_enabled"].(bool); ok && enabled {
		return config.UMQModeSerialize
	}
	return ""
}
// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
// 启用后将在一段时间内15分钟固定 metadata.user_id 中的 session ID
@@ -911,6 +1158,80 @@ func (a *Account) GetSessionIdleTimeoutMinutes() int {
return 5
}
// GetBaseRPM returns the configured base RPM limit. A result of 0 means the
// limit is disabled; missing, unparsable, zero, or negative values all read
// as 0.
func (a *Account) GetBaseRPM() int {
	if a.Extra == nil {
		return 0
	}
	raw, ok := a.Extra["base_rpm"]
	if !ok {
		return 0
	}
	if parsed := parseExtraInt(raw); parsed > 0 {
		return parsed
	}
	return 0
}
// GetRPMStrategy returns the RPM scheduling strategy: "sticky_exempt" when
// explicitly configured, otherwise the default "tiered" three-zone model.
// Any other value (wrong type, unknown string) falls back to "tiered".
func (a *Account) GetRPMStrategy() string {
	if a.Extra != nil {
		if strategy, ok := a.Extra["rpm_strategy"].(string); ok && strategy == "sticky_exempt" {
			return "sticky_exempt"
		}
	}
	return "tiered"
}
// GetRPMStickyBuffer returns the sticky (yellow-zone) buffer size used by the
// tiered RPM strategy. An explicit positive rpm_sticky_buffer wins; otherwise
// the buffer defaults to 20% of base_rpm with a floor of 1 whenever
// base_rpm > 0, and 0 when RPM limiting is disabled.
func (a *Account) GetRPMStickyBuffer() int {
	if a.Extra == nil {
		return 0
	}
	if raw, ok := a.Extra["rpm_sticky_buffer"]; ok {
		if custom := parseExtraInt(raw); custom > 0 {
			return custom
		}
	}
	base := a.GetBaseRPM()
	if base <= 0 {
		return 0
	}
	if fallback := base / 5; fallback >= 1 {
		return fallback
	}
	return 1
}
// CheckRPMSchedulability maps the current RPM count to a scheduling state,
// reusing the WindowCostSchedulability tri-state: Schedulable / StickyOnly /
// NotSchedulable.
//
// Zones under the tiered strategy: below base RPM is green (schedulable);
// from base RPM up to base+buffer is yellow (sticky sessions only); at or
// beyond base+buffer is red (not schedulable). The sticky_exempt strategy has
// no red zone.
func (a *Account) CheckRPMSchedulability(currentRPM int) WindowCostSchedulability {
	baseRPM := a.GetBaseRPM()
	// base RPM <= 0 means RPM limiting is disabled entirely.
	if baseRPM <= 0 {
		return WindowCostSchedulable
	}
	if currentRPM < baseRPM {
		return WindowCostSchedulable
	}
	strategy := a.GetRPMStrategy()
	if strategy == "sticky_exempt" {
		return WindowCostStickyOnly // sticky-exempt: no red zone
	}
	// tiered: yellow zone + red zone
	buffer := a.GetRPMStickyBuffer()
	if currentRPM < baseRPM+buffer {
		return WindowCostStickyOnly
	}
	return WindowCostNotSchedulable
}
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable可正常调度
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly仅粘性会话
@@ -974,6 +1295,12 @@ func parseExtraFloat64(value any) float64 {
}
// parseExtraInt 从 extra 字段解析 int 值
// ParseExtraInt parses an `any` value taken from an extra map into an int.
// Supports int, int64, float64, json.Number, and string; returns 0 when the
// value cannot be parsed. Exported wrapper around parseExtraInt.
func ParseExtraInt(value any) int {
	return parseExtraInt(value)
}
func parseExtraInt(value any) int {
switch v := value.(type) {
case int:

View File

@@ -134,3 +134,161 @@ func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
require.False(t, otherPlatform.IsCodexCLIOnlyEnabled())
})
}
// TestAccount_IsOpenAIResponsesWebSocketV2Enabled verifies the per-type
// switch fields, their priority over the legacy compatibility keys, and that
// non-OpenAI accounts are always off.
func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
	// OAuth accounts read the OAuth-specific switch.
	t.Run("OAuth使用OAuth专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
	// API Key accounts read the API-Key-specific switch.
	t.Run("API Key使用API Key专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
	// An OAuth account must not pick up the API-Key-specific switch.
	t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
	// Type-specific new keys take priority over legacy compatibility keys.
	t.Run("分类型新键优先于兼容键", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":              true,
				"openai_ws_enabled":                            true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
	// Missing type-specific keys fall back to the compatibility key.
	t.Run("分类型键缺失时回退兼容键", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
	// Non-OpenAI accounts are off regardless of extra flags.
	t.Run("非OpenAI账号默认关闭", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
}
// TestAccount_ResolveOpenAIResponsesWebSocketV2Mode verifies the mode
// resolution priority: new string mode fields beat legacy bool switches, legacy
// bools map to shared/off, and invalid defaults fall back to shared.
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
	t.Run("default fallback to shared", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra:    map[string]any{},
		}
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
	})
	t.Run("oauth mode field has highest priority", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode":    OpenAIWSIngressModeDedicated,
				"openai_oauth_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":              false,
			},
		}
		require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
	})
	t.Run("legacy enabled maps to shared", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
	})
	t.Run("legacy disabled maps to off", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":               true,
			},
		}
		require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
	})
	t.Run("non openai always off", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
			},
		}
		require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
	})
}
// TestAccount_OpenAIWSExtraFlags verifies the account-level force-HTTP and
// store-recovery switches, including nil-receiver and non-OpenAI behavior.
func TestAccount_OpenAIWSExtraFlags(t *testing.T) {
	account := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_ws_force_http":           true,
			"openai_ws_allow_store_recovery": true,
		},
	}
	require.True(t, account.IsOpenAIWSForceHTTPEnabled())
	require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled())
	// Missing keys default to false.
	off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
	require.False(t, off.IsOpenAIWSForceHTTPEnabled())
	require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled())
	// A nil receiver must be safe and report false.
	var nilAccount *Account
	require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled())
	// Non-OpenAI accounts ignore the flags entirely.
	nonOpenAI := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_ws_allow_store_recovery": true,
		},
	}
	require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled())
}

View File

@@ -0,0 +1,120 @@
package service
import (
"encoding/json"
"testing"
)
// TestGetBaseRPM is a table-driven test covering every supported value type
// for base_rpm (int, int64, float64, json.Number, string) plus the
// disabled cases (nil extra, missing key, zero, negative).
func TestGetBaseRPM(t *testing.T) {
	tests := []struct {
		name     string
		extra    map[string]any
		expected int
	}{
		{"nil extra", nil, 0},
		{"no key", map[string]any{}, 0},
		{"zero", map[string]any{"base_rpm": 0}, 0},
		{"int value", map[string]any{"base_rpm": 15}, 15},
		{"float value", map[string]any{"base_rpm": 15.0}, 15},
		{"string value", map[string]any{"base_rpm": "15"}, 15},
		{"negative value", map[string]any{"base_rpm": -5}, 0},
		{"int64 value", map[string]any{"base_rpm": int64(20)}, 20},
		{"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &Account{Extra: tt.extra}
			if got := a.GetBaseRPM(); got != tt.expected {
				t.Errorf("GetBaseRPM() = %d, want %d", got, tt.expected)
			}
		})
	}
}
// TestGetRPMStrategy is a table-driven test verifying that only the exact
// string "sticky_exempt" selects the sticky-exempt strategy; everything else
// (missing, empty, unknown, wrong type) falls back to "tiered".
func TestGetRPMStrategy(t *testing.T) {
	tests := []struct {
		name     string
		extra    map[string]any
		expected string
	}{
		{"nil extra", nil, "tiered"},
		{"no key", map[string]any{}, "tiered"},
		{"tiered", map[string]any{"rpm_strategy": "tiered"}, "tiered"},
		{"sticky_exempt", map[string]any{"rpm_strategy": "sticky_exempt"}, "sticky_exempt"},
		{"invalid", map[string]any{"rpm_strategy": "foobar"}, "tiered"},
		{"empty string fallback", map[string]any{"rpm_strategy": ""}, "tiered"},
		{"numeric value fallback", map[string]any{"rpm_strategy": 123}, "tiered"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &Account{Extra: tt.extra}
			if got := a.GetRPMStrategy(); got != tt.expected {
				t.Errorf("GetRPMStrategy() = %q, want %q", got, tt.expected)
			}
		})
	}
}
// TestCheckRPMSchedulability is a table-driven test covering the green /
// yellow / red zones of the tiered strategy, the no-red-zone behavior of
// sticky_exempt, custom buffers, and boundary/degenerate inputs.
func TestCheckRPMSchedulability(t *testing.T) {
	tests := []struct {
		name       string
		extra      map[string]any
		currentRPM int
		expected   WindowCostSchedulability
	}{
		{"disabled", map[string]any{}, 100, WindowCostSchedulable},
		{"green zone", map[string]any{"base_rpm": 15}, 10, WindowCostSchedulable},
		{"yellow zone tiered", map[string]any{"base_rpm": 15}, 15, WindowCostStickyOnly},
		{"red zone tiered", map[string]any{"base_rpm": 15}, 18, WindowCostNotSchedulable},
		{"sticky_exempt at limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 15, WindowCostStickyOnly},
		{"sticky_exempt over limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 100, WindowCostStickyOnly},
		{"custom buffer", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 14, WindowCostStickyOnly},
		{"custom buffer red", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 15, WindowCostNotSchedulable},
		{"base_rpm=1 green", map[string]any{"base_rpm": 1}, 0, WindowCostSchedulable},
		{"base_rpm=1 yellow (at limit)", map[string]any{"base_rpm": 1}, 1, WindowCostStickyOnly},
		{"base_rpm=1 red (at limit+buffer)", map[string]any{"base_rpm": 1}, 2, WindowCostNotSchedulable},
		{"negative currentRPM", map[string]any{"base_rpm": 15}, -1, WindowCostSchedulable},
		{"base_rpm negative disabled", map[string]any{"base_rpm": -5}, 10, WindowCostSchedulable},
		{"very high currentRPM", map[string]any{"base_rpm": 10}, 9999, WindowCostNotSchedulable},
		{"sticky_exempt very high currentRPM", map[string]any{"base_rpm": 10, "rpm_strategy": "sticky_exempt"}, 9999, WindowCostStickyOnly},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &Account{Extra: tt.extra}
			if got := a.CheckRPMSchedulability(tt.currentRPM); got != tt.expected {
				t.Errorf("CheckRPMSchedulability(%d) = %d, want %d", tt.currentRPM, got, tt.expected)
			}
		})
	}
}
// TestGetRPMStickyBuffer is a table-driven test verifying the 20%-of-base
// default buffer, the floor of 1 for small positive base values, and that
// explicit positive rpm_sticky_buffer values override the default while
// zero/negative ones fall back to it.
func TestGetRPMStickyBuffer(t *testing.T) {
	tests := []struct {
		name     string
		extra    map[string]any
		expected int
	}{
		{"nil extra", nil, 0},
		{"no keys", map[string]any{}, 0},
		{"base_rpm=0", map[string]any{"base_rpm": 0}, 0},
		{"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1},
		{"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1},
		{"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1},
		{"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2},
		{"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3},
		{"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20},
		{"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5},
		{"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2},
		{"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2},
		{"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7},
		{"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &Account{Extra: tt.extra}
			if got := a.GetRPMStickyBuffer(); got != tt.expected {
				t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected)
			}
		})
	}
}

View File

@@ -54,6 +54,8 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
@@ -119,6 +121,10 @@ type AccountService struct {
groupRepo GroupRepository
}
// groupExistenceBatchChecker is an optional repository capability: when the
// configured GroupRepository also implements it, group existence can be
// verified with a single batch query instead of one GetByID call per group.
type groupExistenceBatchChecker interface {
	ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAccountService 创建账号服务实例
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
return &AccountService{
@@ -131,11 +137,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
// 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 {
for _, groupID := range req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil {
return nil, err
}
}
@@ -256,11 +259,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
// 先验证分组是否存在(在任何写操作之前)
if req.GroupIDs != nil {
for _, groupID := range *req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil {
return nil, err
}
}
@@ -300,6 +300,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
return nil
}
// validateGroupIDsExist verifies that every ID in groupIDs refers to an
// existing group. It returns nil for an empty list, an error wrapping
// ErrGroupNotFound for missing/invalid IDs, or the wrapped repository error.
// A single batch existence query is preferred when the repository supports
// it; otherwise it falls back to one GetByID lookup per group.
func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
	if len(groupIDs) == 0 {
		return nil
	}
	if s.groupRepo == nil {
		return fmt.Errorf("group repository not configured")
	}
	// Fast path: one batch query when the repository implements the checker.
	if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok {
		existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs)
		if err != nil {
			return fmt.Errorf("check groups exists: %w", err)
		}
		for _, groupID := range groupIDs {
			// Non-positive IDs can never exist; reject them explicitly since
			// the batch result map would simply not contain them.
			if groupID <= 0 {
				return fmt.Errorf("get group: %w", ErrGroupNotFound)
			}
			if !existsByID[groupID] {
				return fmt.Errorf("get group: %w", ErrGroupNotFound)
			}
		}
		return nil
	}
	// Fallback: per-ID lookup; the repository error conveys "not found".
	for _, groupID := range groupIDs {
		_, err := s.groupRepo.GetByID(ctx, groupID)
		if err != nil {
			return fmt.Errorf("get group: %w", err)
		}
	}
	return nil
}
// UpdateStatus 更新账号状态
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id)

View File

@@ -147,6 +147,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
}
// ListSchedulableUngroupedByPlatform is a test stub that panics if called,
// signalling an unexpected repository interaction.
func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
	panic("unexpected ListSchedulableUngroupedByPlatform call")
}

// ListSchedulableUngroupedByPlatforms is a test stub that panics if called,
// signalling an unexpected repository interaction.
func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
	panic("unexpected ListSchedulableUngroupedByPlatforms call")
}
func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
panic("unexpected SetRateLimited call")
}

View File

@@ -598,9 +598,102 @@ func ceilSeconds(d time.Duration) int {
return sec
}
// testSoraAPIKeyAccountConnection tests connectivity for a Sora apikey-type
// account by sending a lightweight prompt-enhance request to the upstream
// base_url, verifying both reachability and API key validity. Progress and
// results are streamed to the client as SSE TestEvents; the returned error
// reflects the SSE write path, not the upstream probe outcome.
func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
	ctx := c.Request.Context()
	apiKey := account.GetCredential("api_key")
	if apiKey == "" {
		return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
	}
	baseURL := account.GetBaseURL()
	if baseURL == "" {
		return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
	}
	// Validate the base_url format before building the probe URL.
	normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
	}
	upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
	// Set SSE headers so the client can consume streamed test events.
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.Flush()
	// Per-account rate limit on test attempts.
	if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
		msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
		return s.sendErrorAndEnd(c, msg)
	}
	s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
	// Build a lightweight prompt-enhance request as the connectivity probe.
	testPayload := map[string]any{
		"model":    "prompt-enhance-short-10s",
		"messages": []map[string]string{{"role": "user", "content": "test"}},
		"stream":   false,
	}
	payloadBytes, _ := json.Marshal(testPayload)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
	if err != nil {
		return s.sendErrorAndEnd(c, "构建测试请求失败")
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)
	// Resolve the proxy URL when the account has one attached.
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
	}
	defer func() { _ = resp.Body.Close() }()
	// Cap the body read to 64 KiB — only used for error reporting.
	respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
	if resp.StatusCode == http.StatusOK {
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
		s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
		return nil
	}
	if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
		return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
	}
	// Other errors that still reached the upstream (e.g. 400 parameter errors)
	// also count as a passing connectivity test.
	if resp.StatusCode == http.StatusBadRequest {
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d参数校验错误属正常", resp.StatusCode)})
		s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
		return nil
	}
	return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
}
// testSoraAccountConnection 测试 Sora 账号的连接
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token
// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
// apikey 类型走独立测试流程
if account.Type == AccountTypeAPIKey {
return s.testSoraAPIKeyAccountConnection(c, account)
}
ctx := c.Request.Context()
recorder := &soraProbeRecorder{}

View File

@@ -9,7 +9,9 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"golang.org/x/sync/errgroup"
)
type UsageLogRepository interface {
@@ -33,9 +35,9 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
@@ -63,6 +65,10 @@ type UsageLogRepository interface {
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
}
type accountWindowStatsBatchReader interface {
GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error)
}
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据utilization, resets_at
type apiUsageCache struct {
response *ClaudeUsageResponse
@@ -298,7 +304,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
@@ -320,7 +326,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}
@@ -441,6 +447,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
}, nil
}
// GetTodayStatsBatch returns today's window stats for a set of accounts.
// It prefers a single batched SQL query when the repository supports it and
// falls back to bounded-concurrency per-account queries otherwise. Every
// requested (positive, deduplicated) account ID is guaranteed an entry in the
// result map; accounts whose fallback query failed get empty stats.
func (s *AccountUsageService) GetTodayStatsBatch(ctx context.Context, accountIDs []int64) (map[int64]*WindowStats, error) {
	// Deduplicate and drop non-positive IDs while preserving input order.
	seen := make(map[int64]struct{}, len(accountIDs))
	ids := make([]int64, 0, len(accountIDs))
	for _, id := range accountIDs {
		if id <= 0 {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		ids = append(ids, id)
	}

	out := make(map[int64]*WindowStats, len(ids))
	if len(ids) == 0 {
		return out, nil
	}

	startTime := timezone.Today()

	// Fast path: one batched query when the repository implements it.
	if reader, ok := s.usageLogRepo.(accountWindowStatsBatchReader); ok {
		if byAccount, err := reader.GetAccountWindowStatsBatch(ctx, ids, startTime); err == nil {
			for _, id := range ids {
				out[id] = windowStatsFromAccountStats(byAccount[id])
			}
			return out, nil
		}
		// On batch failure, fall through to the per-account path below.
	}

	// Fallback: per-account queries with a concurrency cap of 8.
	var mu sync.Mutex
	g, gctx := errgroup.WithContext(ctx)
	g.SetLimit(8)
	for _, id := range ids {
		id := id
		g.Go(func() error {
			stats, err := s.usageLogRepo.GetAccountWindowStats(gctx, id, startTime)
			if err != nil {
				// Best-effort: a failed account is skipped here and filled
				// with empty stats after the wait.
				return nil
			}
			mu.Lock()
			out[id] = windowStatsFromAccountStats(stats)
			mu.Unlock()
			return nil
		})
	}
	_ = g.Wait()

	// Ensure every requested account has an entry.
	for _, id := range ids {
		if _, ok := out[id]; !ok {
			out[id] = &WindowStats{}
		}
	}
	return out, nil
}
// windowStatsFromAccountStats converts repository-level account stats into a
// WindowStats value. A nil input yields zeroed stats rather than a nil pointer,
// so callers never need a nil check.
func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
	ws := &WindowStats{}
	if stats != nil {
		ws.Requests = stats.Requests
		ws.Tokens = stats.Tokens
		ws.Cost = stats.Cost
		ws.StandardCost = stats.StandardCost
		ws.UserCost = stats.UserCost
	}
	return ws
}
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
if err != nil {

View File

@@ -314,3 +314,72 @@ func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
}
}
// TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace verifies that
// swapping out the whole Credentials map drops any cached model mapping.
func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) {
	acc := &Account{
		Credentials: map[string]any{
			"model_mapping": map[string]any{"claude-3-5-sonnet": "upstream-a"},
		},
	}
	if got := acc.GetModelMapping(); got["claude-3-5-sonnet"] != "upstream-a" {
		t.Fatalf("unexpected first mapping: %v", got)
	}

	// Replace the credentials map entirely; the old cached mapping must not survive.
	acc.Credentials = map[string]any{
		"model_mapping": map[string]any{"claude-3-5-sonnet": "upstream-b"},
	}
	if got := acc.GetModelMapping(); got["claude-3-5-sonnet"] != "upstream-b" {
		t.Fatalf("expected cache invalidated after credentials replace, got: %v", got)
	}
}
// TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange verifies that
// adding a key to the underlying mapping map is observed by GetModelMapping.
func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) {
	mapping := map[string]any{"claude-sonnet": "sonnet-a"}
	acc := &Account{Credentials: map[string]any{"model_mapping": mapping}}

	if got := acc.GetModelMapping(); len(got) != 1 {
		t.Fatalf("unexpected first mapping length: %d", len(got))
	}

	// Grow the shared map in place; the cache must notice the size change.
	mapping["claude-opus"] = "opus-b"
	if got := acc.GetModelMapping(); got["claude-opus"] != "opus-b" {
		t.Fatalf("expected cache invalidated after mapping len change, got: %v", got)
	}
}
// TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange verifies that
// overwriting a value in the underlying mapping map (same length, same keys)
// is still observed by GetModelMapping.
func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) {
	mapping := map[string]any{"claude-sonnet": "sonnet-a"}
	acc := &Account{Credentials: map[string]any{"model_mapping": mapping}}

	if got := acc.GetModelMapping(); got["claude-sonnet"] != "sonnet-a" {
		t.Fatalf("unexpected first mapping: %v", got)
	}

	// Mutate an existing entry in place; length is unchanged.
	mapping["claude-sonnet"] = "sonnet-b"
	if got := acc.GetModelMapping(); got["claude-sonnet"] != "sonnet-b" {
		t.Fatalf("expected cache invalidated after in-place value change, got: %v", got)
	}
}

View File

@@ -9,6 +9,8 @@ import (
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -42,6 +44,9 @@ type AdminService interface {
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin)
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
@@ -83,13 +88,14 @@ type AdminService interface {
// CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct {
Email string
Password string
Username string
Notes string
Balance float64
Concurrency int
AllowedGroups []int64
Email string
Password string
Username string
Notes string
Balance float64
Concurrency int
AllowedGroups []int64
SoraStorageQuotaBytes int64
}
type UpdateUserInput struct {
@@ -103,7 +109,8 @@ type UpdateUserInput struct {
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
GroupRates map[int64]*float64
SoraStorageQuotaBytes *int64
}
type CreateGroupInput struct {
@@ -136,6 +143,8 @@ type CreateGroupInput struct {
SimulateClaudeMaxEnabled *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// Sora 存储配额
SoraStorageQuotaBytes int64
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
@@ -171,6 +180,8 @@ type UpdateGroupInput struct {
SimulateClaudeMaxEnabled *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// Sora 存储配额
SoraStorageQuotaBytes *int64
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
@@ -238,6 +249,14 @@ type BulkUpdateAccountResult struct {
Error string `json:"error,omitempty"`
}
// AdminUpdateAPIKeyGroupIDResult is the result of AdminUpdateAPIKeyGroupID.
// When binding to an exclusive group, the service may also grant the group to
// the key's owner; the AutoGranted* fields report that side effect.
type AdminUpdateAPIKeyGroupIDResult struct {
	APIKey                 *APIKey // the API key after the operation (unchanged for a nil-groupID no-op)
	AutoGrantedGroupAccess bool    // true if a new exclusive group permission was auto-added
	GrantedGroupID         *int64  // the group ID that was auto-granted
	GrantedGroupName       string  // the group name that was auto-granted
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
@@ -406,6 +425,17 @@ type adminServiceImpl struct {
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
authCacheInvalidator APIKeyAuthCacheInvalidator
entClient *dbent.Client // 用于开启数据库事务
settingService *SettingService
defaultSubAssigner DefaultSubscriptionAssigner
}
// userGroupRateBatchReader is an optional repository capability: loading the
// per-group rate overrides for many users in one query. ListUsers type-asserts
// userGroupRateRepo against it and falls back to per-user loads otherwise.
type userGroupRateBatchReader interface {
	GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
}
// groupExistenceBatchReader is an optional repository capability: checking the
// existence of several group IDs in a single query.
// NOTE(review): presumably consumed by validateGroupIDsExist — the call site is
// outside this view; confirm before relying on it.
type groupExistenceBatchReader interface {
	ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAdminService creates a new AdminService
@@ -422,6 +452,9 @@ func NewAdminService(
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
authCacheInvalidator APIKeyAuthCacheInvalidator,
entClient *dbent.Client,
settingService *SettingService,
defaultSubAssigner DefaultSubscriptionAssigner,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@@ -436,6 +469,9 @@ func NewAdminService(
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
authCacheInvalidator: authCacheInvalidator,
entClient: entClient,
settingService: settingService,
defaultSubAssigner: defaultSubAssigner,
}
}
@@ -448,18 +484,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
}
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
userIDs := make([]int64, 0, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
}
users[i].GroupRates = rates
ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err)
s.loadUserGroupRatesOneByOne(ctx, users)
} else {
for i := range users {
if rates, ok := ratesByUser[users[i].ID]; ok {
users[i].GroupRates = rates
}
}
}
} else {
s.loadUserGroupRatesOneByOne(ctx, users)
}
}
return users, result.Total, nil
}
// loadUserGroupRatesOneByOne fills users[i].GroupRates with one repository call
// per user. Failures are logged and the affected user is skipped; the slice is
// mutated in place.
func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) {
	if s.userGroupRateRepo == nil {
		return
	}
	for idx := range users {
		u := &users[idx]
		rates, err := s.userGroupRateRepo.GetByUserID(ctx, u.ID)
		if err != nil {
			logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", u.ID, err)
			continue
		}
		u.GroupRates = rates
	}
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
@@ -479,14 +540,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &User{
Email: input.Email,
Username: input.Username,
Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
Email: input.Email,
Username: input.Username,
Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
@@ -494,9 +556,27 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
s.assignDefaultSubscriptions(ctx, user.ID)
return user, nil
}
// assignDefaultSubscriptions grants the configured default subscriptions to a
// newly created user. Each failure is logged individually and does not stop
// the remaining assignments; the whole routine is a no-op when either the
// settings service or the assigner is unavailable.
func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) {
	if userID <= 0 || s.settingService == nil || s.defaultSubAssigner == nil {
		return
	}
	for _, sub := range s.settingService.GetDefaultSubscriptions(ctx) {
		input := &AssignSubscriptionInput{
			UserID:       userID,
			GroupID:      sub.GroupID,
			ValidityDays: sub.ValidityDays,
			Notes:        "auto assigned by default user subscriptions setting",
		}
		if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, input); err != nil {
			logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, sub.GroupID, err)
		}
	}
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
@@ -540,6 +620,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.AllowedGroups = *input.AllowedGroups
}
if input.SoraStorageQuotaBytes != nil {
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
@@ -667,7 +751,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
if err != nil {
return nil, 0, err
}
@@ -834,6 +918,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
MCPXMLInject: mcpXMLInject,
SimulateClaudeMaxEnabled: simulateClaudeMaxEnabled,
SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
@@ -996,6 +1081,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
}
if input.SoraStorageQuotaBytes != nil {
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
@@ -1160,6 +1248,103 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
// AdminUpdateAPIKeyGroupID changes the group binding of an API key on behalf
// of an administrator.
//
// groupID semantics: nil = no change, pointer to 0 = unbind, pointer to a
// positive ID = bind to that group. Exclusive groups additionally grant the
// group to the key's owner, atomically with the key update when an ent client
// is available.
func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) {
	apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
	if err != nil {
		return nil, err
	}

	if groupID == nil {
		// nil means "do not modify": return the key untouched.
		return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil
	}

	if *groupID < 0 {
		return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative")
	}

	result := &AdminUpdateAPIKeyGroupIDResult{}

	if *groupID == 0 {
		// 0 means unbind. user_allowed_groups is deliberately left untouched
		// so the user's other keys are unaffected.
		apiKey.GroupID = nil
		apiKey.Group = nil
	} else {
		// Validate the target group exists and is active.
		group, err := s.groupRepo.GetByID(ctx, *groupID)
		if err != nil {
			return nil, err
		}
		if group.Status != StatusActive {
			return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
		}
		// Subscription-type groups may not be bound through this API; they
		// must go through the subscription management workflow.
		if group.IsSubscriptionType() {
			return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow")
		}

		// Copy the ID so the caller's pointer cannot later mutate stored state.
		gid := *groupID
		apiKey.GroupID = &gid
		apiKey.Group = group

		// Exclusive standard group: use a transaction to keep "add group
		// permission" and "update API key" atomic.
		if group.IsExclusive {
			opCtx := ctx
			var tx *dbent.Tx
			if s.entClient == nil {
				logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding")
			} else {
				var txErr error
				tx, txErr = s.entClient.Tx(ctx)
				if txErr != nil {
					return nil, fmt.Errorf("begin transaction: %w", txErr)
				}
				// Rollback becomes a no-op once Commit has succeeded.
				defer func() { _ = tx.Rollback() }()
				opCtx = dbent.NewTxContext(ctx, tx)
			}

			if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil {
				return nil, fmt.Errorf("add group to user allowed groups: %w", addErr)
			}
			if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil {
				return nil, fmt.Errorf("update api key: %w", err)
			}
			if tx != nil {
				if err := tx.Commit(); err != nil {
					return nil, fmt.Errorf("commit transaction: %w", err)
				}
			}

			result.AutoGrantedGroupAccess = true
			result.GrantedGroupID = &gid
			result.GrantedGroupName = group.Name

			// Invalidate the auth cache only after the transaction committed.
			if s.authCacheInvalidator != nil {
				s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
			}

			result.APIKey = apiKey
			return result, nil
		}
	}

	// Non-exclusive group / unbind: a single update, no transaction needed.
	if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
		return nil, fmt.Errorf("update api key: %w", err)
	}

	// Invalidate the auth cache.
	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
	}

	result.APIKey = apiKey
	return result, nil
}
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
@@ -1211,6 +1396,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
// Sora apikey 账号的 base_url 必填校验
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
baseURL, _ := input.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
account := &Account{
Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
@@ -1324,12 +1521,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
// Sora apikey 账号的 base_url 必填校验
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
baseURL, _ := account.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
for _, groupID := range *input.GroupIDs {
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
return nil, err
}
// 检查混合渠道风险(除非用户已确认)
@@ -1371,6 +1578,11 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if len(input.AccountIDs) == 0 {
return result, nil
}
if input.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
return nil, err
}
}
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
@@ -1839,7 +2051,6 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr
ProxyURL: proxyURL,
Timeout: proxyQualityRequestTimeout,
ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout,
ProxyStrict: true,
})
if err != nil {
result.Items = append(result.Items, ProxyQualityCheckItem{

View File

@@ -0,0 +1,429 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Stubs
// ---------------------------------------------------------------------------
// userRepoStubForGroupUpdate implements UserRepository for AdminUpdateAPIKeyGroupID tests.
// Only AddGroupToAllowedGroups is expected to be called; it records its
// arguments and returns the configured error. Every other method panics so any
// unexpected repository call fails the test loudly.
type userRepoStubForGroupUpdate struct {
	addGroupErr    error // error to return from AddGroupToAllowedGroups
	addGroupCalled bool  // set when AddGroupToAllowedGroups is invoked
	addedUserID    int64 // last userID passed to AddGroupToAllowedGroups
	addedGroupID   int64 // last groupID passed to AddGroupToAllowedGroups
}

// AddGroupToAllowedGroups records the call for later assertions and returns
// the configured error.
func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, userID int64, groupID int64) error {
	s.addGroupCalled = true
	s.addedUserID = userID
	s.addedGroupID = groupID
	return s.addGroupErr
}

// All remaining UserRepository methods are unreachable in these tests.
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
// GetByID and Update are the only methods the code under test should reach;
// both work on clones so the stub's fixture and the captured update cannot
// alias the value the service mutates.
type apiKeyRepoStubForGroupUpdate struct {
	key       *APIKey // fixture returned (as a clone) by GetByID
	getErr    error   // error returned by GetByID instead of the fixture
	updateErr error   // error returned by Update
	updated   *APIKey // captures what was passed to Update
}

// GetByID returns a shallow clone of the fixture (or the configured error).
func (s *apiKeyRepoStubForGroupUpdate) GetByID(_ context.Context, _ int64) (*APIKey, error) {
	if s.getErr != nil {
		return nil, s.getErr
	}
	clone := *s.key
	return &clone, nil
}

// Update records a shallow clone of its argument (or fails with updateErr).
func (s *apiKeyRepoStubForGroupUpdate) Update(_ context.Context, key *APIKey) error {
	if s.updateErr != nil {
		return s.updateErr
	}
	clone := *key
	s.updated = &clone
	return nil
}

// Unused methods panic on unexpected call.
func (s *apiKeyRepoStubForGroupUpdate) Create(context.Context, *APIKey) error { panic("unexpected") }
func (s *apiKeyRepoStubForGroupUpdate) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetByKey(context.Context, string) (*APIKey, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByUserID(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ExistsByKey(context.Context, string) (bool, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByUserID(context.Context, int64) ([]string, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByGroupID(context.Context, int64) ([]string, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) IncrementRateLimitUsage(context.Context, int64, float64) error {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, int64) error {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
	panic("unexpected")
}
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
// Only GetByID should be reached; it records its argument for assertions and
// returns a clone of the fixture so the service cannot mutate it.
type groupRepoStubForGroupUpdate struct {
	group          *Group // fixture returned (as a clone) by GetByID
	getErr         error  // error returned by GetByID instead of the fixture
	lastGetByIDArg int64  // records the ID the service looked up
}

// GetByID records the requested ID and returns a shallow clone of the fixture
// (or the configured error).
func (s *groupRepoStubForGroupUpdate) GetByID(_ context.Context, id int64) (*Group, error) {
	s.lastGetByIDArg = id
	if s.getErr != nil {
		return nil, s.getErr
	}
	clone := *s.group
	return &clone, nil
}

// Unused methods panic on unexpected call.
func (s *groupRepoStubForGroupUpdate) Create(context.Context, *Group) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) GetByIDLite(context.Context, int64) (*Group, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) Update(context.Context, *Group) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) DeleteCascade(context.Context, int64) ([]int64, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListActive(context.Context) ([]Group, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, string) ([]Group, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) BindAccountsToGroup(context.Context, int64, []int64) error {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
	panic("unexpected")
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
func TestAdminService_AdminUpdateAPIKeyGroupID_KeyNotFound(t *testing.T) {
repo := &apiKeyRepoStubForGroupUpdate{getErr: ErrAPIKeyNotFound}
svc := &adminServiceImpl{apiKeyRepo: repo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 999, int64Ptr(1))
require.ErrorIs(t, err, ErrAPIKeyNotFound)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NilGroupID_NoOp(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5)}
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
svc := &adminServiceImpl{apiKeyRepo: repo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, nil)
require.NoError(t, err)
require.Equal(t, int64(1), got.APIKey.ID)
// Update should NOT have been called (updated stays nil)
require.Nil(t, repo.updated)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5), Group: &Group{ID: 5, Name: "Old"}}
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: repo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
require.NoError(t, err)
require.Nil(t, got.APIKey.GroupID, "group_id should be nil after unbind")
require.Nil(t, got.APIKey.Group, "group object should be nil after unbind")
require.NotNil(t, repo.updated, "Update should have been called")
require.Nil(t, repo.updated.GroupID)
require.Equal(t, []string{"sk-test"}, cache.keys, "cache should be invalidated")
}
// TestAdminService_AdminUpdateAPIKeyGroupID_BindActiveGroup covers the happy
// path: binding an unbound key to an active, non-exclusive group persists the
// new group ID, populates the Group object, and invalidates the auth cache.
func TestAdminService_AdminUpdateAPIKeyGroupID_BindActiveGroup(t *testing.T) {
	existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
	apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
	groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
	cache := &authCacheInvalidatorStub{}
	svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
	got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
	require.NoError(t, err)
	require.NotNil(t, got.APIKey.GroupID)
	require.Equal(t, int64(10), *got.APIKey.GroupID)
	require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
	require.Equal(t, []string{"sk-test"}, cache.keys)
	// M3: verify correct group ID was passed to repo
	require.Equal(t, int64(10), groupRepo.lastGetByIDArg)
	// C1 fix: verify Group object is populated
	require.NotNil(t, got.APIKey.Group)
	require.Equal(t, "Pro", got.APIKey.Group.Name)
}
// TestAdminService_AdminUpdateAPIKeyGroupID_SameGroup_Idempotent documents
// current behavior when re-binding to the group the key already has: the
// operation still writes and still invalidates the cache (no short-circuit).
func TestAdminService_AdminUpdateAPIKeyGroupID_SameGroup_Idempotent(t *testing.T) {
	existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Pro"}}
	apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
	groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
	cache := &authCacheInvalidatorStub{}
	svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
	got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
	require.NoError(t, err)
	require.NotNil(t, got.APIKey.GroupID)
	require.Equal(t, int64(10), *got.APIKey.GroupID)
	// Update is still called (current impl doesn't short-circuit on same group)
	require.NotNil(t, apiKeyRepo.updated)
	require.Equal(t, []string{"sk-test"}, cache.keys)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotFound(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{getErr: ErrGroupNotFound}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(99))
require.ErrorIs(t, err, ErrGroupNotFound)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotActive(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 5, Status: StatusDisabled}}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(5))
require.Error(t, err)
require.Equal(t, "GROUP_NOT_ACTIVE", infraerrors.Reason(err))
}
func TestAdminService_AdminUpdateAPIKeyGroupID_UpdateFails(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(3)}
repo := &apiKeyRepoStubForGroupUpdate{key: existing, updateErr: errors.New("db write error")}
svc := &adminServiceImpl{apiKeyRepo: repo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
require.Error(t, err)
require.Contains(t, err.Error(), "update api key")
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NegativeGroupID(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(-5))
require.Error(t, err)
require.Equal(t, "INVALID_GROUP_ID", infraerrors.Reason(err))
}
func TestAdminService_AdminUpdateAPIKeyGroupID_PointerIsolation(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
inputGID := int64(10)
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, &inputGID)
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
// Mutating the input pointer must NOT affect the stored value
inputGID = 999
require.Equal(t, int64(10), *got.APIKey.GroupID)
require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NilCacheInvalidator(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 7, Status: StatusActive}}
// authCacheInvalidator is nil should not panic
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(7))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(7), *got.APIKey.GroupID)
}
// ---------------------------------------------------------------------------
// Tests: AllowedGroup auto-sync
// ---------------------------------------------------------------------------
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AddsAllowedGroup(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
// 验证 AddGroupToAllowedGroups 被调用,且参数正确
require.True(t, userRepo.addGroupCalled)
require.Equal(t, int64(42), userRepo.addedUserID)
require.Equal(t, int64(10), userRepo.addedGroupID)
// 验证 result 标记了自动授权
require.True(t, got.AutoGrantedGroupAccess)
require.NotNil(t, got.GrantedGroupID)
require.Equal(t, int64(10), *got.GrantedGroupID)
require.Equal(t, "Exclusive", got.GrantedGroupName)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupUpdate(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Public", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
// 非专属分组不触发 AddGroupToAllowedGroups
require.False(t, userRepo.addGroupCalled)
require.False(t, got.AutoGrantedGroupAccess)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
// 订阅类型分组应被阻止绑定
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err))
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AllowedGroupAddFails_ReturnsError(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{addGroupErr: errors.New("db error")}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
// 严格模式AddGroupToAllowedGroups 失败时,整体操作报错
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Contains(t, err.Error(), "add group to user allowed groups")
require.True(t, userRepo.addGroupCalled)
// apiKey 不应被更新
require.Nil(t, apiKeyRepo.updated)
}
// TestAdminService_AdminUpdateAPIKeyGroupID_Unbind_NoAllowedGroupUpdate verifies
// that unbinding (group ID 0) clears the key's group and leaves the user's
// allowed_groups untouched.
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind_NoAllowedGroupUpdate(t *testing.T) {
	bound := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Exclusive"}}
	userRepo := &userRepoStubForGroupUpdate{}
	svc := &adminServiceImpl{
		apiKeyRepo:           &apiKeyRepoStubForGroupUpdate{key: bound},
		userRepo:             userRepo,
		authCacheInvalidator: &authCacheInvalidatorStub{},
	}
	got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
	require.NoError(t, err)
	require.Nil(t, got.APIKey.GroupID)
	// Unbinding must not modify allowed_groups or record a grant.
	require.False(t, userRepo.addGroupCalled)
	require.False(t, got.AutoGrantedGroupAccess)
}

View File

@@ -100,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
2: errors.New("bind failed"),
},
}
svc := &adminServiceImpl{accountRepo: repo}
svc := &adminServiceImpl{
accountRepo: repo,
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}},
}
groupIDs := []int64{10}
schedulable := false
@@ -120,6 +123,22 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
require.Len(t, result.Results, 3)
}
func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
svc := &adminServiceImpl{accountRepo: repo}
groupIDs := []int64{10}
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1},
GroupIDs: &groupIDs,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "group repository not configured")
}
// TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies
// that the global pre-check detects a conflict with existing group members and returns an
// error before any DB write is performed.

View File

@@ -7,6 +7,7 @@ import (
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
@@ -65,3 +66,32 @@ func TestAdminService_CreateUser_CreateError(t *testing.T) {
require.ErrorIs(t, err, createErr)
require.Empty(t, repo.created)
}
func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 21}
assigner := &defaultSubscriptionAssignerStub{}
cfg := &config.Config{
Default: config.DefaultConfig{
UserBalance: 0,
UserConcurrency: 1,
},
}
settingService := NewSettingService(&settingRepoStub{values: map[string]string{
SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`,
}}, cfg)
svc := &adminServiceImpl{
userRepo: repo,
settingService: settingService,
defaultSubAssigner: assigner,
}
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
Email: "new-user@test.com",
Password: "password",
})
require.NoError(t, err)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(21), assigner.calls[0].UserID)
require.Equal(t, int64(5), assigner.calls[0].GroupID)
require.Equal(t, 30, assigner.calls[0].ValidityDays)
}

View File

@@ -93,6 +93,10 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
@@ -344,6 +348,19 @@ func (s *billingCacheStub) InvalidateSubscriptionCache(ctx context.Context, user
return nil
}
func (s *billingCacheStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
panic("unexpected GetAPIKeyRateLimit call")
}
func (s *billingCacheStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
panic("unexpected SetAPIKeyRateLimit call")
}
func (s *billingCacheStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
panic("unexpected UpdateAPIKeyRateLimitUsage call")
}
func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
panic("unexpected InvalidateAPIKeyRateLimit call")
}
func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall {
t.Helper()
calls := make([]subscriptionInvalidateCall, 0, expected)

View File

@@ -0,0 +1,106 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type userRepoStubForListUsers struct {
userRepoStub
users []User
err error
}
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
if s.err != nil {
return nil, nil, s.err
}
out := make([]User, len(s.users))
copy(out, s.users)
return out, &pagination.PaginationResult{
Total: int64(len(out)),
Page: params.Page,
PageSize: params.PageSize,
}, nil
}
type userGroupRateRepoStubForListUsers struct {
batchCalls int
singleCall []int64
batchErr error
batchData map[int64]map[int64]float64
singleErr map[int64]error
singleData map[int64]map[int64]float64
}
func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) {
s.batchCalls++
if s.batchErr != nil {
return nil, s.batchErr
}
return s.batchData, nil
}
func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) {
s.singleCall = append(s.singleCall, userID)
if err, ok := s.singleErr[userID]; ok {
return nil, err
}
if rates, ok := s.singleData[userID]; ok {
return rates, nil
}
return map[int64]float64{}, nil
}
func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) {
panic("unexpected GetByUserAndGroup call")
}
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
panic("unexpected DeleteByGroupID call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error {
panic("unexpected DeleteByUserID call")
}
func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
userRepo := &userRepoStubForListUsers{
users: []User{
{ID: 101, Username: "u1"},
{ID: 202, Username: "u2"},
},
}
rateRepo := &userGroupRateRepoStubForListUsers{
batchErr: errors.New("batch unavailable"),
singleData: map[int64]map[int64]float64{
101: {11: 1.1},
202: {22: 2.2},
},
}
svc := &adminServiceImpl{
userRepo: userRepo,
userGroupRateRepo: rateRepo,
}
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{})
require.NoError(t, err)
require.Equal(t, int64(2), total)
require.Len(t, users, 2)
require.Equal(t, 1, rateRepo.batchCalls)
require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall)
require.Equal(t, 1.1, users[0].GroupRates[11])
require.Equal(t, 2.2, users[1].GroupRates[22])
}

View File

@@ -21,7 +21,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -2294,7 +2293,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
// isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记
func isSingleAccountRetry(ctx context.Context) bool {
v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool)
v, _ := SingleAccountRetryFromContext(ctx)
return v
}

View File

@@ -112,7 +112,10 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
}
}
client := antigravity.NewClient(proxyURL)
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
// 交换 token
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
@@ -167,7 +170,10 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
time.Sleep(backoff)
}
client := antigravity.NewClient(proxyURL)
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
tokenResp, err := client.RefreshToken(ctx, refreshToken)
if err == nil {
now := time.Now()
@@ -209,7 +215,10 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr
}
// 获取用户信息email
client := antigravity.NewClient(proxyURL)
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
@@ -309,7 +318,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
time.Sleep(backoff)
}
client := antigravity.NewClient(proxyURL)
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return "", fmt.Errorf("create antigravity client failed: %w", err)
}
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {

View File

@@ -2,6 +2,7 @@ package service
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -31,7 +32,10 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
client := antigravity.NewClient(proxyURL)
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
// 调用 API 获取配额
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)

View File

@@ -1,6 +1,10 @@
package service
import "time"
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
)
// API Key status constants
const (
@@ -19,22 +23,41 @@ type APIKey struct {
Status string
IPWhitelist []string
IPBlacklist []string
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
// 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。
CompiledIPWhitelist *ip.CompiledIPRules `json:"-"`
CompiledIPBlacklist *ip.CompiledIPRules `json:"-"`
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
// Quota fields
Quota float64 // Quota limit in USD (0 = unlimited)
QuotaUsed float64 // Used quota amount
ExpiresAt *time.Time // Expiration time (nil = never expires)
// Rate limit fields
RateLimit5h float64 // Rate limit in USD per 5h (0 = unlimited)
RateLimit1d float64 // Rate limit in USD per 1d (0 = unlimited)
RateLimit7d float64 // Rate limit in USD per 7d (0 = unlimited)
Usage5h float64 // Used amount in current 5h window
Usage1d float64 // Used amount in current 1d window
Usage7d float64 // Used amount in current 7d window
Window5hStart *time.Time // Start of current 5h window
Window1dStart *time.Time // Start of current 1d window
Window7dStart *time.Time // Start of current 7d window
}
func (k *APIKey) IsActive() bool {
return k.Status == StatusActive
}
// HasRateLimits returns true if any rate limit window is configured
func (k *APIKey) HasRateLimits() bool {
return k.RateLimit5h > 0 || k.RateLimit1d > 0 || k.RateLimit7d > 0
}
// IsExpired checks if the API key has expired
func (k *APIKey) IsExpired() bool {
if k.ExpiresAt == nil {
@@ -74,3 +97,10 @@ func (k *APIKey) GetDaysUntilExpiry() int {
}
return int(duration.Hours() / 24)
}
// APIKeyListFilters holds optional filtering parameters for listing API keys.
type APIKeyListFilters struct {
Search string
Status string
GroupID *int64 // nil=不筛选, 0=无分组, >0=指定分组
}

View File

@@ -19,6 +19,11 @@ type APIKeyAuthSnapshot struct {
// Expiration field for API Key expiration feature
ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires)
// Rate limit configuration (only limits, not usage - usage read from Redis at check time)
RateLimit5h float64 `json:"rate_limit_5h"`
RateLimit1d float64 `json:"rate_limit_1d"`
RateLimit7d float64 `json:"rate_limit_7d"`
}
// APIKeyAuthUserSnapshot 用户快照

View File

@@ -209,6 +209,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
Quota: apiKey.Quota,
QuotaUsed: apiKey.QuotaUsed,
ExpiresAt: apiKey.ExpiresAt,
RateLimit5h: apiKey.RateLimit5h,
RateLimit1d: apiKey.RateLimit1d,
RateLimit7d: apiKey.RateLimit7d,
User: APIKeyAuthUserSnapshot{
ID: apiKey.User.ID,
Status: apiKey.User.Status,
@@ -263,6 +266,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
Quota: snapshot.Quota,
QuotaUsed: snapshot.QuotaUsed,
ExpiresAt: snapshot.ExpiresAt,
RateLimit5h: snapshot.RateLimit5h,
RateLimit1d: snapshot.RateLimit1d,
RateLimit7d: snapshot.RateLimit7d,
User: &User{
ID: snapshot.User.ID,
Status: snapshot.User.Status,
@@ -300,5 +306,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
}
}
s.compileAPIKeyIPRules(apiKey)
return apiKey
}

View File

@@ -30,6 +30,11 @@ var (
ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期")
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完")
// Rate limit errors
ErrAPIKeyRateLimit5hExceeded = infraerrors.TooManyRequests("API_KEY_RATE_5H_EXCEEDED", "api key 5小时限额已用完")
ErrAPIKeyRateLimit1dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_1D_EXCEEDED", "api key 日限额已用完")
ErrAPIKeyRateLimit7dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_7D_EXCEEDED", "api key 7天限额已用完")
)
const (
@@ -50,7 +55,7 @@ type APIKeyRepository interface {
Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
@@ -64,6 +69,21 @@ type APIKeyRepository interface {
// Quota methods
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error
// Rate limit methods
IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error
ResetRateLimitWindows(ctx context.Context, id int64) error
GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error)
}
// APIKeyRateLimitData holds rate limit usage and window state for an API key.
type APIKeyRateLimitData struct {
Usage5h float64
Usage1d float64
Usage7d float64
Window5hStart *time.Time
Window1dStart *time.Time
Window7dStart *time.Time
}
// APIKeyCache defines cache operations for API key service
@@ -102,6 +122,11 @@ type CreateAPIKeyRequest struct {
// Quota fields
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires)
// Rate limit fields (0 = unlimited)
RateLimit5h float64 `json:"rate_limit_5h"`
RateLimit1d float64 `json:"rate_limit_1d"`
RateLimit7d float64 `json:"rate_limit_7d"`
}
// UpdateAPIKeyRequest 更新API Key请求
@@ -117,22 +142,34 @@ type UpdateAPIKeyRequest struct {
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change)
ClearExpiration bool `json:"-"` // Clear expiration (internal use)
ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0
// Rate limit fields (nil = no change, 0 = unlimited)
RateLimit5h *float64 `json:"rate_limit_5h"`
RateLimit1d *float64 `json:"rate_limit_1d"`
RateLimit7d *float64 `json:"rate_limit_7d"`
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // Reset all usage counters to 0
}
// APIKeyService API Key服务
// RateLimitCacheInvalidator invalidates rate limit cache entries on manual reset.
type RateLimitCacheInvalidator interface {
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
}
type APIKeyService struct {
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
lastUsedTouchSF singleflight.Group
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache APIKeyCache
rateLimitCacheInvalid RateLimitCacheInvalidator // optional: invalidate Redis rate limit cache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
lastUsedTouchSF singleflight.Group
}
// NewAPIKeyService 创建API Key服务实例
@@ -158,6 +195,20 @@ func NewAPIKeyService(
return svc
}
// SetRateLimitCacheInvalidator sets the optional rate limit cache invalidator.
// Called after construction (e.g. in wire) to avoid circular dependencies.
func (s *APIKeyService) SetRateLimitCacheInvalidator(inv RateLimitCacheInvalidator) {
s.rateLimitCacheInvalid = inv
}
func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) {
if apiKey == nil {
return
}
apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist)
apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist)
}
// GenerateKey 生成随机API Key
func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
@@ -319,6 +370,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
IPBlacklist: req.IPBlacklist,
Quota: req.Quota,
QuotaUsed: 0,
RateLimit5h: req.RateLimit5h,
RateLimit1d: req.RateLimit1d,
RateLimit7d: req.RateLimit7d,
}
// Set expiration time if specified
@@ -332,13 +386,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
// List 获取用户的API Key列表
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, filters)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
}
@@ -363,6 +418,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
@@ -375,6 +431,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
}
@@ -391,6 +448,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
} else {
@@ -402,6 +460,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
}
@@ -411,6 +470,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
@@ -505,11 +565,37 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
apiKey.IPWhitelist = req.IPWhitelist
apiKey.IPBlacklist = req.IPBlacklist
// Update rate limit configuration
if req.RateLimit5h != nil {
apiKey.RateLimit5h = *req.RateLimit5h
}
if req.RateLimit1d != nil {
apiKey.RateLimit1d = *req.RateLimit1d
}
if req.RateLimit7d != nil {
apiKey.RateLimit7d = *req.RateLimit7d
}
resetRateLimit := req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage
if resetRateLimit {
apiKey.Usage5h = 0
apiKey.Usage1d = 0
apiKey.Usage7d = 0
apiKey.Window5hStart = nil
apiKey.Window1dStart = nil
apiKey.Window7dStart = nil
}
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
s.compileAPIKeyIPRules(apiKey)
// Invalidate Redis rate limit cache so reset takes effect immediately
if resetRateLimit && s.rateLimitCacheInvalid != nil {
_ = s.rateLimitCacheInvalid.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
}
return apiKey, nil
}
@@ -731,3 +817,16 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
return nil
}
// GetRateLimitData returns rate limit usage and window state for an API key.
func (s *APIKeyService) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
return s.apiKeyRepo.GetRateLimitData(ctx, id)
}
// UpdateRateLimitUsage atomically increments rate limit usage counters in the DB.
func (s *APIKeyService) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
if cost <= 0 {
return nil
}
return s.apiKeyRepo.IncrementRateLimitUsage(ctx, apiKeyID, cost)
}

View File

@@ -53,7 +53,7 @@ func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
@@ -106,6 +106,15 @@ func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount
func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
panic("unexpected UpdateLastUsed call")
}
func (s *authRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
panic("unexpected IncrementRateLimitUsage call")
}
func (s *authRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
panic("unexpected ResetRateLimitWindows call")
}
func (s *authRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
panic("unexpected GetRateLimitData call")
}
type authCacheStub struct {
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)

View File

@@ -81,7 +81,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
@@ -134,6 +134,18 @@ func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
return nil
}
func (s *apiKeyRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
panic("unexpected IncrementRateLimitUsage call")
}
func (s *apiKeyRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
panic("unexpected ResetRateLimitWindows call")
}
func (s *apiKeyRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
panic("unexpected GetRateLimitData call")
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"net/mail"
"strconv"
"strings"
"time"
@@ -33,6 +34,7 @@ var (
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
@@ -56,15 +58,20 @@ type JWTClaims struct {
// AuthService 认证服务
type AuthService struct {
userRepo UserRepository
redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache
cfg *config.Config
settingService *SettingService
emailService *EmailService
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
userRepo UserRepository
redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache
cfg *config.Config
settingService *SettingService
emailService *EmailService
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
defaultSubAssigner DefaultSubscriptionAssigner
}
type DefaultSubscriptionAssigner interface {
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
}
// NewAuthService 创建认证服务实例
@@ -78,17 +85,19 @@ func NewAuthService(
turnstileService *TurnstileService,
emailQueueService *EmailQueueService,
promoService *PromoService,
defaultSubAssigner DefaultSubscriptionAssigner,
) *AuthService {
return &AuthService{
userRepo: userRepo,
redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache,
cfg: cfg,
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
userRepo: userRepo,
redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache,
cfg: cfg,
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
defaultSubAssigner: defaultSubAssigner,
}
}
@@ -108,6 +117,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
if isReservedEmail(email) {
return "", nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return "", nil, err
}
// 检查是否需要邀请码
var invitationRedeemCode *RedeemCode
@@ -188,6 +200,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
s.assignDefaultSubscriptions(ctx, user.ID)
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
@@ -233,6 +246,9 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
if isReservedEmail(email) {
return ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return err
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
@@ -271,6 +287,9 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
if isReservedEmail(email) {
return nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return nil, err
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
@@ -477,6 +496,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -572,6 +592,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -597,6 +618,49 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
// assignDefaultSubscriptions grants every configured default subscription to a
// freshly created user. Each failure is logged and skipped so that a
// misconfigured default subscription never breaks registration itself.
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
	if userID <= 0 || s.settingService == nil || s.defaultSubAssigner == nil {
		return
	}
	for _, sub := range s.settingService.GetDefaultSubscriptions(ctx) {
		input := &AssignSubscriptionInput{
			UserID:       userID,
			GroupID:      sub.GroupID,
			ValidityDays: sub.ValidityDays,
			Notes:        "auto assigned by default user subscriptions setting",
		}
		_, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, input)
		if err != nil {
			logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, sub.GroupID, err)
		}
	}
}
// validateRegistrationEmailPolicy enforces the admin-configured email suffix
// whitelist for registration. A nil setting service disables the check.
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
	if s.settingService == nil {
		return nil
	}
	allowed := s.settingService.GetRegistrationEmailSuffixWhitelist(ctx)
	if IsRegistrationEmailSuffixAllowed(email, allowed) {
		return nil
	}
	return buildEmailSuffixNotAllowedError(allowed)
}
// buildEmailSuffixNotAllowedError builds the rejection error for an email
// whose suffix is outside the whitelist. When the whitelist is non-empty the
// message and metadata enumerate the allowed suffixes for the caller/UI.
func buildEmailSuffixNotAllowedError(whitelist []string) error {
	if len(whitelist) == 0 {
		return ErrEmailSuffixNotAllowed
	}
	msg := fmt.Sprintf("email suffix is not allowed, allowed suffixes: %s", strings.Join(whitelist, ", "))
	meta := map[string]string{
		"allowed_suffixes":     strings.Join(whitelist, ","),
		"allowed_suffix_count": strconv.Itoa(len(whitelist)),
	}
	return infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", msg).WithMetadata(meta)
}
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token降低 DoS 风险。

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
@@ -56,6 +57,21 @@ type emailCacheStub struct {
err error
}
// defaultSubscriptionAssignerStub records every AssignOrExtendSubscription
// call so tests can assert which default subscriptions were granted.
type defaultSubscriptionAssignerStub struct {
	calls []AssignSubscriptionInput // copies of every non-nil input, in call order
	err   error                     // when set, returned from every call
}

// AssignOrExtendSubscription records the input and returns either the stubbed
// error or a minimal subscription echoing the requested user/group.
// Fix: the original dereferenced a nil input (input.UserID) on the success
// path; a nil input is now surfaced as the stubbed error (or nothing assigned)
// instead of panicking.
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
	if input == nil {
		return nil, false, s.err
	}
	s.calls = append(s.calls, *input)
	if s.err != nil {
		return nil, false, s.err
	}
	return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
@@ -123,6 +139,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
nil,
nil,
nil, // promoService
nil, // defaultSubAssigner
)
}
@@ -215,6 +232,51 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) {
require.ErrorIs(t, err, ErrEmailReserved)
}
// TestAuthService_Register_EmailSuffixNotAllowed verifies that registration is
// rejected when the email's suffix is outside the configured whitelist, and
// that the error carries the allowed suffixes in both message and metadata.
func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) {
	repo := &userRepoStub{}
	service := newAuthService(repo, map[string]string{
		SettingKeyRegistrationEnabled:              "true",
		SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
	}, nil)

	_, _, err := service.Register(context.Background(), "user@other.com", "password")
	require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)

	appErr := infraerrors.FromError(err)
	// Human-readable message enumerates every allowed suffix.
	require.Contains(t, appErr.Message, "@example.com")
	require.Contains(t, appErr.Message, "@company.com")
	require.Equal(t, "EMAIL_SUFFIX_NOT_ALLOWED", appErr.Reason)
	// Machine-readable metadata mirrors the whitelist for API consumers.
	require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
	require.Equal(t, "@example.com,@company.com", appErr.Metadata["allowed_suffixes"])
}
// TestAuthService_Register_EmailSuffixAllowed verifies registration succeeds
// for an email matching a whitelisted suffix. Note the whitelist entry omits
// the leading "@" — presumably normalized inside
// IsRegistrationEmailSuffixAllowed; confirm there.
func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) {
	repo := &userRepoStub{nextID: 8}
	service := newAuthService(repo, map[string]string{
		SettingKeyRegistrationEnabled:              "true",
		SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`,
	}, nil)

	_, user, err := service.Register(context.Background(), "user@example.com", "password")
	require.NoError(t, err)
	require.NotNil(t, user)
	// The stub repo assigns nextID to the created user.
	require.Equal(t, int64(8), user.ID)
}
// TestAuthService_SendVerifyCode_EmailSuffixNotAllowed verifies the suffix
// whitelist is enforced on the verification-code path too, so no code is ever
// sent to a disallowed address.
func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) {
	repo := &userRepoStub{}
	service := newAuthService(repo, map[string]string{
		SettingKeyRegistrationEnabled:              "true",
		SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
	}, nil)

	err := service.SendVerifyCode(context.Background(), "user@other.com")
	require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)

	appErr := infraerrors.FromError(err)
	// Same enriched error shape as the Register path.
	require.Contains(t, appErr.Message, "@example.com")
	require.Contains(t, appErr.Message, "@company.com")
	require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
}
func TestAuthService_Register_CreateError(t *testing.T) {
repo := &userRepoStub{createErr: errors.New("create failed")}
service := newAuthService(repo, map[string]string{
@@ -381,3 +443,23 @@ func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) {
require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second)
}
// TestAuthService_Register_AssignsDefaultSubscriptions verifies that after a
// successful registration every configured default subscription is assigned,
// in the order they appear in the setting.
func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
	repo := &userRepoStub{nextID: 42}
	assigner := &defaultSubscriptionAssignerStub{}
	service := newAuthService(repo, map[string]string{
		SettingKeyRegistrationEnabled:  "true",
		SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
	}, nil)
	// newAuthService wires a nil assigner; inject the spy directly.
	service.defaultSubAssigner = assigner

	_, user, err := service.Register(context.Background(), "default-sub@test.com", "password")
	require.NoError(t, err)
	require.NotNil(t, user)
	// One assignment per configured default subscription.
	require.Len(t, assigner.calls, 2)
	require.Equal(t, int64(42), assigner.calls[0].UserID)
	require.Equal(t, int64(11), assigner.calls[0].GroupID)
	require.Equal(t, 30, assigner.calls[0].ValidityDays)
	require.Equal(t, int64(12), assigner.calls[1].GroupID)
	require.Equal(t, 7, assigner.calls[1].ValidityDays)
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// turnstileVerifierSpy captures VerifyToken invocations for test assertions.
type turnstileVerifierSpy struct {
	called    int    // number of VerifyToken invocations
	lastToken string // token passed on the most recent call
	result    *TurnstileVerifyResponse
	err       error
}

// VerifyToken records the call, then returns the stubbed error or result, or
// a generic success response when neither is configured.
func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) {
	s.called++
	s.lastToken = token
	switch {
	case s.err != nil:
		return nil, s.err
	case s.result != nil:
		return s.result, nil
	default:
		return &TurnstileVerifyResponse{Success: true}, nil
	}
}
// newAuthServiceForRegisterTurnstileTest wires a minimal AuthService whose
// Turnstile verification is backed by the given stub verifier. Only the
// dependencies needed by VerifyTurnstileForRegister are non-nil.
func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService {
	cfg := &config.Config{
		Server: config.ServerConfig{
			// "release" mode — presumably disables any dev-only verification
			// bypass; confirm against VerifyTurnstileForRegister.
			Mode: "release",
		},
		Turnstile: config.TurnstileConfig{
			Required: true,
		},
	}
	settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
	turnstileService := NewTurnstileService(settingService, verifier)

	return NewAuthService(
		&userRepoStub{},
		nil, // redeemRepo
		nil, // refreshTokenCache
		cfg,
		settingService,
		nil, // emailService
		turnstileService,
		nil, // emailQueueService
		nil, // promoService
		nil, // defaultSubAssigner
	)
}
// With email verification enabled, supplying a verify code skips Turnstile
// entirely: the verifier must never be consulted.
func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) {
	verifier := &turnstileVerifierSpy{}
	service := newAuthServiceForRegisterTurnstileTest(map[string]string{
		SettingKeyEmailVerifyEnabled:  "true",
		SettingKeyTurnstileEnabled:    "true",
		SettingKeyTurnstileSecretKey:  "secret",
		SettingKeyRegistrationEnabled: "true",
	}, verifier)

	// Empty turnstile token, non-empty email verify code.
	err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456")
	require.NoError(t, err)
	require.Equal(t, 0, verifier.called)
}
// Without an email verify code there is no bypass: an empty Turnstile token
// must fail verification.
func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) {
	verifier := &turnstileVerifierSpy{}
	service := newAuthServiceForRegisterTurnstileTest(map[string]string{
		SettingKeyEmailVerifyEnabled: "true",
		SettingKeyTurnstileEnabled:   "true",
		SettingKeyTurnstileSecretKey: "secret",
	}, verifier)

	err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "")
	require.ErrorIs(t, err, ErrTurnstileVerificationFailed)
}
// With email verification disabled, a verify code grants no bypass: the
// Turnstile token is still checked against the verifier.
func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) {
	verifier := &turnstileVerifierSpy{}
	service := newAuthServiceForRegisterTurnstileTest(map[string]string{
		SettingKeyEmailVerifyEnabled: "false",
		SettingKeyTurnstileEnabled:   "true",
		SettingKeyTurnstileSecretKey: "secret",
	}, verifier)

	err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456")
	require.NoError(t, err)
	// Exactly one verification with the supplied token.
	require.Equal(t, 1, verifier.called)
	require.Equal(t, "turnstile-token", verifier.lastToken)
}

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -10,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"golang.org/x/sync/singleflight"
)
// 错误定义
@@ -38,6 +40,7 @@ const (
cacheWriteSetSubscription
cacheWriteUpdateSubscriptionUsage
cacheWriteDeductBalance
cacheWriteUpdateRateLimitUsage
)
// 异步缓存写入工作池配置
@@ -58,6 +61,7 @@ const (
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
balanceLoadTimeout = 3 * time.Second
)
// cacheWriteTask 缓存写入任务
@@ -65,23 +69,33 @@ type cacheWriteTask struct {
kind cacheWriteKind
userID int64
groupID int64
apiKeyID int64
balance float64
amount float64
subscriptionData *subscriptionCacheData
}
// apiKeyRateLimitLoader defines the interface for loading rate limit data from DB.
type apiKeyRateLimitLoader interface {
GetRateLimitData(ctx context.Context, keyID int64) (*APIKeyRateLimitData, error)
}
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
cache BillingCache
userRepo UserRepository
subRepo UserSubscriptionRepository
cfg *config.Config
circuitBreaker *billingCircuitBreaker
cache BillingCache
userRepo UserRepository
subRepo UserSubscriptionRepository
apiKeyRateLimitLoader apiKeyRateLimitLoader
cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once
cacheWriteMu sync.RWMutex
stopped atomic.Bool
balanceLoadSF singleflight.Group
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64
@@ -90,12 +104,13 @@ type BillingCacheService struct {
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
svc := &BillingCacheService{
cache: cache,
userRepo: userRepo,
subRepo: subRepo,
cfg: cfg,
cache: cache,
userRepo: userRepo,
subRepo: subRepo,
apiKeyRateLimitLoader: apiKeyRepo,
cfg: cfg,
}
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers()
@@ -105,35 +120,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
// Stop 关闭缓存写入工作池
func (s *BillingCacheService) Stop() {
s.cacheWriteStopOnce.Do(func() {
if s.cacheWriteChan == nil {
s.stopped.Store(true)
s.cacheWriteMu.Lock()
ch := s.cacheWriteChan
if ch != nil {
close(ch)
}
s.cacheWriteMu.Unlock()
if ch == nil {
return
}
close(s.cacheWriteChan)
s.cacheWriteWg.Wait()
s.cacheWriteChan = nil
s.cacheWriteMu.Lock()
if s.cacheWriteChan == ch {
s.cacheWriteChan = nil
}
s.cacheWriteMu.Unlock()
})
}
func (s *BillingCacheService) startCacheWriteWorkers() {
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
ch := make(chan cacheWriteTask, cacheWriteBufferSize)
s.cacheWriteChan = ch
for i := 0; i < cacheWriteWorkerCount; i++ {
s.cacheWriteWg.Add(1)
go s.cacheWriteWorker()
go s.cacheWriteWorker(ch)
}
}
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false并记录告警
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
if s.cacheWriteChan == nil {
if s.stopped.Load() {
s.logCacheWriteDrop(task, "closed")
return false
}
defer func() {
if recovered := recover(); recovered != nil {
// 队列已关闭时可能触发 panic记录后静默失败。
s.logCacheWriteDrop(task, "closed")
enqueued = false
}
}()
s.cacheWriteMu.RLock()
defer s.cacheWriteMu.RUnlock()
if s.cacheWriteChan == nil {
s.logCacheWriteDrop(task, "closed")
return false
}
select {
case s.cacheWriteChan <- task:
return true
@@ -144,9 +176,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b
}
}
func (s *BillingCacheService) cacheWriteWorker() {
func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
defer s.cacheWriteWg.Done()
for task := range s.cacheWriteChan {
for task := range ch {
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
switch task.kind {
case cacheWriteSetBalance:
@@ -165,6 +197,12 @@ func (s *BillingCacheService) cacheWriteWorker() {
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err)
}
}
case cacheWriteUpdateRateLimitUsage:
if s.cache != nil {
if err := s.cache.UpdateAPIKeyRateLimitUsage(ctx, task.apiKeyID, task.amount); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: update rate limit usage cache failed for api key %d: %v", task.apiKeyID, err)
}
}
}
cancel()
}
@@ -181,6 +219,8 @@ func cacheWriteKindName(kind cacheWriteKind) string {
return "update_subscription_usage"
case cacheWriteDeductBalance:
return "deduct_balance"
case cacheWriteUpdateRateLimitUsage:
return "update_rate_limit_usage"
default:
return "unknown"
}
@@ -243,19 +283,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
return balance, nil
}
// 缓存未命中,从数据库读取
balance, err = s.getUserBalanceFromDB(ctx, userID)
// 缓存未命中singleflight 合并同一 userID 的并发回源请求。
value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) {
loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout)
defer cancel()
balance, err := s.getUserBalanceFromDB(loadCtx, userID)
if err != nil {
return nil, err
}
// 异步建立缓存
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetBalance,
userID: userID,
balance: balance,
})
return balance, nil
})
if err != nil {
return 0, err
}
// 异步建立缓存
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetBalance,
userID: userID,
balance: balance,
})
balance, ok := value.(float64)
if !ok {
return 0, fmt.Errorf("unexpected balance type: %T", value)
}
return balance, nil
}
@@ -441,6 +493,137 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
return nil
}
// ============================================
// API Key 限速缓存方法
// ============================================
// checkAPIKeyRateLimits checks rate limit windows for an API key.
// It loads usage from Redis cache (falling back to DB on cache miss),
// resets expired windows in-memory and triggers async DB reset,
// and returns an error if any window limit is exceeded.
//
// Fail-open design: when neither cache nor DB loader is available, or the DB
// read fails, the request is allowed rather than blocked.
// NOTE(review): any cache error — not only a true miss — takes the DB
// fallback path; confirm GetAPIKeyRateLimit's error contract.
func (s *BillingCacheService) checkAPIKeyRateLimits(ctx context.Context, apiKey *APIKey) error {
	if s.cache == nil {
		// No cache: fall back to reading from DB directly
		if s.apiKeyRateLimitLoader == nil {
			return nil
		}
		data, err := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
		if err != nil {
			return nil // Don't block requests on DB errors
		}
		return s.evaluateRateLimits(ctx, apiKey, data.Usage5h, data.Usage1d, data.Usage7d,
			data.Window5hStart, data.Window1dStart, data.Window7dStart)
	}

	cacheData, err := s.cache.GetAPIKeyRateLimit(ctx, apiKey.ID)
	if err != nil {
		// Cache miss: load from DB and populate cache
		if s.apiKeyRateLimitLoader == nil {
			return nil
		}
		dbData, dbErr := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
		if dbErr != nil {
			return nil // Don't block requests on DB errors
		}
		// Build cache entry from DB data; nil window start maps to the
		// zero sentinel (0 = window not started).
		cacheEntry := &APIKeyRateLimitCacheData{
			Usage5h: dbData.Usage5h,
			Usage1d: dbData.Usage1d,
			Usage7d: dbData.Usage7d,
		}
		if dbData.Window5hStart != nil {
			cacheEntry.Window5h = dbData.Window5hStart.Unix()
		}
		if dbData.Window1dStart != nil {
			cacheEntry.Window1d = dbData.Window1dStart.Unix()
		}
		if dbData.Window7dStart != nil {
			cacheEntry.Window7d = dbData.Window7dStart.Unix()
		}
		// Best effort: failing to repopulate the cache must not block the check.
		_ = s.cache.SetAPIKeyRateLimit(ctx, apiKey.ID, cacheEntry)
		cacheData = cacheEntry
	}

	// Convert cached unix timestamps back to *time.Time (nil = not started).
	var w5h, w1d, w7d *time.Time
	if cacheData.Window5h > 0 {
		t := time.Unix(cacheData.Window5h, 0)
		w5h = &t
	}
	if cacheData.Window1d > 0 {
		t := time.Unix(cacheData.Window1d, 0)
		w1d = &t
	}
	if cacheData.Window7d > 0 {
		t := time.Unix(cacheData.Window7d, 0)
		w7d = &t
	}

	return s.evaluateRateLimits(ctx, apiKey, cacheData.Usage5h, cacheData.Usage1d, cacheData.Usage7d, w5h, w1d, w7d)
}
// evaluateRateLimits checks usage against limits, triggering async resets for expired windows.
//
// Expired windows are zeroed in-memory so the current request is judged
// against a fresh window, while the durable reset runs asynchronously. A
// limit <= 0 disables that window's check.
// NOTE(review): every request that observes an expired window spawns its own
// reset goroutine, so concurrent requests can trigger duplicate resets —
// confirm ResetRateLimitWindows is idempotent.
func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *APIKey, usage5h, usage1d, usage7d float64, w5h, w1d, w7d *time.Time) error {
	needsReset := false

	// Reset expired windows in-memory for check purposes
	if w5h != nil && time.Since(*w5h) >= 5*time.Hour {
		usage5h = 0
		needsReset = true
	}
	if w1d != nil && time.Since(*w1d) >= 24*time.Hour {
		usage1d = 0
		needsReset = true
	}
	if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour {
		usage7d = 0
		needsReset = true
	}

	// Trigger async DB reset if any window expired
	if needsReset {
		keyID := apiKey.ID
		go func() {
			// Detached from the request context so the reset survives the
			// request; bounded by the shared cache-write timeout.
			resetCtx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
			defer cancel()
			if s.apiKeyRateLimitLoader != nil {
				// Use the repo directly - reset then reload cache
				if loader, ok := s.apiKeyRateLimitLoader.(interface {
					ResetRateLimitWindows(ctx context.Context, id int64) error
				}); ok {
					_ = loader.ResetRateLimitWindows(resetCtx, keyID)
				}
			}
			// Invalidate cache so next request loads fresh data
			if s.cache != nil {
				_ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID)
			}
		}()
	}

	// Check limits
	if apiKey.RateLimit5h > 0 && usage5h >= apiKey.RateLimit5h {
		return ErrAPIKeyRateLimit5hExceeded
	}
	if apiKey.RateLimit1d > 0 && usage1d >= apiKey.RateLimit1d {
		return ErrAPIKeyRateLimit1dExceeded
	}
	if apiKey.RateLimit7d > 0 && usage7d >= apiKey.RateLimit7d {
		return ErrAPIKeyRateLimit7dExceeded
	}
	return nil
}
// QueueUpdateAPIKeyRateLimitUsage asynchronously updates rate limit usage in
// the cache. Without a cache backend there is nothing to update, so the call
// is a no-op.
func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, cost float64) {
	if s.cache == nil {
		return
	}
	task := cacheWriteTask{
		kind:     cacheWriteUpdateRateLimitUsage,
		apiKeyID: apiKeyID,
		amount:   cost,
	}
	// Best effort: a full or stopped queue drops the update (logged inside).
	s.enqueueCacheWrite(task)
}
// ============================================
// 统一检查方法
// ============================================
@@ -461,10 +644,23 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
if isSubscriptionMode {
return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription)
if err := s.checkSubscriptionEligibility(ctx, user.ID, group, subscription); err != nil {
return err
}
} else {
if err := s.checkBalanceEligibility(ctx, user.ID); err != nil {
return err
}
}
return s.checkBalanceEligibility(ctx, user.ID)
// Check API Key rate limits (applies to both billing modes)
if apiKey != nil && apiKey.HasRateLimits() {
if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil {
return err
}
}
return nil
}
// checkBalanceEligibility 检查余额模式资格

View File

@@ -0,0 +1,131 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// billingCacheMissStub is a BillingCache whose read paths always miss,
// forcing the service back onto its repositories. Write paths succeed; only
// SetUserBalance calls are counted (to observe async cache population).
type billingCacheMissStub struct {
	setBalanceCalls atomic.Int64 // incremented once per SetUserBalance call
}

// Balance reads always miss.
func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
	return 0, errors.New("cache miss")
}

// SetUserBalance counts the write and succeeds.
func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
	s.setBalanceCalls.Add(1)
	return nil
}

func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
	return nil
}

func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
	return nil
}

// Subscription reads always miss; writes are accepted and discarded.
func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
	return nil, errors.New("cache miss")
}

func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
	return nil
}

func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
	return nil
}

func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
	return nil
}

// API-key rate-limit reads always miss; writes are accepted and discarded.
func (s *billingCacheMissStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
	return nil, errors.New("cache miss")
}

func (s *billingCacheMissStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
	return nil
}

func (s *billingCacheMissStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
	return nil
}

func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
	return nil
}
// balanceLoadUserRepoStub counts GetByID calls and can delay each one, letting
// tests hold the first DB load in flight while concurrent callers pile up.
type balanceLoadUserRepoStub struct {
	mockUserRepo
	calls   atomic.Int64  // number of GetByID invocations
	delay   time.Duration // artificial per-call latency (0 = none)
	balance float64       // balance reported for every user
}

// GetByID returns a user with the configured balance after the optional
// delay; it honors context cancellation while waiting.
func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
	s.calls.Add(1)
	if s.delay > 0 {
		select {
		case <-time.After(s.delay):
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
	return &User{ID: id, Balance: s.balance}, nil
}
// TestBillingCacheServiceGetUserBalance_Singleflight fires many concurrent
// GetUserBalance calls for one user against an always-missing cache and
// asserts the DB is hit exactly once (requests merged by singleflight) while
// every caller still receives the loaded balance.
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
	cache := &billingCacheMissStub{}
	userRepo := &balanceLoadUserRepoStub{
		// Keep the first load in flight long enough for the others to arrive.
		delay:   80 * time.Millisecond,
		balance: 12.34,
	}
	svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
	t.Cleanup(svc.Stop)

	const goroutines = 16
	start := make(chan struct{})
	var wg sync.WaitGroup
	errCh := make(chan error, goroutines)
	balCh := make(chan float64, goroutines)
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Line all goroutines up, then release them together.
			<-start
			bal, err := svc.GetUserBalance(context.Background(), 99)
			errCh <- err
			balCh <- bal
		}()
	}
	close(start)
	wg.Wait()
	close(errCh)
	close(balCh)

	for err := range errCh {
		require.NoError(t, err)
	}
	for bal := range balCh {
		require.Equal(t, 12.34, bal)
	}
	require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并")
	// Cache population is queued asynchronously; wait for at least one write.
	require.Eventually(t, func() bool {
		return cache.setBalanceCalls.Load() >= 1
	}, time.Second, 10*time.Millisecond)
}

View File

@@ -52,9 +52,25 @@ func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context
return nil
}
// API-key rate-limit surface of the worker stub: reads are unimplemented,
// writes are accepted and discarded (this stub only observes queue behavior).
func (b *billingCacheWorkerStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
	return nil, errors.New("not implemented")
}

func (b *billingCacheWorkerStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
	return nil
}

func (b *billingCacheWorkerStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
	return nil
}

func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
	return nil
}
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
start := time.Now()
@@ -73,3 +89,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
}
// TestBillingCacheServiceEnqueueAfterStopReturnsFalse verifies that enqueueing
// a cache write after Stop reports failure instead of panicking on the closed
// channel.
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
	cache := &billingCacheWorkerStub{}
	svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
	svc.Stop()

	enqueued := svc.enqueueCacheWrite(cacheWriteTask{
		kind:   cacheWriteDeductBalance,
		userID: 1,
		amount: 1,
	})
	require.False(t, enqueued)
}

View File

@@ -10,6 +10,16 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
)
// APIKeyRateLimitCacheData holds per-API-key rate limit usage cached in Redis.
// Usage* fields accumulate cost within each rolling window; Window* fields
// record each window's start as a unix timestamp, with 0 meaning the window
// has not started yet.
type APIKeyRateLimitCacheData struct {
	Usage5h  float64 `json:"usage_5h"`
	Usage1d  float64 `json:"usage_1d"`
	Usage7d  float64 `json:"usage_7d"`
	Window5h int64   `json:"window_5h"` // unix timestamp, 0 = not started
	Window1d int64   `json:"window_1d"`
	Window7d int64   `json:"window_7d"`
}
// BillingCache defines cache operations for billing service
type BillingCache interface {
// Balance operations
@@ -23,6 +33,12 @@ type BillingCache interface {
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
// API Key rate limit operations
GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error)
SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error
UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
}
// ModelPricing 模型价格配置per-token价格与LiteLLM格式一致

View File

@@ -63,7 +63,7 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) {
// 费率倍数 1.5x
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5
// 费率倍数 2.0x

View File

@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
@@ -17,6 +18,9 @@ var (
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI大小写不敏感)
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
// 带捕获组的版本提取正则
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
@@ -78,7 +82,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku {
if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku {
return true // 绕过 system prompt 检查UA 已在 Step 1 验证
}
@@ -270,3 +274,55 @@ func IsClaudeCodeClient(ctx context.Context) bool {
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
}
// ExtractVersion pulls the Claude Code version (e.g. "2.1.22") out of a
// User-Agent string. Non-matching user agents yield "".
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
	if m := claudeCodeUAVersionPattern.FindStringSubmatch(ua); len(m) >= 2 {
		return m[1]
	}
	return ""
}
// SetClaudeCodeVersion stores the Claude Code version string in the context.
func SetClaudeCodeVersion(ctx context.Context, version string) context.Context {
	return context.WithValue(ctx, ctxkey.ClaudeCodeVersion, version)
}
// GetClaudeCodeVersion retrieves the Claude Code version from the context,
// returning "" when none was set.
func GetClaudeCodeVersion(ctx context.Context) string {
	version, _ := ctx.Value(ctxkey.ClaudeCodeVersion).(string)
	return version
}
// CompareVersions compares two semver version strings.
// It returns -1 when a < b, 0 when a == b, and 1 when a > b.
func CompareVersions(a, b string) int {
	va, vb := parseSemver(a), parseSemver(b)
	for i := range va {
		switch {
		case va[i] < vb[i]:
			return -1
		case va[i] > vb[i]:
			return 1
		}
	}
	return 0
}

// parseSemver splits a version string into [major, minor, patch].
// A leading "v" is tolerated; missing or non-numeric components stay 0.
func parseSemver(v string) [3]int {
	var out [3]int
	fields := strings.Split(strings.TrimPrefix(v, "v"), ".")
	for i, field := range fields {
		if i == 3 {
			break
		}
		if n, err := strconv.Atoi(field); err == nil {
			out[i] = n
		}
	}
	return out
}

View File

@@ -56,3 +56,51 @@ func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) {
ok := validator.Validate(req, nil)
require.True(t, ok)
}
// TestExtractVersion table-drives ExtractVersion over valid, malformed and
// case-variant user agents.
func TestExtractVersion(t *testing.T) {
	v := NewClaudeCodeValidator()
	tests := []struct {
		ua   string
		want string
	}{
		{"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"},
		{"claude-cli/1.0.0", "1.0.0"},
		{"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // case-insensitive
		{"curl/8.0.0", ""},                              // not the Claude CLI
		{"", ""},                                        // empty string
		{"claude-cli/", ""},                             // missing version
		{"claude-cli/2.1.22-beta", "2.1.22"},            // suffix still yields core version
	}
	for _, tt := range tests {
		got := v.ExtractVersion(tt.ua)
		require.Equal(t, tt.want, got, "ExtractVersion(%q)", tt.ua)
	}
}
// TestCompareVersions covers ordering across each semver component, the "v"
// prefix, and the empty-string edge case.
func TestCompareVersions(t *testing.T) {
	tests := []struct {
		a, b string
		want int
	}{
		{"2.1.0", "2.1.0", 0},   // equal
		{"2.1.1", "2.1.0", 1},   // larger patch
		{"2.0.0", "2.1.0", -1},  // smaller minor
		{"3.0.0", "2.99.99", 1}, // larger major wins regardless of the rest
		{"1.0.0", "2.0.0", -1},  // smaller major
		{"0.0.1", "0.0.0", 1},   // patch difference
		{"", "1.0.0", -1},       // empty parses as 0.0.0
		{"v2.1.0", "2.1.0", 0},  // "v" prefix is stripped
	}
	for _, tt := range tests {
		got := CompareVersions(tt.a, tt.b)
		require.Equal(t, tt.want, got, "CompareVersions(%q, %q)", tt.a, tt.b)
	}
}
// TestSetGetClaudeCodeVersion verifies the context round-trip for the Claude
// Code version, including the unset default.
func TestSetGetClaudeCodeVersion(t *testing.T) {
	ctx := context.Background()
	require.Equal(t, "", GetClaudeCodeVersion(ctx), "empty context should return empty string")

	ctx = SetClaudeCodeVersion(ctx, "2.1.63")
	require.Equal(t, "2.1.63", GetClaudeCodeVersion(ctx))
}

View File

@@ -3,8 +3,10 @@ package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"encoding/binary"
"os"
"strconv"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -18,6 +20,7 @@ type ConcurrencyCache interface {
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
// 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
@@ -42,15 +45,25 @@ type ConcurrencyCache interface {
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
}
// generateRequestID generates a unique request ID for concurrency slot tracking
// Uses 8 random bytes (16 hex chars) for uniqueness
func generateRequestID() string {
var (
	// requestIDPrefix is a per-process random prefix computed once at startup.
	requestIDPrefix = initRequestIDPrefix()
	// requestIDCounter is a process-wide monotonic counter appended to the prefix.
	requestIDCounter atomic.Uint64
)
func initRequestIDPrefix() string {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
// Fallback to nanosecond timestamp (extremely rare case)
return fmt.Sprintf("%x", time.Now().UnixNano())
if _, err := rand.Read(b); err == nil {
return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36)
}
return hex.EncodeToString(b)
fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
return "r" + strconv.FormatUint(fallback, 36)
}
// generateRequestID generates a unique request ID for concurrency slot tracking.
// Format: {process_random_prefix}-{base36_counter}
func generateRequestID() string {
seq := requestIDCounter.Add(1)
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
}
const (
@@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int)
for _, accountID := range accountIDs {
count, err := s.cache.GetAccountConcurrency(ctx, accountID)
if err != nil {
// If key doesn't exist in Redis, count is 0
count = 0
}
result[accountID] = count
if len(accountIDs) == 0 {
return map[int64]int{}, nil
}
return result, nil
if s.cache == nil {
result := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
result[accountID] = 0
}
return result, nil
}
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
}

View File

@@ -5,6 +5,8 @@ package service
import (
"context"
"errors"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/require"
@@ -12,20 +14,20 @@ import (
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
type stubConcurrencyCacheForTest struct {
acquireResult bool
acquireErr error
releaseErr error
concurrency int
acquireResult bool
acquireErr error
releaseErr error
concurrency int
concurrencyErr error
waitAllowed bool
waitErr error
waitCount int
waitCountErr error
loadBatch map[int64]*AccountLoadInfo
loadBatchErr error
waitAllowed bool
waitErr error
waitCount int
waitCountErr error
loadBatch map[int64]*AccountLoadInfo
loadBatchErr error
usersLoadBatch map[int64]*UserLoadInfo
usersLoadErr error
cleanupErr error
cleanupErr error
// 记录调用
releasedAccountIDs []int64
@@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco
func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) {
return c.concurrency, c.concurrencyErr
}
// GetAccountConcurrencyBatch returns the stubbed concurrency value for every
// requested account ID, or the configured error.
//
// The error is only reported when at least one account ID is supplied, which
// preserves the original semantics of checking the error inside the loop
// (an empty input returned an empty map and nil error even with an error set).
// Hoisting the loop-invariant check avoids re-evaluating it per iteration.
func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
	if len(accountIDs) > 0 && c.concurrencyErr != nil {
		return nil, c.concurrencyErr
	}
	result := make(map[int64]int, len(accountIDs))
	for _, accountID := range accountIDs {
		result[accountID] = c.concurrency
	}
	return result, nil
}
// IncrementAccountWaitCount returns the stubbed wait-queue admission result.
func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
	return c.waitAllowed, c.waitErr
}
@@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
require.True(t, result.Acquired)
}
// TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter verifies that two
// consecutive IDs share the same per-process prefix and carry consecutive
// base36 counter values, i.e. the format is {prefix}-{base36 counter}.
func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) {
	id1 := generateRequestID()
	id2 := generateRequestID()
	require.NotEmpty(t, id1)
	require.NotEmpty(t, id2)
	// Split each ID into its prefix and counter parts.
	p1 := strings.Split(id1, "-")
	p2 := strings.Split(id2, "-")
	require.Len(t, p1, 2)
	require.Len(t, p2, 2)
	require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致")
	// Counters are base36-encoded; consecutive calls must differ by exactly 1.
	n1, err := strconv.ParseUint(p1[1], 36, 64)
	require.NoError(t, err)
	n2, err := strconv.ParseUint(p2[1], 36, 64)
	require.NoError(t, err)
	require.Equal(t, n1+1, n2, "计数器应单调递增")
}
func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) {
expected := map[int64]*AccountLoadInfo{
1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60},

View File

@@ -221,7 +221,7 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}
return nil, fmt.Errorf("create http client failed: %w", err)
}
adminToken, err := crsLogin(ctx, client, normalizedURL, username, password)

View File

@@ -124,24 +124,24 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil
}
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get group stats with filters: %w", err)
}

View File

@@ -0,0 +1,252 @@
package service
import "context"
// DataManagementPostgresConfig describes the PostgreSQL source used by the
// (deprecated) data-management backup feature. Password is omitted from JSON
// when empty; PasswordConfigured signals that a password is stored server-side.
type DataManagementPostgresConfig struct {
	Host string `json:"host"`
	Port int32 `json:"port"`
	User string `json:"user"`
	Password string `json:"password,omitempty"`
	PasswordConfigured bool `json:"password_configured"`
	Database string `json:"database"`
	SSLMode string `json:"ssl_mode"`
	ContainerName string `json:"container_name"`
}

// DataManagementRedisConfig describes the Redis source for backups; the same
// password-configured convention as the Postgres config applies.
type DataManagementRedisConfig struct {
	Addr string `json:"addr"`
	Username string `json:"username"`
	Password string `json:"password,omitempty"`
	PasswordConfigured bool `json:"password_configured"`
	DB int32 `json:"db"`
	ContainerName string `json:"container_name"`
}

// DataManagementS3Config describes the S3-compatible target for backup
// uploads. ForcePathStyle supports MinIO-style endpoints.
type DataManagementS3Config struct {
	Enabled bool `json:"enabled"`
	Endpoint string `json:"endpoint"`
	Region string `json:"region"`
	Bucket string `json:"bucket"`
	AccessKeyID string `json:"access_key_id"`
	SecretAccessKey string `json:"secret_access_key,omitempty"`
	SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
	Prefix string `json:"prefix"`
	ForcePathStyle bool `json:"force_path_style"`
	UseSSL bool `json:"use_ssl"`
}

// DataManagementConfig aggregates the full data-management configuration:
// source selection, retention policy, active profile IDs, and the embedded
// Postgres/Redis/S3 connection settings.
type DataManagementConfig struct {
	SourceMode string `json:"source_mode"`
	BackupRoot string `json:"backup_root"`
	SQLitePath string `json:"sqlite_path,omitempty"`
	RetentionDays int32 `json:"retention_days"`
	KeepLast int32 `json:"keep_last"`
	ActivePostgresID string `json:"active_postgres_profile_id"`
	ActiveRedisID string `json:"active_redis_profile_id"`
	Postgres DataManagementPostgresConfig `json:"postgres"`
	Redis DataManagementRedisConfig `json:"redis"`
	S3 DataManagementS3Config `json:"s3"`
	ActiveS3ProfileID string `json:"active_s3_profile_id"`
}
// DataManagementTestS3Result reports the outcome of an S3 connectivity test.
type DataManagementTestS3Result struct {
	OK bool `json:"ok"`
	Message string `json:"message"`
}

// DataManagementCreateBackupJobInput carries the parameters for starting a
// backup job; IdempotencyKey allows safe retries of the same request.
type DataManagementCreateBackupJobInput struct {
	BackupType string
	UploadToS3 bool
	TriggeredBy string
	IdempotencyKey string
	S3ProfileID string
	PostgresID string
	RedisID string
}

// DataManagementListBackupJobsInput carries pagination and filter parameters
// for listing backup jobs.
type DataManagementListBackupJobsInput struct {
	PageSize int32
	PageToken string
	Status string
	BackupType string
}

// DataManagementArtifactInfo describes the local backup artifact produced by
// a job, including its content hash for integrity checks.
type DataManagementArtifactInfo struct {
	LocalPath string `json:"local_path"`
	SizeBytes int64 `json:"size_bytes"`
	SHA256 string `json:"sha256"`
}

// DataManagementS3ObjectInfo identifies the uploaded backup object in S3.
type DataManagementS3ObjectInfo struct {
	Bucket string `json:"bucket"`
	Key string `json:"key"`
	ETag string `json:"etag"`
}
// DataManagementBackupJob is the full record of a backup job: its request
// parameters, lifecycle timestamps, and the produced artifact / S3 object.
type DataManagementBackupJob struct {
	JobID string `json:"job_id"`
	BackupType string `json:"backup_type"`
	Status string `json:"status"`
	TriggeredBy string `json:"triggered_by"`
	IdempotencyKey string `json:"idempotency_key,omitempty"`
	UploadToS3 bool `json:"upload_to_s3"`
	S3ProfileID string `json:"s3_profile_id,omitempty"`
	PostgresID string `json:"postgres_profile_id,omitempty"`
	RedisID string `json:"redis_profile_id,omitempty"`
	StartedAt string `json:"started_at,omitempty"`
	FinishedAt string `json:"finished_at,omitempty"`
	ErrorMessage string `json:"error_message,omitempty"`
	Artifact DataManagementArtifactInfo `json:"artifact"`
	S3Object DataManagementS3ObjectInfo `json:"s3"`
}

// DataManagementSourceProfile is a named, switchable source configuration
// (Postgres or Redis, discriminated by SourceType).
type DataManagementSourceProfile struct {
	SourceType string `json:"source_type"`
	ProfileID string `json:"profile_id"`
	Name string `json:"name"`
	IsActive bool `json:"is_active"`
	Config DataManagementSourceConfig `json:"config"`
	PasswordConfigured bool `json:"password_configured"`
	CreatedAt string `json:"created_at,omitempty"`
	UpdatedAt string `json:"updated_at,omitempty"`
}

// DataManagementSourceConfig is the union of Postgres and Redis connection
// fields; which subset is meaningful depends on the profile's SourceType.
type DataManagementSourceConfig struct {
	Host string `json:"host"`
	Port int32 `json:"port"`
	User string `json:"user"`
	Password string `json:"password,omitempty"`
	Database string `json:"database"`
	SSLMode string `json:"ssl_mode"`
	Addr string `json:"addr"`
	Username string `json:"username"`
	DB int32 `json:"db"`
	ContainerName string `json:"container_name"`
}
// DataManagementCreateSourceProfileInput carries the parameters for creating
// a source profile; SetActive makes the new profile the active one.
type DataManagementCreateSourceProfileInput struct {
	SourceType string
	ProfileID string
	Name string
	Config DataManagementSourceConfig
	SetActive bool
}

// DataManagementUpdateSourceProfileInput carries the parameters for updating
// an existing source profile identified by SourceType + ProfileID.
type DataManagementUpdateSourceProfileInput struct {
	SourceType string
	ProfileID string
	Name string
	Config DataManagementSourceConfig
}

// DataManagementS3Profile is a named, switchable S3 target configuration.
type DataManagementS3Profile struct {
	ProfileID string `json:"profile_id"`
	Name string `json:"name"`
	IsActive bool `json:"is_active"`
	S3 DataManagementS3Config `json:"s3"`
	SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
	CreatedAt string `json:"created_at,omitempty"`
	UpdatedAt string `json:"updated_at,omitempty"`
}

// DataManagementCreateS3ProfileInput carries the parameters for creating an
// S3 profile; SetActive makes the new profile the active one.
type DataManagementCreateS3ProfileInput struct {
	ProfileID string
	Name string
	S3 DataManagementS3Config
	SetActive bool
}

// DataManagementUpdateS3ProfileInput carries the parameters for updating an
// existing S3 profile.
type DataManagementUpdateS3ProfileInput struct {
	ProfileID string
	Name string
	S3 DataManagementS3Config
}

// DataManagementListBackupJobsResult is one page of backup jobs plus the
// token for fetching the next page (empty when exhausted).
type DataManagementListBackupJobsResult struct {
	Items []DataManagementBackupJob `json:"items"`
	NextPageToken string `json:"next_page_token,omitempty"`
}
// The RPC methods below are retained only for API compatibility: the data
// management feature has been deprecated, so every call fails with the
// deprecation error produced by deprecatedError().

// GetConfig always fails with the deprecation error.
func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) {
	_ = ctx
	return DataManagementConfig{}, s.deprecatedError()
}

// UpdateConfig always fails with the deprecation error.
func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) {
	_, _ = ctx, cfg
	return DataManagementConfig{}, s.deprecatedError()
}

// ListSourceProfiles always fails with the deprecation error.
func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) {
	_, _ = ctx, sourceType
	return nil, s.deprecatedError()
}

// CreateSourceProfile always fails with the deprecation error.
func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) {
	_, _ = ctx, input
	return DataManagementSourceProfile{}, s.deprecatedError()
}

// UpdateSourceProfile always fails with the deprecation error.
func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) {
	_, _ = ctx, input
	return DataManagementSourceProfile{}, s.deprecatedError()
}

// DeleteSourceProfile always fails with the deprecation error.
func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error {
	_, _, _ = ctx, sourceType, profileID
	return s.deprecatedError()
}

// SetActiveSourceProfile always fails with the deprecation error.
func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) {
	_, _, _ = ctx, sourceType, profileID
	return DataManagementSourceProfile{}, s.deprecatedError()
}

// ValidateS3 always fails with the deprecation error.
func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) {
	_, _ = ctx, cfg
	return DataManagementTestS3Result{}, s.deprecatedError()
}

// ListS3Profiles always fails with the deprecation error.
func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) {
	_ = ctx
	return nil, s.deprecatedError()
}

// CreateS3Profile always fails with the deprecation error.
func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) {
	_, _ = ctx, input
	return DataManagementS3Profile{}, s.deprecatedError()
}

// UpdateS3Profile always fails with the deprecation error.
func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) {
	_, _ = ctx, input
	return DataManagementS3Profile{}, s.deprecatedError()
}

// DeleteS3Profile always fails with the deprecation error.
func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error {
	_, _ = ctx, profileID
	return s.deprecatedError()
}

// SetActiveS3Profile always fails with the deprecation error.
func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) {
	_, _ = ctx, profileID
	return DataManagementS3Profile{}, s.deprecatedError()
}

// CreateBackupJob always fails with the deprecation error.
func (s *DataManagementService) CreateBackupJob(ctx context.Context, input DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) {
	_, _ = ctx, input
	return DataManagementBackupJob{}, s.deprecatedError()
}

// ListBackupJobs always fails with the deprecation error.
func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) {
	_, _ = ctx, input
	return DataManagementListBackupJobsResult{}, s.deprecatedError()
}

// GetBackupJob always fails with the deprecation error.
func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) {
	_, _ = ctx, jobID
	return DataManagementBackupJob{}, s.deprecatedError()
}

// deprecatedError builds the shared deprecation error, attaching the agent
// socket path as metadata for callers that still probe the agent.
func (s *DataManagementService) deprecatedError() error {
	return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
}

View File

@@ -0,0 +1,36 @@
package service
import (
"context"
"path/filepath"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
// TestDataManagementService_DeprecatedRPCMethods verifies that a sample of
// RPC methods all fail with the shared deprecation error carrying the
// configured socket path.
func TestDataManagementService_DeprecatedRPCMethods(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "datamanagement.sock")
	svc := NewDataManagementServiceWithOptions(socketPath, 0)
	_, err := svc.GetConfig(context.Background())
	assertDeprecatedDataManagementError(t, err, socketPath)
	_, err = svc.CreateBackupJob(context.Background(), DataManagementCreateBackupJobInput{BackupType: "full"})
	assertDeprecatedDataManagementError(t, err, socketPath)
	err = svc.DeleteS3Profile(context.Background(), "s3-default")
	assertDeprecatedDataManagementError(t, err, socketPath)
}
// assertDeprecatedDataManagementError asserts that err maps to HTTP 503 with
// the deprecation reason and carries socketPath in its metadata.
func assertDeprecatedDataManagementError(t *testing.T, err error, socketPath string) {
	t.Helper()
	require.Error(t, err)
	statusCode, status := infraerrors.ToHTTP(err)
	require.Equal(t, 503, statusCode)
	require.Equal(t, DataManagementDeprecatedReason, status.Reason)
	require.Equal(t, socketPath, status.Metadata["socket_path"])
}

View File

@@ -0,0 +1,95 @@
package service
import (
"context"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// Socket paths and machine-readable error reasons for the (deprecated)
// data-management agent integration.
const (
	DefaultDataManagementAgentSocketPath = "/tmp/sub2api-datamanagement.sock"
	LegacyBackupAgentSocketPath = "/tmp/sub2api-backup.sock"
	DataManagementDeprecatedReason = "DATA_MANAGEMENT_DEPRECATED"
	DataManagementAgentSocketMissingReason = "DATA_MANAGEMENT_AGENT_SOCKET_MISSING"
	DataManagementAgentUnavailableReason = "DATA_MANAGEMENT_AGENT_UNAVAILABLE"
	// Deprecated: keep old names for compatibility.
	DefaultBackupAgentSocketPath = DefaultDataManagementAgentSocketPath
	BackupAgentSocketMissingReason = DataManagementAgentSocketMissingReason
	BackupAgentUnavailableReason = DataManagementAgentUnavailableReason
)

// Pre-built service-unavailable errors, one per reason above.
var (
	ErrDataManagementDeprecated = infraerrors.ServiceUnavailable(
		DataManagementDeprecatedReason,
		"data management feature is deprecated",
	)
	ErrDataManagementAgentSocketMissing = infraerrors.ServiceUnavailable(
		DataManagementAgentSocketMissingReason,
		"data management agent socket is missing",
	)
	ErrDataManagementAgentUnavailable = infraerrors.ServiceUnavailable(
		DataManagementAgentUnavailableReason,
		"data management agent is unavailable",
	)
	// Deprecated: keep old names for compatibility.
	ErrBackupAgentSocketMissing = ErrDataManagementAgentSocketMissing
	ErrBackupAgentUnavailable = ErrDataManagementAgentUnavailable
)
// DataManagementAgentHealth describes whether the data-management agent is
// usable, and why not when it is not.
type DataManagementAgentHealth struct {
	Enabled bool
	Reason string
	SocketPath string
	Agent *DataManagementAgentInfo
}

// DataManagementAgentInfo carries agent self-reported status details.
type DataManagementAgentInfo struct {
	Status string
	Version string
	UptimeSeconds int64
}

// DataManagementService is the (deprecated) facade for the data-management
// agent; it only retains the socket path for error metadata.
type DataManagementService struct {
	socketPath string
}
// NewDataManagementService returns a service bound to the default socket path.
// The 500ms dial timeout is passed through for compatibility only (the
// options constructor discards it).
func NewDataManagementService() *DataManagementService {
	return NewDataManagementServiceWithOptions(DefaultDataManagementAgentSocketPath, 500*time.Millisecond)
}
// NewDataManagementServiceWithOptions builds a service bound to the given
// agent socket path. A blank (or whitespace-only) path falls back to
// DefaultDataManagementAgentSocketPath. dialTimeout is accepted for API
// compatibility but unused since the feature was deprecated.
func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService {
	_ = dialTimeout // kept for backward compatibility; the agent is never dialed
	trimmed := strings.TrimSpace(socketPath)
	if trimmed == "" {
		trimmed = DefaultDataManagementAgentSocketPath
	}
	return &DataManagementService{socketPath: trimmed}
}
// SocketPath reports the configured agent socket path, falling back to the
// default when the receiver is nil or the stored path is blank.
func (s *DataManagementService) SocketPath() string {
	if s != nil {
		if path := s.socketPath; strings.TrimSpace(path) != "" {
			return path
		}
	}
	return DefaultDataManagementAgentSocketPath
}
// GetAgentHealth always reports the agent as disabled with the deprecation
// reason; no agent probe is performed.
func (s *DataManagementService) GetAgentHealth(ctx context.Context) DataManagementAgentHealth {
	_ = ctx
	return DataManagementAgentHealth{
		Enabled: false,
		Reason: DataManagementDeprecatedReason,
		SocketPath: s.SocketPath(),
	}
}

// EnsureAgentEnabled always fails with the deprecation error, attaching the
// socket path as metadata.
func (s *DataManagementService) EnsureAgentEnabled(ctx context.Context) error {
	_ = ctx
	return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
}

View File

@@ -0,0 +1,37 @@
package service
import (
"context"
"path/filepath"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
// TestDataManagementService_GetAgentHealth_Deprecated verifies the health
// report is disabled with the deprecation reason and no agent info.
func TestDataManagementService_GetAgentHealth_Deprecated(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "unused.sock")
	svc := NewDataManagementServiceWithOptions(socketPath, 0)
	health := svc.GetAgentHealth(context.Background())
	require.False(t, health.Enabled)
	require.Equal(t, DataManagementDeprecatedReason, health.Reason)
	require.Equal(t, socketPath, health.SocketPath)
	require.Nil(t, health.Agent)
}
// TestDataManagementService_EnsureAgentEnabled_Deprecated verifies the error
// maps to HTTP 503 with the deprecation reason and socket-path metadata.
func TestDataManagementService_EnsureAgentEnabled_Deprecated(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "unused.sock")
	svc := NewDataManagementServiceWithOptions(socketPath, 100)
	err := svc.EnsureAgentEnabled(context.Background())
	require.Error(t, err)
	statusCode, status := infraerrors.ToHTTP(err)
	require.Equal(t, 503, statusCode)
	require.Equal(t, DataManagementDeprecatedReason, status.Reason)
	require.Equal(t, socketPath, status.Metadata["socket_path"])
}

View File

@@ -74,11 +74,12 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单JSON 数组)
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
@@ -104,6 +105,7 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置
SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
@@ -112,12 +114,14 @@ const (
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML或 URL 作为 iframe src
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示购买订阅页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // 购买订阅页面 URL作为 iframe src
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL作为 iframe src
SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项JSON 数组)
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表JSON
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
@@ -170,6 +174,37 @@ const (
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
// =========================
// Sora S3 存储配置
// =========================
SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
SettingKeySoraS3Region = "sora_s3_region" // S3 区域
SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key加密存储
SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style兼容 MinIO 等)
SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL可选
SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置JSON
// =========================
// Sora 用户存储配额
// =========================
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
// =========================
// Claude Code Version Check
// =========================
// SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查)
SettingKeyMinClaudeCodeVersion = "min_claude_code_version"
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false未分组 Key 返回 403
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -279,10 +279,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotE
wantPassthrough: true,
},
{
name: "404 generic not found passes through as 404",
name: "404 generic not found does not passthrough",
statusCode: http.StatusNotFound,
respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
wantPassthrough: true,
wantPassthrough: false,
},
{
name: "400 Invalid URL does not passthrough",

View File

@@ -136,3 +136,67 @@ func TestDroppedBetaSet(t *testing.T) {
require.Contains(t, extended, claude.BetaClaudeCode)
require.Len(t, extended, len(claude.DroppedBetas)+1)
}
// TestBuildBetaTokenSet verifies deduplication and filtering of empty tokens,
// and that a nil input yields an empty set.
func TestBuildBetaTokenSet(t *testing.T) {
	set := buildBetaTokenSet([]string{"foo", "", "bar", "foo"})
	require.Len(t, set, 2)
	for _, token := range []string{"foo", "bar"} {
		require.Contains(t, set, token)
	}
	require.NotContains(t, set, "")
	require.Empty(t, buildBetaTokenSet(nil))
}
// TestStripBetaTokensWithSet_EmptyDropSet verifies that an empty drop set
// leaves the beta header untouched.
func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) {
	header := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
	got := stripBetaTokensWithSet(header, map[string]struct{}{})
	require.Equal(t, header, got)
}
// TestIsCountTokensUnsupported404 verifies detection of a 404 that indicates
// the upstream lacks the count_tokens endpoint: the status must be 404 AND
// the error message must reference count_tokens; generic 404s do not match.
func TestIsCountTokensUnsupported404(t *testing.T) {
	tests := []struct {
		name string
		statusCode int
		body string
		want bool
	}{
		{
			name: "exact endpoint not found",
			statusCode: 404,
			body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
			want: true,
		},
		{
			name: "contains count_tokens and not found",
			statusCode: 404,
			body: `{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`,
			want: true,
		},
		{
			name: "generic 404",
			statusCode: 404,
			body: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
			want: false,
		},
		{
			name: "404 with empty error message",
			statusCode: 404,
			body: `{"error":{"message":"","type":"not_found_error"}}`,
			want: false,
		},
		{
			name: "non-404 status",
			statusCode: 400,
			body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`,
			want: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := isCountTokensUnsupported404(tt.statusCode, []byte(tt.body))
			require.Equal(t, tt.want, got)
		})
	}
}

View File

@@ -0,0 +1,363 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ============================================================================
// Part 1: isAccountInGroup 单元测试
// ============================================================================
// TestIsAccountInGroup covers group-membership checks for account scheduling:
// a nil groupID (ungrouped API key) matches only accounts with no groups,
// a non-nil groupID matches only accounts that list that group, and a nil
// account is never a match.
func TestIsAccountInGroup(t *testing.T) {
	svc := &GatewayService{}
	groupID100 := int64(100)
	groupID200 := int64(200)
	tests := []struct {
		name string
		account *Account
		groupID *int64
		expected bool
	}{
		// groupID == nil: ungrouped API key
		{
			"nil_groupID_ungrouped_account_nil_groups",
			&Account{ID: 1, AccountGroups: nil},
			nil, true,
		},
		{
			"nil_groupID_ungrouped_account_empty_slice",
			&Account{ID: 2, AccountGroups: []AccountGroup{}},
			nil, true,
		},
		{
			"nil_groupID_grouped_account_single",
			&Account{ID: 3, AccountGroups: []AccountGroup{{GroupID: 100}}},
			nil, false,
		},
		{
			"nil_groupID_grouped_account_multiple",
			&Account{ID: 4, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
			nil, false,
		},
		// groupID != nil: grouped API key
		{
			"with_groupID_account_in_group",
			&Account{ID: 5, AccountGroups: []AccountGroup{{GroupID: 100}}},
			&groupID100, true,
		},
		{
			"with_groupID_account_not_in_group",
			&Account{ID: 6, AccountGroups: []AccountGroup{{GroupID: 200}}},
			&groupID100, false,
		},
		{
			"with_groupID_ungrouped_account",
			&Account{ID: 7, AccountGroups: nil},
			&groupID100, false,
		},
		{
			"with_groupID_multi_group_account_match_one",
			&Account{ID: 8, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
			&groupID200, true,
		},
		{
			"with_groupID_multi_group_account_no_match",
			&Account{ID: 9, AccountGroups: []AccountGroup{{GroupID: 300}, {GroupID: 400}}},
			&groupID100, false,
		},
		// defensive edge cases
		{
			"nil_account_nil_groupID",
			nil,
			nil, false,
		},
		{
			"nil_account_with_groupID",
			nil,
			&groupID100, false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := svc.isAccountInGroup(tt.account, tt.groupID)
			require.Equal(t, tt.expected, got, "isAccountInGroup 结果不符预期")
		})
	}
}
// ============================================================================
// Part 2: 分组隔离端到端调度测试
// ============================================================================
// groupAwareMockAccountRepo embeds mockAccountRepoForPlatform and overrides
// the group-isolation query methods. allAccounts holds every account; the
// group-scoped list methods filter on the AccountGroups field for real.
type groupAwareMockAccountRepo struct {
	*mockAccountRepoForPlatform
	allAccounts []Account
}
// ListSchedulableUngroupedByPlatform returns only schedulable accounts on the
// given platform that belong to no group (empty AccountGroups).
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
	var matched []Account
	for _, candidate := range m.allAccounts {
		if candidate.Platform != platform {
			continue
		}
		if !candidate.IsSchedulable() || len(candidate.AccountGroups) != 0 {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched, nil
}
// ListSchedulableUngroupedByPlatforms is the multi-platform variant: it
// returns only schedulable, ungrouped accounts whose platform is listed.
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
	wanted := make(map[string]bool, len(platforms))
	for _, platform := range platforms {
		wanted[platform] = true
	}
	var matched []Account
	for _, candidate := range m.allAccounts {
		if !wanted[candidate.Platform] {
			continue
		}
		if !candidate.IsSchedulable() || len(candidate.AccountGroups) != 0 {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched, nil
}
// ListSchedulableByGroupIDAndPlatform returns schedulable accounts on the
// given platform that are members of the given group.
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
	var matched []Account
	for _, candidate := range m.allAccounts {
		if candidate.Platform != platform || !candidate.IsSchedulable() {
			continue
		}
		if accountBelongsToGroup(candidate, groupID) {
			matched = append(matched, candidate)
		}
	}
	return matched, nil
}
// ListSchedulableByGroupIDAndPlatforms is the multi-platform variant: it
// returns schedulable group members whose platform is listed.
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
	wanted := make(map[string]bool, len(platforms))
	for _, platform := range platforms {
		wanted[platform] = true
	}
	var matched []Account
	for _, candidate := range m.allAccounts {
		if !wanted[candidate.Platform] || !candidate.IsSchedulable() {
			continue
		}
		if accountBelongsToGroup(candidate, groupID) {
			matched = append(matched, candidate)
		}
	}
	return matched, nil
}
// accountBelongsToGroup reports whether acc is a member of the given group.
func accountBelongsToGroup(acc Account, groupID int64) bool {
	for i := range acc.AccountGroups {
		if acc.AccountGroups[i].GroupID == groupID {
			return true
		}
	}
	return false
}
// Verify interface implementation
var _ AccountRepository = (*groupAwareMockAccountRepo)(nil)
// newGroupAwareMockRepo builds a group-aware mock repo over the given
// accounts, indexing them by ID for the embedded platform mock.
func newGroupAwareMockRepo(accounts []Account) *groupAwareMockAccountRepo {
	index := make(map[int64]*Account, len(accounts))
	for i := range accounts {
		acc := &accounts[i]
		index[acc.ID] = acc
	}
	base := &mockAccountRepoForPlatform{
		accounts: accounts,
		accountsByID: index,
	}
	return &groupAwareMockAccountRepo{
		mockAccountRepoForPlatform: base,
		allAccounts: accounts,
	}
}
func TestGroupIsolation_UngroupedKey_ShouldNotScheduleGroupedAccounts(t *testing.T) {
	// Scenario: an ungrouped API key (groupID=nil) with only grouped accounts
	// in the pool — scheduling must fail rather than leak a grouped account.
	ctx := context.Background()
	accounts := []Account{
		{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
			AccountGroups: []AccountGroup{{GroupID: 100}}},
		{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
			AccountGroups: []AccountGroup{{GroupID: 200}}},
	}
	repo := newGroupAwareMockRepo(accounts)
	cache := &mockGatewayCacheForPlatform{}
	svc := &GatewayService{
		accountRepo: repo,
		cache: cache,
		cfg: testConfig(),
	}
	acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
	require.Error(t, err, "无分组 Key 不应调度到已分组账号")
	require.Nil(t, acc)
}
func TestGroupIsolation_GroupedKey_ShouldNotScheduleUngroupedAccounts(t *testing.T) {
	// Scenario: grouped API key (groupID=100) while the pool contains only
	// ungrouped accounts (both nil and empty slices) -> selection must fail.
	ctx := context.Background()
	groupID := int64(100)
	accounts := []Account{
		{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
			AccountGroups: nil},
		{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
			AccountGroups: []AccountGroup{}},
	}
	repo := newGroupAwareMockRepo(accounts)
	cache := &mockGatewayCacheForPlatform{}
	svc := &GatewayService{
		accountRepo: repo,
		cache:       cache,
		cfg:         testConfig(),
	}
	acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
	require.Error(t, err, "有分组 Key 不应调度到未分组账号")
	require.Nil(t, acc)
}
func TestGroupIsolation_UngroupedKey_ShouldOnlyScheduleUngroupedAccounts(t *testing.T) {
	// Scenario: ungrouped API key (groupID=nil), pool mixes grouped and
	// ungrouped accounts -> only the ungrouped one may be selected.
	ctx := context.Background()
	accounts := []Account{
		{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
			AccountGroups: []AccountGroup{{GroupID: 100}}}, // grouped, must not be selected
		{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
			AccountGroups: nil}, // ungrouped, should be selected
		{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
			AccountGroups: []AccountGroup{{GroupID: 200}}}, // grouped, must not be selected
	}
	repo := newGroupAwareMockRepo(accounts)
	cache := &mockGatewayCacheForPlatform{}
	svc := &GatewayService{
		accountRepo: repo,
		cache:       cache,
		cfg:         testConfig(),
	}
	acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
	require.NoError(t, err, "应成功调度未分组账号")
	require.NotNil(t, acc)
	require.Equal(t, int64(2), acc.ID, "应选中未分组的账号 ID=2")
}
func TestGroupIsolation_GroupedKey_ShouldOnlyScheduleMatchingGroupAccounts(t *testing.T) {
	// Scenario: grouped API key (groupID=100), pool has an ungrouped account and
	// accounts from several groups -> only the one inside group 100 qualifies.
	ctx := context.Background()
	groupID := int64(100)
	accounts := []Account{
		{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
			AccountGroups: nil}, // ungrouped, must not be selected
		{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
			AccountGroups: []AccountGroup{{GroupID: 200}}}, // group 200, must not be selected
		{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
			AccountGroups: []AccountGroup{{GroupID: 100}}}, // group 100, should be selected
	}
	repo := newGroupAwareMockRepo(accounts)
	cache := &mockGatewayCacheForPlatform{}
	svc := &GatewayService{
		accountRepo: repo,
		cache:       cache,
		cfg:         testConfig(),
	}
	acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
	require.NoError(t, err, "应成功调度分组内账号")
	require.NotNil(t, acc)
	require.Equal(t, int64(3), acc.ID, "应选中分组 100 内的账号 ID=3")
}
// ============================================================================
// Part 3: SimpleMode bypass tests
// ============================================================================

func TestGroupIsolation_SimpleMode_SkipsGroupIsolation(t *testing.T) {
	// SimpleMode should skip group isolation and use ListSchedulableByPlatform,
	// which returns every account. This exercises the non-useMixed path
	// (platform=openai) so mixed-scheduling logic is not triggered.
	ctx := context.Background()
	// Mix grouped and ungrouped accounts; all should be schedulable in SimpleMode.
	accounts := []Account{
		{ID: 1, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
			AccountGroups: []AccountGroup{{GroupID: 100}}}, // grouped
		{ID: 2, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
			AccountGroups: nil}, // ungrouped
	}
	// Use the basic mock (ListSchedulableByPlatform returns all accounts on the
	// platform with no group filtering).
	byID := make(map[int64]*Account, len(accounts))
	for i := range accounts {
		byID[accounts[i].ID] = &accounts[i]
	}
	repo := &mockAccountRepoForPlatform{
		accounts:     accounts,
		accountsByID: byID,
	}
	cache := &mockGatewayCacheForPlatform{}
	svc := &GatewayService{
		accountRepo: repo,
		cache:       cache,
		cfg:         &config.Config{RunMode: config.RunModeSimple},
	}
	// With groupID=nil, SimpleMode should use ListSchedulableByPlatform (no group filter).
	acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
	require.NoError(t, err, "SimpleMode 应跳过分组隔离直接返回账号")
	require.NotNil(t, acc)
	// The highest-priority account (Priority=1, ID=2) wins even though it is ungrouped.
	require.Equal(t, int64(2), acc.ID, "SimpleMode 应按优先级选择,不考虑分组")
}
func TestGroupIsolation_SimpleMode_GroupedAccountAlsoSchedulable(t *testing.T) {
	// In SimpleMode with groupID=nil, grouped accounts must also be schedulable.
	ctx := context.Background()
	// Only grouped accounts exist; in standard mode groupID=nil would error,
	// but simple mode should succeed.
	accounts := []Account{
		{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
			AccountGroups: []AccountGroup{{GroupID: 100}}},
	}
	byID := make(map[int64]*Account, len(accounts))
	for i := range accounts {
		byID[accounts[i].ID] = &accounts[i]
	}
	repo := &mockAccountRepoForPlatform{
		accounts:     accounts,
		accountsByID: byID,
	}
	cache := &mockGatewayCacheForPlatform{}
	svc := &GatewayService{
		accountRepo: repo,
		cache:       cache,
		cfg:         &config.Config{RunMode: config.RunModeSimple},
	}
	acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
	require.NoError(t, err, "SimpleMode 下已分组账号也应可调度")
	require.NotNil(t, acc)
	require.Equal(t, int64(1), acc.ID, "SimpleMode 应能调度已分组账号")
}

View File

@@ -147,6 +147,12 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Cont
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
// ListSchedulableUngroupedByPlatform delegates to ListSchedulableByPlatform:
// this basic mock does not model groups, so "ungrouped" is the full platform list.
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
	return m.ListSchedulableByPlatform(ctx, platform)
}

// ListSchedulableUngroupedByPlatforms delegates to ListSchedulableByPlatforms
// for the same reason as the single-platform variant.
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
	return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
@@ -1892,6 +1898,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun
return 0, nil
}
// GetAccountConcurrencyBatch reports zero in-flight requests for every
// requested account ID (mock implementation; never errors).
func (m *mockConcurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
	out := make(map[int64]int, len(accountIDs))
	for _, id := range accountIDs {
		out[id] = 0
	}
	return out, nil
}
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}

View File

@@ -61,6 +61,10 @@ type ParsedRequest struct {
ThinkingEnabled bool // 是否开启 thinking部分平台会影响最终模型名
MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选请求上下文区分因子nil 时行为不变)
// OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁)
// 流式请求在收到 2xx 响应头后调用,避免持锁等流完成
OnUpstreamAccepted func()
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果。

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,141 @@
package service
import (
"context"
"strings"
"testing"
"time"
)
// TestCollectSelectionFailureStats builds one account per failure category plus
// one eligible account and asserts each counter in the resulting stats is 1.
func TestCollectSelectionFailureStats(t *testing.T) {
	svc := &GatewayService{}
	model := "sora2-landscape-10s"
	resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339)
	accounts := []Account{
		// excluded (listed in the excluded set below)
		{
			ID:          1,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
		},
		// unschedulable (Schedulable flag off)
		{
			ID:          2,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: false,
		},
		// platform filtered (wrong platform for this request)
		{
			ID:          3,
			Platform:    PlatformOpenAI,
			Status:      StatusActive,
			Schedulable: true,
		},
		// model unsupported (mapping only allows gpt-image)
		{
			ID:          4,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
			Credentials: map[string]any{
				"model_mapping": map[string]any{
					"gpt-image": "gpt-image",
				},
			},
		},
		// model rate limited (per-model reset time still in the future)
		{
			ID:          5,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
			Extra: map[string]any{
				"model_rate_limits": map[string]any{
					model: map[string]any{
						"rate_limit_reset_at": resetAt,
					},
				},
			},
		},
		// eligible (no blocking condition)
		{
			ID:          6,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
		},
	}
	excluded := map[int64]struct{}{1: {}}
	stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false)
	if stats.Total != 6 {
		t.Fatalf("total=%d want=6", stats.Total)
	}
	if stats.Excluded != 1 {
		t.Fatalf("excluded=%d want=1", stats.Excluded)
	}
	if stats.Unschedulable != 1 {
		t.Fatalf("unschedulable=%d want=1", stats.Unschedulable)
	}
	if stats.PlatformFiltered != 1 {
		t.Fatalf("platform_filtered=%d want=1", stats.PlatformFiltered)
	}
	if stats.ModelUnsupported != 1 {
		t.Fatalf("model_unsupported=%d want=1", stats.ModelUnsupported)
	}
	if stats.ModelRateLimited != 1 {
		t.Fatalf("model_rate_limited=%d want=1", stats.ModelRateLimited)
	}
	if stats.Eligible != 1 {
		t.Fatalf("eligible=%d want=1", stats.Eligible)
	}
}
// TestDiagnoseSelectionFailure_SoraUnschedulableDetail checks that an account
// with Schedulable=false is diagnosed as "unschedulable" with a matching detail.
func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) {
	svc := &GatewayService{}
	acc := &Account{
		ID:          7,
		Platform:    PlatformSora,
		Status:      StatusActive,
		Schedulable: false,
	}
	diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
	if diagnosis.Category != "unschedulable" {
		t.Fatalf("category=%s want=unschedulable", diagnosis.Category)
	}
	if diagnosis.Detail != "schedulable=false" {
		t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail)
	}
}
// TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail checks that a
// per-model rate limit with a future reset time yields the
// "model_rate_limited" category and a detail containing the remaining window.
func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
	svc := &GatewayService{}
	model := "sora2-landscape-10s"
	resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
	acc := &Account{
		ID:          8,
		Platform:    PlatformSora,
		Status:      StatusActive,
		Schedulable: true,
		Extra: map[string]any{
			"model_rate_limits": map[string]any{
				model: map[string]any{
					"rate_limit_reset_at": resetAt,
				},
			},
		},
	}
	diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false)
	if diagnosis.Category != "model_rate_limited" {
		t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category)
	}
	if !strings.Contains(diagnosis.Detail, "remaining=") {
		t.Fatalf("detail=%s want contains remaining=", diagnosis.Detail)
	}
}

View File

@@ -0,0 +1,79 @@
package service
import "testing"
// An empty model_mapping on a Sora account must allow every Sora model.
func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) {
	svc := &GatewayService{}
	account := &Account{
		Platform:    PlatformSora,
		Credentials: map[string]any{},
	}
	if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
		t.Fatalf("expected sora model to be supported when model_mapping is empty")
	}
}
// A legacy mapping containing only non-Sora selectors (e.g. gpt-4o) must not
// block Sora models.
func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) {
	svc := &GatewayService{}
	account := &Account{
		Platform: PlatformSora,
		Credentials: map[string]any{
			"model_mapping": map[string]any{
				"gpt-4o": "gpt-4o",
			},
		},
	}
	if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
		t.Fatalf("expected sora model to be supported when mapping has no sora selectors")
	}
}
// A family selector ("sora2") must cover concrete variants such as
// sora2-landscape-15s.
func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) {
	svc := &GatewayService{}
	account := &Account{
		Platform: PlatformSora,
		Credentials: map[string]any{
			"model_mapping": map[string]any{
				"sora2": "sora2",
			},
		},
	}
	if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") {
		t.Fatalf("expected family selector sora2 to support sora2-landscape-15s")
	}
}
// An underlying-model selector ("sy_8") must also admit the public Sora
// model names that map onto it.
func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) {
	svc := &GatewayService{}
	account := &Account{
		Platform: PlatformSora,
		Credentials: map[string]any{
			"model_mapping": map[string]any{
				"sy_8": "sy_8",
			},
		},
	}
	if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
		t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s")
	}
}
// A mapping that explicitly allows only gpt-image must block video models.
func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) {
	svc := &GatewayService{}
	account := &Account{
		Platform: PlatformSora,
		Credentials: map[string]any{
			"model_mapping": map[string]any{
				"gpt-image": "gpt-image",
			},
		},
	}
	if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
		t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image")
	}
}

View File

@@ -0,0 +1,89 @@
package service
import (
"context"
"testing"
"time"
)
// Sora accounts should ignore the generic expiry/overload/rate-limit windows
// when evaluated for selection (Sora uses per-model limits instead).
func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) {
	svc := &GatewayService{}
	now := time.Now()
	past := now.Add(-1 * time.Minute)
	future := now.Add(5 * time.Minute)
	acc := &Account{
		Platform:           PlatformSora,
		Status:             StatusActive,
		Schedulable:        true,
		AutoPauseOnExpired: true,
		ExpiresAt:          &past,   // already expired
		OverloadUntil:      &future, // overload window still open
		RateLimitResetAt:   &future, // generic rate limit still open
	}
	if !svc.isAccountSchedulableForSelection(acc) {
		t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows")
	}
}
// Non-Sora accounts must keep the generic checks: an open RateLimitResetAt
// window makes the account unschedulable.
func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) {
	svc := &GatewayService{}
	future := time.Now().Add(5 * time.Minute)
	acc := &Account{
		Platform:         PlatformAnthropic,
		Status:           StatusActive,
		Schedulable:      true,
		RateLimitResetAt: &future,
	}
	if svc.isAccountSchedulableForSelection(acc) {
		t.Fatalf("expected non-sora account to keep generic schedulable checks")
	}
}
// For Sora, model-level selection must honor the per-model rate limit even
// though the generic (global) reset window is ignored.
func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) {
	svc := &GatewayService{}
	model := "sora2-landscape-10s"
	resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
	globalResetAt := time.Now().Add(2 * time.Minute)
	acc := &Account{
		Platform:         PlatformSora,
		Status:           StatusActive,
		Schedulable:      true,
		RateLimitResetAt: &globalResetAt, // generic window: ignored for Sora
		Extra: map[string]any{
			"model_rate_limits": map[string]any{
				model: map[string]any{
					"rate_limit_reset_at": resetAt, // per-model window: enforced
				},
			},
		},
	}
	if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) {
		t.Fatalf("expected sora account to be blocked by model scope rate limit")
	}
}
// The failure-stats collector must also apply the Sora exception: a generic
// rate-limit window does not count the account as unschedulable.
func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) {
	svc := &GatewayService{}
	future := time.Now().Add(3 * time.Minute)
	accounts := []Account{
		{
			ID:               1,
			Platform:         PlatformSora,
			Status:           StatusActive,
			Schedulable:      true,
			RateLimitResetAt: &future,
		},
	}
	stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
	if stats.Unschedulable != 0 || stats.Eligible != 1 {
		t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible)
	}
}

View File

@@ -105,12 +105,12 @@ func TestCalculateMaxWait_Scenarios(t *testing.T) {
concurrency int
expected int
}{
{5, 25}, // 5 + 20
{10, 30}, // 10 + 20
{1, 21}, // 1 + 20
{0, 21}, // min(1) + 20
{-1, 21}, // min(1) + 20
{-10, 21}, // min(1) + 20
{5, 25}, // 5 + 20
{10, 30}, // 10 + 20
{1, 21}, // 1 + 20
{0, 21}, // min(1) + 20
{-1, 21}, // min(1) + 20
{-10, 21}, // min(1) + 20
{100, 120}, // 100 + 20
}
for _, tt := range tests {

View File

@@ -53,6 +53,7 @@ type GeminiMessagesCompatService struct {
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
cfg *config.Config
responseHeaderFilter *responseheaders.CompiledHeaderFilter
}
func NewGeminiMessagesCompatService(
@@ -76,6 +77,7 @@ func NewGeminiMessagesCompatService(
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
}
}
@@ -229,6 +231,16 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
account *Account,
requestedModel, platform string,
useMixedScheduling bool,
) bool {
return s.isAccountUsableForRequestWithPrecheck(ctx, account, requestedModel, platform, useMixedScheduling, nil)
}
func (s *GeminiMessagesCompatService) isAccountUsableForRequestWithPrecheck(
ctx context.Context,
account *Account,
requestedModel, platform string,
useMixedScheduling bool,
precheckResult map[int64]bool,
) bool {
// 检查模型调度能力
// Check model scheduling capability
@@ -250,7 +262,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
// 速率限制预检
// Rate limit precheck
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
if !s.passesRateLimitPreCheckWithCache(ctx, account, requestedModel, precheckResult) {
return false
}
@@ -272,15 +284,17 @@ func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account
return false
}
// passesRateLimitPreCheck 执行速率限制预检。
// 返回 true 表示通过预检或无需预检。
//
// passesRateLimitPreCheck performs rate limit precheck.
// Returns true if passed or precheck not required.
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
func (s *GeminiMessagesCompatService) passesRateLimitPreCheckWithCache(ctx context.Context, account *Account, requestedModel string, precheckResult map[int64]bool) bool {
if s.rateLimitService == nil || requestedModel == "" {
return true
}
if precheckResult != nil {
if ok, exists := precheckResult[account.ID]; exists {
return ok
}
}
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
@@ -302,6 +316,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
useMixedScheduling bool,
) *Account {
var selected *Account
precheckResult := s.buildPreCheckUsageResultMap(ctx, accounts, requestedModel)
for i := range accounts {
acc := &accounts[i]
@@ -312,7 +327,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
}
// 检查账号是否可用于当前请求
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
if !s.isAccountUsableForRequestWithPrecheck(ctx, acc, requestedModel, platform, useMixedScheduling, precheckResult) {
continue
}
@@ -330,6 +345,23 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
return selected
}
// buildPreCheckUsageResultMap runs the rate-limit precheck for all candidate
// accounts in a single batch call and returns accountID -> passed.
// It returns nil when prechecking does not apply (no rate-limit service, no
// model, or no accounts); a batch failure is logged and whatever result the
// batch call returned is still handed back to the caller.
func (s *GeminiMessagesCompatService) buildPreCheckUsageResultMap(ctx context.Context, accounts []Account, requestedModel string) map[int64]bool {
	if s.rateLimitService == nil || requestedModel == "" || len(accounts) == 0 {
		return nil
	}
	candidates := make([]*Account, len(accounts))
	for i := range accounts {
		candidates[i] = &accounts[i]
	}
	result, err := s.rateLimitService.PreCheckUsageBatch(ctx, candidates, requestedModel)
	if err != nil {
		logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheckBatch] failed: %v", err)
	}
	return result
}
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
// 规则优先级更高数值更小优先同优先级时未使用过的优先OAuth > 非 OAuth其次是最久未使用的。
//
@@ -399,7 +431,10 @@ func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Co
if groupID != nil {
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
}
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
return s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, queryPlatforms)
}
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
@@ -2390,7 +2425,7 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
}
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
@@ -2415,8 +2450,8 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
}
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Status(resp.StatusCode)
@@ -2557,7 +2592,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.responseHeaderFilter)
if wwwAuthenticate != "" {
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
}

View File

@@ -138,6 +138,12 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
}
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return m.ListSchedulableByPlatform(ctx, platform)
}
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}

View File

@@ -1045,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
ValidateResolvedIP: true,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
return "", fmt.Errorf("create http client failed: %w", err)
}
resp, err := client.Do(req)

View File

@@ -32,6 +32,9 @@ type Group struct {
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
// Sora 存储配额
SoraStorageQuotaBytes int64
// Claude Code 客户端限制
ClaudeCodeOnly bool
FallbackGroupID *int64

View File

@@ -46,6 +46,7 @@ type Fingerprint struct {
StainlessArch string
StainlessRuntime string
StainlessRuntimeVersion string
UpdatedAt int64 `json:",omitempty"` // Unix timestamp用于判断是否需要续期TTL
}
// IdentityCache defines cache operations for identity service
@@ -78,14 +79,26 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
// 尝试从缓存获取指纹
cached, err := s.cache.GetFingerprint(ctx, accountID)
if err == nil && cached != nil {
needWrite := false
// 检查客户端的user-agent是否是更新版本
clientUA := headers.Get("User-Agent")
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
// 更新user-agent
cached.UserAgent = clientUA
// 保存更新后的指纹
_ = s.cache.SetFingerprint(ctx, accountID, cached)
logger.LegacyPrintf("service.identity", "Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
// 版本升级merge 语义 — 仅更新请求中实际携带的字段,保留缓存值
// 避免缺失的头被硬编码默认值覆盖(如新 CLI 版本 + 旧 SDK 默认值的不一致)
mergeHeadersIntoFingerprint(cached, headers)
needWrite = true
logger.LegacyPrintf("service.identity", "Updated fingerprint for account %d: %s (merge update)", accountID, clientUA)
} else if time.Since(time.Unix(cached.UpdatedAt, 0)) > 24*time.Hour {
// 距上次写入超过24小时续期TTL
needWrite = true
}
if needWrite {
cached.UpdatedAt = time.Now().Unix()
if err := s.cache.SetFingerprint(ctx, accountID, cached); err != nil {
logger.LegacyPrintf("service.identity", "Warning: failed to refresh fingerprint for account %d: %v", accountID, err)
}
}
return cached, nil
}
@@ -95,8 +108,9 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
// 生成随机ClientID
fp.ClientID = generateClientID()
fp.UpdatedAt = time.Now().Unix()
// 保存到缓存(永不过期)
// 保存到缓存(7天TTL每24小时自动续期)
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
logger.LegacyPrintf("service.identity", "Warning: failed to cache fingerprint for account %d: %v", accountID, err)
}
@@ -127,6 +141,31 @@ func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fin
return fp
}
// mergeHeadersIntoFingerprint merges only the fields actually present in the
// request headers into an existing fingerprint (used on version upgrades).
// Key semantics: headers present in the request overwrite the cached value;
// missing headers keep the cached value.
// Difference from createFingerprintFromHeaders: that function is for first-time
// creation and falls back to defaultFingerprint for missing headers. This one
// is for upgrade updates and keeps cached values instead, so known-real values
// are never degraded to hard-coded defaults.
func mergeHeadersIntoFingerprint(fp *Fingerprint, headers http.Header) {
	// User-Agent triggered the version upgrade, so it is always present.
	if ua := headers.Get("User-Agent"); ua != "" {
		fp.UserAgent = ua
	}
	// X-Stainless-* headers: updated only when actually carried by the request;
	// otherwise the cached value is preserved.
	mergeHeader(headers, "X-Stainless-Lang", &fp.StainlessLang)
	mergeHeader(headers, "X-Stainless-Package-Version", &fp.StainlessPackageVersion)
	mergeHeader(headers, "X-Stainless-OS", &fp.StainlessOS)
	mergeHeader(headers, "X-Stainless-Arch", &fp.StainlessArch)
	mergeHeader(headers, "X-Stainless-Runtime", &fp.StainlessRuntime)
	mergeHeader(headers, "X-Stainless-Runtime-Version", &fp.StainlessRuntimeVersion)
}
// mergeHeader 如果请求头中存在该字段则更新目标值,否则保留原值
func mergeHeader(headers http.Header, key string, target *string) {
if v := headers.Get(key); v != "" {
*target = v
}
}
// getHeaderOrDefault 获取header值如果不存在则返回默认值
func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
if v := headers.Get(key); v != "" {
@@ -371,8 +410,25 @@ func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
return major, minor, patch, true
}
// extractProduct returns the lowercased product name before the first "/" in a
// User-Agent string, e.g. "claude-cli/2.1.22 (external, cli)" -> "claude-cli".
// It returns "" when there is no slash or the slash is the first character.
func extractProduct(ua string) string {
	product, _, found := strings.Cut(ua, "/")
	if !found || product == "" {
		return ""
	}
	return strings.ToLower(product)
}
// isNewerVersion 比较版本号判断newUA是否比cachedUA更新
// 要求产品名一致(防止浏览器 UA 如 Mozilla/5.0 误判为更新版本)
func isNewerVersion(newUA, cachedUA string) bool {
// 校验产品名一致性
newProduct := extractProduct(newUA)
cachedProduct := extractProduct(cachedUA)
if newProduct == "" || cachedProduct == "" || newProduct != cachedProduct {
return false
}
newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA)
cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA)

View File

@@ -4,8 +4,6 @@ import (
"context"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
const modelRateLimitsKey = "model_rate_limits"
@@ -73,7 +71,7 @@ func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requ
return ""
}
// thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking
if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
if enabled, ok := ThinkingEnabledFromContext(ctx); ok {
modelKey = applyThinkingModelSuffix(modelKey, enabled)
}
return modelKey

View File

@@ -12,7 +12,7 @@ import (
// OpenAIOAuthClient interface for OpenAI OAuth operations
type OpenAIOAuthClient interface {
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
}

View File

@@ -14,10 +14,10 @@ import (
// --- mock: ClaudeOAuthClient ---
type mockClaudeOAuthClient struct {
getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
}
func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
@@ -437,9 +437,9 @@ func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
// 无 refresh_token 的账号
account := &Account{
ID: 1,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
ID: 1,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "some-token",
},
@@ -460,9 +460,9 @@ func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) {
defer svc.Stop()
account := &Account{
ID: 2,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
ID: 2,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "some-token",
"refresh_token": "",

View File

@@ -0,0 +1,909 @@
package service
import (
"container/heap"
"context"
"errors"
"hash/fnv"
"math"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// Scheduling-layer names recorded in OpenAIAccountScheduleDecision.Layer,
// ordered from most specific (previous-response pinning) to fallback
// (load balancing).
const (
	openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
	openAIAccountScheduleLayerSessionSticky    = "session_hash"
	openAIAccountScheduleLayerLoadBalance      = "load_balance"
)
// OpenAIAccountScheduleRequest carries the inputs the scheduler uses to pick
// an account for one OpenAI request.
type OpenAIAccountScheduleRequest struct {
	GroupID            *int64                  // optional group scope; nil means ungrouped
	SessionHash        string                  // sticky-session routing key
	StickyAccountID    int64                   // account previously bound to this session, if any
	PreviousResponseID string                  // presumably pins follow-ups to the prior account — see layer consts
	RequestedModel     string                  // model requested by the client
	RequiredTransport  OpenAIUpstreamTransport // upstream transport constraint
	ExcludedIDs        map[int64]struct{}      // account IDs to skip for this selection
}
// OpenAIAccountScheduleDecision describes how one selection was made; it is
// fed into the scheduler metrics via recordSelect.
type OpenAIAccountScheduleDecision struct {
	Layer               string  // which scheduling layer chose the account (see layer consts)
	StickyPreviousHit   bool    // previous_response_id pinning succeeded
	StickySessionHit    bool    // session-hash stickiness succeeded
	CandidateCount      int     // number of candidate accounts considered
	TopK                int     // size of the final candidate set
	LatencyMs           int64   // time spent making this decision
	LoadSkew            float64 // load imbalance metric; aggregated in thousandths
	SelectedAccountID   int64   // chosen account
	SelectedAccountType string  // type of the chosen account
}
// OpenAIAccountSchedulerMetricsSnapshot is a point-in-time, read-only copy of
// the scheduler counters plus derived ratios/averages.
type OpenAIAccountSchedulerMetricsSnapshot struct {
	SelectTotal             int64
	StickyPreviousHitTotal  int64
	StickySessionHitTotal   int64
	LoadBalanceSelectTotal  int64
	AccountSwitchTotal      int64
	SchedulerLatencyMsTotal int64
	SchedulerLatencyMsAvg   float64 // derived: latency total / select total
	StickyHitRatio          float64 // derived: sticky hits / selects
	AccountSwitchRate       float64 // derived: switches / selects
	LoadSkewAvg             float64 // derived from the milli-unit skew counter
	RuntimeStatsAccountCount int    // distinct accounts tracked in runtime stats
}
// OpenAIAccountScheduler selects upstream OpenAI accounts for requests and
// aggregates scheduling metrics. Implementations are expected to be safe for
// concurrent use — TODO confirm against the concrete implementation.
type OpenAIAccountScheduler interface {
	// Select picks an account for the request and reports how the choice was made.
	Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error)
	// ReportResult feeds back the outcome (success and optional TTFT) of a request.
	ReportResult(accountID int64, success bool, firstTokenMs *int)
	// ReportSwitch records that the serving account changed.
	ReportSwitch()
	// SnapshotMetrics returns a point-in-time copy of the scheduler counters.
	SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot
}
// openAIAccountSchedulerMetrics holds lock-free scheduler counters.
// loadSkewMilliTotal accumulates LoadSkew in thousandths so the float metric
// fits an integer counter.
type openAIAccountSchedulerMetrics struct {
	selectTotal            atomic.Int64
	stickyPreviousHitTotal atomic.Int64
	stickySessionHitTotal  atomic.Int64
	loadBalanceSelectTotal atomic.Int64
	accountSwitchTotal     atomic.Int64
	latencyMsTotal         atomic.Int64
	loadSkewMilliTotal     atomic.Int64
}
// recordSelect folds one scheduling decision into the aggregate counters.
// A nil receiver is a no-op; safe for concurrent use.
func (m *openAIAccountSchedulerMetrics) recordSelect(d OpenAIAccountScheduleDecision) {
	if m == nil {
		return
	}
	m.selectTotal.Add(1)
	m.latencyMsTotal.Add(d.LatencyMs)
	// Load skew is accumulated in thousandths to keep an integer counter.
	m.loadSkewMilliTotal.Add(int64(math.Round(d.LoadSkew * 1000)))
	if d.StickyPreviousHit {
		m.stickyPreviousHitTotal.Add(1)
	}
	if d.StickySessionHit {
		m.stickySessionHitTotal.Add(1)
	}
	if d.Layer == openAIAccountScheduleLayerLoadBalance {
		m.loadBalanceSelectTotal.Add(1)
	}
}
// recordSwitch increments the account-switch counter; nil receivers are a no-op.
func (m *openAIAccountSchedulerMetrics) recordSwitch() {
	if m != nil {
		m.accountSwitchTotal.Add(1)
	}
}
// openAIAccountRuntimeStats tracks per-account runtime quality stats,
// keyed by account ID in a sync.Map, plus a count of distinct accounts seen.
type openAIAccountRuntimeStats struct {
	accounts     sync.Map     // int64 account ID -> *openAIAccountRuntimeStat
	accountCount atomic.Int64 // number of distinct accounts ever tracked
}

// openAIAccountRuntimeStat stores EWMA values as float64 bit patterns inside
// atomic.Uint64 so they can be read and updated lock-free via CAS.
// ttftEWMABits is seeded with NaN to mark "no sample yet" (see loadOrCreate).
type openAIAccountRuntimeStat struct {
	errorRateEWMABits atomic.Uint64
	ttftEWMABits      atomic.Uint64
}

// newOpenAIAccountRuntimeStats returns an empty stats container.
func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats {
	return &openAIAccountRuntimeStats{}
}
// loadOrCreate returns the runtime stat entry for accountID, creating and
// registering a fresh one on first use (TTFT EWMA initialized to the NaN
// "no sample" sentinel). Safe for concurrent callers: exactly one goroutine
// wins the LoadOrStore race and increments accountCount.
func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat {
	// Fast path: entry already exists.
	if value, ok := s.accounts.Load(accountID); ok {
		stat, _ := value.(*openAIAccountRuntimeStat)
		if stat != nil {
			return stat
		}
	}
	stat := &openAIAccountRuntimeStat{}
	stat.ttftEWMABits.Store(math.Float64bits(math.NaN()))
	actual, loaded := s.accounts.LoadOrStore(accountID, stat)
	if !loaded {
		// We inserted the new entry; keep the element counter in sync.
		s.accountCount.Add(1)
		return stat
	}
	// Another goroutine beat us to insertion; use its entry.
	existing, _ := actual.(*openAIAccountRuntimeStat)
	if existing != nil {
		return existing
	}
	return stat
}
func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
for {
oldBits := target.Load()
oldValue := math.Float64frombits(oldBits)
newValue := alpha*sample + (1-alpha)*oldValue
if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
return
}
}
}
// report folds one request outcome into the account's EWMA statistics: the
// error rate gets a 0/1 sample, and TTFT (time-to-first-token, ms) is updated
// only when a positive sample is supplied. The first TTFT sample replaces the
// NaN sentinel directly instead of blending with it.
func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) {
	if s == nil || accountID <= 0 {
		return
	}
	// Smoothing factor: each new sample contributes 20% of the updated EWMA.
	const alpha = 0.2
	stat := s.loadOrCreate(accountID)
	errorSample := 1.0
	if success {
		errorSample = 0.0
	}
	updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha)
	if firstTokenMs != nil && *firstTokenMs > 0 {
		ttft := float64(*firstTokenMs)
		ttftBits := math.Float64bits(ttft)
		// CAS loop with an extra NaN branch, so it cannot reuse updateEWMAAtomic.
		for {
			oldBits := stat.ttftEWMABits.Load()
			oldValue := math.Float64frombits(oldBits)
			if math.IsNaN(oldValue) {
				// First sample: seed the EWMA rather than blending with NaN.
				if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) {
					break
				}
				continue
			}
			newValue := alpha*ttft + (1-alpha)*oldValue
			if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
				break
			}
		}
	}
}
// snapshot returns the account's smoothed error rate (clamped to [0,1]) and
// its TTFT EWMA. hasTTFT is false when the account is unknown or no TTFT
// sample has been reported yet (the NaN sentinel is still in place); in that
// case the returned errorRate is still meaningful for a known account.
func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) {
	if s == nil || accountID <= 0 {
		return 0, 0, false
	}
	value, ok := s.accounts.Load(accountID)
	if !ok {
		return 0, 0, false
	}
	stat, _ := value.(*openAIAccountRuntimeStat)
	if stat == nil {
		return 0, 0, false
	}
	errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load()))
	ttftValue := math.Float64frombits(stat.ttftEWMABits.Load())
	if math.IsNaN(ttftValue) {
		return errorRate, 0, false
	}
	return errorRate, ttftValue, true
}
// size reports how many accounts currently have runtime stats. Nil-safe.
func (s *openAIAccountRuntimeStats) size() int {
	if s != nil {
		return int(s.accountCount.Load())
	}
	return 0
}
// defaultOpenAIAccountScheduler is the built-in OpenAIAccountScheduler. It
// delegates account lookup/acquisition to the gateway service and keeps its
// own metrics and per-account runtime statistics.
type defaultOpenAIAccountScheduler struct {
	service *OpenAIGatewayService
	metrics openAIAccountSchedulerMetrics
	stats   *openAIAccountRuntimeStats
}
// newDefaultOpenAIAccountScheduler builds the default scheduler for service.
// A nil stats store is replaced with a fresh one so the scheduler is always
// usable.
func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler {
	scheduler := &defaultOpenAIAccountScheduler{service: service, stats: stats}
	if scheduler.stats == nil {
		scheduler.stats = newOpenAIAccountRuntimeStats()
	}
	return scheduler
}
// Select resolves an upstream account for the request by walking three layers
// in order:
//  1. previous_response_id stickiness (conversation continuity),
//  2. session-hash stickiness,
//  3. load-aware balancing across schedulable candidates.
//
// The returned decision reports which layer won plus telemetry, and is also
// folded into the scheduler metrics. The results are named so the deferred
// latency stamp reaches the caller too: with unnamed results the defer only
// mutated a local copy, so callers always observed LatencyMs == 0 even though
// metrics recorded the real value.
func (s *defaultOpenAIAccountScheduler) Select(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (selection *AccountSelectionResult, decision OpenAIAccountScheduleDecision, err error) {
	start := time.Now()
	defer func() {
		decision.LatencyMs = time.Since(start).Milliseconds()
		s.metrics.recordSelect(decision)
	}()
	// Layer 1: previous-response stickiness.
	previousResponseID := strings.TrimSpace(req.PreviousResponseID)
	if previousResponseID != "" {
		selection, err = s.service.SelectAccountByPreviousResponseID(
			ctx,
			req.GroupID,
			previousResponseID,
			req.RequestedModel,
			req.ExcludedIDs,
		)
		if err != nil {
			return nil, decision, err
		}
		// Drop the sticky hit when the bound account cannot serve the required transport.
		if selection != nil && selection.Account != nil && !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
			selection = nil
		}
		if selection != nil && selection.Account != nil {
			decision.Layer = openAIAccountScheduleLayerPreviousResponse
			decision.StickyPreviousHit = true
			decision.SelectedAccountID = selection.Account.ID
			decision.SelectedAccountType = selection.Account.Type
			// Refresh the session binding so follow-ups without a
			// previous_response_id keep landing on the same account.
			if req.SessionHash != "" {
				_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID)
			}
			return selection, decision, nil
		}
	}
	// Layer 2: session-hash stickiness.
	selection, err = s.selectBySessionHash(ctx, req)
	if err != nil {
		return nil, decision, err
	}
	if selection != nil && selection.Account != nil {
		decision.Layer = openAIAccountScheduleLayerSessionSticky
		decision.StickySessionHit = true
		decision.SelectedAccountID = selection.Account.ID
		decision.SelectedAccountType = selection.Account.Type
		return selection, decision, nil
	}
	// Layer 3: load-aware balancing.
	var (
		candidateCount int
		topK           int
		loadSkew       float64
	)
	selection, candidateCount, topK, loadSkew, err = s.selectByLoadBalance(ctx, req)
	decision.Layer = openAIAccountScheduleLayerLoadBalance
	decision.CandidateCount = candidateCount
	decision.TopK = topK
	decision.LoadSkew = loadSkew
	if err != nil {
		return nil, decision, err
	}
	if selection != nil && selection.Account != nil {
		decision.SelectedAccountID = selection.Account.ID
		decision.SelectedAccountType = selection.Account.Type
	}
	return selection, decision, nil
}
// selectBySessionHash tries to honor an existing session-hash -> account
// binding. It returns (nil, nil) — no selection, no error — whenever the
// sticky path does not apply, letting the caller fall through to load
// balancing. Stale or incompatible bindings are proactively deleted so future
// requests rebind cleanly; a busy sticky account yields a wait plan instead of
// switching accounts.
func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, error) {
	sessionHash := strings.TrimSpace(req.SessionHash)
	if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil {
		return nil, nil
	}
	// Use the pre-resolved binding when the caller already looked it up.
	accountID := req.StickyAccountID
	if accountID <= 0 {
		var err error
		accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash)
		if err != nil || accountID <= 0 {
			return nil, nil
		}
	}
	if accountID <= 0 {
		return nil, nil
	}
	if req.ExcludedIDs != nil {
		if _, excluded := req.ExcludedIDs[accountID]; excluded {
			return nil, nil
		}
	}
	account, err := s.service.getSchedulableAccount(ctx, accountID)
	if err != nil || account == nil {
		// Account gone or unschedulable: drop the stale binding.
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	// NOTE(review): an unsupported model falls through WITHOUT deleting the
	// binding — presumably so the session can still stick for other models;
	// confirm this asymmetry with the other cleanup branches is intended.
	if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
		return nil, nil
	}
	if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
	if acquireErr == nil && result.Acquired {
		// Slot acquired: extend the sticky TTL and hand back the account.
		_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
		return &AccountSelectionResult{
			Account:     account,
			Acquired:    true,
			ReleaseFunc: result.ReleaseFunc,
		}, nil
	}
	// Account busy: keep the session sticky by queueing on it rather than
	// switching to another account.
	cfg := s.service.schedulingConfig()
	if s.service.concurrencyService != nil {
		return &AccountSelectionResult{
			Account: account,
			WaitPlan: &AccountWaitPlan{
				AccountID:      accountID,
				MaxConcurrency: account.Concurrency,
				Timeout:        cfg.StickySessionWaitTimeout,
				MaxWaiting:     cfg.StickySessionMaxWaiting,
			},
		}, nil
	}
	return nil, nil
}
// openAIAccountCandidateScore bundles an account with its load info, runtime
// EWMA stats and the composite score computed in selectByLoadBalance.
type openAIAccountCandidateScore struct {
	account   *Account
	loadInfo  *AccountLoadInfo
	score     float64 // weighted composite score; higher is better
	errorRate float64 // smoothed error rate in [0,1]
	ttft      float64 // smoothed time-to-first-token (ms); valid only when hasTTFT
	hasTTFT   bool    // false until the account has at least one TTFT sample
}
// openAIAccountCandidateHeap is a bounded min-heap whose root is the WORST
// retained candidate (Less is deliberately inverted), allowing top-K
// maintenance in O(log k) per insertion. Implements container/heap.Interface.
type openAIAccountCandidateHeap []openAIAccountCandidateScore

func (h openAIAccountCandidateHeap) Len() int {
	return len(h)
}
func (h openAIAccountCandidateHeap) Less(i, j int) bool {
	// The min-heap root holds the "worst" candidate, enabling O(log k) top-K maintenance.
	return isOpenAIAccountCandidateBetter(h[j], h[i])
}
func (h openAIAccountCandidateHeap) Swap(i, j int) {
	h[i], h[j] = h[j], h[i]
}
func (h *openAIAccountCandidateHeap) Push(x any) {
	candidate, ok := x.(openAIAccountCandidateScore)
	if !ok {
		// container/heap only ever passes what we Push; a mismatch is a programmer bug.
		panic("openAIAccountCandidateHeap: invalid element type")
	}
	*h = append(*h, candidate)
}
func (h *openAIAccountCandidateHeap) Pop() any {
	old := *h
	n := len(old)
	last := old[n-1]
	*h = old[:n-1]
	return last
}
// isOpenAIAccountCandidateBetter is the total order used for candidate
// ranking: higher score wins, then lower priority value, lower load rate,
// fewer waiters, and finally the smaller account ID as a deterministic
// tie-break.
func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool {
	switch {
	case left.score != right.score:
		return left.score > right.score
	case left.account.Priority != right.account.Priority:
		return left.account.Priority < right.account.Priority
	case left.loadInfo.LoadRate != right.loadInfo.LoadRate:
		return left.loadInfo.LoadRate < right.loadInfo.LoadRate
	case left.loadInfo.WaitingCount != right.loadInfo.WaitingCount:
		return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount
	default:
		return left.account.ID < right.account.ID
	}
}
// selectTopKOpenAICandidates returns the best topK candidates ordered
// best-first. When topK covers the whole input it just sorts a copy; otherwise
// it maintains a bounded min-heap (root = worst retained candidate) for
// O(n log k) selection and sorts the k survivors at the end. The input slice
// is never mutated.
func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
	if len(candidates) == 0 {
		return nil
	}
	if topK <= 0 {
		topK = 1
	}
	if topK >= len(candidates) {
		// Degenerate case: top-K is everything, so a plain sorted copy suffices.
		ranked := append([]openAIAccountCandidateScore(nil), candidates...)
		sort.Slice(ranked, func(i, j int) bool {
			return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
		})
		return ranked
	}
	best := make(openAIAccountCandidateHeap, 0, topK)
	for _, candidate := range candidates {
		if len(best) < topK {
			heap.Push(&best, candidate)
			continue
		}
		// Replace the current worst (heap root) when this candidate beats it.
		if isOpenAIAccountCandidateBetter(candidate, best[0]) {
			best[0] = candidate
			heap.Fix(&best, 0)
		}
	}
	ranked := make([]openAIAccountCandidateScore, len(best))
	copy(ranked, best)
	sort.Slice(ranked, func(i, j int) bool {
		return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
	})
	return ranked
}
// openAISelectionRNG is a tiny deterministic PRNG (xorshift64*) used for
// weighted candidate sampling without touching the locked global rand source.
type openAISelectionRNG struct {
	state uint64
}

// newOpenAISelectionRNG seeds the generator. A zero seed is replaced with a
// fixed odd constant because the all-zero state is a fixed point of xorshift.
func newOpenAISelectionRNG(seed uint64) openAISelectionRNG {
	if seed == 0 {
		seed = 0x9e3779b97f4a7c15
	}
	return openAISelectionRNG{state: seed}
}

// nextUint64 advances the xorshift64* state and returns the next value.
func (r *openAISelectionRNG) nextUint64() uint64 {
	s := r.state
	s ^= s >> 12
	s ^= s << 25
	s ^= s >> 27
	r.state = s
	return s * 2685821657736338717
}

// nextFloat64 returns a uniform float64 in [0, 1) using the top 53 bits.
func (r *openAISelectionRNG) nextFloat64() float64 {
	return float64(r.nextUint64()>>11) / (1 << 53)
}
// deriveOpenAISelectionSeed produces the RNG seed for weighted candidate
// ordering. Requests anchored to a session or previous response hash to a
// stable seed (stable ordering per conversation); anchor-free load-balance
// requests get time entropy so they do not keep landing on the same account.
func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 {
	h := fnv.New64a()
	writeField := func(value string) {
		v := strings.TrimSpace(value)
		if v == "" {
			return
		}
		// NUL separator keeps ("ab","c") distinct from ("a","bc").
		_, _ = h.Write([]byte(v))
		_, _ = h.Write([]byte{0})
	}
	writeField(req.SessionHash)
	writeField(req.PreviousResponseID)
	writeField(req.RequestedModel)
	if req.GroupID != nil {
		_, _ = h.Write([]byte(strconv.FormatInt(*req.GroupID, 10)))
	}
	seed := h.Sum64()
	// Inject time entropy for pure load-balance requests with no session anchor,
	// so they avoid deterministically hitting the same account.
	if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" {
		seed ^= uint64(time.Now().UnixNano())
	}
	if seed == 0 {
		seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15
	}
	return seed
}
// buildOpenAIWeightedSelectionOrder turns the ranked top-K candidates into a
// score-weighted random permutation (weighted sampling without replacement),
// so acquisition attempts favor higher-scored accounts without always starting
// at the same one. O(k^2) overall, which is fine for the small top-K used here.
func buildOpenAIWeightedSelectionOrder(
	candidates []openAIAccountCandidateScore,
	req OpenAIAccountScheduleRequest,
) []openAIAccountCandidateScore {
	if len(candidates) <= 1 {
		return append([]openAIAccountCandidateScore(nil), candidates...)
	}
	pool := append([]openAIAccountCandidateScore(nil), candidates...)
	weights := make([]float64, len(pool))
	minScore := pool[0].score
	for i := 1; i < len(pool); i++ {
		if pool[i].score < minScore {
			minScore = pool[i].score
		}
	}
	for i := range pool {
		// Shift top-K scores into a positive range so a single top-scored
		// account cannot monopolize selection.
		weight := (pool[i].score - minScore) + 1.0
		if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 {
			weight = 1.0
		}
		weights[i] = weight
	}
	order := make([]openAIAccountCandidateScore, 0, len(pool))
	rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
	// Repeatedly draw one candidate proportionally to weight, then remove it.
	for len(pool) > 0 {
		total := 0.0
		for _, w := range weights {
			total += w
		}
		selectedIdx := 0
		if total > 0 {
			r := rng.nextFloat64() * total
			acc := 0.0
			for i, w := range weights {
				acc += w
				if r <= acc {
					selectedIdx = i
					break
				}
			}
		} else {
			// All weights degenerate: fall back to a uniform pick.
			selectedIdx = int(rng.nextUint64() % uint64(len(pool)))
		}
		order = append(order, pool[selectedIdx])
		pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...)
		weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...)
	}
	return order
}
// selectByLoadBalance performs load-aware selection across all schedulable
// OpenAI accounts of the group. Pipeline: filter candidates -> fetch load info
// in one batch -> compute weighted composite scores -> keep the top-K -> try
// to acquire a slot in weighted-random order -> otherwise return a wait plan
// on the preferred candidate. Returns (selection, candidateCount, topK,
// loadSkew, error), where loadSkew is the stddev of candidate load rates.
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, int, int, float64, error) {
	accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID)
	if err != nil {
		return nil, 0, 0, 0, err
	}
	if len(accounts) == 0 {
		return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
	}
	// Phase 1: filter out excluded, unschedulable, model- or
	// transport-incompatible accounts.
	filtered := make([]*Account, 0, len(accounts))
	loadReq := make([]AccountWithConcurrency, 0, len(accounts))
	for i := range accounts {
		account := &accounts[i]
		if req.ExcludedIDs != nil {
			if _, excluded := req.ExcludedIDs[account.ID]; excluded {
				continue
			}
		}
		if !account.IsSchedulable() || !account.IsOpenAI() {
			continue
		}
		if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
			continue
		}
		if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
			continue
		}
		filtered = append(filtered, account)
		loadReq = append(loadReq, AccountWithConcurrency{
			ID:             account.ID,
			MaxConcurrency: account.Concurrency,
		})
	}
	if len(filtered) == 0 {
		return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
	}
	// Phase 2: best-effort batch load lookup; on error every account falls
	// back to zero-valued load info rather than failing the request.
	loadMap := map[int64]*AccountLoadInfo{}
	if s.service.concurrencyService != nil {
		if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
			loadMap = batchLoad
		}
	}
	// Phase 3: collect per-candidate stats and the extremes needed to
	// normalize the scoring factors below.
	minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
	maxWaiting := 1 // floor of 1 avoids division by zero in queueFactor
	loadRateSum := 0.0
	loadRateSumSquares := 0.0
	minTTFT, maxTTFT := 0.0, 0.0
	hasTTFTSample := false
	candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
	for _, account := range filtered {
		loadInfo := loadMap[account.ID]
		if loadInfo == nil {
			loadInfo = &AccountLoadInfo{AccountID: account.ID}
		}
		if account.Priority < minPriority {
			minPriority = account.Priority
		}
		if account.Priority > maxPriority {
			maxPriority = account.Priority
		}
		if loadInfo.WaitingCount > maxWaiting {
			maxWaiting = loadInfo.WaitingCount
		}
		errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
		if hasTTFT && ttft > 0 {
			if !hasTTFTSample {
				minTTFT, maxTTFT = ttft, ttft
				hasTTFTSample = true
			} else {
				if ttft < minTTFT {
					minTTFT = ttft
				}
				if ttft > maxTTFT {
					maxTTFT = ttft
				}
			}
		}
		loadRate := float64(loadInfo.LoadRate)
		loadRateSum += loadRate
		loadRateSumSquares += loadRate * loadRate
		candidates = append(candidates, openAIAccountCandidateScore{
			account:   account,
			loadInfo:  loadInfo,
			errorRate: errorRate,
			ttft:      ttft,
			hasTTFT:   hasTTFT,
		})
	}
	loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
	// Phase 4: composite score — every factor is normalized into [0,1]
	// (1 = best) and combined with the configured weights.
	weights := s.service.openAIWSSchedulerWeights()
	for i := range candidates {
		item := &candidates[i]
		priorityFactor := 1.0
		if maxPriority > minPriority {
			priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
		}
		loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
		queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
		errorFactor := 1 - clamp01(item.errorRate)
		// Accounts without a TTFT sample get a neutral 0.5.
		ttftFactor := 0.5
		if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
			ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
		}
		item.score = weights.Priority*priorityFactor +
			weights.Load*loadFactor +
			weights.Queue*queueFactor +
			weights.ErrorRate*errorFactor +
			weights.TTFT*ttftFactor
	}
	topK := s.service.openAIWSLBTopK()
	if topK > len(candidates) {
		topK = len(candidates)
	}
	if topK <= 0 {
		topK = 1
	}
	// Phase 5: rank the top-K and try acquisition in weighted-random order.
	rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
	selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
	for i := 0; i < len(selectionOrder); i++ {
		candidate := selectionOrder[i]
		result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
		if acquireErr != nil {
			return nil, len(candidates), topK, loadSkew, acquireErr
		}
		if result != nil && result.Acquired {
			if req.SessionHash != "" {
				_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
			}
			return &AccountSelectionResult{
				Account:     candidate.account,
				Acquired:    true,
				ReleaseFunc: result.ReleaseFunc,
			}, len(candidates), topK, loadSkew, nil
		}
	}
	// Phase 6: every candidate is saturated — queue on the first entry of the
	// weighted order. (selectionOrder is non-empty because filtered was.)
	cfg := s.service.schedulingConfig()
	candidate := selectionOrder[0]
	return &AccountSelectionResult{
		Account: candidate.account,
		WaitPlan: &AccountWaitPlan{
			AccountID:      candidate.account.ID,
			MaxConcurrency: candidate.account.Concurrency,
			Timeout:        cfg.FallbackWaitTimeout,
			MaxWaiting:     cfg.FallbackMaxWaiting,
		},
	}, len(candidates), topK, loadSkew, nil
}
// isAccountTransportCompatible reports whether the account can serve the
// transport the request requires.
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
	// HTTP ingress can always fall back to the HTTP upstream path, so those
	// modes need no hard transport filtering at selection time.
	switch requiredTransport {
	case OpenAIUpstreamTransportAny, OpenAIUpstreamTransportHTTPSSE:
		return true
	}
	if s == nil || s.service == nil || account == nil {
		return false
	}
	return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
}
// ReportResult feeds one request outcome into the per-account runtime stats.
// Nil scheduler or stats store is a no-op.
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
	if s != nil && s.stats != nil {
		s.stats.report(accountID, success, firstTokenMs)
	}
}
// ReportSwitch records one mid-conversation account switch. Nil-safe.
func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
	if s != nil {
		s.metrics.recordSwitch()
	}
}
// SnapshotMetrics returns the cumulative counters plus derived averages and
// ratios. Counters are loaded individually rather than as one atomic group,
// so a snapshot taken under load can be slightly inconsistent across fields —
// acceptable for monitoring.
func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
	if s == nil {
		return OpenAIAccountSchedulerMetricsSnapshot{}
	}
	selectTotal := s.metrics.selectTotal.Load()
	prevHit := s.metrics.stickyPreviousHitTotal.Load()
	sessionHit := s.metrics.stickySessionHitTotal.Load()
	switchTotal := s.metrics.accountSwitchTotal.Load()
	latencyTotal := s.metrics.latencyMsTotal.Load()
	loadSkewTotal := s.metrics.loadSkewMilliTotal.Load()
	snapshot := OpenAIAccountSchedulerMetricsSnapshot{
		SelectTotal:              selectTotal,
		StickyPreviousHitTotal:   prevHit,
		StickySessionHitTotal:    sessionHit,
		LoadBalanceSelectTotal:   s.metrics.loadBalanceSelectTotal.Load(),
		AccountSwitchTotal:       switchTotal,
		SchedulerLatencyMsTotal:  latencyTotal,
		RuntimeStatsAccountCount: s.stats.size(),
	}
	if selectTotal > 0 {
		snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal)
		snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal)
		snapshot.AccountSwitchRate = float64(switchTotal) / float64(selectTotal)
		// loadSkewMilliTotal accumulates skew * 1000 as an integer; scale back.
		snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal)
	}
	return snapshot
}
// getOpenAIAccountScheduler lazily builds the singleton scheduler (and its
// runtime-stats store) on first use via sync.Once. Pre-populated fields — e.g.
// injected by tests — are kept as-is. Returns nil only when the service itself
// is nil.
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
	if s == nil {
		return nil
	}
	s.openaiSchedulerOnce.Do(func() {
		if s.openaiAccountStats == nil {
			s.openaiAccountStats = newOpenAIAccountRuntimeStats()
		}
		if s.openaiScheduler == nil {
			s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats)
		}
	})
	return s.openaiScheduler
}
// SelectAccountWithScheduler is the public entry point for scheduler-based
// account selection. It pre-resolves the sticky-session binding (saving the
// scheduler a cache round-trip) and falls back to the legacy load-aware
// selection when no scheduler is available (nil service).
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
	ctx context.Context,
	groupID *int64,
	previousResponseID string,
	sessionHash string,
	requestedModel string,
	excludedIDs map[int64]struct{},
	requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
	decision := OpenAIAccountScheduleDecision{}
	scheduler := s.getOpenAIAccountScheduler()
	if scheduler == nil {
		// Legacy path: no scheduler, so no transport filtering or telemetry.
		selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
		decision.Layer = openAIAccountScheduleLayerLoadBalance
		return selection, decision, err
	}
	// Best-effort sticky lookup; lookup errors simply leave the hint unset.
	var stickyAccountID int64
	if sessionHash != "" && s.cache != nil {
		if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {
			stickyAccountID = accountID
		}
	}
	return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
		GroupID:            groupID,
		SessionHash:        sessionHash,
		StickyAccountID:    stickyAccountID,
		PreviousResponseID: previousResponseID,
		RequestedModel:     requestedModel,
		RequiredTransport:  requiredTransport,
		ExcludedIDs:        excludedIDs,
	})
}
// ReportOpenAIAccountScheduleResult forwards a request outcome to the
// scheduler's runtime stats; a nil scheduler makes this a no-op.
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		scheduler.ReportResult(accountID, success, firstTokenMs)
	}
}
// RecordOpenAIAccountSwitch forwards a mid-session account switch to the
// scheduler metrics; a nil scheduler makes this a no-op.
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		scheduler.ReportSwitch()
	}
}
// SnapshotOpenAIAccountSchedulerMetrics returns the scheduler's cumulative
// metrics, or a zero snapshot when no scheduler exists.
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		return scheduler.SnapshotMetrics()
	}
	return OpenAIAccountSchedulerMetricsSnapshot{}
}
// openAIWSSessionStickyTTL returns the configured sticky-session TTL, falling
// back to the package default when unset or non-positive.
func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration {
	if s == nil || s.cfg == nil {
		return openaiStickySessionTTL
	}
	if secs := s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds; secs > 0 {
		return time.Duration(secs) * time.Second
	}
	return openaiStickySessionTTL
}
// openAIWSLBTopK returns the configured load-balance top-K, defaulting to 7
// when unset or non-positive.
func (s *OpenAIGatewayService) openAIWSLBTopK() int {
	const defaultTopK = 7
	if s == nil || s.cfg == nil {
		return defaultTopK
	}
	if k := s.cfg.Gateway.OpenAIWS.LBTopK; k > 0 {
		return k
	}
	return defaultTopK
}
// openAIWSSchedulerWeights returns the scoring weights from configuration, or
// built-in defaults when the service/config is nil.
//
// NOTE(review): when cfg exists but the weights were never set, this returns
// all-zero weights (every candidate scores 0 and only tie-breaks decide) — the
// defaults below apply only to a nil cfg. Confirm that is intentional.
func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
	if s != nil && s.cfg != nil {
		return GatewayOpenAIWSSchedulerScoreWeightsView{
			Priority:  s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority,
			Load:      s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load,
			Queue:     s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue,
			ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate,
			TTFT:      s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT,
		}
	}
	return GatewayOpenAIWSSchedulerScoreWeightsView{
		Priority:  1.0,
		Load:      1.0,
		Queue:     0.7,
		ErrorRate: 0.8,
		TTFT:      0.5,
	}
}
// GatewayOpenAIWSSchedulerScoreWeightsView is a config-independent view of the
// scoring weights used by selectByLoadBalance; each weight scales its
// normalized [0,1] factor.
type GatewayOpenAIWSSchedulerScoreWeightsView struct {
	Priority  float64 // weight for the account-priority factor
	Load      float64 // weight for the inverse load-rate factor
	Queue     float64 // weight for the inverse waiting-count factor
	ErrorRate float64 // weight for the inverse error-rate factor
	TTFT      float64 // weight for the inverse first-token-latency factor
}
func clamp01(value float64) float64 {
switch {
case value < 0:
return 0
case value > 1:
return 1
default:
return value
}
}
// calcLoadSkewByMoments computes the population standard deviation from the
// first two moments (sum and sum of squares) over count samples. Fewer than
// two samples yield 0; tiny negative variances from floating-point error are
// clamped to 0 before the square root.
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
	if count <= 1 {
		return 0
	}
	n := float64(count)
	mean := sum / n
	variance := sumSquares/n - mean*mean
	return math.Sqrt(math.Max(variance, 0))
}

View File

@@ -0,0 +1,83 @@
package service
import (
"sort"
"testing"
)
// buildOpenAISchedulerBenchmarkCandidates fabricates size candidates with a
// deterministic spread of priorities, load rates, scores, error rates and
// TTFT values for the top-K benchmarks.
func buildOpenAISchedulerBenchmarkCandidates(size int) []openAIAccountCandidateScore {
	if size <= 0 {
		return nil
	}
	candidates := make([]openAIAccountCandidateScore, 0, size)
	for i := 0; i < size; i++ {
		accountID := int64(10_000 + i)
		candidates = append(candidates, openAIAccountCandidateScore{
			account: &Account{
				ID:       accountID,
				Priority: i % 7,
			},
			loadInfo: &AccountLoadInfo{
				AccountID:    accountID,
				LoadRate:     (i * 17) % 100,
				WaitingCount: (i * 11) % 13,
			},
			score: float64((i*29)%1000) / 100,
			// Convert to float64 BEFORE dividing: the original
			// float64((i*5)%100/100) did integer division and always yielded 0.
			errorRate: float64((i*5)%100) / 100,
			ttft:      float64(30 + (i*3)%500),
			hasTTFT:   i%3 != 0,
		})
	}
	return candidates
}
// selectTopKOpenAICandidatesBySortBenchmark is the full-sort baseline used to
// benchmark against the heap-based selectTopKOpenAICandidates: sort a copy of
// all candidates best-first, then slice off the first topK.
func selectTopKOpenAICandidatesBySortBenchmark(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
	if len(candidates) == 0 {
		return nil
	}
	if topK <= 0 {
		topK = 1
	}
	ranked := append([]openAIAccountCandidateScore(nil), candidates...)
	sort.Slice(ranked, func(i, j int) bool {
		return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
	})
	if topK > len(ranked) {
		topK = len(ranked)
	}
	return ranked[:topK]
}
// BenchmarkOpenAIAccountSchedulerSelectTopK compares the heap-based top-K
// selection against the full-sort baseline across several candidate-set sizes.
func BenchmarkOpenAIAccountSchedulerSelectTopK(b *testing.B) {
	cases := []struct {
		name string
		size int
		topK int
	}{
		{name: "n_16_k_3", size: 16, topK: 3},
		{name: "n_64_k_3", size: 64, topK: 3},
		{name: "n_256_k_5", size: 256, topK: 5},
	}
	for _, tc := range cases {
		candidates := buildOpenAISchedulerBenchmarkCandidates(tc.size)
		b.Run(tc.name+"/heap_topk", func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				result := selectTopKOpenAICandidates(candidates, tc.topK)
				if len(result) == 0 {
					b.Fatal("unexpected empty result")
				}
			}
		})
		b.Run(tc.name+"/full_sort", func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				result := selectTopKOpenAICandidatesBySortBenchmark(candidates, tc.topK)
				if len(result) == 0 {
					b.Fatal("unexpected empty result")
				}
			}
		})
	}
}

View File

@@ -0,0 +1,841 @@
package service
import (
"context"
"fmt"
"math"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// Verifies layer 1: a previous_response_id bound to an account resolves that
// account, marks the decision as a previous-response sticky hit, and rebinds
// the session hash to the same account.
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(9)
	account := Account{
		ID:          1001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		},
	}
	cache := &stubGatewayCache{}
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
	cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}
	// Pre-bind the previous response ID to the account.
	store := svc.getOpenAIWSStateStore()
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_001", account.ID, time.Hour))
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"resp_prev_001",
		"session_hash_001",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
	require.True(t, decision.StickyPreviousHit)
	// The session hash must have been rebound to the sticky account.
	require.Equal(t, account.ID, cache.sessionBindings["openai:session_hash_001"])
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// Verifies layer 2: an existing session-hash binding resolves the bound
// account and the decision reports a session-sticky hit.
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(10)
	account := Account{
		ID:          2001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_abc": account.ID,
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                &config.Config{},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_abc",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
	require.True(t, decision.StickySessionHit)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// Verifies that a saturated sticky account yields a wait plan on that same
// account instead of switching to an idle one — session stickiness must win
// over load balancing.
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(10100)
	accounts := []Account{
		{
			ID:          21001,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
		{
			ID:          21002,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    9,
		},
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_sticky_busy": 21001,
		},
	}
	cfg := &config.Config{}
	cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2
	cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	concurrencyCache := stubConcurrencyCache{
		acquireResults: map[int64]bool{
			21001: false, // sticky account is saturated
			21002: true,  // load-balance fallback would hit this account (the test requires no switch)
		},
		waitCounts: map[int64]int{
			21001: 999,
		},
		loadMap: map[int64]*AccountLoadInfo{
			21001: {AccountID: 21001, LoadRate: 90, WaitingCount: 9},
			21002: {AccountID: 21002, LoadRate: 1, WaitingCount: 0},
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
	}
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_sticky_busy",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(21001), selection.Account.ID, "busy sticky account should remain selected")
	require.False(t, selection.Acquired)
	require.NotNil(t, selection.WaitPlan)
	require.Equal(t, int64(21001), selection.WaitPlan.AccountID)
	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
	require.True(t, decision.StickySessionHit)
}
// Verifies that a sticky account flagged as force-HTTP still satisfies a
// request with OpenAIUpstreamTransportAny — transport is not a hard filter in
// that mode.
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP(t *testing.T) {
	ctx := context.Background()
	groupID := int64(1010)
	account := Account{
		ID:          2101,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_ws_force_http": true,
		},
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_force_http": account.ID,
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                &config.Config{},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_force_http",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
	require.True(t, decision.StickySessionHit)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// Verifies that requiring the WS-v2 transport overrides both stickiness and
// load: the HTTP-only sticky account is skipped and the (busier) WS-capable
// account is selected through load balancing.
func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStickyHTTPAccount(t *testing.T) {
	ctx := context.Background()
	groupID := int64(1011)
	accounts := []Account{
		{
			ID:          2201,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
		{
			ID:          2202,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    5,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		},
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_ws_only": 2201,
		},
	}
	cfg := newOpenAIWSV2TestConfig()
	// Build a scenario where the HTTP-only account has lower load, verifying
	// that the required transport acts as a hard filter.
	concurrencyCache := stubConcurrencyCache{
		loadMap: map[int64]*AccountLoadInfo{
			2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
			2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
	}
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_ws_only",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportResponsesWebsocketV2,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(2202), selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
	require.False(t, decision.StickySessionHit)
	require.Equal(t, 1, decision.CandidateCount)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// Verifies the hard-failure path: requiring WS-v2 with no transport-capable
// account yields an error from the load-balance layer with zero candidates.
func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailableAccount(t *testing.T) {
	ctx := context.Background()
	groupID := int64(1012)
	accounts := []Account{
		{
			ID:          2301,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeOAuth,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              &stubGatewayCache{},
		cfg:                newOpenAIWSV2TestConfig(),
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportResponsesWebsocketV2,
	)
	require.Error(t, err)
	require.Nil(t, selection)
	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
	require.Equal(t, 0, decision.CandidateCount)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback(t *testing.T) {
ctx := context.Background()
groupID := int64(11)
accounts := []Account{
{
ID: 3001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
{
ID: 3002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
{
ID: 3003,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.LBTopK = 2
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
concurrencyCache := stubConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
3003: {AccountID: 3003, LoadRate: 10, WaitingCount: 0},
},
acquireResults: map[int64]bool{
3003: false, // top1 失败,必须回退到 top-K 的下一候选
3002: true,
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{},
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(3002), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
require.Equal(t, 3, decision.CandidateCount)
require.Equal(t, 2, decision.TopK)
require.Greater(t, decision.LoadSkew, 0.0)
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
ctx := context.Background()
groupID := int64(12)
account := Account{
ID: 4001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
}
cache := &stubGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_metrics": account.ID,
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
require.NotNil(t, selection)
svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.GreaterOrEqual(t, snapshot.SelectTotal, int64(1))
require.GreaterOrEqual(t, snapshot.StickySessionHitTotal, int64(1))
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
require.GreaterOrEqual(t, snapshot.SchedulerLatencyMsAvg, float64(0))
require.GreaterOrEqual(t, snapshot.StickyHitRatio, 0.0)
require.GreaterOrEqual(t, snapshot.RuntimeStatsAccountCount, 1)
}
// intPtrForTest returns a pointer to a copy of v, so tests can pass optional
// *int arguments (e.g. TTFT samples) inline without declaring a local variable.
func intPtrForTest(v int) *int {
	out := v
	return &out
}
// TestOpenAIAccountRuntimeStats_ReportAndSnapshot feeds one success (no TTFT)
// and two failures (with TTFT samples) into the per-account runtime stats and
// checks the smoothed snapshot. The expected values are consistent with an
// exponentially weighted moving average with alpha = 0.2:
//   error rate: 0 -> 0.2*1+0.8*0 = 0.2 -> 0.2*1+0.8*0.2 = 0.36
//   TTFT:       100 -> 0.2*200+0.8*100 = 120
func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) {
	stats := newOpenAIAccountRuntimeStats()
	// Success with no TTFT sample: seeds the error-rate series at 0.
	stats.report(1001, true, nil)
	firstTTFT := 100
	stats.report(1001, false, &firstTTFT)
	secondTTFT := 200
	stats.report(1001, false, &secondTTFT)
	errorRate, ttft, hasTTFT := stats.snapshot(1001)
	require.True(t, hasTTFT)
	require.InDelta(t, 0.36, errorRate, 1e-9)
	require.InDelta(t, 120.0, ttft, 1e-9)
	// Only one account has ever reported.
	require.Equal(t, 1, stats.size())
}
// TestOpenAIAccountRuntimeStats_ReportConcurrent hammers report() from many
// goroutines across a small set of accounts and then sanity-checks each
// snapshot. Primarily meaningful under `go test -race`: it proves report and
// snapshot are safe for concurrent use and that values stay in valid ranges.
func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) {
	stats := newOpenAIAccountRuntimeStats()
	const (
		accountCount = 4
		workers      = 16
		iterations   = 800
	)
	var wg sync.WaitGroup
	wg.Add(workers)
	for worker := 0; worker < workers; worker++ {
		worker := worker // shadow for pre-Go-1.22 loop-variable capture semantics
		go func() {
			defer wg.Done()
			for i := 0; i < iterations; i++ {
				accountID := int64(i%accountCount + 1)
				// Mix of successes/failures and varying TTFTs per worker.
				success := (i+worker)%3 != 0
				ttft := 80 + (i+worker)%40
				stats.report(accountID, success, &ttft)
			}
		}()
	}
	wg.Wait()
	require.Equal(t, accountCount, stats.size())
	for accountID := int64(1); accountID <= accountCount; accountID++ {
		errorRate, ttft, hasTTFT := stats.snapshot(accountID)
		// Exact values depend on interleaving; only check invariants.
		require.GreaterOrEqual(t, errorRate, 0.0)
		require.LessOrEqual(t, errorRate, 1.0)
		require.True(t, hasTTFT)
		require.Greater(t, ttft, 0.0)
	}
}
func TestSelectTopKOpenAICandidates(t *testing.T) {
candidates := []openAIAccountCandidateScore{
{
account: &Account{ID: 11, Priority: 2},
loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 1},
score: 10.0,
},
{
account: &Account{ID: 12, Priority: 1},
loadInfo: &AccountLoadInfo{LoadRate: 20, WaitingCount: 1},
score: 9.5,
},
{
account: &Account{ID: 13, Priority: 1},
loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 0},
score: 10.0,
},
{
account: &Account{ID: 14, Priority: 0},
loadInfo: &AccountLoadInfo{LoadRate: 40, WaitingCount: 0},
score: 8.0,
},
}
top2 := selectTopKOpenAICandidates(candidates, 2)
require.Len(t, top2, 2)
require.Equal(t, int64(13), top2[0].account.ID)
require.Equal(t, int64(11), top2[1].account.ID)
topAll := selectTopKOpenAICandidates(candidates, 8)
require.Len(t, topAll, len(candidates))
require.Equal(t, int64(13), topAll[0].account.ID)
require.Equal(t, int64(11), topAll[1].account.ID)
require.Equal(t, int64(12), topAll[2].account.ID)
require.Equal(t, int64(14), topAll[3].account.ID)
}
func TestBuildOpenAIWeightedSelectionOrder_DeterministicBySessionSeed(t *testing.T) {
candidates := []openAIAccountCandidateScore{
{
account: &Account{ID: 101},
loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 0},
score: 4.2,
},
{
account: &Account{ID: 102},
loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 1},
score: 3.5,
},
{
account: &Account{ID: 103},
loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2},
score: 2.1,
},
}
req := OpenAIAccountScheduleRequest{
GroupID: int64PtrForTest(99),
SessionHash: "session_seed_fixed",
RequestedModel: "gpt-5.1",
}
first := buildOpenAIWeightedSelectionOrder(candidates, req)
second := buildOpenAIWeightedSelectionOrder(candidates, req)
require.Len(t, first, len(candidates))
require.Len(t, second, len(candidates))
for i := range first {
require.Equal(t, first[i].account.ID, second[i].account.ID)
}
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesAcrossSessions(t *testing.T) {
ctx := context.Background()
groupID := int64(15)
accounts := []Account{
{
ID: 5101,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 3,
Priority: 0,
},
{
ID: 5102,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 3,
Priority: 0,
},
{
ID: 5103,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 3,
Priority: 0,
},
}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.LBTopK = 3
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
concurrencyCache := stubConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
5103: {AccountID: 5103, LoadRate: 20, WaitingCount: 1},
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
selected := make(map[int64]int, len(accounts))
for i := 0; i < 60; i++ {
sessionHash := fmt.Sprintf("session_hash_lb_%d", i)
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"",
sessionHash,
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
selected[selection.Account.ID]++
if selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
}
// 多 session 应该能打散到多个账号,避免“恒定单账号命中”。
require.GreaterOrEqual(t, len(selected), 2)
}
func TestDeriveOpenAISelectionSeed_NoAffinityAddsEntropy(t *testing.T) {
req := OpenAIAccountScheduleRequest{
RequestedModel: "gpt-5.1",
}
seed1 := deriveOpenAISelectionSeed(req)
time.Sleep(1 * time.Millisecond)
seed2 := deriveOpenAISelectionSeed(req)
require.NotZero(t, seed1)
require.NotZero(t, seed2)
require.NotEqual(t, seed1, seed2)
}
func TestBuildOpenAIWeightedSelectionOrder_HandlesInvalidScores(t *testing.T) {
candidates := []openAIAccountCandidateScore{
{
account: &Account{ID: 901},
loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
score: math.NaN(),
},
{
account: &Account{ID: 902},
loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
score: math.Inf(1),
},
{
account: &Account{ID: 903},
loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
score: -1,
},
}
req := OpenAIAccountScheduleRequest{
SessionHash: "seed_invalid_scores",
}
order := buildOpenAIWeightedSelectionOrder(candidates, req)
require.Len(t, order, len(candidates))
seen := map[int64]struct{}{}
for _, item := range order {
seen[item.account.ID] = struct{}{}
}
require.Len(t, seen, len(candidates))
}
// TestOpenAISelectionRNG_SeedZeroStillWorks ensures a zero seed does not
// degenerate the RNG: consecutive draws differ and floats stay in [0, 1).
func TestOpenAISelectionRNG_SeedZeroStillWorks(t *testing.T) {
	rng := newOpenAISelectionRNG(0)
	first := rng.nextUint64()
	second := rng.nextUint64()
	require.NotEqual(t, first, second)
	require.GreaterOrEqual(t, rng.nextFloat64(), 0.0)
	require.Less(t, rng.nextFloat64(), 1.0)
}
func TestOpenAIAccountCandidateHeap_PushPopAndInvalidType(t *testing.T) {
h := openAIAccountCandidateHeap{}
h.Push(openAIAccountCandidateScore{
account: &Account{ID: 7001},
loadInfo: &AccountLoadInfo{LoadRate: 0, WaitingCount: 0},
score: 1.0,
})
require.Equal(t, 1, h.Len())
popped, ok := h.Pop().(openAIAccountCandidateScore)
require.True(t, ok)
require.Equal(t, int64(7001), popped.account.ID)
require.Equal(t, 0, h.Len())
require.Panics(t, func() {
h.Push("bad_element_type")
})
}
// TestClamp01_AllBranches covers the below-range, above-range and
// pass-through branches of clamp01.
func TestClamp01_AllBranches(t *testing.T) {
	cases := []struct {
		in, want float64
	}{
		{-0.2, 0.0}, // clamped up to the lower bound
		{1.3, 1.0},  // clamped down to the upper bound
		{0.5, 0.5},  // in range: returned unchanged
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, clamp01(tc.in))
	}
}
// TestCalcLoadSkewByMoments_Branches covers the degenerate and clamped branches
// of the moment-based load-skew calculation.
func TestCalcLoadSkewByMoments_Branches(t *testing.T) {
	require.Equal(t, 0.0, calcLoadSkewByMoments(1, 1, 1))
	// variance < 0 branch: when sumSquares/count - mean^2 goes negative
	// (floating-point artifacts), the result must be clamped to 0.
	require.Equal(t, 0.0, calcLoadSkewByMoments(1, 0, 2))
	require.GreaterOrEqual(t, calcLoadSkewByMoments(6, 20, 3), 0.0)
}
func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
schedulerAny := newDefaultOpenAIAccountScheduler(&OpenAIGatewayService{}, nil)
scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
require.True(t, ok)
ttft := 100
scheduler.ReportResult(1001, true, &ttft)
scheduler.ReportSwitch()
scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
Layer: openAIAccountScheduleLayerLoadBalance,
LatencyMs: 8,
LoadSkew: 0.5,
StickyPreviousHit: true,
})
scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
Layer: openAIAccountScheduleLayerSessionSticky,
LatencyMs: 6,
LoadSkew: 0.2,
StickySessionHit: true,
})
snapshot := scheduler.SnapshotMetrics()
require.Equal(t, int64(2), snapshot.SelectTotal)
require.Equal(t, int64(1), snapshot.StickyPreviousHitTotal)
require.Equal(t, int64(1), snapshot.StickySessionHitTotal)
require.Equal(t, int64(1), snapshot.LoadBalanceSelectTotal)
require.Equal(t, int64(1), snapshot.AccountSwitchTotal)
require.Greater(t, snapshot.SchedulerLatencyMsAvg, 0.0)
require.Greater(t, snapshot.StickyHitRatio, 0.0)
require.Greater(t, snapshot.LoadSkewAvg, 0.0)
}
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
require.Equal(t, 7, svc.openAIWSLBTopK())
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
defaultWeights := svc.openAIWSSchedulerWeights()
require.Equal(t, 1.0, defaultWeights.Priority)
require.Equal(t, 1.0, defaultWeights.Load)
require.Equal(t, 0.7, defaultWeights.Queue)
require.Equal(t, 0.8, defaultWeights.ErrorRate)
require.Equal(t, 0.5, defaultWeights.TTFT)
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.LBTopK = 9
cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 180
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.2
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.3
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.4
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.5
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.6
svcWithCfg := &OpenAIGatewayService{cfg: cfg}
require.Equal(t, 9, svcWithCfg.openAIWSLBTopK())
require.Equal(t, 180*time.Second, svcWithCfg.openAIWSSessionStickyTTL())
customWeights := svcWithCfg.openAIWSSchedulerWeights()
require.Equal(t, 0.2, customWeights.Priority)
require.Equal(t, 0.3, customWeights.Load)
require.Equal(t, 0.4, customWeights.Queue)
require.Equal(t, 0.5, customWeights.ErrorRate)
require.Equal(t, 0.6, customWeights.TTFT)
}
func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *testing.T) {
scheduler := &defaultOpenAIAccountScheduler{}
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportAny))
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
cfg := newOpenAIWSV2TestConfig()
scheduler.service = &OpenAIGatewayService{cfg: cfg}
account := &Account{
ID: 8801,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2))
}
// int64PtrForTest returns a pointer to a copy of v, letting tests build
// optional *int64 arguments (e.g. group IDs) inline.
func int64PtrForTest(v int64) *int64 {
	out := v
	return &out
}

View File

@@ -0,0 +1,71 @@
package service
import (
"strings"
"github.com/gin-gonic/gin"
)
// OpenAIClientTransport identifies the inbound (client-side) protocol of a request.
type OpenAIClientTransport string

const (
	// OpenAIClientTransportUnknown means the inbound protocol has not been determined.
	OpenAIClientTransportUnknown OpenAIClientTransport = ""
	// OpenAIClientTransportHTTP marks a plain HTTP (SSE) client connection.
	OpenAIClientTransportHTTP OpenAIClientTransport = "http"
	// OpenAIClientTransportWS marks a WebSocket client connection.
	OpenAIClientTransportWS OpenAIClientTransport = "ws"
)

// openAIClientTransportContextKey is the gin context key under which the
// normalized transport value is stored by SetOpenAIClientTransport.
const openAIClientTransportContextKey = "openai_client_transport"
// SetOpenAIClientTransport records the inbound client protocol on the request
// context. Nil contexts and values that do not normalize to a known transport
// are ignored, so the context key is only ever set to "http" or "ws".
func SetOpenAIClientTransport(c *gin.Context, transport OpenAIClientTransport) {
	if c == nil {
		return
	}
	if normalized := normalizeOpenAIClientTransport(transport); normalized != OpenAIClientTransportUnknown {
		c.Set(openAIClientTransportContextKey, string(normalized))
	}
}
// GetOpenAIClientTransport returns the inbound client protocol previously
// stored via SetOpenAIClientTransport, normalizing the raw context value.
// It yields OpenAIClientTransportUnknown for a nil context, a missing key,
// or a value of an unexpected type.
func GetOpenAIClientTransport(c *gin.Context) OpenAIClientTransport {
	if c == nil {
		return OpenAIClientTransportUnknown
	}
	raw, ok := c.Get(openAIClientTransportContextKey)
	if !ok || raw == nil {
		return OpenAIClientTransportUnknown
	}
	// The value may have been stored as either the typed constant or a string.
	if typed, isTyped := raw.(OpenAIClientTransport); isTyped {
		return normalizeOpenAIClientTransport(typed)
	}
	if s, isString := raw.(string); isString {
		return normalizeOpenAIClientTransport(OpenAIClientTransport(s))
	}
	return OpenAIClientTransportUnknown
}
// normalizeOpenAIClientTransport maps raw transport strings onto the canonical
// constants, case-insensitively and ignoring surrounding whitespace:
// "http", "http_sse" and "sse" become HTTP; "ws" and "websocket" become WS;
// everything else normalizes to Unknown.
func normalizeOpenAIClientTransport(transport OpenAIClientTransport) OpenAIClientTransport {
	normalized := strings.ToLower(strings.TrimSpace(string(transport)))
	switch normalized {
	case "http", "http_sse", "sse":
		return OpenAIClientTransportHTTP
	case "ws", "websocket":
		return OpenAIClientTransportWS
	}
	return OpenAIClientTransportUnknown
}
// resolveOpenAIWSDecisionByClientTransport downgrades a WS upstream decision
// when the client itself connected over plain HTTP: such a client cannot
// consume a WebSocket upstream, so the decision is replaced by the HTTP
// fallback (reason "client_protocol_http"). For WS or unknown client
// transports the original decision is returned unchanged.
func resolveOpenAIWSDecisionByClientTransport(
	decision OpenAIWSProtocolDecision,
	clientTransport OpenAIClientTransport,
) OpenAIWSProtocolDecision {
	if clientTransport == OpenAIClientTransportHTTP {
		return openAIWSHTTPDecision("client_protocol_http")
	}
	return decision
}

View File

@@ -0,0 +1,107 @@
package service
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestOpenAIClientTransport_SetAndGet checks the basic round trip: an unset
// context reads as Unknown, and each Set overwrites the previous value.
func TestOpenAIClientTransport_SetAndGet(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(c))
	SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
	require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
	// A second Set replaces the stored value rather than being ignored.
	SetOpenAIClientTransport(c, OpenAIClientTransportWS)
	require.Equal(t, OpenAIClientTransportWS, GetOpenAIClientTransport(c))
}
func TestOpenAIClientTransport_GetNormalizesRawContextValue(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
rawValue any
want OpenAIClientTransport
}{
{
name: "type_value_ws",
rawValue: OpenAIClientTransportWS,
want: OpenAIClientTransportWS,
},
{
name: "http_sse_alias",
rawValue: "http_sse",
want: OpenAIClientTransportHTTP,
},
{
name: "sse_alias",
rawValue: "sSe",
want: OpenAIClientTransportHTTP,
},
{
name: "websocket_alias",
rawValue: "WebSocket",
want: OpenAIClientTransportWS,
},
{
name: "invalid_string",
rawValue: "tcp",
want: OpenAIClientTransportUnknown,
},
{
name: "invalid_type",
rawValue: 123,
want: OpenAIClientTransportUnknown,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Set(openAIClientTransportContextKey, tt.rawValue)
require.Equal(t, tt.want, GetOpenAIClientTransport(c))
})
}
}
func TestOpenAIClientTransport_NilAndUnknownInput(t *testing.T) {
SetOpenAIClientTransport(nil, OpenAIClientTransportHTTP)
require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(nil))
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
SetOpenAIClientTransport(c, OpenAIClientTransportUnknown)
_, exists := c.Get(openAIClientTransportContextKey)
require.False(t, exists)
SetOpenAIClientTransport(c, OpenAIClientTransport(" "))
_, exists = c.Get(openAIClientTransportContextKey)
require.False(t, exists)
}
func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) {
base := OpenAIWSProtocolDecision{
Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
Reason: "ws_v2_enabled",
}
httpDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportHTTP)
require.Equal(t, OpenAIUpstreamTransportHTTPSSE, httpDecision.Transport)
require.Equal(t, "client_protocol_http", httpDecision.Reason)
wsDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportWS)
require.Equal(t, base, wsDecision)
unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown)
require.Equal(t, base, unknownDecision)
}

File diff suppressed because it is too large Load Diff

View File

@@ -123,3 +123,19 @@ func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "parse request")
}
// TestGetOpenAIRequestBodyMap_WriteBackContextCache verifies that a successful
// parse caches the resulting map on the gin context under
// OpenAIParsedRequestBodyKey, so later consumers avoid re-parsing the body.
func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	got, err := getOpenAIRequestBodyMap(c, []byte(`{"model":"gpt-5","stream":true}`))
	require.NoError(t, err)
	require.Equal(t, "gpt-5", got["model"])
	// The parsed map must be written back to the context cache verbatim.
	cached, ok := c.Get(OpenAIParsedRequestBodyKey)
	require.True(t, ok)
	cachedMap, ok := cached.(map[string]any)
	require.True(t, ok)
	require.Equal(t, got, cachedMap)
}

View File

@@ -5,6 +5,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -13,6 +14,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/cespare/xxhash/v2"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -55,6 +57,10 @@ func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, pl
return result, nil
}
func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
type stubConcurrencyCache struct {
ConcurrencyCache
loadBatchErr error
@@ -166,6 +172,54 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
}
}
func TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("session_id", "sess-fixed-value")
svc := &OpenAIGatewayService{}
got := svc.GenerateSessionHash(c, nil)
want := fmt.Sprintf("%016x", xxhash.Sum64String("sess-fixed-value"))
require.Equal(t, want, got)
}
func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("session_id", "sess-legacy-check")
svc := &OpenAIGatewayService{}
sessionHash := svc.GenerateSessionHash(c, nil)
require.NotEmpty(t, sessionHash)
require.NotNil(t, c.Request)
require.NotNil(t, c.Request.Context())
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
}
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
svc := &OpenAIGatewayService{}
seed := "openai_ws_ingress:9:100:200"
got := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), seed)
want := fmt.Sprintf("%016x", xxhash.Sum64String(seed))
require.Equal(t, want, got)
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
empty := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), " ")
require.Equal(t, "", empty)
}
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if c.waitCounts != nil {
if count, ok := c.waitCounts[accountID]; ok {

View File

@@ -0,0 +1,357 @@
package service
import (
"encoding/json"
"strconv"
"strings"
"testing"
"github.com/tidwall/gjson"
)
var (
benchmarkToolContinuationBoolSink bool
benchmarkWSParseStringSink string
benchmarkWSParseMapSink map[string]any
benchmarkUsageSink OpenAIUsage
)
func BenchmarkToolContinuationValidationLegacy(b *testing.B) {
reqBody := benchmarkToolContinuationRequestBody()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkToolContinuationBoolSink = legacyValidateFunctionCallOutputContext(reqBody)
}
}
func BenchmarkToolContinuationValidationOptimized(b *testing.B) {
reqBody := benchmarkToolContinuationRequestBody()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
benchmarkToolContinuationBoolSink = optimizedValidateFunctionCallOutputContext(reqBody)
}
}
func BenchmarkWSIngressPayloadParseLegacy(b *testing.B) {
raw := benchmarkWSIngressPayloadBytes()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
eventType, model, promptCacheKey, previousResponseID, payload, err := legacyParseWSIngressPayload(raw)
if err == nil {
benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID
benchmarkWSParseMapSink = payload
}
}
}
func BenchmarkWSIngressPayloadParseOptimized(b *testing.B) {
raw := benchmarkWSIngressPayloadBytes()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
eventType, model, promptCacheKey, previousResponseID, payload, err := optimizedParseWSIngressPayload(raw)
if err == nil {
benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID
benchmarkWSParseMapSink = payload
}
}
}
func BenchmarkOpenAIUsageExtractLegacy(b *testing.B) {
body := benchmarkOpenAIUsageJSONBytes()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
usage, ok := legacyExtractOpenAIUsageFromJSONBytes(body)
if ok {
benchmarkUsageSink = usage
}
}
}
func BenchmarkOpenAIUsageExtractOptimized(b *testing.B) {
body := benchmarkOpenAIUsageJSONBytes()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
usage, ok := extractOpenAIUsageFromJSONBytes(body)
if ok {
benchmarkUsageSink = usage
}
}
}
// benchmarkToolContinuationRequestBody builds a representative /responses
// request body for the tool-continuation benchmarks: 24 plain text items
// followed by ten triplets of tool_call / function_call_output /
// item_reference entries sharing a call ID ("call_0" .. "call_9").
func benchmarkToolContinuationRequestBody() map[string]any {
	const (
		textItemCount = 24
		callTriplets  = 10
	)
	input := make([]any, 0, 64)
	for n := 0; n < textItemCount; n++ {
		input = append(input, map[string]any{
			"type": "text",
			"text": "benchmark text",
		})
	}
	for n := 0; n < callTriplets; n++ {
		callID := "call_" + strconv.Itoa(n)
		input = append(input,
			map[string]any{"type": "tool_call", "call_id": callID},
			map[string]any{"type": "function_call_output", "call_id": callID},
			map[string]any{"type": "item_reference", "id": callID},
		)
	}
	return map[string]any{
		"model": "gpt-5.3-codex",
		"input": input,
	}
}
// benchmarkWSIngressPayloadBytes returns a fixed response.create WS ingress
// payload (with cache key, previous response ID and a single user message)
// used as input by the payload-parse benchmarks.
func benchmarkWSIngressPayloadBytes() []byte {
	const payload = `{"type":"response.create","model":"gpt-5.3-codex","prompt_cache_key":"cache_bench","previous_response_id":"resp_prev_bench","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`
	return []byte(payload)
}
// benchmarkOpenAIUsageJSONBytes returns a fixed /responses reply containing a
// usage object (input/output tokens plus cached-token details) used as input
// by the usage-extraction benchmarks.
func benchmarkOpenAIUsageJSONBytes() []byte {
	const body = `{"id":"resp_bench","object":"response","model":"gpt-5.3-codex","usage":{"input_tokens":3210,"output_tokens":987,"input_tokens_details":{"cached_tokens":456}}}`
	return []byte(body)
}
// legacyValidateFunctionCallOutputContext is the pre-optimization validator,
// kept as a benchmark baseline. A request containing function_call_output
// items is considered valid when any of these holds:
//   - it carries a non-blank previous_response_id, or
//   - its input includes tool_call/function_call context with a call_id, or
//   - (absent both) no output is missing a call_id AND every output call_id
//     has a matching item_reference.
// Requests with no function_call_output items are always valid.
func legacyValidateFunctionCallOutputContext(reqBody map[string]any) bool {
	if !legacyHasFunctionCallOutput(reqBody) {
		return true
	}
	previousResponseID, _ := reqBody["previous_response_id"].(string)
	if strings.TrimSpace(previousResponseID) != "" {
		return true
	}
	if legacyHasToolCallContext(reqBody) {
		return true
	}
	// An output with a blank call_id can never be matched by a reference.
	if legacyHasFunctionCallOutputMissingCallID(reqBody) {
		return false
	}
	callIDs := legacyFunctionCallOutputCallIDs(reqBody)
	return legacyHasItemReferenceForCallIDs(reqBody, callIDs)
}
func optimizedValidateFunctionCallOutputContext(reqBody map[string]any) bool {
validation := ValidateFunctionCallOutputContext(reqBody)
if !validation.HasFunctionCallOutput {
return true
}
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) != "" {
return true
}
if validation.HasToolCallContext {
return true
}
if validation.HasFunctionCallOutputMissingCallID {
return false
}
return validation.HasItemReferenceForAllCallIDs
}
// legacyHasFunctionCallOutput reports whether any entry of reqBody's "input"
// slice is a map with type "function_call_output". Nil bodies and inputs of
// the wrong shape yield false.
func legacyHasFunctionCallOutput(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if entryType, _ := entry["type"].(string); entryType == "function_call_output" {
			return true
		}
	}
	return false
}
// legacyHasToolCallContext reports whether the "input" slice contains a
// tool_call or function_call entry carrying a non-blank call_id.
func legacyHasToolCallContext(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		entryType, _ := entry["type"].(string)
		switch entryType {
		case "tool_call", "function_call":
			// Only a non-blank call_id counts as usable context.
			if id, hasID := entry["call_id"].(string); hasID && strings.TrimSpace(id) != "" {
				return true
			}
		}
	}
	return false
}
// legacyFunctionCallOutputCallIDs collects the distinct non-blank call_id
// values of all function_call_output entries in the "input" slice. It returns
// nil when none are found; the order of the returned IDs is unspecified
// (map iteration order).
func legacyFunctionCallOutputCallIDs(reqBody map[string]any) []string {
	if reqBody == nil {
		return nil
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return nil
	}
	seen := make(map[string]struct{})
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if entryType, _ := entry["type"].(string); entryType != "function_call_output" {
			continue
		}
		if id, hasID := entry["call_id"].(string); hasID && strings.TrimSpace(id) != "" {
			seen[id] = struct{}{}
		}
	}
	if len(seen) == 0 {
		return nil
	}
	out := make([]string, 0, len(seen))
	for id := range seen {
		out = append(out, id)
	}
	return out
}
// legacyHasFunctionCallOutputMissingCallID reports whether any
// function_call_output entry in the "input" slice lacks a non-blank call_id
// (missing, wrong type, or whitespace-only).
func legacyHasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if entryType, _ := entry["type"].(string); entryType != "function_call_output" {
			continue
		}
		id, _ := entry["call_id"].(string)
		if strings.TrimSpace(id) == "" {
			return true
		}
	}
	return false
}
// legacyHasItemReferenceForCallIDs reports whether every ID in callIDs is
// covered by an item_reference entry (matched on its trimmed "id" field) in
// the "input" slice. Empty callIDs, nil bodies, malformed input, or an input
// with no usable references all yield false.
func legacyHasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
	if reqBody == nil || len(callIDs) == 0 {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	refs := make(map[string]struct{})
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if entryType, _ := entry["type"].(string); entryType != "item_reference" {
			continue
		}
		id, _ := entry["id"].(string)
		if id = strings.TrimSpace(id); id != "" {
			refs[id] = struct{}{}
		}
	}
	if len(refs) == 0 {
		return false
	}
	// Every requested call ID must be present among the references.
	for _, id := range callIDs {
		if _, covered := refs[id]; !covered {
			return false
		}
	}
	return true
}
// legacyParseWSIngressPayload is the pre-optimization parser kept as a
// benchmark baseline: it reads the four header fields via gjson AND fully
// unmarshals the payload, i.e. the JSON is scanned twice. The event type
// defaults to "response.create" both in the returned value and, when the key
// is absent, inside the payload map itself.
func legacyParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) {
	values := gjson.GetManyBytes(raw, "type", "model", "prompt_cache_key", "previous_response_id")
	eventType = strings.TrimSpace(values[0].String())
	if eventType == "" {
		eventType = "response.create"
	}
	model = strings.TrimSpace(values[1].String())
	promptCacheKey = strings.TrimSpace(values[2].String())
	previousResponseID = strings.TrimSpace(values[3].String())
	// Second pass: full unmarshal of the same bytes (the cost the optimized
	// variant eliminates).
	payload = make(map[string]any)
	if err = json.Unmarshal(raw, &payload); err != nil {
		return "", "", "", "", nil, err
	}
	if _, exists := payload["type"]; !exists {
		payload["type"] = "response.create"
	}
	return eventType, model, promptCacheKey, previousResponseID, payload, nil
}
// optimizedParseWSIngressPayload parses the WS ingress payload with a single
// json.Unmarshal and then reads the header fields from the resulting map via
// openAIWSPayloadString, avoiding the legacy variant's double scan. Defaults
// the event type to "response.create" (and writes it back into the payload)
// when absent.
func optimizedParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) {
	payload = make(map[string]any)
	if err = json.Unmarshal(raw, &payload); err != nil {
		return "", "", "", "", nil, err
	}
	eventType = openAIWSPayloadString(payload, "type")
	if eventType == "" {
		eventType = "response.create"
		payload["type"] = eventType
	}
	model = openAIWSPayloadString(payload, "model")
	promptCacheKey = openAIWSPayloadString(payload, "prompt_cache_key")
	previousResponseID = openAIWSPayloadString(payload, "previous_response_id")
	return eventType, model, promptCacheKey, previousResponseID, payload, nil
}
// legacyExtractOpenAIUsageFromJSONBytes is the pre-optimization usage
// extractor kept as a benchmark baseline: it unmarshals the whole response
// into an anonymous struct to pull out usage token counts. Note it returns
// ok=true whenever the JSON parses, even if the "usage" object is absent
// (fields then default to zero) — matching the behavior it is benchmarked
// against.
func legacyExtractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
	var response struct {
		Usage struct {
			InputTokens       int `json:"input_tokens"`
			OutputTokens      int `json:"output_tokens"`
			InputTokenDetails struct {
				CachedTokens int `json:"cached_tokens"`
			} `json:"input_tokens_details"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(body, &response); err != nil {
		return OpenAIUsage{}, false
	}
	return OpenAIUsage{
		InputTokens:          response.Usage.InputTokens,
		OutputTokens:         response.Usage.OutputTokens,
		CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens,
	}, true
}

View File

@@ -515,7 +515,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
require.NoError(t, err)
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent"))
}
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {

View File

@@ -5,17 +5,28 @@ import (
"crypto/subtle"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`)
// soraSessionChunk is one piece of a chunked Sora session cookie
// (__Secure-*.session-token.<index>=<value>); index orders the pieces
// before they are concatenated back into the full token.
type soraSessionChunk struct {
	index int // numeric suffix parsed from the cookie name
	value string // sanitized cookie value for this piece
}
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct {
sessionStore *openai.SessionStore
@@ -39,7 +50,7 @@ type OpenAIAuthURLResult struct {
}
// GenerateAuthURL generates an OpenAI OAuth authorization URL
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) {
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) {
// Generate PKCE values
state, err := openai.GenerateState()
if err != nil {
@@ -75,11 +86,14 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
}
normalizedPlatform := normalizeOpenAIOAuthPlatform(platform)
clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform)
// Store session
session := &openai.OAuthSession{
State: state,
CodeVerifier: codeVerifier,
ClientID: clientID,
RedirectURI: redirectURI,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
@@ -87,7 +101,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
s.sessionStore.Set(sessionID, session)
// Build authorization URL
authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI)
authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform)
return &OpenAIAuthURLResult{
AuthURL: authURL,
@@ -111,6 +125,7 @@ type OpenAITokenInfo struct {
IDToken string `json:"id_token,omitempty"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
ClientID string `json:"client_id,omitempty"`
Email string `json:"email,omitempty"`
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
@@ -148,9 +163,13 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
if input.RedirectURI != "" {
redirectURI = input.RedirectURI
}
clientID := strings.TrimSpace(session.ClientID)
if clientID == "" {
clientID = openai.ClientID
}
// Exchange code for token
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID)
if err != nil {
return nil, err
}
@@ -158,8 +177,10 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Parse ID token to get user info
var userInfo *openai.UserInfo
if tokenResp.IDToken != "" {
claims, err := openai.ParseIDToken(tokenResp.IDToken)
if err == nil {
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
if parseErr != nil {
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
} else {
userInfo = claims.GetUserInfo()
}
}
@@ -173,6 +194,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
IDToken: tokenResp.IDToken,
ExpiresIn: int64(tokenResp.ExpiresIn),
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
ClientID: clientID,
}
if userInfo != nil {
@@ -200,8 +222,10 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
// Parse ID token to get user info
var userInfo *openai.UserInfo
if tokenResp.IDToken != "" {
claims, err := openai.ParseIDToken(tokenResp.IDToken)
if err == nil {
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
if parseErr != nil {
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
} else {
userInfo = claims.GetUserInfo()
}
}
@@ -213,6 +237,9 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
ExpiresIn: int64(tokenResp.ExpiresIn),
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
}
if trimmed := strings.TrimSpace(clientID); trimmed != "" {
tokenInfo.ClientID = trimmed
}
if userInfo != nil {
tokenInfo.Email = userInfo.Email
@@ -226,6 +253,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
sessionToken = normalizeSoraSessionTokenInput(sessionToken)
if strings.TrimSpace(sessionToken) == "" {
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
}
@@ -245,7 +273,13 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
client := newOpenAIOAuthHTTPClient(proxyURL)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
})
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
@@ -287,10 +321,141 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
ExpiresIn: expiresIn,
ExpiresAt: expiresAt,
ClientID: openai.SoraClientID,
Email: strings.TrimSpace(sessionResp.User.Email),
}, nil
}
// normalizeSoraSessionTokenInput extracts a usable Sora session token from
// user-pasted input. The input may be a bare token, a Cookie/Set-Cookie
// header line, or several such lines containing chunked
// __Secure-(next-auth|authjs).session-token[.N] cookies; chunked values are
// reassembled via mergeLatestSoraSessionChunks and the newest complete
// group wins. Returns "" when nothing usable is found.
func normalizeSoraSessionTokenInput(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return ""
	}
	matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1)
	if len(matches) == 0 {
		// No cookie syntax detected: treat the whole input as a raw token.
		return sanitizeSessionToken(trimmed)
	}
	chunkMatches := make([]soraSessionChunk, 0, len(matches))
	singleValues := make([]string, 0, len(matches))
	for _, match := range matches {
		if len(match) < 3 {
			continue
		}
		value := sanitizeSessionToken(match[2])
		if value == "" {
			continue
		}
		// match[1] is the optional chunk-index suffix (".0", ".1", ...);
		// when absent the cookie carries the whole token in one value.
		if strings.TrimSpace(match[1]) == "" {
			singleValues = append(singleValues, value)
			continue
		}
		idx, err := strconv.Atoi(strings.TrimSpace(match[1]))
		if err != nil || idx < 0 {
			continue
		}
		chunkMatches = append(chunkMatches, soraSessionChunk{
			index: idx,
			value: value,
		})
	}
	// Prefer reassembled chunked cookies over single-value cookies.
	if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" {
		return merged
	}
	// Fall back to the last un-chunked cookie value seen in the input.
	if len(singleValues) > 0 {
		return singleValues[len(singleValues)-1]
	}
	return ""
}
// mergeSoraSessionChunkSegment reassembles one group of cookie chunks into a
// single session token.
//
// Rules:
//   - a chunk with index 0 must be present, otherwise the group is unusable;
//   - when requireComplete is true, every index in [0, requiredMaxIndex]
//     must be present (no gaps), otherwise "" is returned;
//   - duplicate indexes within the group keep the last value seen.
//
// Returns the sanitized concatenation of the chunk values in ascending index
// order, or "" when the group cannot be assembled.
func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string {
	if len(chunks) == 0 {
		return ""
	}
	byIndex := make(map[int]string, len(chunks))
	for _, chunk := range chunks {
		byIndex[chunk.index] = chunk.value
	}
	// The first chunk (index 0) is mandatory in both modes.
	if _, ok := byIndex[0]; !ok {
		return ""
	}
	if requireComplete {
		for idx := 0; idx <= requiredMaxIndex; idx++ {
			if _, ok := byIndex[idx]; !ok {
				return ""
			}
		}
	}
	orderedIndexes := make([]int, 0, len(byIndex))
	for idx := range byIndex {
		orderedIndexes = append(orderedIndexes, idx)
	}
	sort.Ints(orderedIndexes)
	var builder strings.Builder
	for _, idx := range orderedIndexes {
		// strings.Builder.WriteString is documented to always return a nil
		// error, so checking it (as the previous version did) is dead code.
		builder.WriteString(byIndex[idx])
	}
	return sanitizeSessionToken(builder.String())
}
// mergeLatestSoraSessionChunks picks the most recent complete group of cookie
// chunks and merges it into one session token. A paste may contain stale and
// fresh cookies together; each index-0 chunk starts a new group and the
// newest complete group wins. When no group is complete, all chunks are
// merged in a single best-effort (lenient) pass.
func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string {
	if len(chunks) == 0 {
		return ""
	}
	// One pass: track the highest chunk index and the positions where a new
	// group (index 0) begins.
	maxIndex := 0
	var starts []int
	for pos, chunk := range chunks {
		if chunk.index > maxIndex {
			maxIndex = chunk.index
		}
		if chunk.index == 0 {
			starts = append(starts, pos)
		}
	}
	if len(starts) == 0 {
		// No group boundary at all: lenient merge across everything.
		return mergeSoraSessionChunkSegment(chunks, maxIndex, false)
	}
	// Walk the groups newest-first and return the first complete one.
	for i := len(starts) - 1; i >= 0; i-- {
		end := len(chunks)
		if next := i + 1; next < len(starts) {
			end = starts[next]
		}
		if token := mergeSoraSessionChunkSegment(chunks[starts[i]:end], maxIndex, true); token != "" {
			return token
		}
	}
	// No complete group exists: fall back to a lenient merge of all chunks.
	return mergeSoraSessionChunkSegment(chunks, maxIndex, false)
}
// sanitizeSessionToken normalizes a pasted session-token value: it strips
// surrounding whitespace, any wrapping quote characters ("), ('), (`), and a
// single trailing semicolon left over from cookie syntax.
func sanitizeSessionToken(raw string) string {
	cleaned := strings.Trim(strings.TrimSpace(raw), "\"'`")
	cleaned = strings.TrimSuffix(cleaned, ";")
	return strings.TrimSpace(cleaned)
}
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
@@ -322,9 +487,12 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
creds := map[string]any{
"access_token": tokenInfo.AccessToken,
"refresh_token": tokenInfo.RefreshToken,
"expires_at": expiresAt,
"access_token": tokenInfo.AccessToken,
"expires_at": expiresAt,
}
// 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌
if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
creds["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.IDToken != "" {
@@ -342,6 +510,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
if tokenInfo.OrganizationID != "" {
creds["organization_id"] = tokenInfo.OrganizationID
}
if strings.TrimSpace(tokenInfo.ClientID) != "" {
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
}
return creds
}
@@ -365,15 +536,11 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64
return proxy.URL(), nil
}
func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
transport := &http.Transport{}
if strings.TrimSpace(proxyURL) != "" {
if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
transport.Proxy = http.ProxyURL(parsed)
}
}
return &http.Client{
Timeout: 120 * time.Second,
Transport: transport,
// normalizeOpenAIOAuthPlatform maps a caller-supplied platform string onto
// one of the OAuth platform identifiers understood by the openai package.
// Only the Sora platform (after trimming and lowercasing) maps to the Sora
// identifier; everything else falls back to the regular OpenAI platform.
func normalizeOpenAIOAuthPlatform(platform string) string {
	normalized := strings.ToLower(strings.TrimSpace(platform))
	if normalized == PlatformSora {
		return openai.OAuthPlatformSora
	}
	return openai.OAuthPlatformOpenAI
}

View File

@@ -0,0 +1,67 @@
package service
import (
"context"
"errors"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
// openaiOAuthClientAuthURLStub is a minimal OAuth-client stub for tests that
// only exercise GenerateAuthURL; every token operation is unimplemented and
// returns an error if called.
type openaiOAuthClientAuthURLStub struct{}

func (s *openaiOAuthClientAuthURLStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}

func (s *openaiOAuthClientAuthURLStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}

func (s *openaiOAuthClientAuthURLStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}
// TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow verifies the
// OpenAI platform keeps the Codex CLI flow: the generated URL carries the
// Codex client_id, enables codex_cli_simplified_flow, and the stored OAuth
// session records the same client_id.
func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
	defer svc.Stop()
	result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformOpenAI)
	require.NoError(t, err)
	require.NotEmpty(t, result.AuthURL)
	require.NotEmpty(t, result.SessionID)
	parsed, err := url.Parse(result.AuthURL)
	require.NoError(t, err)
	q := parsed.Query()
	require.Equal(t, openai.ClientID, q.Get("client_id"))
	require.Equal(t, "true", q.Get("codex_cli_simplified_flow"))
	session, ok := svc.sessionStore.Get(result.SessionID)
	require.True(t, ok)
	require.Equal(t, openai.ClientID, session.ClientID)
}

// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient verifies that
// the Sora platform reuses the Codex CLI client_id (so a localhost
// redirect_uri is supported) but does NOT enable codex_cli_simplified_flow.
func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) {
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
	defer svc.Stop()
	result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora)
	require.NoError(t, err)
	require.NotEmpty(t, result.AuthURL)
	require.NotEmpty(t, result.SessionID)
	parsed, err := url.Parse(result.AuthURL)
	require.NoError(t, err)
	q := parsed.Query()
	require.Equal(t, openai.ClientID, q.Get("client_id"))
	require.Empty(t, q.Get("codex_cli_simplified_flow"))
	session, ok := svc.sessionStore.Get(result.SessionID)
	require.True(t, ok)
	require.Equal(t, openai.ClientID, session.ClientID)
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -13,7 +14,7 @@ import (
type openaiOAuthClientNoopStub struct{}
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
@@ -67,3 +68,106 @@ func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testi
require.Error(t, err)
require.Contains(t, err.Error(), "missing access token")
}
// TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine
// verifies a single chunked (.0) Set-Cookie value with cookie attributes is
// stripped down to the bare token before being sent upstream.
func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodGet, r.Method)
		require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value")
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
	}))
	defer server.Close()
	// Point the auth endpoint at the local stub server for the duration of
	// the test, restoring it afterwards.
	origin := openAISoraSessionAuthURL
	openAISoraSessionAuthURL = server.URL
	defer func() { openAISoraSessionAuthURL = origin }()
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
	defer svc.Stop()
	raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
	info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
	require.NoError(t, err)
	require.Equal(t, "at-token", info.AccessToken)
}

// TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines
// verifies that out-of-order .0/.1 chunks are concatenated in index order.
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodGet, r.Method)
		require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1")
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
	}))
	defer server.Close()
	origin := openAISoraSessionAuthURL
	openAISoraSessionAuthURL = server.URL
	defer func() { openAISoraSessionAuthURL = origin }()
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
	defer svc.Stop()
	raw := strings.Join([]string{
		"Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly",
		"Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly",
	}, "\n")
	info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
	require.NoError(t, err)
	require.Equal(t, "at-token", info.AccessToken)
}

// TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks
// verifies that when the same chunk indexes appear twice (stale + fresh
// cookies in one paste), the later group wins.
func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodGet, r.Method)
		require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1")
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
	}))
	defer server.Close()
	origin := openAISoraSessionAuthURL
	openAISoraSessionAuthURL = server.URL
	defer func() { openAISoraSessionAuthURL = origin }()
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
	defer svc.Stop()
	raw := strings.Join([]string{
		"Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly",
		"Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly",
		"Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly",
		"Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly",
	}, "\n")
	info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
	require.NoError(t, err)
	require.Equal(t, "at-token", info.AccessToken)
}

// TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup
// verifies that a trailing incomplete group (only .0 present) is skipped in
// favour of the latest COMPLETE group.
func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodGet, r.Method)
		require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1")
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
	}))
	defer server.Close()
	origin := openAISoraSessionAuthURL
	openAISoraSessionAuthURL = server.URL
	defer func() { openAISoraSessionAuthURL = origin }()
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
	defer svc.Stop()
	raw := strings.Join([]string{
		"set-cookie",
		"__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/",
		"set-cookie",
		"__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/",
		"set-cookie",
		"__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/",
	}, "\n")
	info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
	require.NoError(t, err)
	require.Equal(t, "at-token", info.AccessToken)
}

View File

@@ -13,10 +13,12 @@ import (
type openaiOAuthClientStateStub struct {
exchangeCalled int32
lastClientID string
}
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
atomic.AddInt32(&s.exchangeCalled, 1)
s.lastClientID = clientID
return &openai.TokenResponse{
AccessToken: "at",
RefreshToken: "rt",
@@ -95,6 +97,8 @@ func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "at", info.AccessToken)
require.Equal(t, openai.ClientID, info.ClientID)
require.Equal(t, openai.ClientID, client.lastClientID)
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
_, ok := svc.sessionStore.Get("sid")

View File

@@ -0,0 +1,37 @@
package service
import (
"regexp"
"strings"
)
// Kinds returned by ClassifyOpenAIPreviousResponseIDKind.
const (
	OpenAIPreviousResponseIDKindEmpty      = "empty"
	OpenAIPreviousResponseIDKindResponseID = "response_id"
	OpenAIPreviousResponseIDKindMessageID  = "message_id"
	OpenAIPreviousResponseIDKindUnknown    = "unknown"
)

var (
	openAIResponseIDPattern = regexp.MustCompile(`^resp_[A-Za-z0-9_-]{1,256}$`)
	openAIMessageIDPattern  = regexp.MustCompile(`^(msg|message|item|chatcmpl)_[A-Za-z0-9_-]{1,256}$`)
)

// ClassifyOpenAIPreviousResponseIDKind buckets a previous_response_id value
// into one of the Kind* constants so diagnostics can distinguish a real
// response id (resp_*) from a message/item id that clients sometimes send by
// mistake. The message-id match is case-insensitive; the response-id match is
// not.
func ClassifyOpenAIPreviousResponseIDKind(id string) string {
	candidate := strings.TrimSpace(id)
	switch {
	case candidate == "":
		return OpenAIPreviousResponseIDKindEmpty
	case openAIResponseIDPattern.MatchString(candidate):
		return OpenAIPreviousResponseIDKindResponseID
	case openAIMessageIDPattern.MatchString(strings.ToLower(candidate)):
		return OpenAIPreviousResponseIDKindMessageID
	default:
		return OpenAIPreviousResponseIDKindUnknown
	}
}

// IsOpenAIPreviousResponseIDLikelyMessageID reports whether id looks like a
// message/item id rather than a response id.
func IsOpenAIPreviousResponseIDLikelyMessageID(id string) bool {
	return ClassifyOpenAIPreviousResponseIDKind(id) == OpenAIPreviousResponseIDKindMessageID
}

View File

@@ -0,0 +1,34 @@
package service
import "testing"
// TestClassifyOpenAIPreviousResponseIDKind table-tests the classification of
// previous_response_id values into empty/response/message/unknown kinds.
func TestClassifyOpenAIPreviousResponseIDKind(t *testing.T) {
	tests := []struct {
		name string
		id   string
		want string
	}{
		{name: "empty", id: " ", want: OpenAIPreviousResponseIDKindEmpty},
		{name: "response_id", id: "resp_0906a621bc423a8d0169a108637ef88197b74b0e2f37ba358f", want: OpenAIPreviousResponseIDKindResponseID},
		{name: "message_id", id: "msg_123456", want: OpenAIPreviousResponseIDKindMessageID},
		{name: "item_id", id: "item_abcdef", want: OpenAIPreviousResponseIDKindMessageID},
		{name: "unknown", id: "foo_123456", want: OpenAIPreviousResponseIDKindUnknown},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			if got := ClassifyOpenAIPreviousResponseIDKind(tc.id); got != tc.want {
				t.Fatalf("ClassifyOpenAIPreviousResponseIDKind(%q)=%q want=%q", tc.id, got, tc.want)
			}
		})
	}
}

// TestIsOpenAIPreviousResponseIDLikelyMessageID checks the boolean wrapper on
// both sides of the message-id classification.
func TestIsOpenAIPreviousResponseIDLikelyMessageID(t *testing.T) {
	if !IsOpenAIPreviousResponseIDLikelyMessageID("msg_123") {
		t.Fatal("expected msg_123 to be identified as message id")
	}
	if IsOpenAIPreviousResponseIDLikelyMessageID("resp_123") {
		t.Fatal("expected resp_123 not to be identified as message id")
	}
}

View File

@@ -0,0 +1,214 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/cespare/xxhash/v2"
"github.com/gin-gonic/gin"
)
// openAILegacySessionHashContextKey is the private context-key type for the
// legacy (sha256) session hash carried alongside a request.
type openAILegacySessionHashContextKey struct{}

var openAILegacySessionHashKey = openAILegacySessionHashContextKey{}

// Counters observing the legacy session-hash compatibility path.
var (
	openAIStickyLegacyReadFallbackTotal atomic.Int64 // legacy-key read fallbacks attempted
	openAIStickyLegacyReadFallbackHit   atomic.Int64 // legacy-key reads that found a binding
	openAIStickyLegacyDualWriteTotal    atomic.Int64 // successful dual writes to the legacy key
)

// openAIStickyCompatStats returns a snapshot of the compatibility counters.
func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal int64) {
	return openAIStickyLegacyReadFallbackTotal.Load(),
		openAIStickyLegacyReadFallbackHit.Load(),
		openAIStickyLegacyDualWriteTotal.Load()
}

// deriveOpenAISessionHashes computes both hash forms of a session id: the
// current 16-hex-digit xxhash64 form and the legacy sha256-hex form used by
// older cache keys. Both results are empty when sessionID is blank.
func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) {
	normalized := strings.TrimSpace(sessionID)
	if normalized == "" {
		return "", ""
	}
	currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized))
	sum := sha256.Sum256([]byte(normalized))
	legacyHash = hex.EncodeToString(sum[:])
	return currentHash, legacyHash
}

// withOpenAILegacySessionHash stores the legacy session hash on the context;
// a blank hash leaves the context unchanged. A nil ctx yields nil.
func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context {
	if ctx == nil {
		return nil
	}
	trimmed := strings.TrimSpace(legacyHash)
	if trimmed == "" {
		return ctx
	}
	return context.WithValue(ctx, openAILegacySessionHashKey, trimmed)
}

// openAILegacySessionHashFromContext reads the legacy session hash stored by
// withOpenAILegacySessionHash, or "" when absent.
func openAILegacySessionHashFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	value, _ := ctx.Value(openAILegacySessionHashKey).(string)
	return strings.TrimSpace(value)
}

// attachOpenAILegacySessionHashToGin rewrites the gin request context so the
// legacy session hash travels with downstream handlers. No-op on nil input.
func attachOpenAILegacySessionHashToGin(c *gin.Context, legacyHash string) {
	if c == nil || c.Request == nil {
		return
	}
	c.Request = c.Request.WithContext(withOpenAILegacySessionHash(c.Request.Context(), legacyHash))
}
// openAISessionHashReadOldFallbackEnabled reports whether sticky-session
// reads may fall back to the legacy cache key. Defaults to true when the
// service or config is unavailable.
func (s *OpenAIGatewayService) openAISessionHashReadOldFallbackEnabled() bool {
	if s == nil || s.cfg == nil {
		return true
	}
	return s.cfg.Gateway.OpenAIWS.SessionHashReadOldFallback
}

// openAISessionHashDualWriteOldEnabled reports whether sticky-session writes
// are mirrored to the legacy cache key. Defaults to true when the service or
// config is unavailable.
func (s *OpenAIGatewayService) openAISessionHashDualWriteOldEnabled() bool {
	if s == nil || s.cfg == nil {
		return true
	}
	return s.cfg.Gateway.OpenAIWS.SessionHashDualWriteOld
}

// openAISessionCacheKey builds the primary cache key ("openai:<hash>") for a
// session hash, or "" when the hash is blank.
func (s *OpenAIGatewayService) openAISessionCacheKey(sessionHash string) string {
	normalized := strings.TrimSpace(sessionHash)
	if normalized == "" {
		return ""
	}
	return "openai:" + normalized
}

// openAILegacySessionCacheKey builds the legacy cache key from the hash
// carried on the context. Returns "" when no legacy hash is present or when
// the legacy key would be identical to the primary key.
func (s *OpenAIGatewayService) openAILegacySessionCacheKey(ctx context.Context, sessionHash string) string {
	legacyHash := openAILegacySessionHashFromContext(ctx)
	if legacyHash == "" {
		return ""
	}
	legacyKey := "openai:" + legacyHash
	if legacyKey == s.openAISessionCacheKey(sessionHash) {
		return ""
	}
	return legacyKey
}

// openAIStickyLegacyTTL derives the TTL used for legacy-key entries: the
// given ttl (or the default sticky TTL when non-positive), capped at 10
// minutes.
func (s *OpenAIGatewayService) openAIStickyLegacyTTL(ttl time.Duration) time.Duration {
	legacyTTL := ttl
	if legacyTTL <= 0 {
		legacyTTL = openaiStickySessionTTL
	}
	if legacyTTL > 10*time.Minute {
		return 10 * time.Minute
	}
	return legacyTTL
}
// getStickySessionAccountID resolves the account bound to a sticky session.
// It reads the primary key first and, when the read-old fallback is enabled
// and a distinct legacy key is available on the context, falls back to the
// legacy key (counting attempts and hits). On a legacy hit the legacy
// binding is returned; otherwise the primary result/error is returned.
func (s *OpenAIGatewayService) getStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
	if s == nil || s.cache == nil {
		return 0, nil
	}
	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return 0, nil
	}
	accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
	if err == nil && accountID > 0 {
		return accountID, nil
	}
	if !s.openAISessionHashReadOldFallbackEnabled() {
		return accountID, err
	}
	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey == "" {
		return accountID, err
	}
	openAIStickyLegacyReadFallbackTotal.Add(1)
	legacyAccountID, legacyErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
	if legacyErr == nil && legacyAccountID > 0 {
		openAIStickyLegacyReadFallbackHit.Add(1)
		return legacyAccountID, nil
	}
	// Legacy miss/failure: surface the primary read's result unchanged.
	return accountID, err
}

// setStickySessionAccountID binds accountID to the sticky session under the
// primary key and, when dual-write is enabled and a distinct legacy key is
// available, mirrors the binding to the legacy key with a capped TTL. A
// failed legacy write is returned as an error; the dual-write counter only
// advances on success.
func (s *OpenAIGatewayService) setStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string, accountID int64, ttl time.Duration) error {
	if s == nil || s.cache == nil || accountID <= 0 {
		return nil
	}
	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return nil
	}
	if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), primaryKey, accountID, ttl); err != nil {
		return err
	}
	if !s.openAISessionHashDualWriteOldEnabled() {
		return nil
	}
	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey == "" {
		return nil
	}
	if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), legacyKey, accountID, s.openAIStickyLegacyTTL(ttl)); err != nil {
		return err
	}
	openAIStickyLegacyDualWriteTotal.Add(1)
	return nil
}

// refreshStickySessionTTL extends the primary binding's TTL and, when either
// compatibility flag is enabled, best-effort refreshes the legacy key too
// (legacy refresh errors are deliberately ignored).
func (s *OpenAIGatewayService) refreshStickySessionTTL(ctx context.Context, groupID *int64, sessionHash string, ttl time.Duration) error {
	if s == nil || s.cache == nil {
		return nil
	}
	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return nil
	}
	err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), primaryKey, ttl)
	if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
		return err
	}
	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey != "" {
		_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), legacyKey, s.openAIStickyLegacyTTL(ttl))
	}
	return err
}

// deleteStickySessionAccountID removes the primary binding and, when either
// compatibility flag is enabled, best-effort deletes the legacy key too
// (legacy delete errors are deliberately ignored).
func (s *OpenAIGatewayService) deleteStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) error {
	if s == nil || s.cache == nil {
		return nil
	}
	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return nil
	}
	err := s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
	if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
		return err
	}
	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey != "" {
		_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
	}
	return err
}

View File

@@ -0,0 +1,96 @@
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
// TestGetStickySessionAccountID_FallbackToLegacyKey verifies that a miss on
// the primary key falls back to the legacy key when enabled, and that both
// the attempt and hit counters advance.
func TestGetStickySessionAccountID_FallbackToLegacyKey(t *testing.T) {
	beforeFallbackTotal, beforeFallbackHit, _ := openAIStickyCompatStats()
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:legacy-hash": 42,
		},
	}
	svc := &OpenAIGatewayService{
		cache: cache,
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				OpenAIWS: config.GatewayOpenAIWSConfig{
					SessionHashReadOldFallback: true,
				},
			},
		},
	}
	ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
	accountID, err := svc.getStickySessionAccountID(ctx, nil, "new-hash")
	require.NoError(t, err)
	require.Equal(t, int64(42), accountID)
	afterFallbackTotal, afterFallbackHit, _ := openAIStickyCompatStats()
	require.Equal(t, beforeFallbackTotal+1, afterFallbackTotal)
	require.Equal(t, beforeFallbackHit+1, afterFallbackHit)
}

// TestSetStickySessionAccountID_DualWriteOldEnabled verifies the binding is
// written to both the new and legacy keys when dual-write is on, and that
// the dual-write counter advances.
func TestSetStickySessionAccountID_DualWriteOldEnabled(t *testing.T) {
	_, _, beforeDualWriteTotal := openAIStickyCompatStats()
	cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
	svc := &OpenAIGatewayService{
		cache: cache,
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				OpenAIWS: config.GatewayOpenAIWSConfig{
					SessionHashDualWriteOld: true,
				},
			},
		},
	}
	ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
	err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL)
	require.NoError(t, err)
	require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"])
	require.Equal(t, int64(9), cache.sessionBindings["openai:legacy-hash"])
	_, _, afterDualWriteTotal := openAIStickyCompatStats()
	require.Equal(t, beforeDualWriteTotal+1, afterDualWriteTotal)
}

// TestSetStickySessionAccountID_DualWriteOldDisabled verifies that only the
// new key is written when dual-write is off.
func TestSetStickySessionAccountID_DualWriteOldDisabled(t *testing.T) {
	cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
	svc := &OpenAIGatewayService{
		cache: cache,
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				OpenAIWS: config.GatewayOpenAIWSConfig{
					SessionHashDualWriteOld: false,
				},
			},
		},
	}
	ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
	err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL)
	require.NoError(t, err)
	require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"])
	_, exists := cache.sessionBindings["openai:legacy-hash"]
	require.False(t, exists)
}

// TestSnapshotOpenAICompatibilityFallbackMetrics verifies that reading a
// legacy context key bumps the metadata-fallback metrics snapshot.
func TestSnapshotOpenAICompatibilityFallbackMetrics(t *testing.T) {
	before := SnapshotOpenAICompatibilityFallbackMetrics()
	ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true)
	_, _ = ThinkingEnabledFromContext(ctx)
	after := SnapshotOpenAICompatibilityFallbackMetrics()
	require.GreaterOrEqual(t, after.MetadataLegacyFallbackTotal, before.MetadataLegacyFallbackTotal+1)
	require.GreaterOrEqual(t, after.MetadataLegacyFallbackThinkingEnabledTotal, before.MetadataLegacyFallbackThinkingEnabledTotal+1)
}

View File

@@ -2,6 +2,24 @@ package service
import "strings"
// ToolContinuationSignals aggregates tool-continuation-related signals so the
// input array only needs to be traversed once.
type ToolContinuationSignals struct {
	HasFunctionCallOutput              bool     // input contains a function_call_output item
	HasFunctionCallOutputMissingCallID bool     // some function_call_output lacks a call_id
	HasToolCallContext                 bool     // input contains a tool_call/function_call with a call_id
	HasItemReference                   bool     // input contains an item_reference item
	HasItemReferenceForAllCallIDs      bool     // every output call_id is covered by an item_reference id
	FunctionCallOutputCallIDs          []string // distinct call_ids seen on function_call_output items
}

// FunctionCallOutputValidation summarizes the function_call_output
// association-validation result.
type FunctionCallOutputValidation struct {
	HasFunctionCallOutput              bool // input contains a function_call_output item
	HasToolCallContext                 bool // input contains a tool_call/function_call with a call_id
	HasFunctionCallOutputMissingCallID bool // some function_call_output lacks a call_id
	HasItemReferenceForAllCallIDs      bool // every output call_id is covered by an item_reference id
}
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
// 满足以下任一信号即视为续链previous_response_id、input 内包含 function_call_output/item_reference、
// 或显式声明 tools/tool_choice。
@@ -18,107 +36,191 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
if hasToolChoiceSignal(reqBody) {
return true
}
if inputHasType(reqBody, "function_call_output") {
return true
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
if inputHasType(reqBody, "item_reference") {
return true
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType == "function_call_output" || itemType == "item_reference" {
return true
}
}
return false
}
// AnalyzeToolContinuationSignals walks the input array once, extracting
// signals related to function_call_output / tool_call / item_reference items.
// Returns the zero value when reqBody is nil or "input" is not an array.
func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals {
	signals := ToolContinuationSignals{}
	if reqBody == nil {
		return signals
	}
	input, ok := reqBody["input"].([]any)
	if !ok {
		return signals
	}
	// Sets are allocated lazily so the common no-tool case stays cheap.
	var callIDs map[string]struct{}
	var referenceIDs map[string]struct{}
	for _, item := range input {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		itemType, _ := itemMap["type"].(string)
		switch itemType {
		case "tool_call", "function_call":
			callID, _ := itemMap["call_id"].(string)
			if strings.TrimSpace(callID) != "" {
				signals.HasToolCallContext = true
			}
		case "function_call_output":
			signals.HasFunctionCallOutput = true
			callID, _ := itemMap["call_id"].(string)
			callID = strings.TrimSpace(callID)
			if callID == "" {
				signals.HasFunctionCallOutputMissingCallID = true
				continue
			}
			if callIDs == nil {
				callIDs = make(map[string]struct{})
			}
			callIDs[callID] = struct{}{}
		case "item_reference":
			signals.HasItemReference = true
			idValue, _ := itemMap["id"].(string)
			idValue = strings.TrimSpace(idValue)
			if idValue == "" {
				continue
			}
			if referenceIDs == nil {
				referenceIDs = make(map[string]struct{})
			}
			referenceIDs[idValue] = struct{}{}
		}
	}
	if len(callIDs) == 0 {
		return signals
	}
	signals.FunctionCallOutputCallIDs = make([]string, 0, len(callIDs))
	// "All referenced" can only hold if at least one item_reference exists.
	allReferenced := len(referenceIDs) > 0
	for callID := range callIDs {
		signals.FunctionCallOutputCallIDs = append(signals.FunctionCallOutputCallIDs, callID)
		if allReferenced {
			if _, ok := referenceIDs[callID]; !ok {
				allReferenced = false
			}
		}
	}
	signals.HasItemReferenceForAllCallIDs = allReferenced
	return signals
}
// ValidateFunctionCallOutputContext gives handlers a low-overhead validation result:
//  1. returns immediately when there is no function_call_output at all;
//  2. returns early once a tool_call/function_call context is found;
//  3. only builds the call_id / item_reference sets when no tool context exists.
func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation {
	result := FunctionCallOutputValidation{}
	if reqBody == nil {
		return result
	}
	input, ok := reqBody["input"].([]any)
	if !ok {
		return result
	}
	// First pass: detect presence flags only, bailing out as soon as both are known.
	for _, item := range input {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		itemType, _ := itemMap["type"].(string)
		switch itemType {
		case "function_call_output":
			result.HasFunctionCallOutput = true
		case "tool_call", "function_call":
			callID, _ := itemMap["call_id"].(string)
			if strings.TrimSpace(callID) != "" {
				result.HasToolCallContext = true
			}
		}
		if result.HasFunctionCallOutput && result.HasToolCallContext {
			return result
		}
	}
	// Nothing to validate, or a tool context already legitimizes the outputs.
	if !result.HasFunctionCallOutput || result.HasToolCallContext {
		return result
	}
	// Second pass (outputs without tool context): collect call_ids and reference ids.
	callIDs := make(map[string]struct{})
	referenceIDs := make(map[string]struct{})
	for _, item := range input {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		itemType, _ := itemMap["type"].(string)
		switch itemType {
		case "function_call_output":
			callID, _ := itemMap["call_id"].(string)
			callID = strings.TrimSpace(callID)
			if callID == "" {
				result.HasFunctionCallOutputMissingCallID = true
				continue
			}
			callIDs[callID] = struct{}{}
		case "item_reference":
			idValue, _ := itemMap["id"].(string)
			idValue = strings.TrimSpace(idValue)
			if idValue == "" {
				continue
			}
			referenceIDs[idValue] = struct{}{}
		}
	}
	if len(callIDs) == 0 || len(referenceIDs) == 0 {
		return result
	}
	allReferenced := true
	for callID := range callIDs {
		if _, ok := referenceIDs[callID]; !ok {
			allReferenced = false
			break
		}
	}
	result.HasItemReferenceForAllCallIDs = allReferenced
	return result
}
// HasFunctionCallOutput 判断 input 是否包含 function_call_output用于触发续链校验。
func HasFunctionCallOutput(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
return inputHasType(reqBody, "function_call_output")
return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput
}
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call
// 用于判断 function_call_output 是否具备可关联的上下文。
func HasToolCallContext(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "tool_call" && itemType != "function_call" {
continue
}
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
return true
}
}
return false
return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext
}
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
// 仅返回非空 call_id用于与 item_reference.id 做匹配校验。
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
if reqBody == nil {
return nil
}
input, ok := reqBody["input"].([]any)
if !ok {
return nil
}
ids := make(map[string]struct{})
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "function_call_output" {
continue
}
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
ids[callID] = struct{}{}
}
}
if len(ids) == 0 {
return nil
}
result := make([]string, 0, len(ids))
for id := range ids {
result = append(result, id)
}
return result
return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs
}
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "function_call_output" {
continue
}
callID, _ := itemMap["call_id"].(string)
if strings.TrimSpace(callID) == "" {
return true
}
}
return false
return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID
}
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
@@ -152,32 +254,13 @@ func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
return false
}
for _, callID := range callIDs {
if _, ok := referenceIDs[callID]; !ok {
if _, ok := referenceIDs[strings.TrimSpace(callID)]; !ok {
return false
}
}
return true
}
// inputHasType reports whether the request body's "input" array contains an
// item whose "type" field equals want.
func inputHasType(reqBody map[string]any, want string) bool {
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if entryType, _ := entry["type"].(string); entryType == want {
			return true
		}
	}
	return false
}
// hasNonEmptyString 判断字段是否为非空字符串。
func hasNonEmptyString(value any) bool {
stringValue, ok := value.(string)

View File

@@ -1,11 +1,15 @@
package service
import (
"encoding/json"
"bytes"
"fmt"
"strconv"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射
@@ -62,169 +66,201 @@ func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, boo
if data == "" || data == "\n" {
return data, false
}
correctedBytes, corrected := c.CorrectToolCallsInSSEBytes([]byte(data))
if !corrected {
return data, false
}
return string(correctedBytes), true
}
// 尝试解析 JSON
var payload map[string]any
if err := json.Unmarshal([]byte(data), &payload); err != nil {
// 不是有效的 JSON直接返回原数据
// CorrectToolCallsInSSEBytes 修正 SSE JSON 数据中的工具调用(字节路径)。
// 返回修正后的数据和是否进行了修正。
func (c *CodexToolCorrector) CorrectToolCallsInSSEBytes(data []byte) ([]byte, bool) {
if len(bytes.TrimSpace(data)) == 0 {
return data, false
}
if !mayContainToolCallPayload(data) {
return data, false
}
if !gjson.ValidBytes(data) {
// 不是有效 JSON直接返回原数据
return data, false
}
updated := data
corrected := false
// 处理 tool_calls 数组
if toolCalls, ok := payload["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
collect := func(changed bool, next []byte) {
if changed {
corrected = true
updated = next
}
}
// 处理 function_call 对象
if functionCall, ok := payload["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
if next, changed := c.correctToolCallsArrayAtPath(updated, "tool_calls"); changed {
collect(changed, next)
}
if next, changed := c.correctFunctionAtPath(updated, "function_call"); changed {
collect(changed, next)
}
if next, changed := c.correctToolCallsArrayAtPath(updated, "delta.tool_calls"); changed {
collect(changed, next)
}
if next, changed := c.correctFunctionAtPath(updated, "delta.function_call"); changed {
collect(changed, next)
}
// 处理 delta.tool_calls
if delta, ok := payload["delta"].(map[string]any); ok {
if toolCalls, ok := delta["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
corrected = true
}
choicesCount := int(gjson.GetBytes(updated, "choices.#").Int())
for i := 0; i < choicesCount; i++ {
prefix := "choices." + strconv.Itoa(i)
if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".message.tool_calls"); changed {
collect(changed, next)
}
if functionCall, ok := delta["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
if next, changed := c.correctFunctionAtPath(updated, prefix+".message.function_call"); changed {
collect(changed, next)
}
}
// 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
if choices, ok := payload["choices"].([]any); ok {
for _, choice := range choices {
if choiceMap, ok := choice.(map[string]any); ok {
// 处理 message 中的工具调用
if message, ok := choiceMap["message"].(map[string]any); ok {
if toolCalls, ok := message["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
corrected = true
}
}
if functionCall, ok := message["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
}
}
// 处理 delta 中的工具调用
if delta, ok := choiceMap["delta"].(map[string]any); ok {
if toolCalls, ok := delta["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
corrected = true
}
}
if functionCall, ok := delta["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
}
}
}
if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".delta.tool_calls"); changed {
collect(changed, next)
}
if next, changed := c.correctFunctionAtPath(updated, prefix+".delta.function_call"); changed {
collect(changed, next)
}
}
if !corrected {
return data, false
}
return updated, true
}
// 序列化回 JSON
correctedBytes, err := json.Marshal(payload)
if err != nil {
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Failed to marshal corrected data: %v", err)
// mayContainToolCallPayload is a cheap pre-filter: most token/text SSE events
// carry no tool fields, so skip the JSON-parsing hot path unless one of the
// tool-call marker substrings is present.
func mayContainToolCallPayload(data []byte) bool {
	markers := [...][]byte{
		[]byte(`"tool_calls"`),
		[]byte(`"function_call"`),
		[]byte(`"function":{"name"`),
	}
	for _, marker := range markers {
		if bytes.Contains(data, marker) {
			return true
		}
	}
	return false
}
// correctToolCallsArrayAtPath 修正指定路径下 tool_calls 数组中的工具名称。
func (c *CodexToolCorrector) correctToolCallsArrayAtPath(data []byte, toolCallsPath string) ([]byte, bool) {
count := int(gjson.GetBytes(data, toolCallsPath+".#").Int())
if count <= 0 {
return data, false
}
return string(correctedBytes), true
}
// correctToolCallsArray 修正工具调用数组中的工具名称
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
updated := data
corrected := false
for _, toolCall := range toolCalls {
if toolCallMap, ok := toolCall.(map[string]any); ok {
if function, ok := toolCallMap["function"].(map[string]any); ok {
if c.correctFunctionCall(function) {
corrected = true
}
}
for i := 0; i < count; i++ {
functionPath := toolCallsPath + "." + strconv.Itoa(i) + ".function"
if next, changed := c.correctFunctionAtPath(updated, functionPath); changed {
updated = next
corrected = true
}
}
return corrected
return updated, corrected
}
// correctFunctionCall 修正单个函数调用的工具名称和参数
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
name, ok := functionCall["name"].(string)
if !ok || name == "" {
return false
// correctFunctionAtPath 修正指定路径下单个函数调用的工具名称和参数
func (c *CodexToolCorrector) correctFunctionAtPath(data []byte, functionPath string) ([]byte, bool) {
namePath := functionPath + ".name"
nameResult := gjson.GetBytes(data, namePath)
if !nameResult.Exists() || nameResult.Type != gjson.String {
return data, false
}
name := strings.TrimSpace(nameResult.Str)
if name == "" {
return data, false
}
updated := data
corrected := false
// 查找并修正工具名称
if correctName, found := codexToolNameMapping[name]; found {
functionCall["name"] = correctName
c.recordCorrection(name, correctName)
corrected = true
name = correctName // 使用修正后的名称进行参数修正
if next, err := sjson.SetBytes(updated, namePath, correctName); err == nil {
updated = next
c.recordCorrection(name, correctName)
corrected = true
name = correctName // 使用修正后的名称进行参数修正
}
}
// 修正工具参数(基于工具名称)
if c.correctToolParameters(name, functionCall) {
if next, changed := c.correctToolParametersAtPath(updated, functionPath+".arguments", name); changed {
updated = next
corrected = true
}
return corrected
return updated, corrected
}
// correctToolParameters 修正工具参数以符合 OpenCode 规范
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
arguments, ok := functionCall["arguments"]
if !ok {
return false
// correctToolParametersAtPath 修正指定路径下 arguments 参数。
func (c *CodexToolCorrector) correctToolParametersAtPath(data []byte, argumentsPath, toolName string) ([]byte, bool) {
if toolName != "bash" && toolName != "edit" {
return data, false
}
// arguments 可能是字符串JSON或已解析的 map
var argsMap map[string]any
switch v := arguments.(type) {
case string:
// 解析 JSON 字符串
if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
return false
args := gjson.GetBytes(data, argumentsPath)
if !args.Exists() {
return data, false
}
switch args.Type {
case gjson.String:
argsJSON := strings.TrimSpace(args.Str)
if !gjson.Valid(argsJSON) {
return data, false
}
case map[string]any:
argsMap = v
if !gjson.Parse(argsJSON).IsObject() {
return data, false
}
nextArgsJSON, corrected := c.correctToolArgumentsJSON(argsJSON, toolName)
if !corrected {
return data, false
}
next, err := sjson.SetBytes(data, argumentsPath, nextArgsJSON)
if err != nil {
return data, false
}
return next, true
case gjson.JSON:
if !args.IsObject() || !gjson.Valid(args.Raw) {
return data, false
}
nextArgsJSON, corrected := c.correctToolArgumentsJSON(args.Raw, toolName)
if !corrected {
return data, false
}
next, err := sjson.SetRawBytes(data, argumentsPath, []byte(nextArgsJSON))
if err != nil {
return data, false
}
return next, true
default:
return false
return data, false
}
}
// correctToolArgumentsJSON 修正工具参数 JSON对象字符串返回修正后的 JSON 与是否变更。
func (c *CodexToolCorrector) correctToolArgumentsJSON(argsJSON, toolName string) (string, bool) {
if !gjson.Valid(argsJSON) {
return argsJSON, false
}
if !gjson.Parse(argsJSON).IsObject() {
return argsJSON, false
}
updated := argsJSON
corrected := false
// 根据工具名称应用特定的参数修正规则
switch toolName {
case "bash":
// OpenCode bash 支持 workdir有些来源会输出 work_dir。
if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir {
if workDir, exists := argsMap["work_dir"]; exists {
argsMap["workdir"] = workDir
delete(argsMap, "work_dir")
if !gjson.Get(updated, "workdir").Exists() {
if next, changed := moveJSONField(updated, "work_dir", "workdir"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
}
} else {
if _, exists := argsMap["work_dir"]; exists {
delete(argsMap, "work_dir")
if next, changed := deleteJSONField(updated, "work_dir"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
}
@@ -232,67 +268,71 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
case "edit":
// OpenCode edit 参数为 filePath/oldString/newStringcamelCase
if _, exists := argsMap["filePath"]; !exists {
if filePath, exists := argsMap["file_path"]; exists {
argsMap["filePath"] = filePath
delete(argsMap, "file_path")
if !gjson.Get(updated, "filePath").Exists() {
if next, changed := moveJSONField(updated, "file_path", "filePath"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
} else if filePath, exists := argsMap["path"]; exists {
argsMap["filePath"] = filePath
delete(argsMap, "path")
} else if next, changed := moveJSONField(updated, "path", "filePath"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
} else if filePath, exists := argsMap["file"]; exists {
argsMap["filePath"] = filePath
delete(argsMap, "file")
} else if next, changed := moveJSONField(updated, "file", "filePath"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
}
}
if _, exists := argsMap["oldString"]; !exists {
if oldString, exists := argsMap["old_string"]; exists {
argsMap["oldString"] = oldString
delete(argsMap, "old_string")
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
}
if next, changed := moveJSONField(updated, "old_string", "oldString"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
}
if _, exists := argsMap["newString"]; !exists {
if newString, exists := argsMap["new_string"]; exists {
argsMap["newString"] = newString
delete(argsMap, "new_string")
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
}
if next, changed := moveJSONField(updated, "new_string", "newString"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
}
if _, exists := argsMap["replaceAll"]; !exists {
if replaceAll, exists := argsMap["replace_all"]; exists {
argsMap["replaceAll"] = replaceAll
delete(argsMap, "replace_all")
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
}
if next, changed := moveJSONField(updated, "replace_all", "replaceAll"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
}
}
return updated, corrected
}
// 如果修正了参数,需要重新序列化
if corrected {
if _, wasString := arguments.(string); wasString {
// 原本是字符串,序列化回字符串
if newArgsJSON, err := json.Marshal(argsMap); err == nil {
functionCall["arguments"] = string(newArgsJSON)
}
} else {
// 原本是 map直接赋值
functionCall["arguments"] = argsMap
}
// moveJSONField renames the JSON key from -> to inside the object string input.
// It is a no-op (returns input, false) when to already exists, from is absent,
// or either sjson operation fails; otherwise it returns the rewritten JSON and true.
func moveJSONField(input, from, to string) (string, bool) {
	if gjson.Get(input, to).Exists() {
		return input, false
	}
	src := gjson.Get(input, from)
	if !src.Exists() {
		return input, false
	}
	withTarget, setErr := sjson.SetRaw(input, to, src.Raw)
	if setErr != nil {
		return input, false
	}
	cleaned, delErr := sjson.Delete(withTarget, from)
	if delErr != nil {
		return input, false
	}
	return cleaned, true
}
return corrected
// deleteJSONField removes path from the JSON string input when it is present,
// returning the rewritten JSON and whether a change was made.
func deleteJSONField(input, path string) (string, bool) {
	if gjson.Get(input, path).Exists() {
		if cleaned, err := sjson.Delete(input, path); err == nil {
			return cleaned, true
		}
	}
	return input, false
}
// recordCorrection 记录一次工具名称修正

View File

@@ -5,6 +5,15 @@ import (
"testing"
)
// TestMayContainToolCallPayload verifies the substring pre-filter: a plain
// text delta event must be skipped, a tool_calls event must pass through.
func TestMayContainToolCallPayload(t *testing.T) {
	if mayContainToolCallPayload([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)) {
		t.Fatalf("plain text event should not trigger tool-call parsing")
	}
	if !mayContainToolCallPayload([]byte(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)) {
		t.Fatalf("tool_calls event should trigger tool-call parsing")
	}
}
func TestCorrectToolCallsInSSEData(t *testing.T) {
corrector := NewCodexToolCorrector()

View File

@@ -0,0 +1,190 @@
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit verifies that
// a previous_response_id bound to an account resolves back to that account and
// acquires its concurrency slot.
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T) {
	ctx := context.Background()
	groupID := int64(23)
	// Schedulable API-key account with the WS v2 flag enabled via Extra.
	account := Account{
		ID:          2,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		},
	}
	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	cfg := newOpenAIWSV2TestConfig()
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
		openaiWSStateStore: store,
	}
	// Bind the response id to the account, then expect sticky selection to hit it.
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour))
	selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.True(t, selection.Acquired)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded verifies
// that a sticky-bound account present in the exclusion set yields no selection.
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
	ctx := context.Background()
	groupID := int64(23)
	account := Account{
		ID:          8,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		},
	}
	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	cfg := newOpenAIWSV2TestConfig()
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
		openaiWSStateStore: store,
	}
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour))
	// The bound account is excluded, so sticky selection must return nil.
	selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}})
	require.NoError(t, err)
	require.Nil(t, selection)
}
// TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored
// verifies that an account flagged openai_ws_force_http does not participate
// in previous_response_id stickiness.
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored(t *testing.T) {
	ctx := context.Background()
	groupID := int64(23)
	account := Account{
		ID:          11,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_ws_force_http":          true,
			"responses_websockets_v2_enabled": true,
		},
	}
	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	cfg := newOpenAIWSV2TestConfig()
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
		openaiWSStateStore: store,
	}
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour))
	selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil)
	require.NoError(t, err)
	require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连")
}
// TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky
// verifies that when the sticky-bound account is busy, selection still returns
// that account with a wait plan instead of falling back to a free account.
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(23)
	accounts := []Account{
		{
			ID:          21,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		},
		{
			ID:          22,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    9,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		},
	}
	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	cfg := newOpenAIWSV2TestConfig()
	cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2
	cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second
	concurrencyCache := stubConcurrencyCache{
		acquireResults: map[int64]bool{
			21: false, // the account hit by previous_response is busy
			22: true,  // a lower-priority account is free (would be hit on fallback)
		},
		waitCounts: map[int64]int{
			21: 999,
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
		openaiWSStateStore: store,
	}
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour))
	selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(21), selection.Account.ID, "busy previous_response sticky account should remain selected")
	require.False(t, selection.Acquired)
	require.NotNil(t, selection.WaitPlan)
	require.Equal(t, int64(21), selection.WaitPlan.AccountID)
}
// newOpenAIWSV2TestConfig builds a config with every OpenAI WS flag the v2
// sticky-selection tests depend on enabled, plus a 1h sticky-response TTL.
func newOpenAIWSV2TestConfig() *config.Config {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
	return cfg
}

View File

@@ -0,0 +1,285 @@
package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
coderws "github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
// openAIWSMessageReadLimitBytes raises the per-message WS read cap applied to
// each dialed connection (see Dial); large Codex WS events would otherwise
// fail locally with "message too big".
const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024

// Tuning for the per-proxy HTTP client cache: transport idle-connection
// pooling limits plus a size cap and idle TTL on cached clients.
const (
	openAIWSProxyTransportMaxIdleConns        = 128
	openAIWSProxyTransportMaxIdleConnsPerHost = 64
	openAIWSProxyTransportIdleConnTimeout     = 90 * time.Second
	openAIWSProxyClientCacheMaxEntries        = 256
	openAIWSProxyClientCacheIdleTTL           = 15 * time.Minute
)
// OpenAIWSTransportMetricsSnapshot reports the proxy HTTP-client cache
// hit/miss counters and the derived transport reuse ratio.
type OpenAIWSTransportMetricsSnapshot struct {
	ProxyClientCacheHits   int64   `json:"proxy_client_cache_hits"`
	ProxyClientCacheMisses int64   `json:"proxy_client_cache_misses"`
	TransportReuseRatio    float64 `json:"transport_reuse_ratio"`
}

// openAIWSClientConn abstracts a WS client connection so the underlying
// implementation can be swapped.
type openAIWSClientConn interface {
	WriteJSON(ctx context.Context, value any) error
	ReadMessage(ctx context.Context) ([]byte, error)
	Ping(ctx context.Context) error
	Close() error
}

// openAIWSClientDialer abstracts WS connection establishment.
type openAIWSClientDialer interface {
	Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error)
}

// openAIWSTransportMetricsDialer is implemented by dialers that can expose
// transport-level cache metrics.
type openAIWSTransportMetricsDialer interface {
	SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot
}
// newDefaultOpenAIWSClientDialer builds the production WS dialer with an
// empty per-proxy HTTP client cache.
func newDefaultOpenAIWSClientDialer() openAIWSClientDialer {
	dialer := &coderOpenAIWSClientDialer{}
	dialer.proxyClients = make(map[string]*openAIWSProxyClientEntry)
	return dialer
}
// coderOpenAIWSClientDialer is the production dialer backed by coder/websocket.
// It caches one http.Client per proxy URL (proxyClients, guarded by proxyMu)
// and counts cache hits/misses for the transport-reuse metric.
type coderOpenAIWSClientDialer struct {
	proxyMu      sync.Mutex
	proxyClients map[string]*openAIWSProxyClientEntry
	proxyHits    atomic.Int64
	proxyMisses  atomic.Int64
}

// openAIWSProxyClientEntry pairs a cached proxy client with its last-use time
// (unix nanoseconds) used for TTL cleanup and oldest-first eviction.
type openAIWSProxyClientEntry struct {
	client           *http.Client
	lastUsedUnixNano int64
}
// Dial opens a WS connection to wsURL with the given headers, routing through
// proxyURL when non-empty. On handshake failure it returns the upstream HTTP
// status code and response headers alongside the error; on success the status
// is 0 and only the response headers are returned.
func (d *coderOpenAIWSClientDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	targetURL := strings.TrimSpace(wsURL)
	if targetURL == "" {
		return nil, 0, nil, errors.New("ws url is empty")
	}
	opts := &coderws.DialOptions{
		HTTPHeader:      cloneHeader(headers),
		CompressionMode: coderws.CompressionContextTakeover,
	}
	if proxy := strings.TrimSpace(proxyURL); proxy != "" {
		// Reuse the cached per-proxy HTTP client so transports are pooled.
		proxyClient, err := d.proxyHTTPClient(proxy)
		if err != nil {
			return nil, 0, nil, err
		}
		opts.HTTPClient = proxyClient
	}
	conn, resp, err := coderws.Dial(ctx, targetURL, opts)
	if err != nil {
		status := 0
		respHeaders := http.Header(nil)
		if resp != nil {
			status = resp.StatusCode
			respHeaders = cloneHeader(resp.Header)
		}
		return nil, status, respHeaders, err
	}
	// coder/websocket defaults to a 32KB single-message read limit; Codex WS
	// events (e.g. rate_limits / large deltas) can exceed that, so raise the
	// limit explicitly to avoid local read_fail (message too big).
	conn.SetReadLimit(openAIWSMessageReadLimitBytes)
	respHeaders := http.Header(nil)
	if resp != nil {
		respHeaders = cloneHeader(resp.Header)
	}
	return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil
}
// proxyHTTPClient returns a cached (or newly built) HTTP client whose
// transport routes through the given proxy URL. Clients are cached per
// trimmed proxy string under proxyMu; hits/misses feed the reuse metric.
func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) {
	if d == nil {
		return nil, errors.New("openai ws dialer is nil")
	}
	normalizedProxy := strings.TrimSpace(proxy)
	if normalizedProxy == "" {
		return nil, errors.New("proxy url is empty")
	}
	parsedProxyURL, err := url.Parse(normalizedProxy)
	if err != nil {
		return nil, fmt.Errorf("invalid proxy url: %w", err)
	}
	now := time.Now().UnixNano()
	d.proxyMu.Lock()
	defer d.proxyMu.Unlock()
	// Fast path: reuse the cached client and refresh its last-used stamp.
	if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != nil && entry.client != nil {
		entry.lastUsedUnixNano = now
		d.proxyHits.Add(1)
		return entry.client, nil
	}
	// Miss: evict idle entries first, then build and cache a fresh client.
	d.cleanupProxyClientsLocked(now)
	transport := &http.Transport{
		Proxy:               http.ProxyURL(parsedProxyURL),
		MaxIdleConns:        openAIWSProxyTransportMaxIdleConns,
		MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost,
		IdleConnTimeout:     openAIWSProxyTransportIdleConnTimeout,
		TLSHandshakeTimeout: 10 * time.Second,
		ForceAttemptHTTP2:   true,
	}
	client := &http.Client{Transport: transport}
	d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{
		client:           client,
		lastUsedUnixNano: now,
	}
	// Capacity enforcement runs after insert; oldest entries are evicted first.
	d.ensureProxyClientCapacityLocked()
	d.proxyMisses.Add(1)
	return client, nil
}
// cleanupProxyClientsLocked evicts cached proxy clients that have been idle
// longer than the TTL, plus any nil/broken entries. Caller must hold proxyMu.
func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) {
	if d == nil || len(d.proxyClients) == 0 {
		return
	}
	idleTTL := openAIWSProxyClientCacheIdleTTL
	if idleTTL <= 0 {
		return
	}
	// An entry is stale when now-lastUsed > TTL, i.e. lastUsed < cutoff.
	cutoff := nowUnixNano - idleTTL.Nanoseconds()
	for key, entry := range d.proxyClients {
		if entry == nil || entry.client == nil {
			delete(d.proxyClients, key)
			continue
		}
		if entry.lastUsedUnixNano < cutoff {
			closeOpenAIWSProxyClient(entry.client)
			delete(d.proxyClients, key)
		}
	}
}
// ensureProxyClientCapacityLocked evicts the least-recently-used proxy clients
// until the cache is within openAIWSProxyClientCacheMaxEntries.
// Caller must hold proxyMu.
func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() {
	if d == nil {
		return
	}
	maxEntries := openAIWSProxyClientCacheMaxEntries
	if maxEntries <= 0 {
		return
	}
	for len(d.proxyClients) > maxEntries {
		// Linear scan for the oldest entry; the cache is bounded so this is cheap.
		var oldestKey string
		var oldestLastUsed int64
		hasOldest := false
		for key, entry := range d.proxyClients {
			lastUsed := int64(0)
			if entry != nil {
				lastUsed = entry.lastUsedUnixNano
			}
			if !hasOldest || lastUsed < oldestLastUsed {
				hasOldest = true
				oldestKey = key
				oldestLastUsed = lastUsed
			}
		}
		if !hasOldest {
			return
		}
		if entry := d.proxyClients[oldestKey]; entry != nil {
			closeOpenAIWSProxyClient(entry.client)
		}
		delete(d.proxyClients, oldestKey)
	}
}
func closeOpenAIWSProxyClient(client *http.Client) {
if client == nil || client.Transport == nil {
return
}
if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
transport.CloseIdleConnections()
}
}
// SnapshotTransportMetrics returns the current proxy-client cache counters and
// the hit ratio (hits / total lookups, 0 when no lookups happened).
func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
	if d == nil {
		return OpenAIWSTransportMetricsSnapshot{}
	}
	snapshot := OpenAIWSTransportMetricsSnapshot{
		ProxyClientCacheHits:   d.proxyHits.Load(),
		ProxyClientCacheMisses: d.proxyMisses.Load(),
	}
	if total := snapshot.ProxyClientCacheHits + snapshot.ProxyClientCacheMisses; total > 0 {
		snapshot.TransportReuseRatio = float64(snapshot.ProxyClientCacheHits) / float64(total)
	}
	return snapshot
}
// coderOpenAIWSClientConn adapts a coder/websocket connection to the
// openAIWSClientConn interface.
type coderOpenAIWSClientConn struct {
	conn *coderws.Conn
}

// WriteJSON marshals value and writes it as a single WS message.
// Returns errOpenAIWSConnClosed on a nil receiver or connection.
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
	if c == nil || c.conn == nil {
		return errOpenAIWSConnClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}
	return wsjson.Write(ctx, c.conn, value)
}
// ReadMessage reads one message payload. Text and binary frames are returned
// as-is; any other frame type is reported as a closed connection.
func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) {
	if c == nil || c.conn == nil {
		return nil, errOpenAIWSConnClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}
	msgType, payload, err := c.conn.Read(ctx)
	if err != nil {
		return nil, err
	}
	switch msgType {
	case coderws.MessageText, coderws.MessageBinary:
		return payload, nil
	default:
		return nil, errOpenAIWSConnClosed
	}
}
// Ping sends a WS ping and waits for the pong (via the underlying library).
// Returns errOpenAIWSConnClosed on a nil receiver or connection.
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
	if c == nil || c.conn == nil {
		return errOpenAIWSConnClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}
	return c.conn.Ping(ctx)
}
// Close shuts the connection down with a normal-closure status, then forces
// the underlying resources closed. It always returns nil.
func (c *coderOpenAIWSClientConn) Close() error {
	if c == nil || c.conn == nil {
		return nil
	}
	// Close is idempotent; errors from duplicate closes are ignored.
	_ = c.conn.Close(coderws.StatusNormalClosure, "")
	_ = c.conn.CloseNow()
	return nil
}

View File

@@ -0,0 +1,112 @@
package service
import (
"fmt"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse checks that an identical
// proxy URL yields the same cached *http.Client and a different URL yields a
// distinct one.
func TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	c1, err := impl.proxyHTTPClient("http://127.0.0.1:8080")
	require.NoError(t, err)
	c2, err := impl.proxyHTTPClient("http://127.0.0.1:8080")
	require.NoError(t, err)
	require.Same(t, c1, c2, "同一代理地址应复用同一个 HTTP 客户端")
	c3, err := impl.proxyHTTPClient("http://127.0.0.1:8081")
	require.NoError(t, err)
	require.NotSame(t, c1, c3, "不同代理地址应分离客户端")
}
// TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL checks that an
// unparsable proxy URL produces an error rather than a cached client.
func TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	_, err := impl.proxyHTTPClient("://bad")
	require.Error(t, err)
}
// TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot verifies the cache
// hit/miss counters: two distinct proxy URLs produce two misses, a repeated
// URL produces one hit, and the reuse ratio reflects hits/(hits+misses).
func TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	_, err := impl.proxyHTTPClient("http://127.0.0.1:18080")
	require.NoError(t, err)
	_, err = impl.proxyHTTPClient("http://127.0.0.1:18080")
	require.NoError(t, err)
	_, err = impl.proxyHTTPClient("http://127.0.0.1:18081")
	require.NoError(t, err)
	snapshot := impl.SnapshotTransportMetrics()
	require.Equal(t, int64(1), snapshot.ProxyClientCacheHits)
	require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses)
	require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001)
}
// TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity inserts more proxy
// clients than the configured cap and verifies the cache size never exceeds
// openAIWSProxyClientCacheMaxEntries.
func TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	total := openAIWSProxyClientCacheMaxEntries + 32
	for i := 0; i < total; i++ {
		_, err := impl.proxyHTTPClient(fmt.Sprintf("http://127.0.0.1:%d", 20000+i))
		require.NoError(t, err)
	}
	// Inspect the cache under its own lock to avoid racing the dialer.
	impl.proxyMu.Lock()
	cacheSize := len(impl.proxyClients)
	impl.proxyMu.Unlock()
	require.LessOrEqual(t, cacheSize, openAIWSProxyClientCacheMaxEntries, "代理客户端缓存应受容量上限约束")
}
// TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL backdates a cached
// entry past the idle TTL and verifies the next acquisition evicts it.
func TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	oldProxy := "http://127.0.0.1:28080"
	_, err := impl.proxyHTTPClient(oldProxy)
	require.NoError(t, err)
	// Backdate the entry's last-used timestamp beyond the idle TTL.
	impl.proxyMu.Lock()
	oldEntry := impl.proxyClients[oldProxy]
	require.NotNil(t, oldEntry)
	oldEntry.lastUsedUnixNano = time.Now().Add(-openAIWSProxyClientCacheIdleTTL - time.Minute).UnixNano()
	impl.proxyMu.Unlock()
	// Trigger another proxy acquisition to drive the idle-TTL sweep.
	_, err = impl.proxyHTTPClient("http://127.0.0.1:28081")
	require.NoError(t, err)
	impl.proxyMu.Lock()
	_, exists := impl.proxyClients[oldProxy]
	impl.proxyMu.Unlock()
	require.False(t, exists, "超过空闲 TTL 的代理客户端应被回收")
}
// TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout verifies a
// freshly built proxy client uses an *http.Transport with a 10s TLS
// handshake timeout.
func TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	client, err := impl.proxyHTTPClient("http://127.0.0.1:38080")
	require.NoError(t, err)
	require.NotNil(t, client)
	transport, ok := client.Transport.(*http.Transport)
	require.True(t, ok)
	require.NotNil(t, transport)
	require.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
}

View File

@@ -0,0 +1,251 @@
package service
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
)
// TestClassifyOpenAIWSAcquireError covers the acquire-error classifier:
// dial errors map by HTTP status (426/401/429/5xx/other), sentinel errors map
// to their dedicated labels, and anything else (including nil) falls back to
// "acquire_conn".
func TestClassifyOpenAIWSAcquireError(t *testing.T) {
	t.Run("dial_426_upgrade_required", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 426, Err: errors.New("upgrade required")}
		require.Equal(t, "upgrade_required", classifyOpenAIWSAcquireError(err))
	})
	t.Run("queue_full", func(t *testing.T) {
		require.Equal(t, "conn_queue_full", classifyOpenAIWSAcquireError(errOpenAIWSConnQueueFull))
	})
	t.Run("preferred_conn_unavailable", func(t *testing.T) {
		require.Equal(t, "preferred_conn_unavailable", classifyOpenAIWSAcquireError(errOpenAIWSPreferredConnUnavailable))
	})
	t.Run("acquire_timeout", func(t *testing.T) {
		require.Equal(t, "acquire_timeout", classifyOpenAIWSAcquireError(context.DeadlineExceeded))
	})
	t.Run("auth_failed_401", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 401, Err: errors.New("unauthorized")}
		require.Equal(t, "auth_failed", classifyOpenAIWSAcquireError(err))
	})
	t.Run("upstream_rate_limited", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 429, Err: errors.New("rate limited")}
		require.Equal(t, "upstream_rate_limited", classifyOpenAIWSAcquireError(err))
	})
	t.Run("upstream_5xx", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 502, Err: errors.New("bad gateway")}
		require.Equal(t, "upstream_5xx", classifyOpenAIWSAcquireError(err))
	})
	t.Run("dial_failed_other_status", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 418, Err: errors.New("teapot")}
		require.Equal(t, "dial_failed", classifyOpenAIWSAcquireError(err))
	})
	t.Run("other", func(t *testing.T) {
		require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(errors.New("x")))
	})
	t.Run("nil", func(t *testing.T) {
		require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(nil))
	})
}
// TestClassifyOpenAIWSDialError verifies dial-error classification for two
// special cases: the "Handshake not finished" protocol error and a wrapped
// context deadline.
func TestClassifyOpenAIWSDialError(t *testing.T) {
	t.Run("handshake_not_finished", func(t *testing.T) {
		err := &openAIWSDialError{
			StatusCode: http.StatusBadGateway,
			Err:        errors.New("WebSocket protocol error: Handshake not finished"),
		}
		require.Equal(t, "handshake_not_finished", classifyOpenAIWSDialError(err))
	})
	t.Run("context_deadline", func(t *testing.T) {
		err := &openAIWSDialError{
			StatusCode: 0,
			Err:        context.DeadlineExceeded,
		}
		require.Equal(t, "ctx_deadline_exceeded", classifyOpenAIWSDialError(err))
	})
}
// TestSummarizeOpenAIWSDialError verifies the dial-error summary extracts the
// status code, classification, and the Server/Via/Cf-Ray/X-Request-Id response
// headers, with "-" placeholders for the absent close status/reason.
func TestSummarizeOpenAIWSDialError(t *testing.T) {
	err := &openAIWSDialError{
		StatusCode: http.StatusBadGateway,
		ResponseHeaders: http.Header{
			"Server":       []string{"cloudflare"},
			"Via":          []string{"1.1 example"},
			"Cf-Ray":       []string{"abcd1234"},
			"X-Request-Id": []string{"req_123"},
		},
		Err: errors.New("WebSocket protocol error: Handshake not finished"),
	}
	status, class, closeStatus, closeReason, server, via, cfRay, reqID := summarizeOpenAIWSDialError(err)
	require.Equal(t, http.StatusBadGateway, status)
	require.Equal(t, "handshake_not_finished", class)
	require.Equal(t, "-", closeStatus)
	require.Equal(t, "-", closeReason)
	require.Equal(t, "cloudflare", server)
	require.Equal(t, "1.1 example", via)
	require.Equal(t, "abcd1234", cfRay)
	require.Equal(t, "req_123", reqID)
}
// TestClassifyOpenAIWSErrorEvent verifies that error events carrying the
// "upgrade_required" and "previous_response_not_found" codes are classified
// by code and marked recoverable.
func TestClassifyOpenAIWSErrorEvent(t *testing.T) {
	reason, recoverable := classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"upgrade_required","message":"Upgrade required"}}`))
	require.Equal(t, "upgrade_required", reason)
	require.True(t, recoverable)
	reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`))
	require.Equal(t, "previous_response_not_found", reason)
	require.True(t, recoverable)
}
// TestClassifyOpenAIWSReconnectReason verifies reconnect classification of
// wrapped fallback errors: "policy_violation" is non-retryable while
// "read_event" is retryable.
func TestClassifyOpenAIWSReconnectReason(t *testing.T) {
	reason, retryable := classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("policy_violation", errors.New("policy")))
	require.Equal(t, "policy_violation", reason)
	require.False(t, retryable)
	reason, retryable = classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("read_event", errors.New("io")))
	require.Equal(t, "read_event", reason)
	require.True(t, retryable)
}
// TestOpenAIWSErrorHTTPStatus maps upstream error-event types to HTTP status
// codes: invalid_request→400, authentication→401, permission→403,
// rate_limit→429, server_error→502.
func TestOpenAIWSErrorHTTPStatus(t *testing.T) {
	require.Equal(t, http.StatusBadRequest, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)))
	require.Equal(t, http.StatusUnauthorized, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"authentication_error","code":"invalid_api_key","message":"auth failed"}}`)))
	require.Equal(t, http.StatusForbidden, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"permission_error","code":"forbidden","message":"forbidden"}}`)))
	require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"rate limited"}}`)))
	require.Equal(t, http.StatusBadGateway, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"server_error","code":"server_error","message":"server"}}`)))
}
// TestResolveOpenAIWSFallbackErrorResponse verifies fallback errors are
// turned into client-facing responses: previous_response_not_found becomes a
// 400 invalid_request_error, auth_failed inherits the dial error's status,
// and non-fallback errors are not resolved at all.
func TestResolveOpenAIWSFallbackErrorResponse(t *testing.T) {
	t.Run("previous_response_not_found", func(t *testing.T) {
		statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
			wrapOpenAIWSFallback("previous_response_not_found", errors.New("previous response not found")),
		)
		require.True(t, ok)
		require.Equal(t, http.StatusBadRequest, statusCode)
		require.Equal(t, "invalid_request_error", errType)
		require.Equal(t, "previous response not found", clientMessage)
		require.Equal(t, "previous response not found", upstreamMessage)
	})
	t.Run("auth_failed_uses_dial_status", func(t *testing.T) {
		statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
			wrapOpenAIWSFallback("auth_failed", &openAIWSDialError{
				StatusCode: http.StatusForbidden,
				Err:        errors.New("forbidden"),
			}),
		)
		require.True(t, ok)
		require.Equal(t, http.StatusForbidden, statusCode)
		require.Equal(t, "upstream_error", errType)
		require.Equal(t, "forbidden", clientMessage)
		require.Equal(t, "forbidden", upstreamMessage)
	})
	t.Run("non_fallback_error_not_resolved", func(t *testing.T) {
		_, _, _, _, ok := resolveOpenAIWSFallbackErrorResponse(errors.New("plain error"))
		require.False(t, ok)
	})
}
// TestOpenAIWSFallbackCooling verifies the per-account fallback cooldown:
// mark sets it, clear removes it, and it expires on its own after the
// configured number of seconds.
//
// NOTE(review): the 1.2s real sleep makes this test slow and potentially
// flaky under load; consider an injectable clock in the service so expiry
// can be tested without sleeping.
func TestOpenAIWSFallbackCooling(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
	require.False(t, svc.isOpenAIWSFallbackCooling(1))
	svc.markOpenAIWSFallbackCooling(1, "upgrade_required")
	require.True(t, svc.isOpenAIWSFallbackCooling(1))
	svc.clearOpenAIWSFallbackCooling(1)
	require.False(t, svc.isOpenAIWSFallbackCooling(1))
	svc.markOpenAIWSFallbackCooling(2, "x")
	// Wait past the 1-second cooldown so the entry expires naturally.
	time.Sleep(1200 * time.Millisecond)
	require.False(t, svc.isOpenAIWSFallbackCooling(2))
}
// TestOpenAIWSRetryBackoff checks exponential doubling of the retry backoff
// from the configured initial value, capped at the configured maximum, with
// jitter disabled for determinism.
func TestOpenAIWSRetryBackoff(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 100
	svc.cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 400
	svc.cfg.Gateway.OpenAIWS.RetryJitterRatio = 0
	expected := []time.Duration{
		100 * time.Millisecond, // attempt 1: initial
		200 * time.Millisecond, // attempt 2: doubled
		400 * time.Millisecond, // attempt 3: hits the cap
		400 * time.Millisecond, // attempt 4: stays capped
	}
	for i, want := range expected {
		attempt := i + 1
		require.Equal(t, want, svc.openAIWSRetryBackoff(attempt))
	}
}
// TestOpenAIWSRetryTotalBudget verifies the retry budget config is surfaced
// as a millisecond duration, and that zero disables the budget entirely.
func TestOpenAIWSRetryTotalBudget(t *testing.T) {
	service := &OpenAIGatewayService{cfg: &config.Config{}}
	service.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 1200
	require.Equal(t, 1200*time.Millisecond, service.openAIWSRetryTotalBudget())
	service.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 0
	require.Equal(t, time.Duration(0), service.openAIWSRetryTotalBudget())
}
func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) {
require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation}))
require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig}))
require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io")))
}
// TestOpenAIWSStoreDisabledConnMode verifies mode resolution: the legacy
// force-new-conn flag maps to strict, an explicit "adaptive" config wins, and
// no configuration at all yields off.
func TestOpenAIWSStoreDisabledConnMode(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
	require.Equal(t, openAIWSStoreDisabledConnModeStrict, svc.openAIWSStoreDisabledConnMode())
	svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "adaptive"
	require.Equal(t, openAIWSStoreDisabledConnModeAdaptive, svc.openAIWSStoreDisabledConnMode())
	svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = ""
	svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false
	require.Equal(t, openAIWSStoreDisabledConnModeOff, svc.openAIWSStoreDisabledConnMode())
}
func TestShouldForceNewConnOnStoreDisabled(t *testing.T) {
require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeStrict, ""))
require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeOff, "policy_violation"))
require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "policy_violation"))
require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "prewarm_message_too_big"))
require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "read_event"))
}
// TestOpenAIWSRetryMetricsSnapshot verifies the retry counters: two attempts
// (150ms + 0ms of backoff), one exhaustion, and one non-retryable fast
// fallback are all reflected in the snapshot.
func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) {
	svc := &OpenAIGatewayService{}
	svc.recordOpenAIWSRetryAttempt(150 * time.Millisecond)
	svc.recordOpenAIWSRetryAttempt(0)
	svc.recordOpenAIWSRetryExhausted()
	svc.recordOpenAIWSNonRetryableFastFallback()
	snapshot := svc.SnapshotOpenAIWSRetryMetrics()
	require.Equal(t, int64(2), snapshot.RetryAttemptsTotal)
	require.Equal(t, int64(150), snapshot.RetryBackoffMsTotal)
	require.Equal(t, int64(1), snapshot.RetryExhaustedTotal)
	require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal)
}
// TestShouldLogOpenAIWSPayloadSchema verifies payload-schema log sampling:
// the first attempt is always logged regardless of the sample rate, and later
// attempts are logged only when the rate allows.
func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 0
	require.True(t, svc.shouldLogOpenAIWSPayloadSchema(1), "首次尝试应始终记录 payload_schema")
	require.False(t, svc.shouldLogOpenAIWSPayloadSchema(2))
	svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 1
	require.True(t, svc.shouldLogOpenAIWSPayloadSchema(2))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,127 @@
package service
import (
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// Package-level sinks receive benchmark results so the compiler cannot
// dead-code-eliminate the measured calls.
var (
	benchmarkOpenAIWSPayloadJSONSink string
	benchmarkOpenAIWSStringSink      string
	benchmarkOpenAIWSBoolSink        bool
	benchmarkOpenAIWSBytesSink       []byte
)
// BenchmarkOpenAIWSForwarderHotPath measures the per-request forwarder hot
// path: payload build, retry strategy, metadata injection, field reads, and
// the summarization/serialization helpers, all on a representative request.
func BenchmarkOpenAIWSForwarderHotPath(b *testing.B) {
	cfg := &config.Config{}
	svc := &OpenAIGatewayService{cfg: cfg}
	account := &Account{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
	reqBody := benchmarkOpenAIWSHotPathRequest()
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		payload := svc.buildOpenAIWSCreatePayload(reqBody, account)
		_, _ = applyOpenAIWSRetryPayloadStrategy(payload, 2)
		setOpenAIWSTurnMetadata(payload, `{"trace":"bench","turn":"1"}`)
		benchmarkOpenAIWSStringSink = openAIWSPayloadString(payload, "previous_response_id")
		benchmarkOpenAIWSBoolSink = payload["tools"] != nil
		benchmarkOpenAIWSStringSink = summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)
		benchmarkOpenAIWSStringSink = summarizeOpenAIWSInput(payload["input"])
		benchmarkOpenAIWSPayloadJSONSink = payloadAsJSON(payload)
	}
}
// benchmarkOpenAIWSHotPathRequest builds a representative response.create
// request for the hot-path benchmark: 24 function-tool schemas, 16 user
// messages, and the common top-level fields (model, caching, reasoning, …).
func benchmarkOpenAIWSHotPathRequest() map[string]any {
	const (
		toolCount  = 24
		inputCount = 16
	)
	tools := make([]map[string]any, toolCount)
	for i := range tools {
		tools[i] = map[string]any{
			"type":        "function",
			"name":        fmt.Sprintf("tool_%02d", i),
			"description": "benchmark tool schema",
			"parameters": map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{"type": "string"},
					"limit": map[string]any{"type": "number"},
				},
				"required": []string{"query"},
			},
		}
	}
	input := make([]map[string]any, inputCount)
	for i := range input {
		input[i] = map[string]any{
			"role":    "user",
			"type":    "message",
			"content": fmt.Sprintf("benchmark message %d", i),
		}
	}
	return map[string]any{
		"type":                 "response.create",
		"model":                "gpt-5.3-codex",
		"input":                input,
		"tools":                tools,
		"parallel_tool_calls":  true,
		"previous_response_id": "resp_benchmark_prev",
		"prompt_cache_key":     "bench-cache-key",
		"reasoning":            map[string]any{"effort": "medium"},
		"instructions":         "benchmark instructions",
		"store":                false,
	}
}
// BenchmarkOpenAIWSEventEnvelopeParse measures envelope parsing of a typical
// response.completed event (type, response id, and response subtree).
func BenchmarkOpenAIWSEventEnvelopeParse(b *testing.B) {
	event := []byte(`{"type":"response.completed","response":{"id":"resp_bench_1","model":"gpt-5.1","usage":{"input_tokens":12,"output_tokens":8}}}`)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		eventType, responseID, response := parseOpenAIWSEventEnvelope(event)
		benchmarkOpenAIWSStringSink = eventType
		benchmarkOpenAIWSStringSink = responseID
		benchmarkOpenAIWSBoolSink = response.Exists()
	}
}
// BenchmarkOpenAIWSErrorEventFieldReuse measures the raw-field reuse path:
// parse the error-event fields once, then classify, summarize, and map to an
// HTTP status from those raw values without reparsing.
func BenchmarkOpenAIWSErrorEventFieldReuse(b *testing.B) {
	event := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(event)
		benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw)
		code, errType, errMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw)
		benchmarkOpenAIWSStringSink = code
		benchmarkOpenAIWSStringSink = errType
		benchmarkOpenAIWSStringSink = errMsg
		benchmarkOpenAIWSBoolSink = openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) > 0
	}
}
// BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath measures model
// replacement on an event that contains no model field (fast no-op path).
func BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath(b *testing.B) {
	event := []byte(`{"type":"response.output_text.delta","delta":"hello world"}`)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
	}
}
// BenchmarkReplaceOpenAIWSMessageModel_DualReplace measures replacing the
// model in both the root object and the nested response object.
func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) {
	event := []byte(`{"type":"response.completed","model":"gpt-5.1","response":{"id":"resp_1","model":"gpt-5.1"}}`)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
	}
}

View File

@@ -0,0 +1,73 @@
package service
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
// TestParseOpenAIWSEventEnvelope verifies envelope parsing extracts the event
// type, response id (from either response.id or the top-level id), and the
// raw response subtree when present.
func TestParseOpenAIWSEventEnvelope(t *testing.T) {
	eventType, responseID, response := parseOpenAIWSEventEnvelope([]byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`))
	require.Equal(t, "response.completed", eventType)
	require.Equal(t, "resp_1", responseID)
	require.True(t, response.Exists())
	require.Equal(t, `{"id":"resp_1","model":"gpt-5.1"}`, response.Raw)
	eventType, responseID, response = parseOpenAIWSEventEnvelope([]byte(`{"type":"response.delta","id":"evt_1"}`))
	require.Equal(t, "response.delta", eventType)
	require.Equal(t, "evt_1", responseID)
	require.False(t, response.Exists())
}
// TestParseOpenAIWSResponseUsageFromCompletedEvent verifies token usage is
// populated from a response.completed event, including cached input tokens.
func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
	usage := &OpenAIUsage{}
	parseOpenAIWSResponseUsageFromCompletedEvent(
		[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`),
		usage,
	)
	require.Equal(t, 11, usage.InputTokens)
	require.Equal(t, 7, usage.OutputTokens)
	require.Equal(t, 3, usage.CacheReadInputTokens)
}
// TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper verifies that the
// raw-field helper variants (classify / HTTP status / summarize ...FromRaw)
// produce exactly the same results as the whole-message wrapper functions.
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
	message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
	codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
	wrappedReason, wrappedRecoverable := classifyOpenAIWSErrorEvent(message)
	rawReason, rawRecoverable := classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw)
	require.Equal(t, wrappedReason, rawReason)
	require.Equal(t, wrappedRecoverable, rawRecoverable)
	wrappedStatus := openAIWSErrorHTTPStatus(message)
	rawStatus := openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
	require.Equal(t, wrappedStatus, rawStatus)
	require.Equal(t, http.StatusBadRequest, rawStatus)
	wrappedCode, wrappedType, wrappedMsg := summarizeOpenAIWSErrorEventFields(message)
	rawCode, rawType, rawMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw)
	require.Equal(t, wrappedCode, rawCode)
	require.Equal(t, wrappedType, rawType)
	require.Equal(t, wrappedMsg, rawMsg)
}
// TestOpenAIWSMessageLikelyContainsToolCalls verifies the tool-call sniffing
// heuristic: plain text deltas are negative, while messages carrying
// tool_calls arrays or function_call items are positive.
func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) {
	require.False(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)))
	require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"tool_calls":[{"id":"tc1"}]}}`)))
	require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"function_call"}}`)))
}
// TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect verifies model
// replacement across all shapes: no model field (no-op), root-only,
// response-only, and both locations at once.
func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) {
	noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`)
	require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, "gpt-5.1", "custom-model")))
	rootOnly := []byte(`{"type":"response.created","model":"gpt-5.1"}`)
	require.Equal(t, `{"type":"response.created","model":"custom-model"}`, string(replaceOpenAIWSMessageModel(rootOnly, "gpt-5.1", "custom-model")))
	responseOnly := []byte(`{"type":"response.completed","response":{"model":"gpt-5.1"}}`)
	require.Equal(t, `{"type":"response.completed","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(responseOnly, "gpt-5.1", "custom-model")))
	both := []byte(`{"model":"gpt-5.1","response":{"model":"gpt-5.1"}}`)
	require.Equal(t, `{"model":"custom-model","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(both, "gpt-5.1", "custom-model")))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,714 @@
package service
import (
"context"
"encoding/json"
"errors"
"io"
"net"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestIsOpenAIWSClientDisconnectError verifies client-disconnect detection:
// EOF, closed-network, context cancellation, benign websocket close codes
// (including abnormal 1006), and string-matched TCP errors count as
// disconnects, while policy violations do not.
func TestIsOpenAIWSClientDisconnectError(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		err  error
		want bool
	}{
		{name: "nil", err: nil, want: false},
		{name: "io_eof", err: io.EOF, want: true},
		{name: "net_closed", err: net.ErrClosed, want: true},
		{name: "context_canceled", err: context.Canceled, want: true},
		{name: "ws_normal_closure", err: coderws.CloseError{Code: coderws.StatusNormalClosure}, want: true},
		{name: "ws_going_away", err: coderws.CloseError{Code: coderws.StatusGoingAway}, want: true},
		{name: "ws_no_status", err: coderws.CloseError{Code: coderws.StatusNoStatusRcvd}, want: true},
		{name: "ws_abnormal_1006", err: coderws.CloseError{Code: coderws.StatusAbnormalClosure}, want: true},
		{name: "ws_policy_violation", err: coderws.CloseError{Code: coderws.StatusPolicyViolation}, want: false},
		{name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true},
		{name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true},
		{name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			require.Equal(t, tt.want, isOpenAIWSClientDisconnectError(tt.err))
		})
	}
}
// TestIsOpenAIWSIngressPreviousResponseNotFound verifies the detector matches
// only an ingress turn error at the previous-response-not-found stage whose
// final flag is false; other stages, flagged errors, plain errors, and nil
// are all negative.
func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) {
	t.Parallel()
	require.False(t, isOpenAIWSIngressPreviousResponseNotFound(nil))
	require.False(t, isOpenAIWSIngressPreviousResponseNotFound(errors.New("plain error")))
	require.False(t, isOpenAIWSIngressPreviousResponseNotFound(
		wrapOpenAIWSIngressTurnError("read_upstream", errors.New("upstream read failed"), false),
	))
	require.False(t, isOpenAIWSIngressPreviousResponseNotFound(
		wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), true),
	))
	require.True(t, isOpenAIWSIngressPreviousResponseNotFound(
		wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), false),
	))
}
// TestOpenAIWSIngressPreviousResponseRecoveryEnabled verifies the recovery
// toggle: nil service and nil config both default to enabled, an explicit
// zero-value config disables it, and the config flag enables it again.
func TestOpenAIWSIngressPreviousResponseRecoveryEnabled(t *testing.T) {
	t.Parallel()
	var nilService *OpenAIGatewayService
	require.True(t, nilService.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil service should default to enabled")
	svcWithNilCfg := &OpenAIGatewayService{}
	require.True(t, svcWithNilCfg.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil config should default to enabled")
	svc := &OpenAIGatewayService{
		cfg: &config.Config{},
	}
	require.False(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled(), "explicit config default should be false")
	svc.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
	require.True(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled())
}
// TestDropPreviousResponseIDFromRawPayload verifies previous_response_id
// removal from a raw JSON payload: no-ops on empty/absent payloads, removes
// the key (including duplicate occurrences and best-effort on malformed
// JSON), supports a nil custom delete function, and propagates delete errors
// without mutating the payload.
func TestDropPreviousResponseIDFromRawPayload(t *testing.T) {
	t.Parallel()
	t.Run("empty_payload", func(t *testing.T) {
		updated, removed, err := dropPreviousResponseIDFromRawPayload(nil)
		require.NoError(t, err)
		require.False(t, removed)
		require.Empty(t, updated)
	})
	t.Run("payload_without_previous_response_id", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
		updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
		require.NoError(t, err)
		require.False(t, removed)
		require.Equal(t, string(payload), string(updated))
	})
	t.Run("normal_delete_success", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
		updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
		require.NoError(t, err)
		require.True(t, removed)
		require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
	})
	t.Run("duplicate_keys_are_removed", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_a","input":[],"previous_response_id":"resp_b"}`)
		updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
		require.NoError(t, err)
		require.True(t, removed)
		require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
	})
	t.Run("nil_delete_fn_uses_default_delete_logic", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
		updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, nil)
		require.NoError(t, err)
		require.True(t, removed)
		require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
	})
	t.Run("delete_error", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
		updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, func(_ []byte, _ string) ([]byte, error) {
			return nil, errors.New("delete failed")
		})
		require.Error(t, err)
		require.False(t, removed)
		require.Equal(t, string(payload), string(updated))
	})
	t.Run("malformed_json_is_still_best_effort_deleted", func(t *testing.T) {
		// Truncated JSON: the key is still discoverable and removed best-effort.
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_abc"`)
		require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists())
		updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
		require.NoError(t, err)
		require.True(t, removed)
		require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
	})
}
// TestAlignStoreDisabledPreviousResponseID verifies store-disabled alignment
// of previous_response_id: no-ops on empty payload/expected id or when the
// key is missing or already aligned, and rewrites mismatched (including
// duplicated) keys to the single expected value.
func TestAlignStoreDisabledPreviousResponseID(t *testing.T) {
	t.Parallel()
	t.Run("empty_payload", func(t *testing.T) {
		updated, changed, err := alignStoreDisabledPreviousResponseID(nil, "resp_target")
		require.NoError(t, err)
		require.False(t, changed)
		require.Empty(t, updated)
	})
	t.Run("empty_expected", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_old"}`)
		updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "")
		require.NoError(t, err)
		require.False(t, changed)
		require.Equal(t, string(payload), string(updated))
	})
	t.Run("missing_previous_response_id", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
		updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
		require.NoError(t, err)
		require.False(t, changed)
		require.Equal(t, string(payload), string(updated))
	})
	t.Run("already_aligned", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_target"}`)
		updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
		require.NoError(t, err)
		require.False(t, changed)
		require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
	})
	t.Run("mismatch_rewrites_to_expected", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_old","input":[]}`)
		updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
		require.NoError(t, err)
		require.True(t, changed)
		require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
	})
	t.Run("duplicate_keys_rewrites_to_single_expected", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_old_1","input":[],"previous_response_id":"resp_old_2"}`)
		updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
		require.NoError(t, err)
		require.True(t, changed)
		require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
	})
}
// TestSetPreviousResponseIDToRawPayload verifies setting previous_response_id
// on a raw payload: no-ops on empty payload or empty id, inserts the key when
// missing (preserving other fields), and overwrites an existing value.
func TestSetPreviousResponseIDToRawPayload(t *testing.T) {
	t.Parallel()
	t.Run("empty_payload", func(t *testing.T) {
		updated, err := setPreviousResponseIDToRawPayload(nil, "resp_target")
		require.NoError(t, err)
		require.Empty(t, updated)
	})
	t.Run("empty_previous_response_id", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
		updated, err := setPreviousResponseIDToRawPayload(payload, "")
		require.NoError(t, err)
		require.Equal(t, string(payload), string(updated))
	})
	t.Run("set_previous_response_id_when_missing", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
		updated, err := setPreviousResponseIDToRawPayload(payload, "resp_target")
		require.NoError(t, err)
		require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
		require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String())
	})
	t.Run("overwrite_existing_previous_response_id", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_old"}`)
		updated, err := setPreviousResponseIDToRawPayload(payload, "resp_new")
		require.NoError(t, err)
		require.Equal(t, "resp_new", gjson.GetBytes(updated, "previous_response_id").String())
	})
}
// TestShouldInferIngressFunctionCallOutputPreviousResponseID verifies the
// inference gate: only when store is disabled, on turn >= 2, with a
// function_call_output present, no client-supplied previous_response_id, and
// a non-blank expected id (whitespace-trimmed) does inference happen.
func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name                    string
		storeDisabled           bool
		turn                    int
		hasFunctionCallOutput   bool
		currentPreviousResponse string
		expectedPrevious        string
		want                    bool
	}{
		{
			name:                  "infer_when_all_conditions_match",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      "resp_1",
			want:                  true,
		},
		{
			name:                  "skip_when_store_enabled",
			storeDisabled:         false,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      "resp_1",
			want:                  false,
		},
		{
			name:                  "skip_on_first_turn",
			storeDisabled:         true,
			turn:                  1,
			hasFunctionCallOutput: true,
			expectedPrevious:      "resp_1",
			want:                  false,
		},
		{
			name:                  "skip_without_function_call_output",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: false,
			expectedPrevious:      "resp_1",
			want:                  false,
		},
		{
			name:                    "skip_when_request_already_has_previous_response_id",
			storeDisabled:           true,
			turn:                    2,
			hasFunctionCallOutput:   true,
			currentPreviousResponse: "resp_client",
			expectedPrevious:        "resp_1",
			want:                    false,
		},
		{
			name:                  "skip_when_last_turn_response_id_missing",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      "",
			want:                  false,
		},
		{
			name:                  "trim_whitespace_before_judgement",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      " resp_2 ",
			want:                  true,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := shouldInferIngressFunctionCallOutputPreviousResponseID(
				tt.storeDisabled,
				tt.turn,
				tt.hasFunctionCallOutput,
				tt.currentPreviousResponse,
				tt.expectedPrevious,
			)
			require.Equal(t, tt.want, got)
		})
	}
}
// TestOpenAIWSInputIsPrefixExtended verifies prefix detection between the
// previous and current turn's input: the current input must start with the
// previous input, compared item-by-item after JSON key normalization (key
// order differences do not break a match). A missing input is treated as
// empty, a string input counts as a single item, and malformed input JSON on
// either side yields an error.
func TestOpenAIWSInputIsPrefixExtended(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name      string
		previous  []byte // prior turn's payload
		current   []byte // new turn's payload
		want      bool
		expectErr bool
	}{
		{
			name:     "both_missing_input",
			previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
			current:  []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_1"}`),
			want:     true,
		},
		{
			name:     "previous_missing_current_empty_array",
			previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
			current:  []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`),
			want:     true,
		},
		{
			name:     "previous_missing_current_non_empty_array",
			previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
			current:  []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"input_text","text":"hello"}]}`),
			want:     false,
		},
		{
			name:     "array_prefix_match",
			previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`),
			current:  []byte(`{"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]}`),
			want:     true,
		},
		{
			name:     "array_prefix_mismatch",
			previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`),
			current:  []byte(`{"input":[{"type":"input_text","text":"different"}]}`),
			want:     false,
		},
		{
			name:     "current_shorter_than_previous",
			previous: []byte(`{"input":[{"type":"input_text","text":"a"},{"type":"input_text","text":"b"}]}`),
			current:  []byte(`{"input":[{"type":"input_text","text":"a"}]}`),
			want:     false,
		},
		{
			name:     "previous_has_input_current_missing",
			previous: []byte(`{"input":[{"type":"input_text","text":"a"}]}`),
			current:  []byte(`{"model":"gpt-5.1"}`),
			want:     false,
		},
		{
			name:     "input_string_treated_as_single_item",
			previous: []byte(`{"input":"hello"}`),
			current:  []byte(`{"input":"hello"}`),
			want:     true,
		},
		{
			name:      "current_invalid_input_json",
			previous:  []byte(`{"input":[]}`),
			current:   []byte(`{"input":[}`),
			expectErr: true,
		},
		{
			name:      "invalid_input_json",
			previous:  []byte(`{"input":[}`),
			current:   []byte(`{"input":[]}`),
			expectErr: true,
		},
	}
	for _, tt := range tests {
		tt := tt // capture loop variable for the parallel subtest
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got, err := openAIWSInputIsPrefixExtended(tt.previous, tt.current)
			if tt.expectErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.Equal(t, tt.want, got)
		})
	}
}
// TestNormalizeOpenAIWSJSONForCompare checks that object keys are sorted into
// a canonical form and that blank or truncated JSON input is rejected.
func TestNormalizeOpenAIWSJSONForCompare(t *testing.T) {
	t.Parallel()

	got, err := normalizeOpenAIWSJSONForCompare([]byte(`{"b":2,"a":1}`))
	require.NoError(t, err)
	require.Equal(t, `{"a":1,"b":2}`, string(got))

	for _, invalid := range []string{" ", `{"a":`} {
		_, err := normalizeOpenAIWSJSONForCompare([]byte(invalid))
		require.Error(t, err)
	}
}
func TestNormalizeOpenAIWSJSONForCompareOrRaw(t *testing.T) {
t.Parallel()
require.Equal(t, `{"a":1,"b":2}`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"b":2,"a":1}`))))
require.Equal(t, `{"a":`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"a":`))))
}
// TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID verifies that
// normalization strips the input and previous_response_id fields, keeps the
// remaining fields intact, and rejects nil or non-object payloads.
func TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(t *testing.T) {
	t.Parallel()

	payload := []byte(`{"model":"gpt-5.1","input":[1],"previous_response_id":"resp_x","metadata":{"b":2,"a":1}}`)
	stripped, err := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload)
	require.NoError(t, err)
	require.False(t, gjson.GetBytes(stripped, "input").Exists())
	require.False(t, gjson.GetBytes(stripped, "previous_response_id").Exists())
	require.Equal(t, float64(1), gjson.GetBytes(stripped, "metadata.a").Float())

	for _, bad := range [][]byte{nil, []byte(`[]`)} {
		_, err := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(bad)
		require.Error(t, err)
	}
}
// TestOpenAIWSExtractNormalizedInputSequence verifies how the input field is
// normalized into a sequence: a missing input reports exists=false; an array
// is split into its elements; any scalar or object value (string, number,
// bool, null, object) becomes a one-element sequence; malformed array JSON
// returns an error while still reporting that the field exists.
func TestOpenAIWSExtractNormalizedInputSequence(t *testing.T) {
	t.Parallel()
	t.Run("empty_payload", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence(nil)
		require.NoError(t, err)
		require.False(t, exists)
		require.Nil(t, items)
	})
	t.Run("input_missing", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"type":"response.create"}`))
		require.NoError(t, err)
		require.False(t, exists)
		require.Nil(t, items)
	})
	t.Run("input_array", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[{"type":"input_text","text":"hello"}]}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
	})
	t.Run("input_object", func(t *testing.T) {
		// A bare object (not wrapped in an array) is wrapped into one item.
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":{"type":"input_text","text":"hello"}}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
	})
	t.Run("input_string", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":"hello"}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		// The item keeps its raw JSON encoding, quotes included.
		require.Equal(t, `"hello"`, string(items[0]))
	})
	t.Run("input_number", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":42}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "42", string(items[0]))
	})
	t.Run("input_bool", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":true}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "true", string(items[0]))
	})
	t.Run("input_null", func(t *testing.T) {
		// Even a JSON null counts as a present, one-item input.
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":null}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "null", string(items[0]))
	})
	t.Run("input_invalid_array_json", func(t *testing.T) {
		// exists=true with an error: the field is there but unparseable.
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[}`))
		require.Error(t, err)
		require.True(t, exists)
		require.Nil(t, items)
	})
}
// TestShouldKeepIngressPreviousResponseID exercises whether a client-supplied
// previous_response_id may be kept. It is kept when the request is a strict
// incremental continuation of the previous turn (same non-input fields up to
// JSON key order, input extends the previous input or is a delta) or when the
// request contains a function_call_output item; every other case drops it and
// reports a specific reason string.
func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
	t.Parallel()
	// previousPayload is the recorded payload of the prior turn.
	previousPayload := []byte(`{
		"type":"response.create",
		"model":"gpt-5.1",
		"store":false,
		"tools":[{"type":"function","name":"tool_a"}],
		"input":[{"type":"input_text","text":"hello"}]
	}`)
	// currentStrictPayload extends the previous input and only reorders JSON
	// keys in the non-input fields — a strict incremental continuation.
	currentStrictPayload := []byte(`{
		"type":"response.create",
		"model":"gpt-5.1",
		"store":false,
		"tools":[{"name":"tool_a","type":"function"}],
		"previous_response_id":"resp_turn_1",
		"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]
	}`)
	t.Run("strict_incremental_keep", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false)
		require.NoError(t, err)
		require.True(t, keep)
		require.Equal(t, "strict_incremental_ok", reason)
	})
	t.Run("missing_previous_response_id", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "missing_previous_response_id", reason)
	})
	t.Run("missing_last_turn_response_id", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "missing_last_turn_response_id", reason)
	})
	t.Run("previous_response_id_mismatch", func(t *testing.T) {
		// The request references a response id other than the last turn's.
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "previous_response_id_mismatch", reason)
	})
	t.Run("missing_previous_turn_payload", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "missing_previous_turn_payload", reason)
	})
	t.Run("non_input_changed", func(t *testing.T) {
		// Model changed between turns, so the continuation is not strict.
		payload := []byte(`{
			"type":"response.create",
			"model":"gpt-5.1-mini",
			"store":false,
			"tools":[{"type":"function","name":"tool_a"}],
			"previous_response_id":"resp_turn_1",
			"input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]
		}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "non_input_changed", reason)
	})
	t.Run("delta_input_keeps_previous_response_id", func(t *testing.T) {
		// An input that is not a prefix extension still qualifies (delta form).
		payload := []byte(`{
			"type":"response.create",
			"model":"gpt-5.1",
			"store":false,
			"tools":[{"type":"function","name":"tool_a"}],
			"previous_response_id":"resp_turn_1",
			"input":[{"type":"input_text","text":"different"}]
		}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
		require.NoError(t, err)
		require.True(t, keep)
		require.Equal(t, "strict_incremental_ok", reason)
	})
	t.Run("function_call_output_keeps_previous_response_id", func(t *testing.T) {
		// A function_call_output turn keeps even an external response id.
		payload := []byte(`{
			"type":"response.create",
			"model":"gpt-5.1",
			"store":false,
			"previous_response_id":"resp_external",
			"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
		}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true)
		require.NoError(t, err)
		require.True(t, keep)
		require.Equal(t, "has_function_call_output", reason)
	})
	t.Run("non_input_compare_error", func(t *testing.T) {
		// A non-object previous payload cannot be compared.
		keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false)
		require.Error(t, err)
		require.False(t, keep)
		require.Equal(t, "non_input_compare_error", reason)
	})
	t.Run("current_payload_compare_error", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false)
		require.Error(t, err)
		require.False(t, keep)
		require.Equal(t, "non_input_compare_error", reason)
	})
}
// TestBuildOpenAIWSReplayInputSequence verifies how the replay input is
// assembled: without a usable previous_response_id the current input is used
// as-is; with one, a delta input is appended to the last full sequence, while
// an input that already restates the full history is taken as the complete
// sequence (no double-append).
func TestBuildOpenAIWSReplayInputSequence(t *testing.T) {
	t.Parallel()
	// lastFull is the full input sequence recorded from the previous turn.
	lastFull := []json.RawMessage{
		json.RawMessage(`{"type":"input_text","text":"hello"}`),
	}
	t.Run("no_previous_response_id_use_current", func(t *testing.T) {
		items, exists, err := buildOpenAIWSReplayInputSequence(
			lastFull,
			true,
			[]byte(`{"input":[{"type":"input_text","text":"new"}]}`),
			false,
		)
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "new", gjson.GetBytes(items[0], "text").String())
	})
	t.Run("previous_response_id_delta_append", func(t *testing.T) {
		// A delta input ("world") is appended after the recorded history.
		items, exists, err := buildOpenAIWSReplayInputSequence(
			lastFull,
			true,
			[]byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`),
			true,
		)
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 2)
		require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
		require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
	})
	t.Run("previous_response_id_full_input_replace", func(t *testing.T) {
		// The input restates the history plus the new item; it replaces the
		// recorded sequence instead of being appended to it.
		items, exists, err := buildOpenAIWSReplayInputSequence(
			lastFull,
			true,
			[]byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]}`),
			true,
		)
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 2)
		require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
		require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
	})
}
// TestSetOpenAIWSPayloadInputSequence verifies that the input sequence is
// written back into the payload as a JSON array, and that an empty sequence
// is serialized as [] rather than null.
func TestSetOpenAIWSPayloadInputSequence(t *testing.T) {
	t.Parallel()
	t.Run("set_items", func(t *testing.T) {
		base := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`)
		seq := []json.RawMessage{
			json.RawMessage(`{"type":"input_text","text":"hello"}`),
			json.RawMessage(`{"type":"input_text","text":"world"}`),
		}
		got, err := setOpenAIWSPayloadInputSequence(base, seq, true)
		require.NoError(t, err)
		require.Equal(t, "hello", gjson.GetBytes(got, "input.0.text").String())
		require.Equal(t, "world", gjson.GetBytes(got, "input.1.text").String())
	})
	t.Run("preserve_empty_array_not_null", func(t *testing.T) {
		base := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`)
		got, err := setOpenAIWSPayloadInputSequence(base, nil, true)
		require.NoError(t, err)
		inputField := gjson.GetBytes(got, "input")
		require.True(t, inputField.IsArray())
		require.Empty(t, inputField.Array())
		require.NotEqual(t, gjson.Null, inputField.Type)
	})
}
func TestCloneOpenAIWSRawMessages(t *testing.T) {
t.Parallel()
t.Run("nil_slice", func(t *testing.T) {
cloned := cloneOpenAIWSRawMessages(nil)
require.Nil(t, cloned)
})
t.Run("empty_slice", func(t *testing.T) {
items := make([]json.RawMessage, 0)
cloned := cloneOpenAIWSRawMessages(items)
require.NotNil(t, cloned)
require.Len(t, cloned, 0)
})
}

View File

@@ -0,0 +1,50 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey verifies that the
// retry trimming strategy removes the optional include field while leaving
// prompt_cache_key and text in place.
func TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey(t *testing.T) {
	body := map[string]any{
		"model":            "gpt-5.3-codex",
		"prompt_cache_key": "pcache_123",
		"include":          []any{"reasoning.encrypted_content"},
		"text": map[string]any{
			"verbosity": "low",
		},
		"tools": []any{map[string]any{"type": "function"}},
	}
	strategy, removed := applyOpenAIWSRetryPayloadStrategy(body, 3)
	require.Equal(t, "trim_optional_fields", strategy)
	require.Contains(t, removed, "include")
	require.NotContains(t, removed, "prompt_cache_key")
	require.NotContains(t, body, "include")
	require.Contains(t, body, "text")
	require.Equal(t, "pcache_123", body["prompt_cache_key"])
}
// TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields verifies
// that even on a late attempt (6) trimming only drops optional fields and all
// semantically significant fields survive.
func TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields(t *testing.T) {
	body := map[string]any{
		"prompt_cache_key":    "pcache_456",
		"instructions":        "long instructions",
		"tools":               []any{map[string]any{"type": "function"}},
		"parallel_tool_calls": true,
		"tool_choice":         "auto",
		"include":             []any{"reasoning.encrypted_content"},
		"text":                map[string]any{"verbosity": "high"},
	}
	strategy, removed := applyOpenAIWSRetryPayloadStrategy(body, 6)
	require.Equal(t, "trim_optional_fields", strategy)
	require.Contains(t, removed, "include")
	require.NotContains(t, removed, "prompt_cache_key")
	require.Equal(t, "pcache_456", body["prompt_cache_key"])
	for _, kept := range []string{"instructions", "tools", "tool_choice", "parallel_tool_calls", "text"} {
		require.Contains(t, body, kept)
	}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,58 @@
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// BenchmarkOpenAIWSPoolAcquire measures the acquire/release hot path of the
// per-account WS connection pool under parallel load. A counting test dialer
// replaces real network dialing, and the pool is warmed with one successful
// acquire first so dial cost stays out of the measured loop. Transient
// errOpenAIWSConnClosed failures are retried a few times because concurrent
// releases can race with connection close.
func BenchmarkOpenAIWSPoolAcquire(b *testing.B) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(&openAIWSCountingDialer{})
	account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	req := openAIWSAcquireRequest{
		Account: account,
		WSURL:   "wss://example.com/v1/responses",
	}
	ctx := context.Background()
	// Warm-up acquire outside the timed region.
	lease, err := pool.Acquire(ctx, req)
	if err != nil {
		b.Fatalf("warm acquire failed: %v", err)
	}
	lease.Release()
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			var (
				got        *openAIWSConnLease
				acquireErr error
			)
			// Retry only the closed-connection race; any other error aborts.
			for retry := 0; retry < 3; retry++ {
				got, acquireErr = pool.Acquire(ctx, req)
				if acquireErr == nil {
					break
				}
				if !errors.Is(acquireErr, errOpenAIWSConnClosed) {
					break
				}
			}
			if acquireErr != nil {
				b.Fatalf("acquire failed: %v", acquireErr)
			}
			got.Release()
		}
	})
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,117 @@
package service
import "github.com/Wei-Shaw/sub2api/internal/config"
// OpenAIUpstreamTransport identifies the transport protocol used to reach the
// OpenAI upstream.
type OpenAIUpstreamTransport string

const (
	// OpenAIUpstreamTransportAny expresses no specific transport preference.
	OpenAIUpstreamTransportAny OpenAIUpstreamTransport = ""
	// OpenAIUpstreamTransportHTTPSSE is plain HTTP with server-sent events.
	OpenAIUpstreamTransportHTTPSSE OpenAIUpstreamTransport = "http_sse"
	// OpenAIUpstreamTransportResponsesWebsocket is the v1 Responses WebSocket transport.
	OpenAIUpstreamTransportResponsesWebsocket OpenAIUpstreamTransport = "responses_websockets"
	// OpenAIUpstreamTransportResponsesWebsocketV2 is the v2 Responses WebSocket transport.
	OpenAIUpstreamTransportResponsesWebsocketV2 OpenAIUpstreamTransport = "responses_websockets_v2"
)

// OpenAIWSProtocolDecision is the outcome of a protocol resolution: the chosen
// transport plus a machine-readable reason string for observability.
type OpenAIWSProtocolDecision struct {
	Transport OpenAIUpstreamTransport
	Reason    string
}

// OpenAIWSProtocolResolver decides which OpenAI upstream transport to use for
// a given account.
type OpenAIWSProtocolResolver interface {
	Resolve(account *Account) OpenAIWSProtocolDecision
}
// defaultOpenAIWSProtocolResolver is the config-driven implementation of
// OpenAIWSProtocolResolver.
type defaultOpenAIWSProtocolResolver struct {
	cfg *config.Config
}

// NewOpenAIWSProtocolResolver creates the default protocol resolver backed by
// the given application config.
func NewOpenAIWSProtocolResolver(cfg *config.Config) OpenAIWSProtocolResolver {
	return &defaultOpenAIWSProtocolResolver{cfg: cfg}
}
// Resolve decides which upstream transport to use for the given account.
// Every failed precondition short-circuits into an HTTP/SSE fallback whose
// Reason string identifies the gate that rejected WS. Guard order matters:
// account-shape checks run first, then global config, then the per-auth-type
// switches, and finally either the v2 mode router or the legacy per-account
// boolean.
func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProtocolDecision {
	if account == nil {
		return openAIWSHTTPDecision("account_missing")
	}
	if !account.IsOpenAI() {
		return openAIWSHTTPDecision("platform_not_openai")
	}
	// Account-level force-HTTP wins over every WS switch below.
	if account.IsOpenAIWSForceHTTPEnabled() {
		return openAIWSHTTPDecision("account_force_http")
	}
	if r == nil || r.cfg == nil {
		return openAIWSHTTPDecision("config_missing")
	}
	wsCfg := r.cfg.Gateway.OpenAIWS
	if wsCfg.ForceHTTP {
		return openAIWSHTTPDecision("global_force_http")
	}
	if !wsCfg.Enabled {
		return openAIWSHTTPDecision("global_disabled")
	}
	// WS can be toggled independently per auth type (OAuth vs API key).
	if account.IsOpenAIOAuth() {
		if !wsCfg.OAuthEnabled {
			return openAIWSHTTPDecision("oauth_disabled")
		}
	} else if account.IsOpenAIApiKey() {
		if !wsCfg.APIKeyEnabled {
			return openAIWSHTTPDecision("apikey_disabled")
		}
	} else {
		return openAIWSHTTPDecision("unknown_auth_type")
	}
	if wsCfg.ModeRouterV2Enabled {
		// Mode-router path: the account's resolved ingress mode (falling back
		// to the configured default) decides eligibility instead of the
		// legacy boolean switch below.
		mode := account.ResolveOpenAIResponsesWebSocketV2Mode(wsCfg.IngressModeDefault)
		switch mode {
		case OpenAIWSIngressModeOff:
			return openAIWSHTTPDecision("account_mode_off")
		case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
			// continue
		default:
			// Unrecognized modes are treated the same as "off".
			return openAIWSHTTPDecision("account_mode_off")
		}
		if account.Concurrency <= 0 {
			return openAIWSHTTPDecision("account_concurrency_invalid")
		}
		if wsCfg.ResponsesWebsocketsV2 {
			return OpenAIWSProtocolDecision{
				Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
				Reason:    "ws_v2_mode_" + mode,
			}
		}
		if wsCfg.ResponsesWebsockets {
			return OpenAIWSProtocolDecision{
				Transport: OpenAIUpstreamTransportResponsesWebsocket,
				Reason:    "ws_v1_mode_" + mode,
			}
		}
		return openAIWSHTTPDecision("feature_disabled")
	}
	// Legacy path: per-account boolean switch, then v2-before-v1 preference.
	if !account.IsOpenAIResponsesWebSocketV2Enabled() {
		return openAIWSHTTPDecision("account_disabled")
	}
	if wsCfg.ResponsesWebsocketsV2 {
		return OpenAIWSProtocolDecision{
			Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
			Reason:    "ws_v2_enabled",
		}
	}
	if wsCfg.ResponsesWebsockets {
		return OpenAIWSProtocolDecision{
			Transport: OpenAIUpstreamTransportResponsesWebsocket,
			Reason:    "ws_v1_enabled",
		}
	}
	return openAIWSHTTPDecision("feature_disabled")
}
// openAIWSHTTPDecision builds an HTTP/SSE fallback decision tagged with the
// given reason string.
func openAIWSHTTPDecision(reason string) OpenAIWSProtocolDecision {
	decision := OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportHTTPSSE}
	decision.Reason = reason
	return decision
}

View File

@@ -0,0 +1,203 @@
package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIWSProtocolResolver_Resolve covers the legacy (non-mode-router)
// resolution path: v2 is preferred over v1, account-level and global
// force-HTTP/disable switches fall back to HTTP/SSE, and the per-auth-type
// config gates apply. Subtest names are runtime strings and are left as-is.
func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) {
	baseCfg := &config.Config{}
	baseCfg.Gateway.OpenAIWS.Enabled = true
	baseCfg.Gateway.OpenAIWS.OAuthEnabled = true
	baseCfg.Gateway.OpenAIWS.APIKeyEnabled = true
	baseCfg.Gateway.OpenAIWS.ResponsesWebsockets = false
	baseCfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	// OAuth account with the v2 WS switch enabled via its Extra map.
	openAIOAuthEnabled := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
		},
	}
	t.Run("v2优先", func(t *testing.T) {
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})
	t.Run("v2关闭时回退v1", func(t *testing.T) {
		// Shallow copy so per-subtest config tweaks don't leak into baseCfg.
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
		cfg.Gateway.OpenAIWS.ResponsesWebsockets = true
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocket, decision.Transport)
		require.Equal(t, "ws_v1_enabled", decision.Reason)
	})
	t.Run("透传开关不影响WS协议判定", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
			"openai_passthrough":                           true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})
	t.Run("账号级强制HTTP", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
			"openai_ws_force_http":                         true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_force_http", decision.Reason)
	})
	t.Run("全局关闭保持HTTP", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.Enabled = false
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "global_disabled", decision.Reason)
	})
	t.Run("账号开关关闭保持HTTP", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": false,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_disabled", decision.Reason)
	})
	t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
		// The API-key-specific switch must not enable WS for an OAuth account.
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_disabled", decision.Reason)
	})
	t.Run("兼容旧键openai_ws_enabled", func(t *testing.T) {
		// Legacy Extra key is still honored.
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_ws_enabled": true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})
	t.Run("按账号类型开关控制", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.OAuthEnabled = false
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "oauth_disabled", decision.Reason)
	})
	t.Run("API Key 账号关闭开关时回退HTTP", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.APIKeyEnabled = false
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "apikey_disabled", decision.Reason)
	})
	t.Run("未知认证类型回退HTTP", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     "unknown_type",
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "unknown_auth_type", decision.Reason)
	})
}
// TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2 covers the mode-router
// path: dedicated/shared modes route to WS v2 with a mode-tagged reason, an
// explicit "off" mode falls back to HTTP, the legacy boolean Extra key maps to
// shared, and a non-positive account concurrency is rejected.
func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
	cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
	account := &Account{
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
		},
	}
	t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
	})
	t.Run("off mode routes to http", func(t *testing.T) {
		offAccount := &Account{
			Platform:    PlatformOpenAI,
			Type:        AccountTypeOAuth,
			Concurrency: 1,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(offAccount)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_mode_off", decision.Reason)
	})
	t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
		legacyAccount := &Account{
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Concurrency: 1,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_mode_shared", decision.Reason)
	})
	t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
		// Concurrency is left at its zero value on purpose.
		invalidConcurrency := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_concurrency_invalid", decision.Reason)
	})
}

View File

@@ -0,0 +1,440 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
)
const (
	// openAIWSResponseAccountCachePrefix namespaces response→account keys in
	// the shared gateway cache.
	openAIWSResponseAccountCachePrefix = "openai:response:"
	// openAIWSStateStoreCleanupInterval is the interval used by the state
	// store's cleanup pass (consumed by maybeCleanup; helper not shown here).
	openAIWSStateStoreCleanupInterval = time.Minute
	// openAIWSStateStoreCleanupMaxPerMap caps the work done per map in a
	// single cleanup pass — presumably entries scanned/evicted; confirm in
	// the cleanup helper.
	openAIWSStateStoreCleanupMaxPerMap = 512
	// openAIWSStateStoreMaxEntriesPerMap bounds the size of each in-process
	// map (enforced via ensureBindingCapacity before inserts).
	openAIWSStateStoreMaxEntriesPerMap = 65536
	// openAIWSStateStoreRedisTimeout bounds each shared-cache round trip.
	openAIWSStateStoreRedisTimeout = 3 * time.Second
)

// openAIWSAccountBinding maps a response id to the account that produced it,
// with an expiry for lazy eviction.
type openAIWSAccountBinding struct {
	accountID int64
	expiresAt time.Time
}

// openAIWSConnBinding maps a response id to a WS connection id (process-local).
type openAIWSConnBinding struct {
	connID    string
	expiresAt time.Time
}

// openAIWSTurnStateBinding caches serialized turn state per session key.
type openAIWSTurnStateBinding struct {
	turnState string
	expiresAt time.Time
}

// openAIWSSessionConnBinding maps a session key to a WS connection id.
type openAIWSSessionConnBinding struct {
	connID    string
	expiresAt time.Time
}
// OpenAIWSStateStore manages WSv2 stickiness state:
//   - response_id -> account_id is used for continuation routing
//   - response_id -> conn_id is used for in-connection context reuse
//
// The response_id -> account_id binding prefers the GatewayCache (Redis)
// while also maintaining a local hot cache; the response_id -> conn_id
// binding is valid only within this process.
type OpenAIWSStateStore interface {
	// Response→account bindings (local map + optional shared cache).
	BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error
	GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error)
	DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error
	// Response→connection bindings (process-local only).
	BindResponseConn(responseID, connID string, ttl time.Duration)
	GetResponseConn(responseID string) (string, bool)
	DeleteResponseConn(responseID string)
	// Session→turn-state bindings (process-local only).
	BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
	GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
	DeleteSessionTurnState(groupID int64, sessionHash string)
	// Session→connection bindings (process-local only).
	BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
	GetSessionConn(groupID int64, sessionHash string) (string, bool)
	DeleteSessionConn(groupID int64, sessionHash string)
}

// defaultOpenAIWSStateStore is the default OpenAIWSStateStore: four
// independently locked in-process maps plus an optional shared cache for the
// response→account binding.
type defaultOpenAIWSStateStore struct {
	cache GatewayCache // optional; nil disables the shared-cache path

	responseToAccountMu sync.RWMutex
	responseToAccount   map[string]openAIWSAccountBinding

	responseToConnMu sync.RWMutex
	responseToConn   map[string]openAIWSConnBinding

	sessionToTurnStateMu sync.RWMutex
	sessionToTurnState   map[string]openAIWSTurnStateBinding

	sessionToConnMu sync.RWMutex
	sessionToConn   map[string]openAIWSSessionConnBinding

	// lastCleanupUnixNano records the last cleanup time for rate-limiting
	// maybeCleanup sweeps.
	lastCleanupUnixNano atomic.Int64
}
// NewOpenAIWSStateStore creates the default WS state store backed by the given
// gateway cache (may be nil, in which case only in-process maps are used).
func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
	s := &defaultOpenAIWSStateStore{cache: cache}
	s.responseToAccount = make(map[string]openAIWSAccountBinding, 256)
	s.responseToConn = make(map[string]openAIWSConnBinding, 256)
	s.sessionToTurnState = make(map[string]openAIWSTurnStateBinding, 256)
	s.sessionToConn = make(map[string]openAIWSSessionConnBinding, 256)
	s.lastCleanupUnixNano.Store(time.Now().UnixNano())
	return s
}
// BindResponseAccount records that responseID was produced by accountID so a
// later continuation can be routed back to the same account. The binding is
// written to the in-process map first and then, when a shared cache is
// configured, to Redis with the same TTL. Blank response ids and non-positive
// account ids are silently ignored.
func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" || accountID <= 0 {
		return nil
	}
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()
	expiresAt := time.Now().Add(ttl)
	s.responseToAccountMu.Lock()
	// Keep the map within its configured bound before inserting.
	ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap)
	s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt}
	s.responseToAccountMu.Unlock()
	if s.cache == nil {
		return nil
	}
	cacheKey := openAIWSResponseAccountCacheKey(id)
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
}
// GetResponseAccount resolves the account bound to responseID. The in-process
// hot cache is consulted first (honoring expiry); on a miss it falls back to
// the shared cache. Cache errors are deliberately swallowed and reported as a
// miss (0, nil) so routing can degrade gracefully.
func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return 0, nil
	}
	s.maybeCleanup()
	now := time.Now()
	s.responseToAccountMu.RLock()
	if binding, ok := s.responseToAccount[id]; ok {
		// Entry is live only strictly before its expiry instant.
		if now.Before(binding.expiresAt) {
			accountID := binding.accountID
			s.responseToAccountMu.RUnlock()
			return accountID, nil
		}
	}
	s.responseToAccountMu.RUnlock()
	if s.cache == nil {
		return 0, nil
	}
	cacheKey := openAIWSResponseAccountCacheKey(id)
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey)
	if err != nil || accountID <= 0 {
		// 缓存读取失败不阻断主流程,按未命中降级。
		// (Cache read failures do not block the main flow; degrade to a miss.)
		return 0, nil
	}
	return accountID, nil
}
// DeleteResponseAccount drops the response→account binding from the local map
// and, when a shared cache is configured, from Redis as well.
func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
	key := normalizeOpenAIWSResponseID(responseID)
	if key == "" {
		return nil
	}
	s.responseToAccountMu.Lock()
	delete(s.responseToAccount, key)
	s.responseToAccountMu.Unlock()
	cache := s.cache
	if cache == nil {
		return nil
	}
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	return cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(key))
}
// BindResponseConn records, in-process only, which WS connection produced the
// given response id so a follow-up turn can reuse that connection's context.
// Blank ids are ignored.
func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
	key := normalizeOpenAIWSResponseID(responseID)
	trimmedConn := strings.TrimSpace(connID)
	if key == "" || trimmedConn == "" {
		return
	}
	s.maybeCleanup()
	entry := openAIWSConnBinding{
		connID:    trimmedConn,
		expiresAt: time.Now().Add(normalizeOpenAIWSTTL(ttl)),
	}
	s.responseToConnMu.Lock()
	defer s.responseToConnMu.Unlock()
	ensureBindingCapacity(s.responseToConn, key, openAIWSStateStoreMaxEntriesPerMap)
	s.responseToConn[key] = entry
}
// GetResponseConn returns the process-local connection id bound to
// responseID, reporting false when the id is blank, the entry is absent or
// expired, or the stored connection id is blank.
//
// Consistency fix: an entry is treated as live only while now < expiresAt,
// matching GetResponseAccount's expiry check (the previous now.After test
// accepted an entry at exactly its expiry instant).
func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return "", false
	}
	s.maybeCleanup()
	now := time.Now()
	s.responseToConnMu.RLock()
	binding, ok := s.responseToConn[id]
	s.responseToConnMu.RUnlock()
	if !ok || !now.Before(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
		return "", false
	}
	return binding.connID, true
}
// DeleteResponseConn removes the process-local response→connection binding,
// if present.
func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
	key := normalizeOpenAIWSResponseID(responseID)
	if key == "" {
		return
	}
	s.responseToConnMu.Lock()
	defer s.responseToConnMu.Unlock()
	delete(s.responseToConn, key)
}
// BindSessionTurnState caches the latest turn state for a (group, session)
// pair so a follow-up request can resume the same conversation turn. Blank
// keys or states are ignored; ttl is defaulted when non-positive.
func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
	mapKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	trimmedState := strings.TrimSpace(turnState)
	if mapKey == "" || trimmedState == "" {
		return
	}
	s.maybeCleanup()
	deadline := time.Now().Add(normalizeOpenAIWSTTL(ttl))
	s.sessionToTurnStateMu.Lock()
	defer s.sessionToTurnStateMu.Unlock()
	ensureBindingCapacity(s.sessionToTurnState, mapKey, openAIWSStateStoreMaxEntriesPerMap)
	s.sessionToTurnState[mapKey] = openAIWSTurnStateBinding{turnState: trimmedState, expiresAt: deadline}
}
// GetSessionTurnState returns the cached turn state for a (group, session)
// pair. A missing, expired, or blank entry reports ok=false.
func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
	mapKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if mapKey == "" {
		return "", false
	}
	s.maybeCleanup()
	checkedAt := time.Now()
	s.sessionToTurnStateMu.RLock()
	binding, found := s.sessionToTurnState[mapKey]
	s.sessionToTurnStateMu.RUnlock()
	switch {
	case !found:
		return "", false
	case checkedAt.After(binding.expiresAt):
		return "", false
	case strings.TrimSpace(binding.turnState) == "":
		return "", false
	}
	return binding.turnState, true
}
// DeleteSessionTurnState removes the turn-state binding for the pair, if any.
func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) {
	mapKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if mapKey == "" {
		return
	}
	s.sessionToTurnStateMu.Lock()
	defer s.sessionToTurnStateMu.Unlock()
	delete(s.sessionToTurnState, mapKey)
}
// BindSessionConn pins a (group, session) pair to a local WS connection so
// subsequent turns of the same session prefer that connection. Blank keys
// or connection IDs are ignored; ttl is defaulted when non-positive.
func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
	mapKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	trimmedConn := strings.TrimSpace(connID)
	if mapKey == "" || trimmedConn == "" {
		return
	}
	s.maybeCleanup()
	deadline := time.Now().Add(normalizeOpenAIWSTTL(ttl))
	s.sessionToConnMu.Lock()
	defer s.sessionToConnMu.Unlock()
	ensureBindingCapacity(s.sessionToConn, mapKey, openAIWSStateStoreMaxEntriesPerMap)
	s.sessionToConn[mapKey] = openAIWSSessionConnBinding{connID: trimmedConn, expiresAt: deadline}
}
// GetSessionConn returns the connection pinned to a (group, session) pair.
// A missing, expired, or blank binding reports ok=false.
func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) {
	mapKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if mapKey == "" {
		return "", false
	}
	s.maybeCleanup()
	checkedAt := time.Now()
	s.sessionToConnMu.RLock()
	binding, found := s.sessionToConn[mapKey]
	s.sessionToConnMu.RUnlock()
	switch {
	case !found:
		return "", false
	case checkedAt.After(binding.expiresAt):
		return "", false
	case strings.TrimSpace(binding.connID) == "":
		return "", false
	}
	return binding.connID, true
}
// DeleteSessionConn removes the session→connection binding, if any.
func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) {
	mapKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if mapKey == "" {
		return
	}
	s.sessionToConnMu.Lock()
	defer s.sessionToConnMu.Unlock()
	delete(s.sessionToConn, mapKey)
}
// maybeCleanup performs rate-limited, incremental eviction of expired
// entries across all four binding maps. At most one goroutine per interval
// wins the CAS and pays the cleanup cost; each map is scanned up to a fixed
// per-round budget so a single call stays bounded even at large scale.
func (s *defaultOpenAIWSStateStore) maybeCleanup() {
	if s == nil {
		return
	}
	now := time.Now()
	lastNano := s.lastCleanupUnixNano.Load()
	if now.Sub(time.Unix(0, lastNano)) < openAIWSStateStoreCleanupInterval {
		return
	}
	// Losing the CAS means another goroutine is already running this round.
	if !s.lastCleanupUnixNano.CompareAndSwap(lastNano, now.UnixNano()) {
		return
	}
	s.responseToAccountMu.Lock()
	cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap)
	s.responseToAccountMu.Unlock()
	s.responseToConnMu.Lock()
	cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap)
	s.responseToConnMu.Unlock()
	s.sessionToTurnStateMu.Lock()
	cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap)
	s.sessionToTurnStateMu.Unlock()
	s.sessionToConnMu.Lock()
	cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap)
	s.sessionToConnMu.Unlock()
}
// cleanupExpiredBindings deletes entries whose deadline (reported by
// deadlineOf) lies in the past. It inspects at most maxScan entries per
// call so a single round stays bounded on very large maps; entries not
// reached this round are reclaimed by later rounds. A non-positive maxScan
// is a no-op. Consolidates four previously duplicated, byte-identical loops.
func cleanupExpiredBindings[T any](bindings map[string]T, now time.Time, maxScan int, deadlineOf func(T) time.Time) {
	if len(bindings) == 0 || maxScan <= 0 {
		return
	}
	scanned := 0
	for key, binding := range bindings {
		if now.After(deadlineOf(binding)) {
			delete(bindings, key)
		}
		scanned++
		if scanned >= maxScan {
			break
		}
	}
}

// cleanupExpiredAccountBindings incrementally evicts expired response→account entries.
func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) {
	cleanupExpiredBindings(bindings, now, maxScan, func(b openAIWSAccountBinding) time.Time { return b.expiresAt })
}

// cleanupExpiredConnBindings incrementally evicts expired response→connection entries.
func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) {
	cleanupExpiredBindings(bindings, now, maxScan, func(b openAIWSConnBinding) time.Time { return b.expiresAt })
}

// cleanupExpiredTurnStateBindings incrementally evicts expired session turn-state entries.
func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
	cleanupExpiredBindings(bindings, now, maxScan, func(b openAIWSTurnStateBinding) time.Time { return b.expiresAt })
}

// cleanupExpiredSessionConnBindings incrementally evicts expired session→connection entries.
func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
	cleanupExpiredBindings(bindings, now, maxScan, func(b openAIWSSessionConnBinding) time.Time { return b.expiresAt })
}
// ensureBindingCapacity keeps the map below maxEntries ahead of an insert of
// incomingKey. When the map is full and incomingKey is not already present,
// one arbitrary entry is evicted so memory stays strictly bounded.
// A non-positive maxEntries disables the cap entirely.
func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
	if maxEntries <= 0 || len(bindings) < maxEntries {
		return
	}
	if _, alreadyPresent := bindings[incomingKey]; alreadyPresent {
		// Overwriting an existing key does not grow the map — nothing to evict.
		return
	}
	for victim := range bindings {
		delete(bindings, victim)
		break
	}
}
// normalizeOpenAIWSResponseID canonicalizes a response ID by trimming
// surrounding whitespace; an all-whitespace ID normalizes to "".
func normalizeOpenAIWSResponseID(responseID string) string {
	return strings.TrimSpace(responseID)
}
// openAIWSResponseAccountCacheKey derives the shared-cache key for a
// response→account binding. The response ID is SHA-256 hashed so arbitrary
// upstream IDs yield fixed-length, cache-safe key material.
func openAIWSResponseAccountCacheKey(responseID string) string {
	digest := sha256.Sum256([]byte(responseID))
	return openAIWSResponseAccountCachePrefix + hex.EncodeToString(digest[:])
}
func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
if ttl <= 0 {
return time.Hour
}
return ttl
}
// openAIWSSessionTurnStateKey builds the "<groupID>:<sessionHash>" composite
// map key. A blank (or all-whitespace) session hash yields "" so callers can
// treat it as invalid.
func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
	trimmed := strings.TrimSpace(sessionHash)
	if trimmed == "" {
		return ""
	}
	return strconv.FormatInt(groupID, 10) + ":" + trimmed
}
// withOpenAIWSStateStoreRedisTimeout derives a child context carrying the
// store's short Redis deadline. A nil parent falls back to Background so the
// helper never panics; callers must invoke the returned CancelFunc.
func withOpenAIWSStateStoreRedisTimeout(parent context.Context) (context.Context, context.CancelFunc) {
	if parent == nil {
		parent = context.Background()
	}
	return context.WithTimeout(parent, openAIWSStateStoreRedisTimeout)
}

View File

@@ -0,0 +1,235 @@
package service
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestOpenAIWSStateStore_BindGetDeleteResponseAccount covers the happy-path
// lifecycle of a response→account binding: bind, read back, delete, miss.
func TestOpenAIWSStateStore_BindGetDeleteResponseAccount(t *testing.T) {
	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	ctx := context.Background()
	groupID := int64(7)
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_abc", 101, time.Minute))
	accountID, err := store.GetResponseAccount(ctx, groupID, "resp_abc")
	require.NoError(t, err)
	require.Equal(t, int64(101), accountID)
	require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_abc"))
	// After deletion the lookup degrades to a miss: zero account, no error.
	accountID, err = store.GetResponseAccount(ctx, groupID, "resp_abc")
	require.NoError(t, err)
	require.Zero(t, accountID)
}
// TestOpenAIWSStateStore_ResponseConnTTL verifies a response→connection
// binding is visible before its TTL elapses and invisible after.
func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)
	store.BindResponseConn("resp_conn", "conn_1", 30*time.Millisecond)
	connID, ok := store.GetResponseConn("resp_conn")
	require.True(t, ok)
	require.Equal(t, "conn_1", connID)
	// Sleep past the 30ms TTL so the binding must read as expired.
	time.Sleep(60 * time.Millisecond)
	_, ok = store.GetResponseConn("resp_conn")
	require.False(t, ok)
}
// TestOpenAIWSStateStore_SessionTurnStateTTL verifies turn-state bindings
// are scoped per group and expire after their TTL.
func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)
	store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond)
	state, ok := store.GetSessionTurnState(9, "session_hash_1")
	require.True(t, ok)
	require.Equal(t, "turn_state_1", state)
	// Group isolation: another group ID must not observe the binding.
	_, ok = store.GetSessionTurnState(10, "session_hash_1")
	require.False(t, ok)
	// Sleep past the 30ms TTL so the binding must read as expired.
	time.Sleep(60 * time.Millisecond)
	_, ok = store.GetSessionTurnState(9, "session_hash_1")
	require.False(t, ok)
}
// TestOpenAIWSStateStore_SessionConnTTL verifies session→connection bindings
// are scoped per group and expire after their TTL.
func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)
	store.BindSessionConn(9, "session_hash_conn_1", "conn_1", 30*time.Millisecond)
	connID, ok := store.GetSessionConn(9, "session_hash_conn_1")
	require.True(t, ok)
	require.Equal(t, "conn_1", connID)
	// Group isolation: another group ID must not observe the binding.
	_, ok = store.GetSessionConn(10, "session_hash_conn_1")
	require.False(t, ok)
	// Sleep past the 30ms TTL so the binding must read as expired.
	time.Sleep(60 * time.Millisecond)
	_, ok = store.GetSessionConn(9, "session_hash_conn_1")
	require.False(t, ok)
}
// TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss seeds the
// upstream cache directly, reads through the store, then removes the cache
// entry and asserts the store does not keep serving a stale local mapping.
func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) {
	cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
	store := NewOpenAIWSStateStore(cache)
	ctx := context.Background()
	groupID := int64(17)
	responseID := "resp_cache_stale"
	cacheKey := openAIWSResponseAccountCacheKey(responseID)
	cache.sessionBindings[cacheKey] = 501
	accountID, err := store.GetResponseAccount(ctx, groupID, responseID)
	require.NoError(t, err)
	require.Equal(t, int64(501), accountID)
	// Simulate upstream cache expiry; the next read must degrade to a miss.
	delete(cache.sessionBindings, cacheKey)
	accountID, err = store.GetResponseAccount(ctx, groupID, responseID)
	require.NoError(t, err)
	require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射")
}
// TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally seeds more
// expired entries than one cleanup round's per-map budget and verifies that
// (a) a single round makes progress without being required to clear the
// whole map, and (b) repeated rounds eventually drain all expired keys.
func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) {
	raw := NewOpenAIWSStateStore(nil)
	store, ok := raw.(*defaultOpenAIWSStateStore)
	require.True(t, ok)
	expiredAt := time.Now().Add(-time.Minute)
	// 2048 entries exceeds a single round's scan budget, forcing incrementality.
	total := 2048
	store.responseToConnMu.Lock()
	for i := 0; i < total; i++ {
		store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{
			connID: "conn_incremental",
			expiresAt: expiredAt,
		}
	}
	store.responseToConnMu.Unlock()
	// Rewind the rate limiter so the next maybeCleanup call is due.
	store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
	store.maybeCleanup()
	store.responseToConnMu.RLock()
	remainingAfterFirst := len(store.responseToConn)
	store.responseToConnMu.RUnlock()
	require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展")
	require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键")
	// Run enough additional rounds to drain the remainder.
	for i := 0; i < 8; i++ {
		store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
		store.maybeCleanup()
	}
	store.responseToConnMu.RLock()
	remaining := len(store.responseToConn)
	store.responseToConnMu.RUnlock()
	require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键")
}
// TestEnsureBindingCapacity_EvictsOneWhenMapIsFull verifies that preparing a
// full map for a brand-new key evicts exactly one arbitrary victim, so the
// subsequent insert leaves the map at its size cap.
func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) {
	full := map[string]int{
		"a": 1,
		"b": 2,
	}
	ensureBindingCapacity(full, "c", 2)
	full["c"] = 3
	require.Len(t, full, 2)
	require.Equal(t, 3, full["c"])
}
// TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey verifies that
// an overwrite of a key already in a full map triggers no eviction, since the
// map's size does not grow.
func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) {
	entries := map[string]int{
		"a": 1,
		"b": 2,
	}
	ensureBindingCapacity(entries, "a", 2)
	entries["a"] = 9
	require.Len(t, entries, 2)
	require.Equal(t, 9, entries["a"])
}
// openAIWSStateStoreTimeoutProbeCache is a test double that records, per
// Redis-facing operation, whether the incoming context carried a deadline
// and how far away that deadline was. Tests use it to assert the store
// attaches its own short timeout to each cache round-trip.
type openAIWSStateStoreTimeoutProbeCache struct {
	setHasDeadline bool
	getHasDeadline bool
	deleteHasDeadline bool
	setDeadlineDelta time.Duration
	getDeadlineDelta time.Duration
	delDeadlineDelta time.Duration
}
// GetSessionAccountID probes the read path: it captures deadline presence
// and remaining time, then returns a fixed account ID (123) as a cache hit.
func (c *openAIWSStateStoreTimeoutProbeCache) GetSessionAccountID(ctx context.Context, _ int64, _ string) (int64, error) {
	if deadline, ok := ctx.Deadline(); ok {
		c.getHasDeadline = true
		c.getDeadlineDelta = time.Until(deadline)
	}
	return 123, nil
}
// SetSessionAccountID probes the write path and always fails, letting tests
// confirm the store surfaces cache write errors.
func (c *openAIWSStateStoreTimeoutProbeCache) SetSessionAccountID(ctx context.Context, _ int64, _ string, _ int64, _ time.Duration) error {
	if deadline, ok := ctx.Deadline(); ok {
		c.setHasDeadline = true
		c.setDeadlineDelta = time.Until(deadline)
	}
	return errors.New("set failed")
}
// RefreshSessionTTL is a no-op; it exists only to satisfy the cache interface.
func (c *openAIWSStateStoreTimeoutProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error {
	return nil
}
// DeleteSessionAccountID probes the delete path: it captures deadline
// presence and remaining time, then succeeds.
func (c *openAIWSStateStoreTimeoutProbeCache) DeleteSessionAccountID(ctx context.Context, _ int64, _ string) error {
	if deadline, ok := ctx.Deadline(); ok {
		c.deleteHasDeadline = true
		c.delDeadlineDelta = time.Until(deadline)
	}
	return nil
}
// TestOpenAIWSStateStore_RedisOpsUseShortTimeout verifies every Redis
// round-trip issued by the store carries its own short deadline (asserted to
// be in the (2s, 3s] range) and that a local-cache hit skips the Redis read.
func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) {
	probe := &openAIWSStateStoreTimeoutProbeCache{}
	store := NewOpenAIWSStateStore(probe)
	ctx := context.Background()
	groupID := int64(5)
	// The probe's Set always fails, so the bind must surface the cache error.
	err := store.BindResponseAccount(ctx, groupID, "resp_timeout_probe", 11, time.Minute)
	require.Error(t, err)
	accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe")
	require.NoError(t, getErr)
	require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号")
	require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe"))
	require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文")
	require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文")
	require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取")
	require.Greater(t, probe.setDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second)
	require.Greater(t, probe.delDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe.delDeadlineDelta, 3*time.Second)
	// A fresh store has no local binding, so this read must hit Redis with a deadline.
	probe2 := &openAIWSStateStoreTimeoutProbeCache{}
	store2 := NewOpenAIWSStateStore(probe2)
	accountID2, err2 := store2.GetResponseAccount(ctx, groupID, "resp_cache_only")
	require.NoError(t, err2)
	require.Equal(t, int64(123), accountID2)
	require.True(t, probe2.getHasDeadline, "GetSessionAccountID 在缓存未命中时应携带独立超时上下文")
	require.Greater(t, probe2.getDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second)
}
// TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext asserts the
// helper derives a non-nil child context that carries a deadline.
func TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) {
	ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
	defer cancel()
	require.NotNil(t, ctx)
	_, ok := ctx.Deadline()
	require.True(t, ok, "应附加短超时")
}

Some files were not shown because too many files have changed in this diff Show More