mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-18 13:54:46 +08:00
Merge tag 'v0.1.90' into merge/upstream-v0.1.90
注册邮箱域名白名单策略上线,后台大数据场景性能大幅优化。 - 注册邮箱域名白名单:支持管理员配置允许注册的邮箱域名策略 - Keys 页面表单筛选:用户 /keys 页面支持按条件筛选 API Key - Settings 页面分 Tab 拆分:管理后台设置页面按功能模块分 Tab 展示 - 后台大数据场景加载性能优化:仪表盘/用户/账号/Ops 页面大数据集加载显著提速 - Usage 大表分页优化:默认避免全量 COUNT(*),大幅降低分页查询耗时 - 消除重复的 normalizeAccountIDList,补充新增组件的单元测试 - 清理无用文件和过时文档,精简项目结构 - EmailVerifyView 硬编码英文字符串替换为 i18n 调用 - 修复 Anthropic 平台无限流重置时间的 429 误标记账号限流问题 - 修复自定义菜单页面管理员视角菜单不生效问题 - 修复 Ops 错误详情弹窗未展示真实上游 payload 的问题 - 修复充值/订阅菜单 icon 显示问题 # Conflicts: # .gitignore # backend/cmd/server/VERSION # backend/ent/group.go # backend/ent/runtime/runtime.go # backend/ent/schema/group.go # backend/go.sum # backend/internal/handler/admin/account_handler.go # backend/internal/handler/admin/dashboard_handler.go # backend/internal/pkg/usagestats/usage_log_types.go # backend/internal/repository/group_repo.go # backend/internal/repository/usage_log_repo.go # backend/internal/server/middleware/security_headers.go # backend/internal/server/router.go # backend/internal/service/account_usage_service.go # backend/internal/service/admin_service_bulk_update_test.go # backend/internal/service/dashboard_service.go # backend/internal/service/gateway_service.go # frontend/src/api/admin/dashboard.ts # frontend/src/components/account/BulkEditAccountModal.vue # frontend/src/components/charts/GroupDistributionChart.vue # frontend/src/components/layout/AppSidebar.vue # frontend/src/i18n/locales/en.ts # frontend/src/i18n/locales/zh.ts # frontend/src/views/admin/GroupsView.vue # frontend/src/views/admin/SettingsView.vue # frontend/src/views/admin/UsageView.vue # frontend/src/views/user/PurchaseSubscriptionView.vue
This commit is contained in:
@@ -3,11 +3,14 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"hash/fnv"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
)
|
||||
|
||||
@@ -50,6 +53,14 @@ type Account struct {
|
||||
AccountGroups []AccountGroup
|
||||
GroupIDs []int64
|
||||
Groups []*Group
|
||||
|
||||
// model_mapping 热路径缓存(非持久化字段)
|
||||
modelMappingCache map[string]string
|
||||
modelMappingCacheReady bool
|
||||
modelMappingCacheCredentialsPtr uintptr
|
||||
modelMappingCacheRawPtr uintptr
|
||||
modelMappingCacheRawLen int
|
||||
modelMappingCacheRawSig uint64
|
||||
}
|
||||
|
||||
type TempUnschedulableRule struct {
|
||||
@@ -349,6 +360,39 @@ func parseTempUnschedInt(value any) int {
|
||||
}
|
||||
|
||||
// GetModelMapping returns the resolved model-name mapping for this account,
// memoized in the non-persisted modelMappingCache* fields on Account.
//
// Cache validity is checked with cheap identity probes first (credentials map
// pointer, raw mapping pointer and length) and then confirmed with a content
// signature, because the code must detect in-place mutation of the raw map
// (same pointer, different entries).
func (a *Account) GetModelMapping() map[string]string {
	credentialsPtr := mapPtr(a.Credentials)
	rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
	rawPtr := mapPtr(rawMapping)
	rawLen := len(rawMapping)
	rawSig := uint64(0)
	rawSigReady := false

	// Fast path: identity checks pass, so compute the signature and compare
	// against the cached one before trusting the cached mapping.
	if a.modelMappingCacheReady &&
		a.modelMappingCacheCredentialsPtr == credentialsPtr &&
		a.modelMappingCacheRawPtr == rawPtr &&
		a.modelMappingCacheRawLen == rawLen {
		rawSig = modelMappingSignature(rawMapping)
		rawSigReady = true
		if a.modelMappingCacheRawSig == rawSig {
			return a.modelMappingCache
		}
	}

	// Slow path: resolve the mapping and refresh every cache field.
	mapping := a.resolveModelMapping(rawMapping)
	if !rawSigReady {
		rawSig = modelMappingSignature(rawMapping)
	}

	a.modelMappingCache = mapping
	a.modelMappingCacheReady = true
	a.modelMappingCacheCredentialsPtr = credentialsPtr
	a.modelMappingCacheRawPtr = rawPtr
	a.modelMappingCacheRawLen = rawLen
	a.modelMappingCacheRawSig = rawSig
	return mapping
}
|
||||
|
||||
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
|
||||
if a.Credentials == nil {
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
@@ -356,32 +400,31 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["model_mapping"]
|
||||
if !ok || raw == nil {
|
||||
if len(rawMapping) == 0 {
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
result := make(map[string]string)
|
||||
for k, v := range m {
|
||||
if s, ok := v.(string); ok {
|
||||
result[k] = s
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
ensureAntigravityDefaultPassthroughs(result, []string{
|
||||
"gemini-3-flash",
|
||||
"gemini-3.1-pro-high",
|
||||
"gemini-3.1-pro-low",
|
||||
})
|
||||
}
|
||||
return result
|
||||
|
||||
result := make(map[string]string)
|
||||
for k, v := range rawMapping {
|
||||
if s, ok := v.(string); ok {
|
||||
result[k] = s
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
ensureAntigravityDefaultPassthroughs(result, []string{
|
||||
"gemini-3-flash",
|
||||
"gemini-3.1-pro-high",
|
||||
"gemini-3.1-pro-low",
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
@@ -389,6 +432,37 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// mapPtr returns the runtime pointer identity of a map, or 0 for nil.
// Two calls with the same (non-nil) map value yield the same result.
func mapPtr(m map[string]any) uintptr {
	if m != nil {
		return reflect.ValueOf(m).Pointer()
	}
	return 0
}
|
||||
|
||||
// modelMappingSignature computes an order-independent FNV-1a fingerprint of a
// raw model mapping. Keys are hashed in sorted order; string values are hashed
// as-is and any non-string value hashes as the sentinel byte 0x01. The bytes
// 0x00 and 0xff delimit key/value and entry boundaries. An empty or nil map
// yields 0.
func modelMappingSignature(rawMapping map[string]any) uint64 {
	if len(rawMapping) == 0 {
		return 0
	}
	sortedKeys := make([]string, 0, len(rawMapping))
	for key := range rawMapping {
		sortedKeys = append(sortedKeys, key)
	}
	sort.Strings(sortedKeys)

	hasher := fnv.New64a()
	for _, key := range sortedKeys {
		_, _ = hasher.Write([]byte(key))
		_, _ = hasher.Write([]byte{0})
		if text, isString := rawMapping[key].(string); isString {
			_, _ = hasher.Write([]byte(text))
		} else {
			_, _ = hasher.Write([]byte{1})
		}
		_, _ = hasher.Write([]byte{0xff})
	}
	return hasher.Sum64()
}
|
||||
|
||||
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
|
||||
if mapping == nil || model == "" {
|
||||
return
|
||||
@@ -742,6 +816,159 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOpenAIResponsesWebSocketV2Enabled reports whether Responses WebSocket v2
// is enabled for an OpenAI account.
//
// Per-account-type fields:
//   - OAuth accounts:   accounts.extra.openai_oauth_responses_websockets_v2_enabled
//   - API Key accounts: accounts.extra.openai_apikey_responses_websockets_v2_enabled
//
// Compatibility fields:
//   - accounts.extra.responses_websockets_v2_enabled
//   - accounts.extra.openai_ws_enabled (historical switch)
//
// Priority:
//  1. read the per-type field matching the account type
//  2. fall back to the compatibility fields when the per-type field is absent
func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
	if a == nil || !a.IsOpenAI() || a.Extra == nil {
		return false
	}
	if a.IsOpenAIOAuth() {
		if enabled, ok := a.Extra["openai_oauth_responses_websockets_v2_enabled"].(bool); ok {
			return enabled
		}
	}
	if a.IsOpenAIApiKey() {
		if enabled, ok := a.Extra["openai_apikey_responses_websockets_v2_enabled"].(bool); ok {
			return enabled
		}
	}
	// Compatibility fallbacks, checked only when no per-type field was found.
	if enabled, ok := a.Extra["responses_websockets_v2_enabled"].(bool); ok {
		return enabled
	}
	if enabled, ok := a.Extra["openai_ws_enabled"].(bool); ok {
		return enabled
	}
	return false
}
|
||||
|
||||
// Supported OpenAI WS ingress modes.
const (
	OpenAIWSIngressModeOff       = "off"
	OpenAIWSIngressModeShared    = "shared"
	OpenAIWSIngressModeDedicated = "dedicated"
)

// normalizeOpenAIWSIngressMode canonicalizes a mode string: surrounding
// whitespace is trimmed and case is ignored. Any value other than the three
// supported modes normalizes to the empty string.
func normalizeOpenAIWSIngressMode(mode string) string {
	canonical := strings.ToLower(strings.TrimSpace(mode))
	switch canonical {
	case OpenAIWSIngressModeOff, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
		return canonical
	}
	return ""
}
|
||||
|
||||
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
|
||||
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
return OpenAIWSIngressModeShared
|
||||
}
|
||||
|
||||
// ResolveOpenAIResponsesWebSocketV2Mode returns the account's effective mode
// under WSv2 ingress (off/shared/dedicated).
//
// Priority:
//  1. per-type mode field (string)
//  2. per-type legacy enabled field (bool)
//  3. compatibility legacy enabled fields (bool)
//  4. defaultMode (invalid values fall back to shared)
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
	resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
	// Non-OpenAI accounts are always off, regardless of extra fields.
	if a == nil || !a.IsOpenAI() {
		return OpenAIWSIngressModeOff
	}
	if a.Extra == nil {
		return resolvedDefault
	}

	// resolveModeString reads a string mode field; the second result is false
	// when the key is absent, not a string, or not a recognized mode.
	resolveModeString := func(key string) (string, bool) {
		raw, ok := a.Extra[key]
		if !ok {
			return "", false
		}
		mode, ok := raw.(string)
		if !ok {
			return "", false
		}
		normalized := normalizeOpenAIWSIngressMode(mode)
		if normalized == "" {
			return "", false
		}
		return normalized, true
	}
	// resolveBoolMode maps a legacy bool switch onto a mode:
	// true → shared, false → off; absent/non-bool → not resolved.
	resolveBoolMode := func(key string) (string, bool) {
		raw, ok := a.Extra[key]
		if !ok {
			return "", false
		}
		enabled, ok := raw.(bool)
		if !ok {
			return "", false
		}
		if enabled {
			return OpenAIWSIngressModeShared, true
		}
		return OpenAIWSIngressModeOff, true
	}

	if a.IsOpenAIOAuth() {
		if mode, ok := resolveModeString("openai_oauth_responses_websockets_v2_mode"); ok {
			return mode
		}
		if mode, ok := resolveBoolMode("openai_oauth_responses_websockets_v2_enabled"); ok {
			return mode
		}
	}
	if a.IsOpenAIApiKey() {
		if mode, ok := resolveModeString("openai_apikey_responses_websockets_v2_mode"); ok {
			return mode
		}
		if mode, ok := resolveBoolMode("openai_apikey_responses_websockets_v2_enabled"); ok {
			return mode
		}
	}
	// Compatibility switches, consulted only when no per-type field resolved.
	if mode, ok := resolveBoolMode("responses_websockets_v2_enabled"); ok {
		return mode
	}
	if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
		return mode
	}
	return resolvedDefault
}
|
||||
|
||||
// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
|
||||
// 字段:accounts.extra.openai_ws_force_http。
|
||||
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
|
||||
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := a.Extra["openai_ws_force_http"].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// IsOpenAIWSAllowStoreRecoveryEnabled 返回账号级 store 恢复开关。
|
||||
// 字段:accounts.extra.openai_ws_allow_store_recovery。
|
||||
func (a *Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool {
|
||||
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := a.Extra["openai_ws_allow_store_recovery"].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。
|
||||
func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
|
||||
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
|
||||
@@ -806,6 +1033,26 @@ func (a *Account) IsTLSFingerprintEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetUserMsgQueueMode returns the per-account user message queue mode.
// "serialize" = serial queue, "throttle" = soft rate limiting, "" = unset
// (callers fall back to the global configuration).
func (a *Account) GetUserMsgQueueMode() string {
	if a.Extra == nil {
		return ""
	}
	// Prefer the new user_msg_queue_mode field; values are whitelist-checked
	// and anything unrecognized is treated as unset.
	if mode, ok := a.Extra["user_msg_queue_mode"].(string); ok && mode != "" {
		if mode == config.UMQModeSerialize || mode == config.UMQModeThrottle {
			return mode
		}
		return "" // invalid value: fall back to the global configuration
	}
	// Backward compatibility: user_msg_queue_enabled: true maps to "serialize".
	if enabled, ok := a.Extra["user_msg_queue_enabled"].(bool); ok && enabled {
		return config.UMQModeSerialize
	}
	return ""
}
|
||||
|
||||
// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID,
|
||||
@@ -911,6 +1158,80 @@ func (a *Account) GetSessionIdleTimeoutMinutes() int {
|
||||
return 5
|
||||
}
|
||||
|
||||
// GetBaseRPM 获取基础 RPM 限制
|
||||
// 返回 0 表示未启用(负数视为无效配置,按 0 处理)
|
||||
func (a *Account) GetBaseRPM() int {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["base_rpm"]; ok {
|
||||
val := parseExtraInt(v)
|
||||
if val > 0 {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetRPMStrategy 获取 RPM 策略
|
||||
// "tiered" = 三区模型(默认), "sticky_exempt" = 粘性豁免
|
||||
func (a *Account) GetRPMStrategy() string {
|
||||
if a.Extra == nil {
|
||||
return "tiered"
|
||||
}
|
||||
if v, ok := a.Extra["rpm_strategy"]; ok {
|
||||
if s, ok := v.(string); ok && s == "sticky_exempt" {
|
||||
return "sticky_exempt"
|
||||
}
|
||||
}
|
||||
return "tiered"
|
||||
}
|
||||
|
||||
// GetRPMStickyBuffer 获取 RPM 粘性缓冲数量
|
||||
// tiered 模式下的黄区大小,默认为 base_rpm 的 20%(至少 1)
|
||||
func (a *Account) GetRPMStickyBuffer() int {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["rpm_sticky_buffer"]; ok {
|
||||
val := parseExtraInt(v)
|
||||
if val > 0 {
|
||||
return val
|
||||
}
|
||||
}
|
||||
base := a.GetBaseRPM()
|
||||
buffer := base / 5
|
||||
if buffer < 1 && base > 0 {
|
||||
buffer = 1
|
||||
}
|
||||
return buffer
|
||||
}
|
||||
|
||||
// CheckRPMSchedulability maps the current RPM count onto a scheduling state,
// reusing the WindowCostSchedulability tri-state:
// Schedulable / StickyOnly / NotSchedulable.
func (a *Account) CheckRPMSchedulability(currentRPM int) WindowCostSchedulability {
	baseRPM := a.GetBaseRPM()
	// No (or invalid) base limit means RPM scheduling is disabled entirely.
	if baseRPM <= 0 {
		return WindowCostSchedulable
	}

	// Green zone: below the base limit.
	if currentRPM < baseRPM {
		return WindowCostSchedulable
	}

	strategy := a.GetRPMStrategy()
	if strategy == "sticky_exempt" {
		return WindowCostStickyOnly // sticky-exempt has no red zone
	}

	// tiered: yellow zone (sticky only) then red zone (not schedulable).
	buffer := a.GetRPMStickyBuffer()
	if currentRPM < baseRPM+buffer {
		return WindowCostStickyOnly
	}
	return WindowCostNotSchedulable
}
|
||||
|
||||
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
|
||||
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
|
||||
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
|
||||
@@ -974,6 +1295,12 @@ func parseExtraFloat64(value any) float64 {
|
||||
}
|
||||
|
||||
// ParseExtraInt parses an `extra` field value into an int. It supports int,
// int64, float64, json.Number and string inputs and returns 0 when the value
// cannot be parsed. Exported wrapper around the unexported parseExtraInt.
func ParseExtraInt(value any) int {
	return parseExtraInt(value)
}
|
||||
|
||||
func parseExtraInt(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
|
||||
@@ -134,3 +134,161 @@ func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
|
||||
require.False(t, otherPlatform.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
}
|
||||
|
||||
// TestAccount_IsOpenAIResponsesWebSocketV2Enabled verifies the per-type
// extra switches, their priority over the compatibility keys, and that
// non-OpenAI accounts are always off.
func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
	t.Run("OAuth使用OAuth专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})

	t.Run("API Key使用API Key专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})

	t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})

	t.Run("分类型新键优先于兼容键", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":              true,
				"openai_ws_enabled":                            true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})

	t.Run("分类型键缺失时回退兼容键", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})

	t.Run("非OpenAI账号默认关闭", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
}
|
||||
|
||||
// TestAccount_ResolveOpenAIResponsesWebSocketV2Mode verifies the field
// priority chain (per-type mode string > per-type bool > compatibility bool >
// default), the legacy bool→mode mapping, and the non-OpenAI "always off"
// behavior.
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
	t.Run("default fallback to shared", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra:    map[string]any{},
		}
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
	})

	t.Run("oauth mode field has highest priority", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode":    OpenAIWSIngressModeDedicated,
				"openai_oauth_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":              false,
			},
		}
		require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
	})

	t.Run("legacy enabled maps to shared", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
	})

	t.Run("legacy disabled maps to off", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":               true,
			},
		}
		require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
	})

	t.Run("non openai always off", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
			},
		}
		require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
	})
}
|
||||
|
||||
// TestAccount_OpenAIWSExtraFlags covers the openai_ws_force_http and
// openai_ws_allow_store_recovery extra switches, including nil-receiver and
// non-OpenAI-platform cases.
func TestAccount_OpenAIWSExtraFlags(t *testing.T) {
	account := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_ws_force_http":           true,
			"openai_ws_allow_store_recovery": true,
		},
	}
	require.True(t, account.IsOpenAIWSForceHTTPEnabled())
	require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled())

	// Absent keys default to disabled.
	off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
	require.False(t, off.IsOpenAIWSForceHTTPEnabled())
	require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled())

	// A nil receiver must be safe and report disabled.
	var nilAccount *Account
	require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled())

	// Non-OpenAI platforms ignore the switch even when set.
	nonOpenAI := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_ws_allow_store_recovery": true,
		},
	}
	require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled())
}
|
||||
|
||||
120
backend/internal/service/account_rpm_test.go
Normal file
120
backend/internal/service/account_rpm_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetBaseRPM covers parsing of extra["base_rpm"] across the supported
// value types (int, int64, float64, json.Number, string) and the disabled
// cases (nil extra, missing key, zero, negative).
func TestGetBaseRPM(t *testing.T) {
	tests := []struct {
		name     string
		extra    map[string]any
		expected int
	}{
		{"nil extra", nil, 0},
		{"no key", map[string]any{}, 0},
		{"zero", map[string]any{"base_rpm": 0}, 0},
		{"int value", map[string]any{"base_rpm": 15}, 15},
		{"float value", map[string]any{"base_rpm": 15.0}, 15},
		{"string value", map[string]any{"base_rpm": "15"}, 15},
		{"negative value", map[string]any{"base_rpm": -5}, 0},
		{"int64 value", map[string]any{"base_rpm": int64(20)}, 20},
		{"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &Account{Extra: tt.extra}
			if got := a.GetBaseRPM(); got != tt.expected {
				t.Errorf("GetBaseRPM() = %d, want %d", got, tt.expected)
			}
		})
	}
}
|
||||
|
||||
// TestGetRPMStrategy verifies that only the exact string "sticky_exempt"
// selects the sticky-exempt strategy and that every other value (missing,
// empty, unknown, non-string) falls back to "tiered".
func TestGetRPMStrategy(t *testing.T) {
	tests := []struct {
		name     string
		extra    map[string]any
		expected string
	}{
		{"nil extra", nil, "tiered"},
		{"no key", map[string]any{}, "tiered"},
		{"tiered", map[string]any{"rpm_strategy": "tiered"}, "tiered"},
		{"sticky_exempt", map[string]any{"rpm_strategy": "sticky_exempt"}, "sticky_exempt"},
		{"invalid", map[string]any{"rpm_strategy": "foobar"}, "tiered"},
		{"empty string fallback", map[string]any{"rpm_strategy": ""}, "tiered"},
		{"numeric value fallback", map[string]any{"rpm_strategy": 123}, "tiered"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &Account{Extra: tt.extra}
			if got := a.GetRPMStrategy(); got != tt.expected {
				t.Errorf("GetRPMStrategy() = %q, want %q", got, tt.expected)
			}
		})
	}
}
|
||||
|
||||
// TestCheckRPMSchedulability exercises the green/yellow/red zone boundaries
// of tiered mode, the no-red-zone behavior of sticky_exempt, explicit buffer
// sizes, and the disabled/edge cases (no base, negative base, negative count).
func TestCheckRPMSchedulability(t *testing.T) {
	tests := []struct {
		name       string
		extra      map[string]any
		currentRPM int
		expected   WindowCostSchedulability
	}{
		{"disabled", map[string]any{}, 100, WindowCostSchedulable},
		{"green zone", map[string]any{"base_rpm": 15}, 10, WindowCostSchedulable},
		{"yellow zone tiered", map[string]any{"base_rpm": 15}, 15, WindowCostStickyOnly},
		{"red zone tiered", map[string]any{"base_rpm": 15}, 18, WindowCostNotSchedulable},
		{"sticky_exempt at limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 15, WindowCostStickyOnly},
		{"sticky_exempt over limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 100, WindowCostStickyOnly},
		{"custom buffer", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 14, WindowCostStickyOnly},
		{"custom buffer red", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 15, WindowCostNotSchedulable},
		{"base_rpm=1 green", map[string]any{"base_rpm": 1}, 0, WindowCostSchedulable},
		{"base_rpm=1 yellow (at limit)", map[string]any{"base_rpm": 1}, 1, WindowCostStickyOnly},
		{"base_rpm=1 red (at limit+buffer)", map[string]any{"base_rpm": 1}, 2, WindowCostNotSchedulable},
		{"negative currentRPM", map[string]any{"base_rpm": 15}, -1, WindowCostSchedulable},
		{"base_rpm negative disabled", map[string]any{"base_rpm": -5}, 10, WindowCostSchedulable},
		{"very high currentRPM", map[string]any{"base_rpm": 10}, 9999, WindowCostNotSchedulable},
		{"sticky_exempt very high currentRPM", map[string]any{"base_rpm": 10, "rpm_strategy": "sticky_exempt"}, 9999, WindowCostStickyOnly},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &Account{Extra: tt.extra}
			if got := a.CheckRPMSchedulability(tt.currentRPM); got != tt.expected {
				t.Errorf("CheckRPMSchedulability(%d) = %d, want %d", tt.currentRPM, got, tt.expected)
			}
		})
	}
}
|
||||
|
||||
// TestGetRPMStickyBuffer verifies the explicit rpm_sticky_buffer override,
// the derived default of base_rpm/5 with a minimum of 1, and fallback to the
// derived default for zero/negative explicit values.
func TestGetRPMStickyBuffer(t *testing.T) {
	tests := []struct {
		name     string
		extra    map[string]any
		expected int
	}{
		{"nil extra", nil, 0},
		{"no keys", map[string]any{}, 0},
		{"base_rpm=0", map[string]any{"base_rpm": 0}, 0},
		{"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1},
		{"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1},
		{"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1},
		{"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2},
		{"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3},
		{"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20},
		{"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5},
		{"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2},
		{"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2},
		{"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7},
		{"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &Account{Extra: tt.extra}
			if got := a.GetRPMStickyBuffer(); got != tt.expected {
				t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected)
			}
		})
	}
}
|
||||
@@ -54,6 +54,8 @@ type AccountRepository interface {
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
|
||||
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
||||
@@ -119,6 +121,10 @@ type AccountService struct {
|
||||
groupRepo GroupRepository
|
||||
}
|
||||
|
||||
type groupExistenceBatchChecker interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
// NewAccountService 创建账号服务实例
|
||||
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
|
||||
return &AccountService{
|
||||
@@ -131,11 +137,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
|
||||
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
|
||||
// 验证分组是否存在(如果指定了分组)
|
||||
if len(req.GroupIDs) > 0 {
|
||||
for _, groupID := range req.GroupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,11 +259,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if req.GroupIDs != nil {
|
||||
for _, groupID := range *req.GroupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,6 +300,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateGroupIDsExist verifies that every given group ID refers to an
// existing group. It prefers a single batched ExistsByIDs lookup when the
// repository implements groupExistenceBatchChecker, and otherwise falls back
// to one GetByID call per ID. Returns nil for an empty ID list; errors are
// wrapped with a "get group"/"check groups exists" prefix so callers match
// the pre-refactor per-ID error shape (including ErrGroupNotFound).
func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
	if len(groupIDs) == 0 {
		return nil
	}
	if s.groupRepo == nil {
		return fmt.Errorf("group repository not configured")
	}

	// Batched path: one repository round-trip for all IDs.
	if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok {
		existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs)
		if err != nil {
			return fmt.Errorf("check groups exists: %w", err)
		}
		for _, groupID := range groupIDs {
			// Non-positive IDs can never exist; report them as not found.
			if groupID <= 0 {
				return fmt.Errorf("get group: %w", ErrGroupNotFound)
			}
			if !existsByID[groupID] {
				return fmt.Errorf("get group: %w", ErrGroupNotFound)
			}
		}
		return nil
	}

	// Fallback path: per-ID lookup when batch checking is unavailable.
	for _, groupID := range groupIDs {
		_, err := s.groupRepo.GetByID(ctx, groupID)
		if err != nil {
			return fmt.Errorf("get group: %w", err)
		}
	}
	return nil
}
|
||||
|
||||
// UpdateStatus 更新账号状态
|
||||
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
|
||||
@@ -147,6 +147,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
|
||||
panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
|
||||
}
|
||||
|
||||
// The stubs below intentionally panic: tests built on accountRepoStub must
// never reach these repository methods, and a panic surfaces any unexpected
// call immediately.
func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
	panic("unexpected ListSchedulableUngroupedByPlatform call")
}

func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
	panic("unexpected ListSchedulableUngroupedByPlatforms call")
}

func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
	panic("unexpected SetRateLimited call")
}
|
||||
|
||||
@@ -598,9 +598,102 @@ func ceilSeconds(d time.Duration) int {
|
||||
return sec
|
||||
}
|
||||
|
||||
// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
|
||||
// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
|
||||
func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
|
||||
}
|
||||
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
|
||||
}
|
||||
|
||||
// 验证 base_url 格式
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
|
||||
}
|
||||
upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
|
||||
|
||||
// 设置 SSE 头
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
|
||||
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
|
||||
return s.sendErrorAndEnd(c, msg)
|
||||
}
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
|
||||
|
||||
// 构建轻量级 prompt-enhance 请求作为连通性测试
|
||||
testPayload := map[string]any{
|
||||
"model": "prompt-enhance-short-10s",
|
||||
"messages": []map[string]string{{"role": "user", "content": "test"}},
|
||||
"stream": false,
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(testPayload)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "构建测试请求失败")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
// 获取代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
|
||||
}
|
||||
|
||||
// 其他错误但能连通(如 400 参数错误)也算连通性测试通过
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
|
||||
}
|
||||
|
||||
// testSoraAccountConnection 测试 Sora 账号的连接
|
||||
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
|
||||
// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
|
||||
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
|
||||
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
|
||||
// apikey 类型走独立测试流程
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
return s.testSoraAPIKeyAccountConnection(c, account)
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
recorder := &soraProbeRecorder{}
|
||||
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
@@ -33,9 +35,9 @@ type UsageLogRepository interface {
|
||||
|
||||
// Admin dashboard stats
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
@@ -63,6 +65,10 @@ type UsageLogRepository interface {
|
||||
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
|
||||
}
|
||||
|
||||
type accountWindowStatsBatchReader interface {
|
||||
GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error)
|
||||
}
|
||||
|
||||
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
|
||||
type apiUsageCache struct {
|
||||
response *ClaudeUsageResponse
|
||||
@@ -298,7 +304,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
dayStart := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||
}
|
||||
@@ -320,7 +326,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
||||
minuteStart := now.Truncate(time.Minute)
|
||||
minuteResetAt := minuteStart.Add(time.Minute)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
||||
}
|
||||
@@ -441,6 +447,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetTodayStatsBatch 批量获取账号今日统计,优先走批量 SQL,失败时回退单账号查询。
|
||||
func (s *AccountUsageService) GetTodayStatsBatch(ctx context.Context, accountIDs []int64) (map[int64]*WindowStats, error) {
|
||||
uniqueIDs := make([]int64, 0, len(accountIDs))
|
||||
seen := make(map[int64]struct{}, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
if accountID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[accountID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[accountID] = struct{}{}
|
||||
uniqueIDs = append(uniqueIDs, accountID)
|
||||
}
|
||||
|
||||
result := make(map[int64]*WindowStats, len(uniqueIDs))
|
||||
if len(uniqueIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
startTime := timezone.Today()
|
||||
if batchReader, ok := s.usageLogRepo.(accountWindowStatsBatchReader); ok {
|
||||
statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, uniqueIDs, startTime)
|
||||
if err == nil {
|
||||
for _, accountID := range uniqueIDs {
|
||||
result[accountID] = windowStatsFromAccountStats(statsByAccount[accountID])
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(8)
|
||||
|
||||
for _, accountID := range uniqueIDs {
|
||||
id := accountID
|
||||
g.Go(func() error {
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(gctx, id, startTime)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
mu.Lock()
|
||||
result[id] = windowStatsFromAccountStats(stats)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
_ = g.Wait()
|
||||
|
||||
for _, accountID := range uniqueIDs {
|
||||
if _, ok := result[accountID]; !ok {
|
||||
result[accountID] = &WindowStats{}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
|
||||
if stats == nil {
|
||||
return &WindowStats{}
|
||||
}
|
||||
return &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
StandardCost: stats.StandardCost,
|
||||
UserCost: stats.UserCost,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
|
||||
@@ -314,3 +314,72 @@ func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T
|
||||
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-3-5-sonnet": "upstream-a",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
first := account.GetModelMapping()
|
||||
if first["claude-3-5-sonnet"] != "upstream-a" {
|
||||
t.Fatalf("unexpected first mapping: %v", first)
|
||||
}
|
||||
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-3-5-sonnet": "upstream-b",
|
||||
},
|
||||
}
|
||||
second := account.GetModelMapping()
|
||||
if second["claude-3-5-sonnet"] != "upstream-b" {
|
||||
t.Fatalf("expected cache invalidated after credentials replace, got: %v", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) {
|
||||
rawMapping := map[string]any{
|
||||
"claude-sonnet": "sonnet-a",
|
||||
}
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": rawMapping,
|
||||
},
|
||||
}
|
||||
|
||||
first := account.GetModelMapping()
|
||||
if len(first) != 1 {
|
||||
t.Fatalf("unexpected first mapping length: %d", len(first))
|
||||
}
|
||||
|
||||
rawMapping["claude-opus"] = "opus-b"
|
||||
second := account.GetModelMapping()
|
||||
if second["claude-opus"] != "opus-b" {
|
||||
t.Fatalf("expected cache invalidated after mapping len change, got: %v", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) {
|
||||
rawMapping := map[string]any{
|
||||
"claude-sonnet": "sonnet-a",
|
||||
}
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": rawMapping,
|
||||
},
|
||||
}
|
||||
|
||||
first := account.GetModelMapping()
|
||||
if first["claude-sonnet"] != "sonnet-a" {
|
||||
t.Fatalf("unexpected first mapping: %v", first)
|
||||
}
|
||||
|
||||
rawMapping["claude-sonnet"] = "sonnet-b"
|
||||
second := account.GetModelMapping()
|
||||
if second["claude-sonnet"] != "sonnet-b" {
|
||||
t.Fatalf("expected cache invalidated after in-place value change, got: %v", second)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -42,6 +44,9 @@ type AdminService interface {
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||
|
||||
// API Key management (admin)
|
||||
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
|
||||
|
||||
// Account management
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
|
||||
GetAccount(ctx context.Context, id int64) (*Account, error)
|
||||
@@ -83,13 +88,14 @@ type AdminService interface {
|
||||
|
||||
// CreateUserInput represents input for creating a new user via admin operations.
|
||||
type CreateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
Username string
|
||||
Notes string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
AllowedGroups []int64
|
||||
Email string
|
||||
Password string
|
||||
Username string
|
||||
Notes string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
AllowedGroups []int64
|
||||
SoraStorageQuotaBytes int64
|
||||
}
|
||||
|
||||
type UpdateUserInput struct {
|
||||
@@ -103,7 +109,8 @@ type UpdateUserInput struct {
|
||||
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64
|
||||
GroupRates map[int64]*float64
|
||||
SoraStorageQuotaBytes *int64
|
||||
}
|
||||
|
||||
type CreateGroupInput struct {
|
||||
@@ -136,6 +143,8 @@ type CreateGroupInput struct {
|
||||
SimulateClaudeMaxEnabled *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -171,6 +180,8 @@ type UpdateGroupInput struct {
|
||||
SimulateClaudeMaxEnabled *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -238,6 +249,14 @@ type BulkUpdateAccountResult struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUpdateAPIKeyGroupIDResult is the result of AdminUpdateAPIKeyGroupID.
|
||||
type AdminUpdateAPIKeyGroupIDResult struct {
|
||||
APIKey *APIKey
|
||||
AutoGrantedGroupAccess bool // true if a new exclusive group permission was auto-added
|
||||
GrantedGroupID *int64 // the group ID that was auto-granted
|
||||
GrantedGroupName string // the group name that was auto-granted
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
||||
type BulkUpdateAccountsResult struct {
|
||||
Success int `json:"success"`
|
||||
@@ -406,6 +425,17 @@ type adminServiceImpl struct {
|
||||
proxyProber ProxyExitInfoProber
|
||||
proxyLatencyCache ProxyLatencyCache
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
entClient *dbent.Client // 用于开启数据库事务
|
||||
settingService *SettingService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
}
|
||||
|
||||
type userGroupRateBatchReader interface {
|
||||
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
||||
}
|
||||
|
||||
type groupExistenceBatchReader interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
@@ -422,6 +452,9 @@ func NewAdminService(
|
||||
proxyProber ProxyExitInfoProber,
|
||||
proxyLatencyCache ProxyLatencyCache,
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||
entClient *dbent.Client,
|
||||
settingService *SettingService,
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
@@ -436,6 +469,9 @@ func NewAdminService(
|
||||
proxyProber: proxyProber,
|
||||
proxyLatencyCache: proxyLatencyCache,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
entClient: entClient,
|
||||
settingService: settingService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,18 +484,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
|
||||
}
|
||||
// 批量加载用户专属分组倍率
|
||||
if s.userGroupRateRepo != nil && len(users) > 0 {
|
||||
for i := range users {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
|
||||
continue
|
||||
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
|
||||
userIDs := make([]int64, 0, len(users))
|
||||
for i := range users {
|
||||
userIDs = append(userIDs, users[i].ID)
|
||||
}
|
||||
users[i].GroupRates = rates
|
||||
ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err)
|
||||
s.loadUserGroupRatesOneByOne(ctx, users)
|
||||
} else {
|
||||
for i := range users {
|
||||
if rates, ok := ratesByUser[users[i].ID]; ok {
|
||||
users[i].GroupRates = rates
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.loadUserGroupRatesOneByOne(ctx, users)
|
||||
}
|
||||
}
|
||||
return users, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return
|
||||
}
|
||||
for i := range users {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
|
||||
continue
|
||||
}
|
||||
users[i].GroupRates = rates
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -479,14 +540,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
|
||||
|
||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
|
||||
user := &User{
|
||||
Email: input.Email,
|
||||
Username: input.Username,
|
||||
Notes: input.Notes,
|
||||
Role: RoleUser, // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
Status: StatusActive,
|
||||
AllowedGroups: input.AllowedGroups,
|
||||
Email: input.Email,
|
||||
Username: input.Username,
|
||||
Notes: input.Notes,
|
||||
Role: RoleUser, // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
Status: StatusActive,
|
||||
AllowedGroups: input.AllowedGroups,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
}
|
||||
if err := user.SetPassword(input.Password); err != nil {
|
||||
return nil, err
|
||||
@@ -494,9 +556,27 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
||||
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
items := s.settingService.GetDefaultSubscriptions(ctx)
|
||||
for _, item := range items {
|
||||
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
|
||||
UserID: userID,
|
||||
GroupID: item.GroupID,
|
||||
ValidityDays: item.ValidityDays,
|
||||
Notes: "auto assigned by default user subscriptions setting",
|
||||
}); err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -540,6 +620,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
user.AllowedGroups = *input.AllowedGroups
|
||||
}
|
||||
|
||||
if input.SoraStorageQuotaBytes != nil {
|
||||
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
|
||||
}
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -667,7 +751,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -834,6 +918,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
MCPXMLInject: mcpXMLInject,
|
||||
SimulateClaudeMaxEnabled: simulateClaudeMaxEnabled,
|
||||
SupportedModelScopes: input.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -996,6 +1081,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.SoraVideoPricePerRequestHD != nil {
|
||||
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
|
||||
}
|
||||
if input.SoraStorageQuotaBytes != nil {
|
||||
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
|
||||
}
|
||||
|
||||
// Claude Code 客户端限制
|
||||
if input.ClaudeCodeOnly != nil {
|
||||
@@ -1160,6 +1248,103 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
|
||||
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
||||
}
|
||||
|
||||
// AdminUpdateAPIKeyGroupID 管理员修改 API Key 分组绑定
|
||||
// groupID: nil=不修改, 指向0=解绑, 指向正整数=绑定到目标分组
|
||||
func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if groupID == nil {
|
||||
// nil 表示不修改,直接返回
|
||||
return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil
|
||||
}
|
||||
|
||||
if *groupID < 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative")
|
||||
}
|
||||
|
||||
result := &AdminUpdateAPIKeyGroupIDResult{}
|
||||
|
||||
if *groupID == 0 {
|
||||
// 0 表示解绑分组(不修改 user_allowed_groups,避免影响用户其他 Key)
|
||||
apiKey.GroupID = nil
|
||||
apiKey.Group = nil
|
||||
} else {
|
||||
// 验证目标分组存在且状态为 active
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if group.Status != StatusActive {
|
||||
return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
|
||||
}
|
||||
// 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程
|
||||
if group.IsSubscriptionType() {
|
||||
return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow")
|
||||
}
|
||||
|
||||
gid := *groupID
|
||||
apiKey.GroupID = &gid
|
||||
apiKey.Group = group
|
||||
|
||||
// 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性
|
||||
if group.IsExclusive {
|
||||
opCtx := ctx
|
||||
var tx *dbent.Tx
|
||||
if s.entClient == nil {
|
||||
logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding")
|
||||
} else {
|
||||
var txErr error
|
||||
tx, txErr = s.entClient.Tx(ctx)
|
||||
if txErr != nil {
|
||||
return nil, fmt.Errorf("begin transaction: %w", txErr)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
opCtx = dbent.NewTxContext(ctx, tx)
|
||||
}
|
||||
|
||||
if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil {
|
||||
return nil, fmt.Errorf("add group to user allowed groups: %w", addErr)
|
||||
}
|
||||
if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("update api key: %w", err)
|
||||
}
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
result.AutoGrantedGroupAccess = true
|
||||
result.GrantedGroupID = &gid
|
||||
result.GrantedGroupName = group.Name
|
||||
|
||||
// 失效认证缓存(在事务提交后执行)
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
}
|
||||
|
||||
result.APIKey = apiKey
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 非专属分组 / 解绑:无需事务,单步更新即可
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("update api key: %w", err)
|
||||
}
|
||||
|
||||
// 失效认证缓存
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
}
|
||||
|
||||
result.APIKey = apiKey
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Account management implementations
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
@@ -1211,6 +1396,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
}
|
||||
}
|
||||
|
||||
// Sora apikey 账号的 base_url 必填校验
|
||||
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
|
||||
baseURL, _ := input.Credentials["base_url"].(string)
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("sora apikey 账号必须设置 base_url")
|
||||
}
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
|
||||
}
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
Name: input.Name,
|
||||
Notes: normalizeAccountNotes(input.Notes),
|
||||
@@ -1324,12 +1521,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
|
||||
}
|
||||
|
||||
// Sora apikey 账号的 base_url 必填校验
|
||||
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
|
||||
baseURL, _ := account.Credentials["base_url"].(string)
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("sora apikey 账号必须设置 base_url")
|
||||
}
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
|
||||
}
|
||||
}
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if input.GroupIDs != nil {
|
||||
for _, groupID := range *input.GroupIDs {
|
||||
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查混合渠道风险(除非用户已确认)
|
||||
@@ -1371,6 +1578,11 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
if len(input.AccountIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
if input.GroupIDs != nil {
|
||||
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
|
||||
|
||||
@@ -1839,7 +2051,6 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: proxyQualityRequestTimeout,
|
||||
ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout,
|
||||
ProxyStrict: true,
|
||||
})
|
||||
if err != nil {
|
||||
result.Items = append(result.Items, ProxyQualityCheckItem{
|
||||
|
||||
429
backend/internal/service/admin_service_apikey_test.go
Normal file
429
backend/internal/service/admin_service_apikey_test.go
Normal file
@@ -0,0 +1,429 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
// Stubs
// ---------------------------------------------------------------------------

// userRepoStubForGroupUpdate implements UserRepository for AdminUpdateAPIKeyGroupID tests.
// Only AddGroupToAllowedGroups is functional; every other method panics so that
// an unexpected repository call fails the test immediately.
type userRepoStubForGroupUpdate struct {
	addGroupErr    error // error to return from AddGroupToAllowedGroups
	addGroupCalled bool  // whether AddGroupToAllowedGroups was invoked
	addedUserID    int64 // userID argument captured from the last call
	addedGroupID   int64 // groupID argument captured from the last call
}

// AddGroupToAllowedGroups records its arguments for later assertions and
// returns the configured error (nil by default).
func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, userID int64, groupID int64) error {
	s.addGroupCalled = true
	s.addedUserID = userID
	s.addedGroupID = groupID
	return s.addGroupErr
}

// Unused methods – panic on unexpected call.
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
||||
|
||||
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
// GetByID and Update are functional (both work on defensive copies so the test
// fixture is never aliased); every other method panics on an unexpected call.
type apiKeyRepoStubForGroupUpdate struct {
	key       *APIKey // fixture returned by GetByID
	getErr    error   // error returned by GetByID instead of the fixture
	updateErr error   // error returned by Update
	updated   *APIKey // captures what was passed to Update
}

// GetByID returns a copy of the configured key, or getErr when set.
func (s *apiKeyRepoStubForGroupUpdate) GetByID(_ context.Context, _ int64) (*APIKey, error) {
	if s.getErr != nil {
		return nil, s.getErr
	}
	clone := *s.key
	return &clone, nil
}

// Update records a copy of the key passed in, or fails with updateErr.
func (s *apiKeyRepoStubForGroupUpdate) Update(_ context.Context, key *APIKey) error {
	if s.updateErr != nil {
		return s.updateErr
	}
	clone := *key
	s.updated = &clone
	return nil
}

// Unused methods – panic on unexpected call.
func (s *apiKeyRepoStubForGroupUpdate) Create(context.Context, *APIKey) error { panic("unexpected") }
func (s *apiKeyRepoStubForGroupUpdate) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetByKey(context.Context, string) (*APIKey, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByUserID(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ExistsByKey(context.Context, string) (bool, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByUserID(context.Context, int64) ([]string, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByGroupID(context.Context, int64) ([]string, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) IncrementRateLimitUsage(context.Context, int64, float64) error {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, int64) error {
	panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
	panic("unexpected")
}
|
||||
|
||||
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
// Only GetByID is functional (returning a defensive copy and recording its
// argument); every other method panics on an unexpected call.
type groupRepoStubForGroupUpdate struct {
	group          *Group // fixture returned by GetByID
	getErr         error  // error returned by GetByID instead of the fixture
	lastGetByIDArg int64  // records the id passed to the most recent GetByID call
}

// GetByID records the requested id and returns a copy of the fixture group,
// or getErr when set.
func (s *groupRepoStubForGroupUpdate) GetByID(_ context.Context, id int64) (*Group, error) {
	s.lastGetByIDArg = id
	if s.getErr != nil {
		return nil, s.getErr
	}
	clone := *s.group
	return &clone, nil
}

// Unused methods – panic on unexpected call.
func (s *groupRepoStubForGroupUpdate) Create(context.Context, *Group) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) GetByIDLite(context.Context, int64) (*Group, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) Update(context.Context, *Group) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) DeleteCascade(context.Context, int64) ([]int64, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListActive(context.Context) ([]Group, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, string) ([]Group, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) BindAccountsToGroup(context.Context, int64, []int64) error {
	panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
	panic("unexpected")
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_KeyNotFound(t *testing.T) {
|
||||
repo := &apiKeyRepoStubForGroupUpdate{getErr: ErrAPIKeyNotFound}
|
||||
svc := &adminServiceImpl{apiKeyRepo: repo}
|
||||
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 999, int64Ptr(1))
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_NilGroupID_NoOp(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5)}
|
||||
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
svc := &adminServiceImpl{apiKeyRepo: repo}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), got.APIKey.ID)
|
||||
// Update should NOT have been called (updated stays nil)
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5), Group: &Group{ID: 5, Name: "Old"}}
|
||||
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: repo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, got.APIKey.GroupID, "group_id should be nil after unbind")
|
||||
require.Nil(t, got.APIKey.Group, "group object should be nil after unbind")
|
||||
require.NotNil(t, repo.updated, "Update should have been called")
|
||||
require.Nil(t, repo.updated.GroupID)
|
||||
require.Equal(t, []string{"sk-test"}, cache.keys, "cache should be invalidated")
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_BindActiveGroup(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
|
||||
require.Equal(t, []string{"sk-test"}, cache.keys)
|
||||
// M3: verify correct group ID was passed to repo
|
||||
require.Equal(t, int64(10), groupRepo.lastGetByIDArg)
|
||||
// C1 fix: verify Group object is populated
|
||||
require.NotNil(t, got.APIKey.Group)
|
||||
require.Equal(t, "Pro", got.APIKey.Group.Name)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SameGroup_Idempotent(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Pro"}}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||
// Update is still called (current impl doesn't short-circuit on same group)
|
||||
require.NotNil(t, apiKeyRepo.updated)
|
||||
require.Equal(t, []string{"sk-test"}, cache.keys)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotFound(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test"}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{getErr: ErrGroupNotFound}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
|
||||
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(99))
|
||||
require.ErrorIs(t, err, ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotActive(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test"}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 5, Status: StatusDisabled}}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
|
||||
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(5))
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "GROUP_NOT_ACTIVE", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_UpdateFails(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(3)}
|
||||
repo := &apiKeyRepoStubForGroupUpdate{key: existing, updateErr: errors.New("db write error")}
|
||||
svc := &adminServiceImpl{apiKeyRepo: repo}
|
||||
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update api key")
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_NegativeGroupID(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test"}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo}
|
||||
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(-5))
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "INVALID_GROUP_ID", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_PointerIsolation(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
|
||||
|
||||
inputGID := int64(10)
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, &inputGID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
// Mutating the input pointer must NOT affect the stored value
|
||||
inputGID = 999
|
||||
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_NilCacheInvalidator(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test"}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 7, Status: StatusActive}}
|
||||
// authCacheInvalidator is nil – should not panic
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(7))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
require.Equal(t, int64(7), *got.APIKey.GroupID)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests: AllowedGroup auto-sync
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AddsAllowedGroup(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||
// 验证 AddGroupToAllowedGroups 被调用,且参数正确
|
||||
require.True(t, userRepo.addGroupCalled)
|
||||
require.Equal(t, int64(42), userRepo.addedUserID)
|
||||
require.Equal(t, int64(10), userRepo.addedGroupID)
|
||||
// 验证 result 标记了自动授权
|
||||
require.True(t, got.AutoGrantedGroupAccess)
|
||||
require.NotNil(t, got.GrantedGroupID)
|
||||
require.Equal(t, int64(10), *got.GrantedGroupID)
|
||||
require.Equal(t, "Exclusive", got.GrantedGroupName)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupUpdate(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Public", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeStandard}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
// 非专属分组不触发 AddGroupToAllowedGroups
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
require.False(t, got.AutoGrantedGroupAccess)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
|
||||
|
||||
// 订阅类型分组应被阻止绑定
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err))
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AllowedGroupAddFails_ReturnsError(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
|
||||
userRepo := &userRepoStubForGroupUpdate{addGroupErr: errors.New("db error")}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
|
||||
|
||||
// 严格模式:AddGroupToAllowedGroups 失败时,整体操作报错
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "add group to user allowed groups")
|
||||
require.True(t, userRepo.addGroupCalled)
|
||||
// apiKey 不应被更新
|
||||
require.Nil(t, apiKeyRepo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind_NoAllowedGroupUpdate(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Exclusive"}}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, userRepo: userRepo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, got.APIKey.GroupID)
|
||||
// 解绑时不修改 allowed_groups
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
require.False(t, got.AutoGrantedGroupAccess)
|
||||
}
|
||||
@@ -100,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
|
||||
2: errors.New("bind failed"),
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
svc := &adminServiceImpl{
|
||||
accountRepo: repo,
|
||||
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}},
|
||||
}
|
||||
|
||||
groupIDs := []int64{10}
|
||||
schedulable := false
|
||||
@@ -120,6 +123,22 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
|
||||
require.Len(t, result.Results, 3)
|
||||
}
|
||||
|
||||
func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
groupIDs := []int64{10}
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1},
|
||||
GroupIDs: &groupIDs,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "group repository not configured")
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies
|
||||
// that the global pre-check detects a conflict with existing group members and returns an
|
||||
// error before any DB write is performed.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -65,3 +66,32 @@ func TestAdminService_CreateUser_CreateError(t *testing.T) {
|
||||
require.ErrorIs(t, err, createErr)
|
||||
require.Empty(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 21}
|
||||
assigner := &defaultSubscriptionAssignerStub{}
|
||||
cfg := &config.Config{
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 0,
|
||||
UserConcurrency: 1,
|
||||
},
|
||||
}
|
||||
settingService := NewSettingService(&settingRepoStub{values: map[string]string{
|
||||
SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`,
|
||||
}}, cfg)
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
settingService: settingService,
|
||||
defaultSubAssigner: assigner,
|
||||
}
|
||||
|
||||
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
|
||||
Email: "new-user@test.com",
|
||||
Password: "password",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, assigner.calls, 1)
|
||||
require.Equal(t, int64(21), assigner.calls[0].UserID)
|
||||
require.Equal(t, int64(5), assigner.calls[0].GroupID)
|
||||
require.Equal(t, 30, assigner.calls[0].ValidityDays)
|
||||
}
|
||||
|
||||
@@ -93,6 +93,10 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
// AddGroupToAllowedGroups is not exercised by these tests; panic to surface
// any unexpected call immediately.
func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
	panic("unexpected AddGroupToAllowedGroups call")
}
|
||||
|
||||
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
@@ -344,6 +348,19 @@ func (s *billingCacheStub) InvalidateSubscriptionCache(ctx context.Context, user
|
||||
return nil
|
||||
}
|
||||
|
||||
// API-key rate-limit cache methods are not exercised by these tests; each
// panics so an unexpected call fails the test immediately.
func (s *billingCacheStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
	panic("unexpected GetAPIKeyRateLimit call")
}
func (s *billingCacheStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
	panic("unexpected SetAPIKeyRateLimit call")
}
func (s *billingCacheStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
	panic("unexpected UpdateAPIKeyRateLimitUsage call")
}
func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
	panic("unexpected InvalidateAPIKeyRateLimit call")
}
|
||||
|
||||
func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall {
|
||||
t.Helper()
|
||||
calls := make([]subscriptionInvalidateCall, 0, expected)
|
||||
|
||||
106
backend/internal/service/admin_service_list_users_test.go
Normal file
106
backend/internal/service/admin_service_list_users_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// userRepoStubForListUsers overrides ListWithFilters on the embedded
// userRepoStub so ListUsers tests can supply a fixed user set or an error.
type userRepoStubForListUsers struct {
	userRepoStub
	users []User // users returned by ListWithFilters
	err   error  // error returned instead of users when set
}

// ListWithFilters returns a copy of the configured users (so tests keep an
// unaliased fixture) together with a pagination result echoing the request.
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
	if s.err != nil {
		return nil, nil, s.err
	}
	out := make([]User, len(s.users))
	copy(out, s.users)
	return out, &pagination.PaginationResult{
		Total:    int64(len(out)),
		Page:     params.Page,
		PageSize: params.PageSize,
	}, nil
}
|
||||
|
||||
// userGroupRateRepoStubForListUsers lets ListUsers tests control both the
// batch rate lookup (GetByUserIDs) and the per-user fallback (GetByUserID),
// recording how each path was exercised. All mutation methods panic.
type userGroupRateRepoStubForListUsers struct {
	batchCalls int     // number of GetByUserIDs invocations
	singleCall []int64 // user ids passed to GetByUserID, in call order

	batchErr  error                       // error returned by GetByUserIDs
	batchData map[int64]map[int64]float64 // data returned by GetByUserIDs

	singleErr  map[int64]error             // per-user errors for GetByUserID
	singleData map[int64]map[int64]float64 // per-user data for GetByUserID
}

// GetByUserIDs counts the call and returns the configured batch data or error.
func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) {
	s.batchCalls++
	if s.batchErr != nil {
		return nil, s.batchErr
	}
	return s.batchData, nil
}

// GetByUserID records the requested user id and returns the configured
// per-user rates, error, or an empty map when nothing is configured.
func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) {
	s.singleCall = append(s.singleCall, userID)
	if err, ok := s.singleErr[userID]; ok {
		return nil, err
	}
	if rates, ok := s.singleData[userID]; ok {
		return rates, nil
	}
	return map[int64]float64{}, nil
}

// Unused methods – panic on unexpected call.
func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) {
	panic("unexpected GetByUserAndGroup call")
}

func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
	panic("unexpected SyncUserGroupRates call")
}

func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
	panic("unexpected DeleteByGroupID call")
}

func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error {
	panic("unexpected DeleteByUserID call")
}
|
||||
|
||||
func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
|
||||
userRepo := &userRepoStubForListUsers{
|
||||
users: []User{
|
||||
{ID: 101, Username: "u1"},
|
||||
{ID: 202, Username: "u2"},
|
||||
},
|
||||
}
|
||||
rateRepo := &userGroupRateRepoStubForListUsers{
|
||||
batchErr: errors.New("batch unavailable"),
|
||||
singleData: map[int64]map[int64]float64{
|
||||
101: {11: 1.1},
|
||||
202: {22: 2.2},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
userGroupRateRepo: rateRepo,
|
||||
}
|
||||
|
||||
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), total)
|
||||
require.Len(t, users, 2)
|
||||
require.Equal(t, 1, rateRepo.batchCalls)
|
||||
require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall)
|
||||
require.Equal(t, 1.1, users[0].GroupRates[11])
|
||||
require.Equal(t, 2.2, users[1].GroupRates[22])
|
||||
}
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -2294,7 +2293,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||||
|
||||
// isSingleAccountRetry reports whether the context carries the single-account
// backoff-retry marker, delegating to the typed context accessor.
func isSingleAccountRetry(ctx context.Context) bool {
	v, _ := SingleAccountRetryFromContext(ctx)
	return v
}
|
||||
|
||||
|
||||
@@ -112,7 +112,10 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
||||
}
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
client, err := antigravity.NewClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
|
||||
// 交换 token
|
||||
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
|
||||
@@ -167,7 +170,10 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
client, err := antigravity.NewClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
tokenResp, err := client.RefreshToken(ctx, refreshToken)
|
||||
if err == nil {
|
||||
now := time.Now()
|
||||
@@ -209,7 +215,10 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr
|
||||
}
|
||||
|
||||
// 获取用户信息(email)
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
client, err := antigravity.NewClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
|
||||
@@ -309,7 +318,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
client, err := antigravity.NewClient(proxyURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
|
||||
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
@@ -31,7 +32,10 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
accessToken := account.GetCredential("access_token")
|
||||
projectID := account.GetCredential("project_id")
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
client, err := antigravity.NewClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
)
|
||||
|
||||
// API Key status constants
|
||||
const (
|
||||
@@ -19,22 +23,41 @@ type APIKey struct {
|
||||
Status string
|
||||
IPWhitelist []string
|
||||
IPBlacklist []string
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
// 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。
|
||||
CompiledIPWhitelist *ip.CompiledIPRules `json:"-"`
|
||||
CompiledIPBlacklist *ip.CompiledIPRules `json:"-"`
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
|
||||
// Quota fields
|
||||
Quota float64 // Quota limit in USD (0 = unlimited)
|
||||
QuotaUsed float64 // Used quota amount
|
||||
ExpiresAt *time.Time // Expiration time (nil = never expires)
|
||||
|
||||
// Rate limit fields
|
||||
RateLimit5h float64 // Rate limit in USD per 5h (0 = unlimited)
|
||||
RateLimit1d float64 // Rate limit in USD per 1d (0 = unlimited)
|
||||
RateLimit7d float64 // Rate limit in USD per 7d (0 = unlimited)
|
||||
Usage5h float64 // Used amount in current 5h window
|
||||
Usage1d float64 // Used amount in current 1d window
|
||||
Usage7d float64 // Used amount in current 7d window
|
||||
Window5hStart *time.Time // Start of current 5h window
|
||||
Window1dStart *time.Time // Start of current 1d window
|
||||
Window7dStart *time.Time // Start of current 7d window
|
||||
}
|
||||
|
||||
func (k *APIKey) IsActive() bool {
|
||||
return k.Status == StatusActive
|
||||
}
|
||||
|
||||
// HasRateLimits returns true if any rate limit window is configured
|
||||
func (k *APIKey) HasRateLimits() bool {
|
||||
return k.RateLimit5h > 0 || k.RateLimit1d > 0 || k.RateLimit7d > 0
|
||||
}
|
||||
|
||||
// IsExpired checks if the API key has expired
|
||||
func (k *APIKey) IsExpired() bool {
|
||||
if k.ExpiresAt == nil {
|
||||
@@ -74,3 +97,10 @@ func (k *APIKey) GetDaysUntilExpiry() int {
|
||||
}
|
||||
return int(duration.Hours() / 24)
|
||||
}
|
||||
|
||||
// APIKeyListFilters holds optional filtering parameters for listing API keys.
|
||||
type APIKeyListFilters struct {
|
||||
Search string
|
||||
Status string
|
||||
GroupID *int64 // nil=不筛选, 0=无分组, >0=指定分组
|
||||
}
|
||||
|
||||
@@ -19,6 +19,11 @@ type APIKeyAuthSnapshot struct {
|
||||
|
||||
// Expiration field for API Key expiration feature
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires)
|
||||
|
||||
// Rate limit configuration (only limits, not usage - usage read from Redis at check time)
|
||||
RateLimit5h float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d float64 `json:"rate_limit_7d"`
|
||||
}
|
||||
|
||||
// APIKeyAuthUserSnapshot 用户快照
|
||||
|
||||
@@ -209,6 +209,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
Quota: apiKey.Quota,
|
||||
QuotaUsed: apiKey.QuotaUsed,
|
||||
ExpiresAt: apiKey.ExpiresAt,
|
||||
RateLimit5h: apiKey.RateLimit5h,
|
||||
RateLimit1d: apiKey.RateLimit1d,
|
||||
RateLimit7d: apiKey.RateLimit7d,
|
||||
User: APIKeyAuthUserSnapshot{
|
||||
ID: apiKey.User.ID,
|
||||
Status: apiKey.User.Status,
|
||||
@@ -263,6 +266,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
Quota: snapshot.Quota,
|
||||
QuotaUsed: snapshot.QuotaUsed,
|
||||
ExpiresAt: snapshot.ExpiresAt,
|
||||
RateLimit5h: snapshot.RateLimit5h,
|
||||
RateLimit1d: snapshot.RateLimit1d,
|
||||
RateLimit7d: snapshot.RateLimit7d,
|
||||
User: &User{
|
||||
ID: snapshot.User.ID,
|
||||
Status: snapshot.User.Status,
|
||||
@@ -300,5 +306,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
||||
}
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey
|
||||
}
|
||||
|
||||
@@ -30,6 +30,11 @@ var (
|
||||
ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期")
|
||||
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
|
||||
ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完")
|
||||
|
||||
// Rate limit errors
|
||||
ErrAPIKeyRateLimit5hExceeded = infraerrors.TooManyRequests("API_KEY_RATE_5H_EXCEEDED", "api key 5小时限额已用完")
|
||||
ErrAPIKeyRateLimit1dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_1D_EXCEEDED", "api key 日限额已用完")
|
||||
ErrAPIKeyRateLimit7dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_7D_EXCEEDED", "api key 7天限额已用完")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -50,7 +55,7 @@ type APIKeyRepository interface {
|
||||
Update(ctx context.Context, key *APIKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error)
|
||||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
@@ -64,6 +69,21 @@ type APIKeyRepository interface {
|
||||
// Quota methods
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
|
||||
UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error
|
||||
|
||||
// Rate limit methods
|
||||
IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error
|
||||
ResetRateLimitWindows(ctx context.Context, id int64) error
|
||||
GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error)
|
||||
}
|
||||
|
||||
// APIKeyRateLimitData holds rate limit usage and window state for an API key.
|
||||
type APIKeyRateLimitData struct {
|
||||
Usage5h float64
|
||||
Usage1d float64
|
||||
Usage7d float64
|
||||
Window5hStart *time.Time
|
||||
Window1dStart *time.Time
|
||||
Window7dStart *time.Time
|
||||
}
|
||||
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
@@ -102,6 +122,11 @@ type CreateAPIKeyRequest struct {
|
||||
// Quota fields
|
||||
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
|
||||
ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires)
|
||||
|
||||
// Rate limit fields (0 = unlimited)
|
||||
RateLimit5h float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d float64 `json:"rate_limit_7d"`
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
@@ -117,22 +142,34 @@ type UpdateAPIKeyRequest struct {
|
||||
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change)
|
||||
ClearExpiration bool `json:"-"` // Clear expiration (internal use)
|
||||
ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0
|
||||
|
||||
// Rate limit fields (nil = no change, 0 = unlimited)
|
||||
RateLimit5h *float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d *float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d *float64 `json:"rate_limit_7d"`
|
||||
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // Reset all usage counters to 0
|
||||
}
|
||||
|
||||
// APIKeyService API Key服务
|
||||
// RateLimitCacheInvalidator invalidates rate limit cache entries on manual reset.
|
||||
type RateLimitCacheInvalidator interface {
|
||||
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
|
||||
}
|
||||
|
||||
type APIKeyService struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
|
||||
lastUsedTouchSF singleflight.Group
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache APIKeyCache
|
||||
rateLimitCacheInvalid RateLimitCacheInvalidator // optional: invalidate Redis rate limit cache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
|
||||
lastUsedTouchSF singleflight.Group
|
||||
}
|
||||
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
@@ -158,6 +195,20 @@ func NewAPIKeyService(
|
||||
return svc
|
||||
}
|
||||
|
||||
// SetRateLimitCacheInvalidator sets the optional rate limit cache invalidator.
|
||||
// Called after construction (e.g. in wire) to avoid circular dependencies.
|
||||
func (s *APIKeyService) SetRateLimitCacheInvalidator(inv RateLimitCacheInvalidator) {
|
||||
s.rateLimitCacheInvalid = inv
|
||||
}
|
||||
|
||||
func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) {
|
||||
if apiKey == nil {
|
||||
return
|
||||
}
|
||||
apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist)
|
||||
apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist)
|
||||
}
|
||||
|
||||
// GenerateKey 生成随机API Key
|
||||
func (s *APIKeyService) GenerateKey() (string, error) {
|
||||
// 生成32字节随机数据
|
||||
@@ -319,6 +370,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
Quota: req.Quota,
|
||||
QuotaUsed: 0,
|
||||
RateLimit5h: req.RateLimit5h,
|
||||
RateLimit1d: req.RateLimit1d,
|
||||
RateLimit7d: req.RateLimit7d,
|
||||
}
|
||||
|
||||
// Set expiration time if specified
|
||||
@@ -332,13 +386,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
}
|
||||
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, filters)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
}
|
||||
@@ -363,6 +418,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
@@ -375,6 +431,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
}
|
||||
@@ -391,6 +448,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
} else {
|
||||
@@ -402,6 +460,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
}
|
||||
@@ -411,6 +470,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
apiKey.Key = key
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
@@ -505,11 +565,37 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
apiKey.IPWhitelist = req.IPWhitelist
|
||||
apiKey.IPBlacklist = req.IPBlacklist
|
||||
|
||||
// Update rate limit configuration
|
||||
if req.RateLimit5h != nil {
|
||||
apiKey.RateLimit5h = *req.RateLimit5h
|
||||
}
|
||||
if req.RateLimit1d != nil {
|
||||
apiKey.RateLimit1d = *req.RateLimit1d
|
||||
}
|
||||
if req.RateLimit7d != nil {
|
||||
apiKey.RateLimit7d = *req.RateLimit7d
|
||||
}
|
||||
resetRateLimit := req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage
|
||||
if resetRateLimit {
|
||||
apiKey.Usage5h = 0
|
||||
apiKey.Usage1d = 0
|
||||
apiKey.Usage7d = 0
|
||||
apiKey.Window5hStart = nil
|
||||
apiKey.Window1dStart = nil
|
||||
apiKey.Window7dStart = nil
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("update api key: %w", err)
|
||||
}
|
||||
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
|
||||
// Invalidate Redis rate limit cache so reset takes effect immediately
|
||||
if resetRateLimit && s.rateLimitCacheInvalid != nil {
|
||||
_ = s.rateLimitCacheInvalid.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
|
||||
}
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
@@ -731,3 +817,16 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRateLimitData returns rate limit usage and window state for an API key.
|
||||
func (s *APIKeyService) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
|
||||
return s.apiKeyRepo.GetRateLimitData(ctx, id)
|
||||
}
|
||||
|
||||
// UpdateRateLimitUsage atomically increments rate limit usage counters in the DB.
|
||||
func (s *APIKeyService) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
if cost <= 0 {
|
||||
return nil
|
||||
}
|
||||
return s.apiKeyRepo.IncrementRateLimitUsage(ctx, apiKeyID, cost)
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
@@ -106,6 +106,15 @@ func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount
|
||||
func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
panic("unexpected UpdateLastUsed call")
|
||||
}
|
||||
func (s *authRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
panic("unexpected IncrementRateLimitUsage call")
|
||||
}
|
||||
func (s *authRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
panic("unexpected ResetRateLimitWindows call")
|
||||
}
|
||||
func (s *authRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
|
||||
panic("unexpected GetRateLimitData call")
|
||||
}
|
||||
|
||||
type authCacheStub struct {
|
||||
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
|
||||
@@ -81,7 +81,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
|
||||
// 以下是接口要求实现但本测试不关心的方法
|
||||
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
@@ -134,6 +134,18 @@ func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
panic("unexpected IncrementRateLimitUsage call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
panic("unexpected ResetRateLimitWindows call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
|
||||
panic("unexpected GetRateLimitData call")
|
||||
}
|
||||
|
||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||
//
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -33,6 +34,7 @@ var (
|
||||
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
|
||||
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required")
|
||||
@@ -56,15 +58,20 @@ type JWTClaims struct {
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
userRepo UserRepository
|
||||
redeemRepo RedeemCodeRepository
|
||||
refreshTokenCache RefreshTokenCache
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
userRepo UserRepository
|
||||
redeemRepo RedeemCodeRepository
|
||||
refreshTokenCache RefreshTokenCache
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
}
|
||||
|
||||
type DefaultSubscriptionAssigner interface {
|
||||
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
@@ -78,17 +85,19 @@ func NewAuthService(
|
||||
turnstileService *TurnstileService,
|
||||
emailQueueService *EmailQueueService,
|
||||
promoService *PromoService,
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
refreshTokenCache: refreshTokenCache,
|
||||
cfg: cfg,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
userRepo: userRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
refreshTokenCache: refreshTokenCache,
|
||||
cfg: cfg,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,6 +117,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
if isReservedEmail(email) {
|
||||
return "", nil, ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// 检查是否需要邀请码
|
||||
var invitationRedeemCode *RedeemCode
|
||||
@@ -188,6 +200,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
|
||||
// 标记邀请码为已使用(如果使用了邀请码)
|
||||
if invitationRedeemCode != nil {
|
||||
@@ -233,6 +246,9 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
if isReservedEmail(email) {
|
||||
return ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
@@ -271,6 +287,9 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
if isReservedEmail(email) {
|
||||
return nil, ErrEmailReserved
|
||||
}
|
||||
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
@@ -477,6 +496,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||
@@ -572,6 +592,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||
@@ -597,6 +618,49 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
||||
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
items := s.settingService.GetDefaultSubscriptions(ctx)
|
||||
for _, item := range items {
|
||||
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
|
||||
UserID: userID,
|
||||
GroupID: item.GroupID,
|
||||
ValidityDays: item.ValidityDays,
|
||||
Notes: "auto assigned by default user subscriptions setting",
|
||||
}); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
|
||||
if s.settingService == nil {
|
||||
return nil
|
||||
}
|
||||
whitelist := s.settingService.GetRegistrationEmailSuffixWhitelist(ctx)
|
||||
if !IsRegistrationEmailSuffixAllowed(email, whitelist) {
|
||||
return buildEmailSuffixNotAllowedError(whitelist)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildEmailSuffixNotAllowedError(whitelist []string) error {
|
||||
if len(whitelist) == 0 {
|
||||
return ErrEmailSuffixNotAllowed
|
||||
}
|
||||
|
||||
allowed := strings.Join(whitelist, ", ")
|
||||
return infraerrors.BadRequest(
|
||||
"EMAIL_SUFFIX_NOT_ALLOWED",
|
||||
fmt.Sprintf("email suffix is not allowed, allowed suffixes: %s", allowed),
|
||||
).WithMetadata(map[string]string{
|
||||
"allowed_suffixes": strings.Join(whitelist, ","),
|
||||
"allowed_suffix_count": strconv.Itoa(len(whitelist)),
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -56,6 +57,21 @@ type emailCacheStub struct {
|
||||
err error
|
||||
}
|
||||
|
||||
type defaultSubscriptionAssignerStub struct {
|
||||
calls []AssignSubscriptionInput
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
|
||||
if input != nil {
|
||||
s.calls = append(s.calls, *input)
|
||||
}
|
||||
if s.err != nil {
|
||||
return nil, false, s.err
|
||||
}
|
||||
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
@@ -123,6 +139,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
nil,
|
||||
nil,
|
||||
nil, // promoService
|
||||
nil, // defaultSubAssigner
|
||||
)
|
||||
}
|
||||
|
||||
@@ -215,6 +232,51 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) {
|
||||
require.ErrorIs(t, err, ErrEmailReserved)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@other.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
|
||||
appErr := infraerrors.FromError(err)
|
||||
require.Contains(t, appErr.Message, "@example.com")
|
||||
require.Contains(t, appErr.Message, "@company.com")
|
||||
require.Equal(t, "EMAIL_SUFFIX_NOT_ALLOWED", appErr.Reason)
|
||||
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
|
||||
require.Equal(t, "@example.com,@company.com", appErr.Metadata["allowed_suffixes"])
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 8}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`,
|
||||
}, nil)
|
||||
|
||||
_, user, err := service.Register(context.Background(), "user@example.com", "password")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, int64(8), user.ID)
|
||||
}
|
||||
|
||||
func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
|
||||
}, nil)
|
||||
|
||||
err := service.SendVerifyCode(context.Background(), "user@other.com")
|
||||
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
|
||||
appErr := infraerrors.FromError(err)
|
||||
require.Contains(t, appErr.Message, "@example.com")
|
||||
require.Contains(t, appErr.Message, "@company.com")
|
||||
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateError(t *testing.T) {
|
||||
repo := &userRepoStub{createErr: errors.New("create failed")}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
@@ -381,3 +443,23 @@ func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) {
|
||||
|
||||
require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 42}
|
||||
assigner := &defaultSubscriptionAssignerStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
|
||||
}, nil)
|
||||
service.defaultSubAssigner = assigner
|
||||
|
||||
_, user, err := service.Register(context.Background(), "default-sub@test.com", "password")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Len(t, assigner.calls, 2)
|
||||
require.Equal(t, int64(42), assigner.calls[0].UserID)
|
||||
require.Equal(t, int64(11), assigner.calls[0].GroupID)
|
||||
require.Equal(t, 30, assigner.calls[0].ValidityDays)
|
||||
require.Equal(t, int64(12), assigner.calls[1].GroupID)
|
||||
require.Equal(t, 7, assigner.calls[1].ValidityDays)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type turnstileVerifierSpy struct {
|
||||
called int
|
||||
lastToken string
|
||||
result *TurnstileVerifyResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) {
|
||||
s.called++
|
||||
s.lastToken = token
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
if s.result != nil {
|
||||
return s.result, nil
|
||||
}
|
||||
return &TurnstileVerifyResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Mode: "release",
|
||||
},
|
||||
Turnstile: config.TurnstileConfig{
|
||||
Required: true,
|
||||
},
|
||||
}
|
||||
|
||||
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
|
||||
turnstileService := NewTurnstileService(settingService, verifier)
|
||||
|
||||
return NewAuthService(
|
||||
&userRepoStub{},
|
||||
nil, // redeemRepo
|
||||
nil, // refreshTokenCache
|
||||
cfg,
|
||||
settingService,
|
||||
nil, // emailService
|
||||
turnstileService,
|
||||
nil, // emailQueueService
|
||||
nil, // promoService
|
||||
nil, // defaultSubAssigner
|
||||
)
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) {
|
||||
verifier := &turnstileVerifierSpy{}
|
||||
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
SettingKeyTurnstileEnabled: "true",
|
||||
SettingKeyTurnstileSecretKey: "secret",
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, verifier)
|
||||
|
||||
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, verifier.called)
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) {
|
||||
verifier := &turnstileVerifierSpy{}
|
||||
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
SettingKeyTurnstileEnabled: "true",
|
||||
SettingKeyTurnstileSecretKey: "secret",
|
||||
}, verifier)
|
||||
|
||||
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "")
|
||||
require.ErrorIs(t, err, ErrTurnstileVerificationFailed)
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) {
|
||||
verifier := &turnstileVerifierSpy{}
|
||||
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
SettingKeyTurnstileEnabled: "true",
|
||||
SettingKeyTurnstileSecretKey: "secret",
|
||||
}, verifier)
|
||||
|
||||
err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, verifier.called)
|
||||
require.Equal(t, "turnstile-token", verifier.lastToken)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
@@ -38,6 +40,7 @@ const (
|
||||
cacheWriteSetSubscription
|
||||
cacheWriteUpdateSubscriptionUsage
|
||||
cacheWriteDeductBalance
|
||||
cacheWriteUpdateRateLimitUsage
|
||||
)
|
||||
|
||||
// 异步缓存写入工作池配置
|
||||
@@ -58,6 +61,7 @@ const (
|
||||
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
|
||||
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
|
||||
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
|
||||
balanceLoadTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
// cacheWriteTask 缓存写入任务
|
||||
@@ -65,23 +69,33 @@ type cacheWriteTask struct {
|
||||
kind cacheWriteKind
|
||||
userID int64
|
||||
groupID int64
|
||||
apiKeyID int64
|
||||
balance float64
|
||||
amount float64
|
||||
subscriptionData *subscriptionCacheData
|
||||
}
|
||||
|
||||
// apiKeyRateLimitLoader defines the interface for loading rate limit data from DB.
|
||||
type apiKeyRateLimitLoader interface {
|
||||
GetRateLimitData(ctx context.Context, keyID int64) (*APIKeyRateLimitData, error)
|
||||
}
|
||||
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
cache BillingCache
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
cache BillingCache
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
apiKeyRateLimitLoader apiKeyRateLimitLoader
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
cacheWriteStopOnce sync.Once
|
||||
cacheWriteMu sync.RWMutex
|
||||
stopped atomic.Bool
|
||||
balanceLoadSF singleflight.Group
|
||||
// 丢弃日志节流计数器(减少高负载下日志噪音)
|
||||
cacheWriteDropFullCount uint64
|
||||
cacheWriteDropFullLastLog int64
|
||||
@@ -90,12 +104,13 @@ type BillingCacheService struct {
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
|
||||
svc := &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
cfg: cfg,
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
apiKeyRateLimitLoader: apiKeyRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||||
svc.startCacheWriteWorkers()
|
||||
@@ -105,35 +120,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
|
||||
// Stop 关闭缓存写入工作池
|
||||
func (s *BillingCacheService) Stop() {
|
||||
s.cacheWriteStopOnce.Do(func() {
|
||||
if s.cacheWriteChan == nil {
|
||||
s.stopped.Store(true)
|
||||
|
||||
s.cacheWriteMu.Lock()
|
||||
ch := s.cacheWriteChan
|
||||
if ch != nil {
|
||||
close(ch)
|
||||
}
|
||||
s.cacheWriteMu.Unlock()
|
||||
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
close(s.cacheWriteChan)
|
||||
s.cacheWriteWg.Wait()
|
||||
s.cacheWriteChan = nil
|
||||
|
||||
s.cacheWriteMu.Lock()
|
||||
if s.cacheWriteChan == ch {
|
||||
s.cacheWriteChan = nil
|
||||
}
|
||||
s.cacheWriteMu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) startCacheWriteWorkers() {
|
||||
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
|
||||
ch := make(chan cacheWriteTask, cacheWriteBufferSize)
|
||||
s.cacheWriteChan = ch
|
||||
for i := 0; i < cacheWriteWorkerCount; i++ {
|
||||
s.cacheWriteWg.Add(1)
|
||||
go s.cacheWriteWorker()
|
||||
go s.cacheWriteWorker(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
|
||||
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
|
||||
if s.cacheWriteChan == nil {
|
||||
if s.stopped.Load() {
|
||||
s.logCacheWriteDrop(task, "closed")
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
// 队列已关闭时可能触发 panic,记录后静默失败。
|
||||
s.logCacheWriteDrop(task, "closed")
|
||||
enqueued = false
|
||||
}
|
||||
}()
|
||||
|
||||
s.cacheWriteMu.RLock()
|
||||
defer s.cacheWriteMu.RUnlock()
|
||||
|
||||
if s.cacheWriteChan == nil {
|
||||
s.logCacheWriteDrop(task, "closed")
|
||||
return false
|
||||
}
|
||||
|
||||
select {
|
||||
case s.cacheWriteChan <- task:
|
||||
return true
|
||||
@@ -144,9 +176,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) cacheWriteWorker() {
|
||||
func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
|
||||
defer s.cacheWriteWg.Done()
|
||||
for task := range s.cacheWriteChan {
|
||||
for task := range ch {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
switch task.kind {
|
||||
case cacheWriteSetBalance:
|
||||
@@ -165,6 +197,12 @@ func (s *BillingCacheService) cacheWriteWorker() {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err)
|
||||
}
|
||||
}
|
||||
case cacheWriteUpdateRateLimitUsage:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.UpdateAPIKeyRateLimitUsage(ctx, task.apiKeyID, task.amount); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: update rate limit usage cache failed for api key %d: %v", task.apiKeyID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
@@ -181,6 +219,8 @@ func cacheWriteKindName(kind cacheWriteKind) string {
|
||||
return "update_subscription_usage"
|
||||
case cacheWriteDeductBalance:
|
||||
return "deduct_balance"
|
||||
case cacheWriteUpdateRateLimitUsage:
|
||||
return "update_rate_limit_usage"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
@@ -243,19 +283,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库读取
|
||||
balance, err = s.getUserBalanceFromDB(ctx, userID)
|
||||
// 缓存未命中:singleflight 合并同一 userID 的并发回源请求。
|
||||
value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) {
|
||||
loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout)
|
||||
defer cancel()
|
||||
|
||||
balance, err := s.getUserBalanceFromDB(loadCtx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetBalance,
|
||||
userID: userID,
|
||||
balance: balance,
|
||||
})
|
||||
return balance, nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetBalance,
|
||||
userID: userID,
|
||||
balance: balance,
|
||||
})
|
||||
|
||||
balance, ok := value.(float64)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unexpected balance type: %T", value)
|
||||
}
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
@@ -441,6 +493,137 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// API Key 限速缓存方法
|
||||
// ============================================
|
||||
|
||||
// checkAPIKeyRateLimits checks rate limit windows for an API key.
|
||||
// It loads usage from Redis cache (falling back to DB on cache miss),
|
||||
// resets expired windows in-memory and triggers async DB reset,
|
||||
// and returns an error if any window limit is exceeded.
|
||||
func (s *BillingCacheService) checkAPIKeyRateLimits(ctx context.Context, apiKey *APIKey) error {
|
||||
if s.cache == nil {
|
||||
// No cache: fall back to reading from DB directly
|
||||
if s.apiKeyRateLimitLoader == nil {
|
||||
return nil
|
||||
}
|
||||
data, err := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
|
||||
if err != nil {
|
||||
return nil // Don't block requests on DB errors
|
||||
}
|
||||
return s.evaluateRateLimits(ctx, apiKey, data.Usage5h, data.Usage1d, data.Usage7d,
|
||||
data.Window5hStart, data.Window1dStart, data.Window7dStart)
|
||||
}
|
||||
|
||||
cacheData, err := s.cache.GetAPIKeyRateLimit(ctx, apiKey.ID)
|
||||
if err != nil {
|
||||
// Cache miss: load from DB and populate cache
|
||||
if s.apiKeyRateLimitLoader == nil {
|
||||
return nil
|
||||
}
|
||||
dbData, dbErr := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
|
||||
if dbErr != nil {
|
||||
return nil // Don't block requests on DB errors
|
||||
}
|
||||
// Build cache entry from DB data
|
||||
cacheEntry := &APIKeyRateLimitCacheData{
|
||||
Usage5h: dbData.Usage5h,
|
||||
Usage1d: dbData.Usage1d,
|
||||
Usage7d: dbData.Usage7d,
|
||||
}
|
||||
if dbData.Window5hStart != nil {
|
||||
cacheEntry.Window5h = dbData.Window5hStart.Unix()
|
||||
}
|
||||
if dbData.Window1dStart != nil {
|
||||
cacheEntry.Window1d = dbData.Window1dStart.Unix()
|
||||
}
|
||||
if dbData.Window7dStart != nil {
|
||||
cacheEntry.Window7d = dbData.Window7dStart.Unix()
|
||||
}
|
||||
_ = s.cache.SetAPIKeyRateLimit(ctx, apiKey.ID, cacheEntry)
|
||||
cacheData = cacheEntry
|
||||
}
|
||||
|
||||
var w5h, w1d, w7d *time.Time
|
||||
if cacheData.Window5h > 0 {
|
||||
t := time.Unix(cacheData.Window5h, 0)
|
||||
w5h = &t
|
||||
}
|
||||
if cacheData.Window1d > 0 {
|
||||
t := time.Unix(cacheData.Window1d, 0)
|
||||
w1d = &t
|
||||
}
|
||||
if cacheData.Window7d > 0 {
|
||||
t := time.Unix(cacheData.Window7d, 0)
|
||||
w7d = &t
|
||||
}
|
||||
return s.evaluateRateLimits(ctx, apiKey, cacheData.Usage5h, cacheData.Usage1d, cacheData.Usage7d, w5h, w1d, w7d)
|
||||
}
|
||||
|
||||
// evaluateRateLimits checks usage against limits, triggering async resets for expired windows.
|
||||
func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *APIKey, usage5h, usage1d, usage7d float64, w5h, w1d, w7d *time.Time) error {
|
||||
needsReset := false
|
||||
|
||||
// Reset expired windows in-memory for check purposes
|
||||
if w5h != nil && time.Since(*w5h) >= 5*time.Hour {
|
||||
usage5h = 0
|
||||
needsReset = true
|
||||
}
|
||||
if w1d != nil && time.Since(*w1d) >= 24*time.Hour {
|
||||
usage1d = 0
|
||||
needsReset = true
|
||||
}
|
||||
if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour {
|
||||
usage7d = 0
|
||||
needsReset = true
|
||||
}
|
||||
|
||||
// Trigger async DB reset if any window expired
|
||||
if needsReset {
|
||||
keyID := apiKey.ID
|
||||
go func() {
|
||||
resetCtx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if s.apiKeyRateLimitLoader != nil {
|
||||
// Use the repo directly - reset then reload cache
|
||||
if loader, ok := s.apiKeyRateLimitLoader.(interface {
|
||||
ResetRateLimitWindows(ctx context.Context, id int64) error
|
||||
}); ok {
|
||||
_ = loader.ResetRateLimitWindows(resetCtx, keyID)
|
||||
}
|
||||
}
|
||||
// Invalidate cache so next request loads fresh data
|
||||
if s.cache != nil {
|
||||
_ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Check limits
|
||||
if apiKey.RateLimit5h > 0 && usage5h >= apiKey.RateLimit5h {
|
||||
return ErrAPIKeyRateLimit5hExceeded
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 && usage1d >= apiKey.RateLimit1d {
|
||||
return ErrAPIKeyRateLimit1dExceeded
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 && usage7d >= apiKey.RateLimit7d {
|
||||
return ErrAPIKeyRateLimit7dExceeded
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueUpdateAPIKeyRateLimitUsage asynchronously updates rate limit usage in the cache.
|
||||
func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, cost float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteUpdateRateLimitUsage,
|
||||
apiKeyID: apiKeyID,
|
||||
amount: cost,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 统一检查方法
|
||||
// ============================================
|
||||
@@ -461,10 +644,23 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
||||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||||
|
||||
if isSubscriptionMode {
|
||||
return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription)
|
||||
if err := s.checkSubscriptionEligibility(ctx, user.ID, group, subscription); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := s.checkBalanceEligibility(ctx, user.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return s.checkBalanceEligibility(ctx, user.ID)
|
||||
// Check API Key rate limits (applies to both billing modes)
|
||||
if apiKey != nil && apiKey.HasRateLimits() {
|
||||
if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkBalanceEligibility 检查余额模式资格
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type billingCacheMissStub struct {
|
||||
setBalanceCalls atomic.Int64
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
return 0, errors.New("cache miss")
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
s.setBalanceCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
|
||||
return nil, errors.New("cache miss")
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
|
||||
return nil, errors.New("cache miss")
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type balanceLoadUserRepoStub struct {
|
||||
mockUserRepo
|
||||
calls atomic.Int64
|
||||
delay time.Duration
|
||||
balance float64
|
||||
}
|
||||
|
||||
func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
s.calls.Add(1)
|
||||
if s.delay > 0 {
|
||||
select {
|
||||
case <-time.After(s.delay):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
return &User{ID: id, Balance: s.balance}, nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
||||
cache := &billingCacheMissStub{}
|
||||
userRepo := &balanceLoadUserRepoStub{
|
||||
delay: 80 * time.Millisecond,
|
||||
balance: 12.34,
|
||||
}
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
const goroutines = 16
|
||||
start := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, goroutines)
|
||||
balCh := make(chan float64, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
bal, err := svc.GetUserBalance(context.Background(), 99)
|
||||
errCh <- err
|
||||
balCh <- bal
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
close(balCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for bal := range balCh {
|
||||
require.Equal(t, 12.34, bal)
|
||||
}
|
||||
|
||||
require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并")
|
||||
require.Eventually(t, func() bool {
|
||||
return cache.setBalanceCalls.Load() >= 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
@@ -52,9 +52,25 @@ func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
start := time.Now()
|
||||
@@ -73,3 +89,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
||||
svc.Stop()
|
||||
|
||||
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteDeductBalance,
|
||||
userID: 1,
|
||||
amount: 1,
|
||||
})
|
||||
require.False(t, enqueued)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// APIKeyRateLimitCacheData holds rate limit usage data cached in Redis.
|
||||
type APIKeyRateLimitCacheData struct {
|
||||
Usage5h float64 `json:"usage_5h"`
|
||||
Usage1d float64 `json:"usage_1d"`
|
||||
Usage7d float64 `json:"usage_7d"`
|
||||
Window5h int64 `json:"window_5h"` // unix timestamp, 0 = not started
|
||||
Window1d int64 `json:"window_1d"`
|
||||
Window7d int64 `json:"window_7d"`
|
||||
}
|
||||
|
||||
// BillingCache defines cache operations for billing service
|
||||
type BillingCache interface {
|
||||
// Balance operations
|
||||
@@ -23,6 +33,12 @@ type BillingCache interface {
|
||||
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
|
||||
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
|
||||
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
|
||||
|
||||
// API Key rate limit operations
|
||||
GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error)
|
||||
SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error
|
||||
UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error
|
||||
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
|
||||
}
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
|
||||
@@ -63,7 +63,7 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) {
|
||||
|
||||
// 费率倍数 1.5x
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
|
||||
require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5
|
||||
|
||||
// 费率倍数 2.0x
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
@@ -17,6 +18,9 @@ var (
|
||||
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感)
|
||||
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
|
||||
// 带捕获组的版本提取正则
|
||||
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
|
||||
|
||||
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
|
||||
|
||||
@@ -78,7 +82,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
|
||||
|
||||
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
|
||||
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
|
||||
if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku {
|
||||
if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku {
|
||||
return true // 绕过 system prompt 检查,UA 已在 Step 1 验证
|
||||
}
|
||||
|
||||
@@ -270,3 +274,55 @@ func IsClaudeCodeClient(ctx context.Context) bool {
|
||||
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
|
||||
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
|
||||
}
|
||||
|
||||
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
|
||||
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
|
||||
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
|
||||
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
|
||||
if len(matches) >= 2 {
|
||||
return matches[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
|
||||
func SetClaudeCodeVersion(ctx context.Context, version string) context.Context {
|
||||
return context.WithValue(ctx, ctxkey.ClaudeCodeVersion, version)
|
||||
}
|
||||
|
||||
// GetClaudeCodeVersion 从 context 中获取 Claude Code 版本号
|
||||
func GetClaudeCodeVersion(ctx context.Context) string {
|
||||
if v, ok := ctx.Value(ctxkey.ClaudeCodeVersion).(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// CompareVersions 比较两个 semver 版本号
|
||||
// 返回: -1 (a < b), 0 (a == b), 1 (a > b)
|
||||
func CompareVersions(a, b string) int {
|
||||
aParts := parseSemver(a)
|
||||
bParts := parseSemver(b)
|
||||
for i := 0; i < 3; i++ {
|
||||
if aParts[i] < bParts[i] {
|
||||
return -1
|
||||
}
|
||||
if aParts[i] > bParts[i] {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseSemver 解析 semver 版本号为 [major, minor, patch]
|
||||
func parseSemver(v string) [3]int {
|
||||
v = strings.TrimPrefix(v, "v")
|
||||
parts := strings.Split(v, ".")
|
||||
result := [3]int{0, 0, 0}
|
||||
for i := 0; i < len(parts) && i < 3; i++ {
|
||||
if parsed, err := strconv.Atoi(parts[i]); err == nil {
|
||||
result[i] = parsed
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -56,3 +56,51 @@ func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) {
|
||||
ok := validator.Validate(req, nil)
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestExtractVersion(t *testing.T) {
|
||||
v := NewClaudeCodeValidator()
|
||||
tests := []struct {
|
||||
ua string
|
||||
want string
|
||||
}{
|
||||
{"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"},
|
||||
{"claude-cli/1.0.0", "1.0.0"},
|
||||
{"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感
|
||||
{"curl/8.0.0", ""}, // 非 Claude CLI
|
||||
{"", ""}, // 空字符串
|
||||
{"claude-cli/", ""}, // 无版本号
|
||||
{"claude-cli/2.1.22-beta", "2.1.22"}, // 带后缀仍提取主版本号
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := v.ExtractVersion(tt.ua)
|
||||
require.Equal(t, tt.want, got, "ExtractVersion(%q)", tt.ua)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareVersions(t *testing.T) {
|
||||
tests := []struct {
|
||||
a, b string
|
||||
want int
|
||||
}{
|
||||
{"2.1.0", "2.1.0", 0}, // 相等
|
||||
{"2.1.1", "2.1.0", 1}, // patch 更大
|
||||
{"2.0.0", "2.1.0", -1}, // minor 更小
|
||||
{"3.0.0", "2.99.99", 1}, // major 更大
|
||||
{"1.0.0", "2.0.0", -1}, // major 更小
|
||||
{"0.0.1", "0.0.0", 1}, // patch 差异
|
||||
{"", "1.0.0", -1}, // 空字符串 vs 正常版本
|
||||
{"v2.1.0", "2.1.0", 0}, // v 前缀处理
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := CompareVersions(tt.a, tt.b)
|
||||
require.Equal(t, tt.want, got, "CompareVersions(%q, %q)", tt.a, tt.b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetGetClaudeCodeVersion(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
require.Equal(t, "", GetClaudeCodeVersion(ctx), "empty context should return empty string")
|
||||
|
||||
ctx = SetClaudeCodeVersion(ctx, "2.1.63")
|
||||
require.Equal(t, "2.1.63", GetClaudeCodeVersion(ctx))
|
||||
}
|
||||
|
||||
@@ -3,8 +3,10 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@@ -18,6 +20,7 @@ type ConcurrencyCache interface {
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
|
||||
|
||||
// 账号等待队列(账号级)
|
||||
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
|
||||
@@ -42,15 +45,25 @@ type ConcurrencyCache interface {
|
||||
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID for concurrency slot tracking
|
||||
// Uses 8 random bytes (16 hex chars) for uniqueness
|
||||
func generateRequestID() string {
|
||||
var (
|
||||
requestIDPrefix = initRequestIDPrefix()
|
||||
requestIDCounter atomic.Uint64
|
||||
)
|
||||
|
||||
func initRequestIDPrefix() string {
|
||||
b := make([]byte, 8)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to nanosecond timestamp (extremely rare case)
|
||||
return fmt.Sprintf("%x", time.Now().UnixNano())
|
||||
if _, err := rand.Read(b); err == nil {
|
||||
return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36)
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
|
||||
return "r" + strconv.FormatUint(fallback, 36)
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID for concurrency slot tracking.
|
||||
// Format: {process_random_prefix}-{base36_counter}
|
||||
func generateRequestID() string {
|
||||
seq := requestIDCounter.Add(1)
|
||||
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
|
||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
||||
// Returns a map of accountID -> current concurrency count
|
||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int)
|
||||
|
||||
for _, accountID := range accountIDs {
|
||||
count, err := s.cache.GetAccountConcurrency(ctx, accountID)
|
||||
if err != nil {
|
||||
// If key doesn't exist in Redis, count is 0
|
||||
count = 0
|
||||
}
|
||||
result[accountID] = count
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]int{}, nil
|
||||
}
|
||||
|
||||
return result, nil
|
||||
if s.cache == nil {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
result[accountID] = 0
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -12,20 +14,20 @@ import (
|
||||
|
||||
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
|
||||
type stubConcurrencyCacheForTest struct {
|
||||
acquireResult bool
|
||||
acquireErr error
|
||||
releaseErr error
|
||||
concurrency int
|
||||
acquireResult bool
|
||||
acquireErr error
|
||||
releaseErr error
|
||||
concurrency int
|
||||
concurrencyErr error
|
||||
waitAllowed bool
|
||||
waitErr error
|
||||
waitCount int
|
||||
waitCountErr error
|
||||
loadBatch map[int64]*AccountLoadInfo
|
||||
loadBatchErr error
|
||||
waitAllowed bool
|
||||
waitErr error
|
||||
waitCount int
|
||||
waitCountErr error
|
||||
loadBatch map[int64]*AccountLoadInfo
|
||||
loadBatchErr error
|
||||
usersLoadBatch map[int64]*UserLoadInfo
|
||||
usersLoadErr error
|
||||
cleanupErr error
|
||||
cleanupErr error
|
||||
|
||||
// 记录调用
|
||||
releasedAccountIDs []int64
|
||||
@@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) {
|
||||
return c.concurrency, c.concurrencyErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
if c.concurrencyErr != nil {
|
||||
return nil, c.concurrencyErr
|
||||
}
|
||||
result[accountID] = c.concurrency
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
|
||||
return c.waitAllowed, c.waitErr
|
||||
}
|
||||
@@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
|
||||
require.True(t, result.Acquired)
|
||||
}
|
||||
|
||||
func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) {
|
||||
id1 := generateRequestID()
|
||||
id2 := generateRequestID()
|
||||
require.NotEmpty(t, id1)
|
||||
require.NotEmpty(t, id2)
|
||||
|
||||
p1 := strings.Split(id1, "-")
|
||||
p2 := strings.Split(id2, "-")
|
||||
require.Len(t, p1, 2)
|
||||
require.Len(t, p2, 2)
|
||||
require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致")
|
||||
|
||||
n1, err := strconv.ParseUint(p1[1], 36, 64)
|
||||
require.NoError(t, err)
|
||||
n2, err := strconv.ParseUint(p2[1], 36, 64)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n1+1, n2, "计数器应单调递增")
|
||||
}
|
||||
|
||||
func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) {
|
||||
expected := map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60},
|
||||
|
||||
@@ -221,7 +221,7 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username,
|
||||
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 20 * time.Second}
|
||||
return nil, fmt.Errorf("create http client failed: %w", err)
|
||||
}
|
||||
|
||||
adminToken, err := crsLogin(ctx, client, normalizedURL, username, password)
|
||||
|
||||
@@ -124,24 +124,24 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get usage trend with filters: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model stats with filters: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
|
||||
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group stats with filters: %w", err)
|
||||
}
|
||||
|
||||
252
backend/internal/service/data_management_grpc.go
Normal file
252
backend/internal/service/data_management_grpc.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
type DataManagementPostgresConfig struct {
|
||||
Host string `json:"host"`
|
||||
Port int32 `json:"port"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password,omitempty"`
|
||||
PasswordConfigured bool `json:"password_configured"`
|
||||
Database string `json:"database"`
|
||||
SSLMode string `json:"ssl_mode"`
|
||||
ContainerName string `json:"container_name"`
|
||||
}
|
||||
|
||||
type DataManagementRedisConfig struct {
|
||||
Addr string `json:"addr"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password,omitempty"`
|
||||
PasswordConfigured bool `json:"password_configured"`
|
||||
DB int32 `json:"db"`
|
||||
ContainerName string `json:"container_name"`
|
||||
}
|
||||
|
||||
type DataManagementS3Config struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key,omitempty"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
}
|
||||
|
||||
type DataManagementConfig struct {
|
||||
SourceMode string `json:"source_mode"`
|
||||
BackupRoot string `json:"backup_root"`
|
||||
SQLitePath string `json:"sqlite_path,omitempty"`
|
||||
RetentionDays int32 `json:"retention_days"`
|
||||
KeepLast int32 `json:"keep_last"`
|
||||
ActivePostgresID string `json:"active_postgres_profile_id"`
|
||||
ActiveRedisID string `json:"active_redis_profile_id"`
|
||||
Postgres DataManagementPostgresConfig `json:"postgres"`
|
||||
Redis DataManagementRedisConfig `json:"redis"`
|
||||
S3 DataManagementS3Config `json:"s3"`
|
||||
ActiveS3ProfileID string `json:"active_s3_profile_id"`
|
||||
}
|
||||
|
||||
type DataManagementTestS3Result struct {
|
||||
OK bool `json:"ok"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type DataManagementCreateBackupJobInput struct {
|
||||
BackupType string
|
||||
UploadToS3 bool
|
||||
TriggeredBy string
|
||||
IdempotencyKey string
|
||||
S3ProfileID string
|
||||
PostgresID string
|
||||
RedisID string
|
||||
}
|
||||
|
||||
type DataManagementListBackupJobsInput struct {
|
||||
PageSize int32
|
||||
PageToken string
|
||||
Status string
|
||||
BackupType string
|
||||
}
|
||||
|
||||
type DataManagementArtifactInfo struct {
|
||||
LocalPath string `json:"local_path"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
SHA256 string `json:"sha256"`
|
||||
}
|
||||
|
||||
type DataManagementS3ObjectInfo struct {
|
||||
Bucket string `json:"bucket"`
|
||||
Key string `json:"key"`
|
||||
ETag string `json:"etag"`
|
||||
}
|
||||
|
||||
type DataManagementBackupJob struct {
|
||||
JobID string `json:"job_id"`
|
||||
BackupType string `json:"backup_type"`
|
||||
Status string `json:"status"`
|
||||
TriggeredBy string `json:"triggered_by"`
|
||||
IdempotencyKey string `json:"idempotency_key,omitempty"`
|
||||
UploadToS3 bool `json:"upload_to_s3"`
|
||||
S3ProfileID string `json:"s3_profile_id,omitempty"`
|
||||
PostgresID string `json:"postgres_profile_id,omitempty"`
|
||||
RedisID string `json:"redis_profile_id,omitempty"`
|
||||
StartedAt string `json:"started_at,omitempty"`
|
||||
FinishedAt string `json:"finished_at,omitempty"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
Artifact DataManagementArtifactInfo `json:"artifact"`
|
||||
S3Object DataManagementS3ObjectInfo `json:"s3"`
|
||||
}
|
||||
|
||||
type DataManagementSourceProfile struct {
|
||||
SourceType string `json:"source_type"`
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Config DataManagementSourceConfig `json:"config"`
|
||||
PasswordConfigured bool `json:"password_configured"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
type DataManagementSourceConfig struct {
|
||||
Host string `json:"host"`
|
||||
Port int32 `json:"port"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Database string `json:"database"`
|
||||
SSLMode string `json:"ssl_mode"`
|
||||
Addr string `json:"addr"`
|
||||
Username string `json:"username"`
|
||||
DB int32 `json:"db"`
|
||||
ContainerName string `json:"container_name"`
|
||||
}
|
||||
|
||||
type DataManagementCreateSourceProfileInput struct {
|
||||
SourceType string
|
||||
ProfileID string
|
||||
Name string
|
||||
Config DataManagementSourceConfig
|
||||
SetActive bool
|
||||
}
|
||||
|
||||
type DataManagementUpdateSourceProfileInput struct {
|
||||
SourceType string
|
||||
ProfileID string
|
||||
Name string
|
||||
Config DataManagementSourceConfig
|
||||
}
|
||||
|
||||
type DataManagementS3Profile struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
S3 DataManagementS3Config `json:"s3"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
type DataManagementCreateS3ProfileInput struct {
|
||||
ProfileID string
|
||||
Name string
|
||||
S3 DataManagementS3Config
|
||||
SetActive bool
|
||||
}
|
||||
|
||||
type DataManagementUpdateS3ProfileInput struct {
|
||||
ProfileID string
|
||||
Name string
|
||||
S3 DataManagementS3Config
|
||||
}
|
||||
|
||||
type DataManagementListBackupJobsResult struct {
|
||||
Items []DataManagementBackupJob `json:"items"`
|
||||
NextPageToken string `json:"next_page_token,omitempty"`
|
||||
}
|
||||
|
||||
func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) {
|
||||
_ = ctx
|
||||
return DataManagementConfig{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) {
|
||||
_, _ = ctx, cfg
|
||||
return DataManagementConfig{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) {
|
||||
_, _ = ctx, sourceType
|
||||
return nil, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementSourceProfile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementSourceProfile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error {
|
||||
_, _, _ = ctx, sourceType, profileID
|
||||
return s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) {
|
||||
_, _, _ = ctx, sourceType, profileID
|
||||
return DataManagementSourceProfile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) {
|
||||
_, _ = ctx, cfg
|
||||
return DataManagementTestS3Result{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) {
|
||||
_ = ctx
|
||||
return nil, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementS3Profile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementS3Profile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error {
|
||||
_, _ = ctx, profileID
|
||||
return s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) {
|
||||
_, _ = ctx, profileID
|
||||
return DataManagementS3Profile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) CreateBackupJob(ctx context.Context, input DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementBackupJob{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementListBackupJobsResult{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) {
|
||||
_, _ = ctx, jobID
|
||||
return DataManagementBackupJob{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) deprecatedError() error {
|
||||
return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
|
||||
}
|
||||
36
backend/internal/service/data_management_grpc_test.go
Normal file
36
backend/internal/service/data_management_grpc_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDataManagementService_DeprecatedRPCMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "datamanagement.sock")
|
||||
svc := NewDataManagementServiceWithOptions(socketPath, 0)
|
||||
|
||||
_, err := svc.GetConfig(context.Background())
|
||||
assertDeprecatedDataManagementError(t, err, socketPath)
|
||||
|
||||
_, err = svc.CreateBackupJob(context.Background(), DataManagementCreateBackupJobInput{BackupType: "full"})
|
||||
assertDeprecatedDataManagementError(t, err, socketPath)
|
||||
|
||||
err = svc.DeleteS3Profile(context.Background(), "s3-default")
|
||||
assertDeprecatedDataManagementError(t, err, socketPath)
|
||||
}
|
||||
|
||||
func assertDeprecatedDataManagementError(t *testing.T, err error, socketPath string) {
|
||||
t.Helper()
|
||||
|
||||
require.Error(t, err)
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
require.Equal(t, 503, statusCode)
|
||||
require.Equal(t, DataManagementDeprecatedReason, status.Reason)
|
||||
require.Equal(t, socketPath, status.Metadata["socket_path"])
|
||||
}
|
||||
95
backend/internal/service/data_management_service.go
Normal file
95
backend/internal/service/data_management_service.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultDataManagementAgentSocketPath = "/tmp/sub2api-datamanagement.sock"
|
||||
LegacyBackupAgentSocketPath = "/tmp/sub2api-backup.sock"
|
||||
|
||||
DataManagementDeprecatedReason = "DATA_MANAGEMENT_DEPRECATED"
|
||||
DataManagementAgentSocketMissingReason = "DATA_MANAGEMENT_AGENT_SOCKET_MISSING"
|
||||
DataManagementAgentUnavailableReason = "DATA_MANAGEMENT_AGENT_UNAVAILABLE"
|
||||
|
||||
// Deprecated: keep old names for compatibility.
|
||||
DefaultBackupAgentSocketPath = DefaultDataManagementAgentSocketPath
|
||||
BackupAgentSocketMissingReason = DataManagementAgentSocketMissingReason
|
||||
BackupAgentUnavailableReason = DataManagementAgentUnavailableReason
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDataManagementDeprecated = infraerrors.ServiceUnavailable(
|
||||
DataManagementDeprecatedReason,
|
||||
"data management feature is deprecated",
|
||||
)
|
||||
ErrDataManagementAgentSocketMissing = infraerrors.ServiceUnavailable(
|
||||
DataManagementAgentSocketMissingReason,
|
||||
"data management agent socket is missing",
|
||||
)
|
||||
ErrDataManagementAgentUnavailable = infraerrors.ServiceUnavailable(
|
||||
DataManagementAgentUnavailableReason,
|
||||
"data management agent is unavailable",
|
||||
)
|
||||
|
||||
// Deprecated: keep old names for compatibility.
|
||||
ErrBackupAgentSocketMissing = ErrDataManagementAgentSocketMissing
|
||||
ErrBackupAgentUnavailable = ErrDataManagementAgentUnavailable
|
||||
)
|
||||
|
||||
type DataManagementAgentHealth struct {
|
||||
Enabled bool
|
||||
Reason string
|
||||
SocketPath string
|
||||
Agent *DataManagementAgentInfo
|
||||
}
|
||||
|
||||
type DataManagementAgentInfo struct {
|
||||
Status string
|
||||
Version string
|
||||
UptimeSeconds int64
|
||||
}
|
||||
|
||||
type DataManagementService struct {
|
||||
socketPath string
|
||||
}
|
||||
|
||||
func NewDataManagementService() *DataManagementService {
|
||||
return NewDataManagementServiceWithOptions(DefaultDataManagementAgentSocketPath, 500*time.Millisecond)
|
||||
}
|
||||
|
||||
func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService {
|
||||
_ = dialTimeout
|
||||
path := strings.TrimSpace(socketPath)
|
||||
if path == "" {
|
||||
path = DefaultDataManagementAgentSocketPath
|
||||
}
|
||||
return &DataManagementService{
|
||||
socketPath: path,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DataManagementService) SocketPath() string {
|
||||
if s == nil || strings.TrimSpace(s.socketPath) == "" {
|
||||
return DefaultDataManagementAgentSocketPath
|
||||
}
|
||||
return s.socketPath
|
||||
}
|
||||
|
||||
func (s *DataManagementService) GetAgentHealth(ctx context.Context) DataManagementAgentHealth {
|
||||
_ = ctx
|
||||
return DataManagementAgentHealth{
|
||||
Enabled: false,
|
||||
Reason: DataManagementDeprecatedReason,
|
||||
SocketPath: s.SocketPath(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DataManagementService) EnsureAgentEnabled(ctx context.Context) error {
|
||||
_ = ctx
|
||||
return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
|
||||
}
|
||||
37
backend/internal/service/data_management_service_test.go
Normal file
37
backend/internal/service/data_management_service_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDataManagementService_GetAgentHealth_Deprecated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "unused.sock")
|
||||
svc := NewDataManagementServiceWithOptions(socketPath, 0)
|
||||
health := svc.GetAgentHealth(context.Background())
|
||||
|
||||
require.False(t, health.Enabled)
|
||||
require.Equal(t, DataManagementDeprecatedReason, health.Reason)
|
||||
require.Equal(t, socketPath, health.SocketPath)
|
||||
require.Nil(t, health.Agent)
|
||||
}
|
||||
|
||||
func TestDataManagementService_EnsureAgentEnabled_Deprecated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "unused.sock")
|
||||
svc := NewDataManagementServiceWithOptions(socketPath, 100)
|
||||
err := svc.EnsureAgentEnabled(context.Background())
|
||||
require.Error(t, err)
|
||||
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
require.Equal(t, 503, statusCode)
|
||||
require.Equal(t, DataManagementDeprecatedReason, status.Reason)
|
||||
require.Equal(t, socketPath, status.Metadata["socket_path"])
|
||||
}
|
||||
@@ -74,11 +74,12 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
// Setting keys
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单(JSON 数组)
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
@@ -104,6 +105,7 @@ const (
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
|
||||
// OEM设置
|
||||
SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
@@ -112,12 +114,14 @@ const (
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
|
||||
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口
|
||||
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src)
|
||||
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口
|
||||
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src)
|
||||
SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组)
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
@@ -170,6 +174,37 @@ const (
|
||||
|
||||
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
|
||||
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
|
||||
|
||||
// =========================
|
||||
// Sora S3 存储配置
|
||||
// =========================
|
||||
|
||||
SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
|
||||
SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
|
||||
SettingKeySoraS3Region = "sora_s3_region" // S3 区域
|
||||
SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
|
||||
SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
|
||||
SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储)
|
||||
SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
|
||||
SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等)
|
||||
SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选)
|
||||
SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON)
|
||||
|
||||
// =========================
|
||||
// Sora 用户存储配额
|
||||
// =========================
|
||||
|
||||
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
|
||||
|
||||
// =========================
|
||||
// Claude Code Version Check
|
||||
// =========================
|
||||
|
||||
// SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查)
|
||||
SettingKeyMinClaudeCodeVersion = "min_claude_code_version"
|
||||
|
||||
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403)
|
||||
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
|
||||
@@ -279,10 +279,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotE
|
||||
wantPassthrough: true,
|
||||
},
|
||||
{
|
||||
name: "404 generic not found passes through as 404",
|
||||
name: "404 generic not found does not passthrough",
|
||||
statusCode: http.StatusNotFound,
|
||||
respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
|
||||
wantPassthrough: true,
|
||||
wantPassthrough: false,
|
||||
},
|
||||
{
|
||||
name: "400 Invalid URL does not passthrough",
|
||||
|
||||
@@ -136,3 +136,67 @@ func TestDroppedBetaSet(t *testing.T) {
|
||||
require.Contains(t, extended, claude.BetaClaudeCode)
|
||||
require.Len(t, extended, len(claude.DroppedBetas)+1)
|
||||
}
|
||||
|
||||
func TestBuildBetaTokenSet(t *testing.T) {
|
||||
got := buildBetaTokenSet([]string{"foo", "", "bar", "foo"})
|
||||
require.Len(t, got, 2)
|
||||
require.Contains(t, got, "foo")
|
||||
require.Contains(t, got, "bar")
|
||||
require.NotContains(t, got, "")
|
||||
|
||||
empty := buildBetaTokenSet(nil)
|
||||
require.Empty(t, empty)
|
||||
}
|
||||
|
||||
func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) {
|
||||
header := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
|
||||
got := stripBetaTokensWithSet(header, map[string]struct{}{})
|
||||
require.Equal(t, header, got)
|
||||
}
|
||||
|
||||
func TestIsCountTokensUnsupported404(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact endpoint not found",
|
||||
statusCode: 404,
|
||||
body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains count_tokens and not found",
|
||||
statusCode: 404,
|
||||
body: `{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "generic 404",
|
||||
statusCode: 404,
|
||||
body: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "404 with empty error message",
|
||||
statusCode: 404,
|
||||
body: `{"error":{"message":"","type":"not_found_error"}}`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non-404 status",
|
||||
statusCode: 400,
|
||||
body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isCountTokensUnsupported404(tt.statusCode, []byte(tt.body))
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
363
backend/internal/service/gateway_group_isolation_test.go
Normal file
363
backend/internal/service/gateway_group_isolation_test.go
Normal file
@@ -0,0 +1,363 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// Part 1: isAccountInGroup 单元测试
|
||||
// ============================================================================
|
||||
|
||||
func TestIsAccountInGroup(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
groupID100 := int64(100)
|
||||
groupID200 := int64(200)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
groupID *int64
|
||||
expected bool
|
||||
}{
|
||||
// groupID == nil(无分组 API Key)
|
||||
{
|
||||
"nil_groupID_ungrouped_account_nil_groups",
|
||||
&Account{ID: 1, AccountGroups: nil},
|
||||
nil, true,
|
||||
},
|
||||
{
|
||||
"nil_groupID_ungrouped_account_empty_slice",
|
||||
&Account{ID: 2, AccountGroups: []AccountGroup{}},
|
||||
nil, true,
|
||||
},
|
||||
{
|
||||
"nil_groupID_grouped_account_single",
|
||||
&Account{ID: 3, AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||
nil, false,
|
||||
},
|
||||
{
|
||||
"nil_groupID_grouped_account_multiple",
|
||||
&Account{ID: 4, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
|
||||
nil, false,
|
||||
},
|
||||
// groupID != nil(有分组 API Key)
|
||||
{
|
||||
"with_groupID_account_in_group",
|
||||
&Account{ID: 5, AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||
&groupID100, true,
|
||||
},
|
||||
{
|
||||
"with_groupID_account_not_in_group",
|
||||
&Account{ID: 6, AccountGroups: []AccountGroup{{GroupID: 200}}},
|
||||
&groupID100, false,
|
||||
},
|
||||
{
|
||||
"with_groupID_ungrouped_account",
|
||||
&Account{ID: 7, AccountGroups: nil},
|
||||
&groupID100, false,
|
||||
},
|
||||
{
|
||||
"with_groupID_multi_group_account_match_one",
|
||||
&Account{ID: 8, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}},
|
||||
&groupID200, true,
|
||||
},
|
||||
{
|
||||
"with_groupID_multi_group_account_no_match",
|
||||
&Account{ID: 9, AccountGroups: []AccountGroup{{GroupID: 300}, {GroupID: 400}}},
|
||||
&groupID100, false,
|
||||
},
|
||||
// 防御性边界
|
||||
{
|
||||
"nil_account_nil_groupID",
|
||||
nil,
|
||||
nil, false,
|
||||
},
|
||||
{
|
||||
"nil_account_with_groupID",
|
||||
nil,
|
||||
&groupID100, false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.isAccountInGroup(tt.account, tt.groupID)
|
||||
require.Equal(t, tt.expected, got, "isAccountInGroup 结果不符预期")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Part 2: 分组隔离端到端调度测试
|
||||
// ============================================================================
|
||||
|
||||
// groupAwareMockAccountRepo embeds mockAccountRepoForPlatform and overrides
// the group-isolation related query methods.
// allAccounts stores every account; the group query methods perform real
// filtering on the AccountGroups field.
type groupAwareMockAccountRepo struct {
	*mockAccountRepoForPlatform
	// allAccounts is the full pool the overridden list methods filter over.
	allAccounts []Account
}
|
||||
|
||||
// ListSchedulableUngroupedByPlatform 仅返回未分组账号(AccountGroups 为空)
|
||||
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
var result []Account
|
||||
for _, acc := range m.allAccounts {
|
||||
if acc.Platform == platform && acc.IsSchedulable() && len(acc.AccountGroups) == 0 {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListSchedulableUngroupedByPlatforms 仅返回未分组账号(多平台版本)
|
||||
func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
platformSet := make(map[string]bool, len(platforms))
|
||||
for _, p := range platforms {
|
||||
platformSet[p] = true
|
||||
}
|
||||
var result []Account
|
||||
for _, acc := range m.allAccounts {
|
||||
if platformSet[acc.Platform] && acc.IsSchedulable() && len(acc.AccountGroups) == 0 {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupIDAndPlatform 返回属于指定分组的账号
|
||||
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
var result []Account
|
||||
for _, acc := range m.allAccounts {
|
||||
if acc.Platform == platform && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListSchedulableByGroupIDAndPlatforms 返回属于指定分组的账号(多平台版本)
|
||||
func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
platformSet := make(map[string]bool, len(platforms))
|
||||
for _, p := range platforms {
|
||||
platformSet[p] = true
|
||||
}
|
||||
var result []Account
|
||||
for _, acc := range m.allAccounts {
|
||||
if platformSet[acc.Platform] && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// accountBelongsToGroup 检查账号是否属于指定分组
|
||||
func accountBelongsToGroup(acc Account, groupID int64) bool {
|
||||
for _, ag := range acc.AccountGroups {
|
||||
if ag.GroupID == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*groupAwareMockAccountRepo)(nil)
|
||||
|
||||
// newGroupAwareMockRepo 创建分组感知的 mock repo
|
||||
func newGroupAwareMockRepo(accounts []Account) *groupAwareMockAccountRepo {
|
||||
byID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
byID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
return &groupAwareMockAccountRepo{
|
||||
mockAccountRepoForPlatform: &mockAccountRepoForPlatform{
|
||||
accounts: accounts,
|
||||
accountsByID: byID,
|
||||
},
|
||||
allAccounts: accounts,
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupIsolation_UngroupedKey_ShouldNotScheduleGroupedAccounts(t *testing.T) {
|
||||
// 场景:无分组 API Key(groupID=nil),池中只有已分组账号 → 应返回错误
|
||||
ctx := context.Background()
|
||||
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 200}}},
|
||||
}
|
||||
repo := newGroupAwareMockRepo(accounts)
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||
require.Error(t, err, "无分组 Key 不应调度到已分组账号")
|
||||
require.Nil(t, acc)
|
||||
}
|
||||
|
||||
func TestGroupIsolation_GroupedKey_ShouldNotScheduleUngroupedAccounts(t *testing.T) {
|
||||
// 场景:有分组 API Key(groupID=100),池中只有未分组账号 → 应返回错误
|
||||
ctx := context.Background()
|
||||
groupID := int64(100)
|
||||
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: nil},
|
||||
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{}},
|
||||
}
|
||||
repo := newGroupAwareMockRepo(accounts)
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
|
||||
require.Error(t, err, "有分组 Key 不应调度到未分组账号")
|
||||
require.Nil(t, acc)
|
||||
}
|
||||
|
||||
func TestGroupIsolation_UngroupedKey_ShouldOnlyScheduleUngroupedAccounts(t *testing.T) {
|
||||
// 场景:无分组 API Key(groupID=nil),池中有未分组和已分组账号 → 应只选中未分组的
|
||||
ctx := context.Background()
|
||||
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组,不应被选中
|
||||
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: nil}, // 未分组,应被选中
|
||||
{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 200}}}, // 已分组,不应被选中
|
||||
}
|
||||
repo := newGroupAwareMockRepo(accounts)
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||
require.NoError(t, err, "应成功调度未分组账号")
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "应选中未分组的账号 ID=2")
|
||||
}
|
||||
|
||||
func TestGroupIsolation_GroupedKey_ShouldOnlyScheduleMatchingGroupAccounts(t *testing.T) {
|
||||
// 场景:有分组 API Key(groupID=100),池中有未分组和多个分组账号 → 应只选中分组 100 内的
|
||||
ctx := context.Background()
|
||||
groupID := int64(100)
|
||||
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: nil}, // 未分组,不应被选中
|
||||
{ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 200}}}, // 属于分组 200,不应被选中
|
||||
{ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 属于分组 100,应被选中
|
||||
}
|
||||
repo := newGroupAwareMockRepo(accounts)
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI)
|
||||
require.NoError(t, err, "应成功调度分组内账号")
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(3), acc.ID, "应选中分组 100 内的账号 ID=3")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Part 3: SimpleMode 旁路测试
|
||||
// ============================================================================
|
||||
|
||||
func TestGroupIsolation_SimpleMode_SkipsGroupIsolation(t *testing.T) {
|
||||
// SimpleMode 应跳过分组隔离,使用 ListSchedulableByPlatform 返回所有账号。
|
||||
// 测试非 useMixed 路径(platform=openai,不会触发 mixed 调度逻辑)。
|
||||
ctx := context.Background()
|
||||
|
||||
// 混合未分组和已分组账号,SimpleMode 下应全部可调度
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组
|
||||
{ID: 2, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: nil}, // 未分组
|
||||
}
|
||||
|
||||
// 使用基础 mock(ListSchedulableByPlatform 返回所有匹配平台的账号,不做分组过滤)
|
||||
byID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
byID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: accounts,
|
||||
accountsByID: byID,
|
||||
}
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: &config.Config{RunMode: config.RunModeSimple},
|
||||
}
|
||||
|
||||
// groupID=nil 时,SimpleMode 应使用 ListSchedulableByPlatform(不过滤分组)
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||
require.NoError(t, err, "SimpleMode 应跳过分组隔离直接返回账号")
|
||||
require.NotNil(t, acc)
|
||||
// 应选择优先级最高的账号(Priority=1, ID=2),即使它未分组
|
||||
require.Equal(t, int64(2), acc.ID, "SimpleMode 应按优先级选择,不考虑分组")
|
||||
}
|
||||
|
||||
func TestGroupIsolation_SimpleMode_GroupedAccountAlsoSchedulable(t *testing.T) {
|
||||
// SimpleMode + groupID=nil 时,已分组账号也应该可被调度
|
||||
ctx := context.Background()
|
||||
|
||||
// 只有已分组账号,在 standard 模式下 groupID=nil 会报错,但 simple 模式应正常
|
||||
accounts := []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true,
|
||||
AccountGroups: []AccountGroup{{GroupID: 100}}},
|
||||
}
|
||||
|
||||
byID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
byID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: accounts,
|
||||
accountsByID: byID,
|
||||
}
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: &config.Config{RunMode: config.RunModeSimple},
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI)
|
||||
require.NoError(t, err, "SimpleMode 下已分组账号也应可调度")
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "SimpleMode 应能调度已分组账号")
|
||||
}
|
||||
@@ -147,6 +147,12 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Cont
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1892,6 +1898,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
result[accountID] = 0
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -61,6 +61,10 @@ type ParsedRequest struct {
|
||||
ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名)
|
||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
|
||||
|
||||
// OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁)
|
||||
// 流式请求在收到 2xx 响应头后调用,避免持锁等流完成
|
||||
OnUpstreamAccepted func()
|
||||
}
|
||||
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果。
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,141 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCollectSelectionFailureStats builds one account per failure category
// plus one eligible account, and verifies collectSelectionFailureStats
// counts each bucket exactly once.
func TestCollectSelectionFailureStats(t *testing.T) {
	svc := &GatewayService{}
	model := "sora2-landscape-10s"
	// Future reset timestamp so the "model rate limited" account is still in-window.
	resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339)

	accounts := []Account{
		// excluded
		{
			ID:          1,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
		},
		// unschedulable
		{
			ID:          2,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: false,
		},
		// platform filtered
		{
			ID:          3,
			Platform:    PlatformOpenAI,
			Status:      StatusActive,
			Schedulable: true,
		},
		// model unsupported
		{
			ID:          4,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
			Credentials: map[string]any{
				"model_mapping": map[string]any{
					"gpt-image": "gpt-image",
				},
			},
		},
		// model rate limited
		{
			ID:          5,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
			Extra: map[string]any{
				"model_rate_limits": map[string]any{
					model: map[string]any{
						"rate_limit_reset_at": resetAt,
					},
				},
			},
		},
		// eligible
		{
			ID:          6,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
		},
	}

	// Account 1 is placed on the exclusion list so it lands in the Excluded bucket.
	excluded := map[int64]struct{}{1: {}}
	stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false)

	if stats.Total != 6 {
		t.Fatalf("total=%d want=6", stats.Total)
	}
	if stats.Excluded != 1 {
		t.Fatalf("excluded=%d want=1", stats.Excluded)
	}
	if stats.Unschedulable != 1 {
		t.Fatalf("unschedulable=%d want=1", stats.Unschedulable)
	}
	if stats.PlatformFiltered != 1 {
		t.Fatalf("platform_filtered=%d want=1", stats.PlatformFiltered)
	}
	if stats.ModelUnsupported != 1 {
		t.Fatalf("model_unsupported=%d want=1", stats.ModelUnsupported)
	}
	if stats.ModelRateLimited != 1 {
		t.Fatalf("model_rate_limited=%d want=1", stats.ModelRateLimited)
	}
	if stats.Eligible != 1 {
		t.Fatalf("eligible=%d want=1", stats.Eligible)
	}
}
|
||||
|
||||
func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
acc := &Account{
|
||||
ID: 7,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: false,
|
||||
}
|
||||
|
||||
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
|
||||
if diagnosis.Category != "unschedulable" {
|
||||
t.Fatalf("category=%s want=unschedulable", diagnosis.Category)
|
||||
}
|
||||
if diagnosis.Detail != "schedulable=false" {
|
||||
t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
model := "sora2-landscape-10s"
|
||||
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
|
||||
acc := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"model_rate_limits": map[string]any{
|
||||
model: map[string]any{
|
||||
"rate_limit_reset_at": resetAt,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false)
|
||||
if diagnosis.Category != "model_rate_limited" {
|
||||
t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category)
|
||||
}
|
||||
if !strings.Contains(diagnosis.Detail, "remaining=") {
|
||||
t.Fatalf("detail=%s want contains remaining=", diagnosis.Detail)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected sora model to be supported when model_mapping is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-4o": "gpt-4o",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected sora model to be supported when mapping has no sora selectors")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"sora2": "sora2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") {
|
||||
t.Fatalf("expected family selector sora2 to support sora2-landscape-15s")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"sy_8": "sy_8",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-image": "gpt-image",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
now := time.Now()
|
||||
past := now.Add(-1 * time.Minute)
|
||||
future := now.Add(5 * time.Minute)
|
||||
|
||||
acc := &Account{
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
AutoPauseOnExpired: true,
|
||||
ExpiresAt: &past,
|
||||
OverloadUntil: &future,
|
||||
RateLimitResetAt: &future,
|
||||
}
|
||||
|
||||
if !svc.isAccountSchedulableForSelection(acc) {
|
||||
t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
future := time.Now().Add(5 * time.Minute)
|
||||
|
||||
acc := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateLimitResetAt: &future,
|
||||
}
|
||||
|
||||
if svc.isAccountSchedulableForSelection(acc) {
|
||||
t.Fatalf("expected non-sora account to keep generic schedulable checks")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
model := "sora2-landscape-10s"
|
||||
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
|
||||
globalResetAt := time.Now().Add(2 * time.Minute)
|
||||
|
||||
acc := &Account{
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateLimitResetAt: &globalResetAt,
|
||||
Extra: map[string]any{
|
||||
"model_rate_limits": map[string]any{
|
||||
model: map[string]any{
|
||||
"rate_limit_reset_at": resetAt,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) {
|
||||
t.Fatalf("expected sora account to be blocked by model scope rate limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
future := time.Now().Add(3 * time.Minute)
|
||||
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateLimitResetAt: &future,
|
||||
},
|
||||
}
|
||||
|
||||
stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
|
||||
if stats.Unschedulable != 0 || stats.Eligible != 1 {
|
||||
t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible)
|
||||
}
|
||||
}
|
||||
@@ -105,12 +105,12 @@ func TestCalculateMaxWait_Scenarios(t *testing.T) {
|
||||
concurrency int
|
||||
expected int
|
||||
}{
|
||||
{5, 25}, // 5 + 20
|
||||
{10, 30}, // 10 + 20
|
||||
{1, 21}, // 1 + 20
|
||||
{0, 21}, // min(1) + 20
|
||||
{-1, 21}, // min(1) + 20
|
||||
{-10, 21}, // min(1) + 20
|
||||
{5, 25}, // 5 + 20
|
||||
{10, 30}, // 10 + 20
|
||||
{1, 21}, // 1 + 20
|
||||
{0, 21}, // min(1) + 20
|
||||
{-1, 21}, // min(1) + 20
|
||||
{-10, 21}, // min(1) + 20
|
||||
{100, 120}, // 100 + 20
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -53,6 +53,7 @@ type GeminiMessagesCompatService struct {
|
||||
httpUpstream HTTPUpstream
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
cfg *config.Config
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
}
|
||||
|
||||
func NewGeminiMessagesCompatService(
|
||||
@@ -76,6 +77,7 @@ func NewGeminiMessagesCompatService(
|
||||
httpUpstream: httpUpstream,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
cfg: cfg,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,6 +231,16 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
|
||||
account *Account,
|
||||
requestedModel, platform string,
|
||||
useMixedScheduling bool,
|
||||
) bool {
|
||||
return s.isAccountUsableForRequestWithPrecheck(ctx, account, requestedModel, platform, useMixedScheduling, nil)
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) isAccountUsableForRequestWithPrecheck(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
requestedModel, platform string,
|
||||
useMixedScheduling bool,
|
||||
precheckResult map[int64]bool,
|
||||
) bool {
|
||||
// 检查模型调度能力
|
||||
// Check model scheduling capability
|
||||
@@ -250,7 +262,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
|
||||
|
||||
// 速率限制预检
|
||||
// Rate limit precheck
|
||||
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
|
||||
if !s.passesRateLimitPreCheckWithCache(ctx, account, requestedModel, precheckResult) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -272,15 +284,17 @@ func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account
|
||||
return false
|
||||
}
|
||||
|
||||
// passesRateLimitPreCheck 执行速率限制预检。
|
||||
// 返回 true 表示通过预检或无需预检。
|
||||
//
|
||||
// passesRateLimitPreCheck performs rate limit precheck.
|
||||
// Returns true if passed or precheck not required.
|
||||
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
|
||||
func (s *GeminiMessagesCompatService) passesRateLimitPreCheckWithCache(ctx context.Context, account *Account, requestedModel string, precheckResult map[int64]bool) bool {
|
||||
if s.rateLimitService == nil || requestedModel == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
if precheckResult != nil {
|
||||
if ok, exists := precheckResult[account.ID]; exists {
|
||||
return ok
|
||||
}
|
||||
}
|
||||
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||
@@ -302,6 +316,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
|
||||
useMixedScheduling bool,
|
||||
) *Account {
|
||||
var selected *Account
|
||||
precheckResult := s.buildPreCheckUsageResultMap(ctx, accounts, requestedModel)
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@@ -312,7 +327,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
|
||||
}
|
||||
|
||||
// 检查账号是否可用于当前请求
|
||||
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
|
||||
if !s.isAccountUsableForRequestWithPrecheck(ctx, acc, requestedModel, platform, useMixedScheduling, precheckResult) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -330,6 +345,23 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
|
||||
return selected
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) buildPreCheckUsageResultMap(ctx context.Context, accounts []Account, requestedModel string) map[int64]bool {
|
||||
if s.rateLimitService == nil || requestedModel == "" || len(accounts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
candidates := make([]*Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
candidates = append(candidates, &accounts[i])
|
||||
}
|
||||
|
||||
result, err := s.rateLimitService.PreCheckUsageBatch(ctx, candidates, requestedModel)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheckBatch] failed: %v", err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
|
||||
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
|
||||
//
|
||||
@@ -399,7 +431,10 @@ func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Co
|
||||
if groupID != nil {
|
||||
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||||
}
|
||||
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
return s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
@@ -2390,7 +2425,7 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
}
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
@@ -2415,8 +2450,8 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||
}
|
||||
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
@@ -2557,7 +2592,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||||
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
|
||||
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.responseHeaderFilter)
|
||||
if wwwAuthenticate != "" {
|
||||
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
|
||||
}
|
||||
|
||||
@@ -138,6 +138,12 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
|
||||
}
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1045,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
|
||||
ValidateResolvedIP: true,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
return "", fmt.Errorf("create http client failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
@@ -32,6 +32,9 @@ type Group struct {
|
||||
SoraVideoPricePerRequest *float64
|
||||
SoraVideoPricePerRequestHD *float64
|
||||
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
|
||||
@@ -46,6 +46,7 @@ type Fingerprint struct {
|
||||
StainlessArch string
|
||||
StainlessRuntime string
|
||||
StainlessRuntimeVersion string
|
||||
UpdatedAt int64 `json:",omitempty"` // Unix timestamp,用于判断是否需要续期TTL
|
||||
}
|
||||
|
||||
// IdentityCache defines cache operations for identity service
|
||||
@@ -78,14 +79,26 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
||||
// 尝试从缓存获取指纹
|
||||
cached, err := s.cache.GetFingerprint(ctx, accountID)
|
||||
if err == nil && cached != nil {
|
||||
needWrite := false
|
||||
|
||||
// 检查客户端的user-agent是否是更新版本
|
||||
clientUA := headers.Get("User-Agent")
|
||||
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
|
||||
// 更新user-agent
|
||||
cached.UserAgent = clientUA
|
||||
// 保存更新后的指纹
|
||||
_ = s.cache.SetFingerprint(ctx, accountID, cached)
|
||||
logger.LegacyPrintf("service.identity", "Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||
// 版本升级:merge 语义 — 仅更新请求中实际携带的字段,保留缓存值
|
||||
// 避免缺失的头被硬编码默认值覆盖(如新 CLI 版本 + 旧 SDK 默认值的不一致)
|
||||
mergeHeadersIntoFingerprint(cached, headers)
|
||||
needWrite = true
|
||||
logger.LegacyPrintf("service.identity", "Updated fingerprint for account %d: %s (merge update)", accountID, clientUA)
|
||||
} else if time.Since(time.Unix(cached.UpdatedAt, 0)) > 24*time.Hour {
|
||||
// 距上次写入超过24小时,续期TTL
|
||||
needWrite = true
|
||||
}
|
||||
|
||||
if needWrite {
|
||||
cached.UpdatedAt = time.Now().Unix()
|
||||
if err := s.cache.SetFingerprint(ctx, accountID, cached); err != nil {
|
||||
logger.LegacyPrintf("service.identity", "Warning: failed to refresh fingerprint for account %d: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
return cached, nil
|
||||
}
|
||||
@@ -95,8 +108,9 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
|
||||
|
||||
// 生成随机ClientID
|
||||
fp.ClientID = generateClientID()
|
||||
fp.UpdatedAt = time.Now().Unix()
|
||||
|
||||
// 保存到缓存(永不过期)
|
||||
// 保存到缓存(7天TTL,每24小时自动续期)
|
||||
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
|
||||
logger.LegacyPrintf("service.identity", "Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||
}
|
||||
@@ -127,6 +141,31 @@ func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fin
|
||||
return fp
|
||||
}
|
||||
|
||||
// mergeHeadersIntoFingerprint 将请求头中实际存在的字段合并到现有指纹中(用于版本升级场景)
|
||||
// 关键语义:请求中有的字段 → 用新值覆盖;缺失的头 → 保留缓存中的已有值
|
||||
// 与 createFingerprintFromHeaders 的区别:后者用于首次创建,缺失头回退到 defaultFingerprint;
|
||||
// 本函数用于升级更新,缺失头保留缓存值,避免将已知的真实值退化为硬编码默认值
|
||||
func mergeHeadersIntoFingerprint(fp *Fingerprint, headers http.Header) {
|
||||
// User-Agent:版本升级的触发条件,一定存在
|
||||
if ua := headers.Get("User-Agent"); ua != "" {
|
||||
fp.UserAgent = ua
|
||||
}
|
||||
// X-Stainless-* 头:仅在请求中实际携带时才更新,否则保留缓存值
|
||||
mergeHeader(headers, "X-Stainless-Lang", &fp.StainlessLang)
|
||||
mergeHeader(headers, "X-Stainless-Package-Version", &fp.StainlessPackageVersion)
|
||||
mergeHeader(headers, "X-Stainless-OS", &fp.StainlessOS)
|
||||
mergeHeader(headers, "X-Stainless-Arch", &fp.StainlessArch)
|
||||
mergeHeader(headers, "X-Stainless-Runtime", &fp.StainlessRuntime)
|
||||
mergeHeader(headers, "X-Stainless-Runtime-Version", &fp.StainlessRuntimeVersion)
|
||||
}
|
||||
|
||||
// mergeHeader 如果请求头中存在该字段则更新目标值,否则保留原值
|
||||
func mergeHeader(headers http.Header, key string, target *string) {
|
||||
if v := headers.Get(key); v != "" {
|
||||
*target = v
|
||||
}
|
||||
}
|
||||
|
||||
// getHeaderOrDefault 获取header值,如果不存在则返回默认值
|
||||
func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
|
||||
if v := headers.Get(key); v != "" {
|
||||
@@ -371,8 +410,25 @@ func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
||||
return major, minor, patch, true
|
||||
}
|
||||
|
||||
// extractProduct 提取 User-Agent 中 "/" 前的产品名
|
||||
// 例如:claude-cli/2.1.22 (external, cli) -> "claude-cli"
|
||||
func extractProduct(ua string) string {
|
||||
if idx := strings.Index(ua, "/"); idx > 0 {
|
||||
return strings.ToLower(ua[:idx])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// isNewerVersion 比较版本号,判断newUA是否比cachedUA更新
|
||||
// 要求产品名一致(防止浏览器 UA 如 Mozilla/5.0 误判为更新版本)
|
||||
func isNewerVersion(newUA, cachedUA string) bool {
|
||||
// 校验产品名一致性
|
||||
newProduct := extractProduct(newUA)
|
||||
cachedProduct := extractProduct(cachedUA)
|
||||
if newProduct == "" || cachedProduct == "" || newProduct != cachedProduct {
|
||||
return false
|
||||
}
|
||||
|
||||
newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA)
|
||||
cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA)
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
)
|
||||
|
||||
const modelRateLimitsKey = "model_rate_limits"
|
||||
@@ -73,7 +71,7 @@ func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requ
|
||||
return ""
|
||||
}
|
||||
// thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking)
|
||||
if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
|
||||
if enabled, ok := ThinkingEnabledFromContext(ctx); ok {
|
||||
modelKey = applyThinkingModelSuffix(modelKey, enabled)
|
||||
}
|
||||
return modelKey
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
// OpenAIOAuthClient interface for OpenAI OAuth operations
|
||||
type OpenAIOAuthClient interface {
|
||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
||||
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
|
||||
}
|
||||
|
||||
@@ -14,10 +14,10 @@ import (
|
||||
// --- mock: ClaudeOAuthClient ---
|
||||
|
||||
type mockClaudeOAuthClient struct {
|
||||
getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
|
||||
getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
|
||||
exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
|
||||
refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
|
||||
getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
|
||||
getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
|
||||
exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
|
||||
refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
|
||||
}
|
||||
|
||||
func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
@@ -437,9 +437,9 @@ func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
|
||||
|
||||
// 无 refresh_token 的账号
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
ID: 1,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "some-token",
|
||||
},
|
||||
@@ -460,9 +460,9 @@ func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) {
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
ID: 2,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "some-token",
|
||||
"refresh_token": "",
|
||||
|
||||
909
backend/internal/service/openai_account_scheduler.go
Normal file
909
backend/internal/service/openai_account_scheduler.go
Normal file
@@ -0,0 +1,909 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
|
||||
openAIAccountScheduleLayerSessionSticky = "session_hash"
|
||||
openAIAccountScheduleLayerLoadBalance = "load_balance"
|
||||
)
|
||||
|
||||
type OpenAIAccountScheduleRequest struct {
|
||||
GroupID *int64
|
||||
SessionHash string
|
||||
StickyAccountID int64
|
||||
PreviousResponseID string
|
||||
RequestedModel string
|
||||
RequiredTransport OpenAIUpstreamTransport
|
||||
ExcludedIDs map[int64]struct{}
|
||||
}
|
||||
|
||||
type OpenAIAccountScheduleDecision struct {
|
||||
Layer string
|
||||
StickyPreviousHit bool
|
||||
StickySessionHit bool
|
||||
CandidateCount int
|
||||
TopK int
|
||||
LatencyMs int64
|
||||
LoadSkew float64
|
||||
SelectedAccountID int64
|
||||
SelectedAccountType string
|
||||
}
|
||||
|
||||
type OpenAIAccountSchedulerMetricsSnapshot struct {
|
||||
SelectTotal int64
|
||||
StickyPreviousHitTotal int64
|
||||
StickySessionHitTotal int64
|
||||
LoadBalanceSelectTotal int64
|
||||
AccountSwitchTotal int64
|
||||
SchedulerLatencyMsTotal int64
|
||||
SchedulerLatencyMsAvg float64
|
||||
StickyHitRatio float64
|
||||
AccountSwitchRate float64
|
||||
LoadSkewAvg float64
|
||||
RuntimeStatsAccountCount int
|
||||
}
|
||||
|
||||
type OpenAIAccountScheduler interface {
|
||||
Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error)
|
||||
ReportResult(accountID int64, success bool, firstTokenMs *int)
|
||||
ReportSwitch()
|
||||
SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot
|
||||
}
|
||||
|
||||
type openAIAccountSchedulerMetrics struct {
|
||||
selectTotal atomic.Int64
|
||||
stickyPreviousHitTotal atomic.Int64
|
||||
stickySessionHitTotal atomic.Int64
|
||||
loadBalanceSelectTotal atomic.Int64
|
||||
accountSwitchTotal atomic.Int64
|
||||
latencyMsTotal atomic.Int64
|
||||
loadSkewMilliTotal atomic.Int64
|
||||
}
|
||||
|
||||
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.selectTotal.Add(1)
|
||||
m.latencyMsTotal.Add(decision.LatencyMs)
|
||||
m.loadSkewMilliTotal.Add(int64(math.Round(decision.LoadSkew * 1000)))
|
||||
if decision.StickyPreviousHit {
|
||||
m.stickyPreviousHitTotal.Add(1)
|
||||
}
|
||||
if decision.StickySessionHit {
|
||||
m.stickySessionHitTotal.Add(1)
|
||||
}
|
||||
if decision.Layer == openAIAccountScheduleLayerLoadBalance {
|
||||
m.loadBalanceSelectTotal.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *openAIAccountSchedulerMetrics) recordSwitch() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.accountSwitchTotal.Add(1)
|
||||
}
|
||||
|
||||
type openAIAccountRuntimeStats struct {
|
||||
accounts sync.Map
|
||||
accountCount atomic.Int64
|
||||
}
|
||||
|
||||
type openAIAccountRuntimeStat struct {
|
||||
errorRateEWMABits atomic.Uint64
|
||||
ttftEWMABits atomic.Uint64
|
||||
}
|
||||
|
||||
func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats {
|
||||
return &openAIAccountRuntimeStats{}
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat {
|
||||
if value, ok := s.accounts.Load(accountID); ok {
|
||||
stat, _ := value.(*openAIAccountRuntimeStat)
|
||||
if stat != nil {
|
||||
return stat
|
||||
}
|
||||
}
|
||||
|
||||
stat := &openAIAccountRuntimeStat{}
|
||||
stat.ttftEWMABits.Store(math.Float64bits(math.NaN()))
|
||||
actual, loaded := s.accounts.LoadOrStore(accountID, stat)
|
||||
if !loaded {
|
||||
s.accountCount.Add(1)
|
||||
return stat
|
||||
}
|
||||
existing, _ := actual.(*openAIAccountRuntimeStat)
|
||||
if existing != nil {
|
||||
return existing
|
||||
}
|
||||
return stat
|
||||
}
|
||||
|
||||
func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
|
||||
for {
|
||||
oldBits := target.Load()
|
||||
oldValue := math.Float64frombits(oldBits)
|
||||
newValue := alpha*sample + (1-alpha)*oldValue
|
||||
if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) {
|
||||
if s == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
const alpha = 0.2
|
||||
stat := s.loadOrCreate(accountID)
|
||||
|
||||
errorSample := 1.0
|
||||
if success {
|
||||
errorSample = 0.0
|
||||
}
|
||||
updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha)
|
||||
|
||||
if firstTokenMs != nil && *firstTokenMs > 0 {
|
||||
ttft := float64(*firstTokenMs)
|
||||
ttftBits := math.Float64bits(ttft)
|
||||
for {
|
||||
oldBits := stat.ttftEWMABits.Load()
|
||||
oldValue := math.Float64frombits(oldBits)
|
||||
if math.IsNaN(oldValue) {
|
||||
if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
newValue := alpha*ttft + (1-alpha)*oldValue
|
||||
if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) {
|
||||
if s == nil || accountID <= 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
value, ok := s.accounts.Load(accountID)
|
||||
if !ok {
|
||||
return 0, 0, false
|
||||
}
|
||||
stat, _ := value.(*openAIAccountRuntimeStat)
|
||||
if stat == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load()))
|
||||
ttftValue := math.Float64frombits(stat.ttftEWMABits.Load())
|
||||
if math.IsNaN(ttftValue) {
|
||||
return errorRate, 0, false
|
||||
}
|
||||
return errorRate, ttftValue, true
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) size() int {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return int(s.accountCount.Load())
|
||||
}
|
||||
|
||||
type defaultOpenAIAccountScheduler struct {
|
||||
service *OpenAIGatewayService
|
||||
metrics openAIAccountSchedulerMetrics
|
||||
stats *openAIAccountRuntimeStats
|
||||
}
|
||||
|
||||
func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler {
|
||||
if stats == nil {
|
||||
stats = newOpenAIAccountRuntimeStats()
|
||||
}
|
||||
return &defaultOpenAIAccountScheduler{
|
||||
service: service,
|
||||
stats: stats,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) Select(
|
||||
ctx context.Context,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||
decision := OpenAIAccountScheduleDecision{}
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
decision.LatencyMs = time.Since(start).Milliseconds()
|
||||
s.metrics.recordSelect(decision)
|
||||
}()
|
||||
|
||||
previousResponseID := strings.TrimSpace(req.PreviousResponseID)
|
||||
if previousResponseID != "" {
|
||||
selection, err := s.service.SelectAccountByPreviousResponseID(
|
||||
ctx,
|
||||
req.GroupID,
|
||||
previousResponseID,
|
||||
req.RequestedModel,
|
||||
req.ExcludedIDs,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, decision, err
|
||||
}
|
||||
if selection != nil && selection.Account != nil {
|
||||
if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
|
||||
selection = nil
|
||||
}
|
||||
}
|
||||
if selection != nil && selection.Account != nil {
|
||||
decision.Layer = openAIAccountScheduleLayerPreviousResponse
|
||||
decision.StickyPreviousHit = true
|
||||
decision.SelectedAccountID = selection.Account.ID
|
||||
decision.SelectedAccountType = selection.Account.Type
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID)
|
||||
}
|
||||
return selection, decision, nil
|
||||
}
|
||||
}
|
||||
|
||||
selection, err := s.selectBySessionHash(ctx, req)
|
||||
if err != nil {
|
||||
return nil, decision, err
|
||||
}
|
||||
if selection != nil && selection.Account != nil {
|
||||
decision.Layer = openAIAccountScheduleLayerSessionSticky
|
||||
decision.StickySessionHit = true
|
||||
decision.SelectedAccountID = selection.Account.ID
|
||||
decision.SelectedAccountType = selection.Account.Type
|
||||
return selection, decision, nil
|
||||
}
|
||||
|
||||
selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req)
|
||||
decision.Layer = openAIAccountScheduleLayerLoadBalance
|
||||
decision.CandidateCount = candidateCount
|
||||
decision.TopK = topK
|
||||
decision.LoadSkew = loadSkew
|
||||
if err != nil {
|
||||
return nil, decision, err
|
||||
}
|
||||
if selection != nil && selection.Account != nil {
|
||||
decision.SelectedAccountID = selection.Account.ID
|
||||
decision.SelectedAccountType = selection.Account.Type
|
||||
}
|
||||
return selection, decision, nil
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
ctx context.Context,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
) (*AccountSelectionResult, error) {
|
||||
sessionHash := strings.TrimSpace(req.SessionHash)
|
||||
if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
accountID := req.StickyAccountID
|
||||
if accountID <= 0 {
|
||||
var err error
|
||||
accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
if err != nil || accountID <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
if accountID <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if req.ExcludedIDs != nil {
|
||||
if _, excluded := req.ExcludedIDs[accountID]; excluded {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
account, err := s.service.getSchedulableAccount(ctx, accountID)
|
||||
if err != nil || account == nil {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
return nil, nil
|
||||
}
|
||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if acquireErr == nil && result.Acquired {
|
||||
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
cfg := s.service.schedulingConfig()
|
||||
if s.service.concurrencyService != nil {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type openAIAccountCandidateScore struct {
|
||||
account *Account
|
||||
loadInfo *AccountLoadInfo
|
||||
score float64
|
||||
errorRate float64
|
||||
ttft float64
|
||||
hasTTFT bool
|
||||
}
|
||||
|
||||
type openAIAccountCandidateHeap []openAIAccountCandidateScore
|
||||
|
||||
func (h openAIAccountCandidateHeap) Len() int {
|
||||
return len(h)
|
||||
}
|
||||
|
||||
func (h openAIAccountCandidateHeap) Less(i, j int) bool {
|
||||
// 最小堆根节点保存“最差”候选,便于 O(log k) 维护 topK。
|
||||
return isOpenAIAccountCandidateBetter(h[j], h[i])
|
||||
}
|
||||
|
||||
func (h openAIAccountCandidateHeap) Swap(i, j int) {
|
||||
h[i], h[j] = h[j], h[i]
|
||||
}
|
||||
|
||||
func (h *openAIAccountCandidateHeap) Push(x any) {
|
||||
candidate, ok := x.(openAIAccountCandidateScore)
|
||||
if !ok {
|
||||
panic("openAIAccountCandidateHeap: invalid element type")
|
||||
}
|
||||
*h = append(*h, candidate)
|
||||
}
|
||||
|
||||
func (h *openAIAccountCandidateHeap) Pop() any {
|
||||
old := *h
|
||||
n := len(old)
|
||||
last := old[n-1]
|
||||
*h = old[:n-1]
|
||||
return last
|
||||
}
|
||||
|
||||
func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool {
|
||||
if left.score != right.score {
|
||||
return left.score > right.score
|
||||
}
|
||||
if left.account.Priority != right.account.Priority {
|
||||
return left.account.Priority < right.account.Priority
|
||||
}
|
||||
if left.loadInfo.LoadRate != right.loadInfo.LoadRate {
|
||||
return left.loadInfo.LoadRate < right.loadInfo.LoadRate
|
||||
}
|
||||
if left.loadInfo.WaitingCount != right.loadInfo.WaitingCount {
|
||||
return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount
|
||||
}
|
||||
return left.account.ID < right.account.ID
|
||||
}
|
||||
|
||||
func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 1
|
||||
}
|
||||
if topK >= len(candidates) {
|
||||
ranked := append([]openAIAccountCandidateScore(nil), candidates...)
|
||||
sort.Slice(ranked, func(i, j int) bool {
|
||||
return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
|
||||
})
|
||||
return ranked
|
||||
}
|
||||
|
||||
best := make(openAIAccountCandidateHeap, 0, topK)
|
||||
for _, candidate := range candidates {
|
||||
if len(best) < topK {
|
||||
heap.Push(&best, candidate)
|
||||
continue
|
||||
}
|
||||
if isOpenAIAccountCandidateBetter(candidate, best[0]) {
|
||||
best[0] = candidate
|
||||
heap.Fix(&best, 0)
|
||||
}
|
||||
}
|
||||
|
||||
ranked := make([]openAIAccountCandidateScore, len(best))
|
||||
copy(ranked, best)
|
||||
sort.Slice(ranked, func(i, j int) bool {
|
||||
return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
|
||||
})
|
||||
return ranked
|
||||
}
|
||||
|
||||
type openAISelectionRNG struct {
|
||||
state uint64
|
||||
}
|
||||
|
||||
func newOpenAISelectionRNG(seed uint64) openAISelectionRNG {
|
||||
if seed == 0 {
|
||||
seed = 0x9e3779b97f4a7c15
|
||||
}
|
||||
return openAISelectionRNG{state: seed}
|
||||
}
|
||||
|
||||
func (r *openAISelectionRNG) nextUint64() uint64 {
|
||||
// xorshift64*
|
||||
x := r.state
|
||||
x ^= x >> 12
|
||||
x ^= x << 25
|
||||
x ^= x >> 27
|
||||
r.state = x
|
||||
return x * 2685821657736338717
|
||||
}
|
||||
|
||||
func (r *openAISelectionRNG) nextFloat64() float64 {
|
||||
// [0,1)
|
||||
return float64(r.nextUint64()>>11) / (1 << 53)
|
||||
}
|
||||
|
||||
func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 {
|
||||
hasher := fnv.New64a()
|
||||
writeValue := func(value string) {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return
|
||||
}
|
||||
_, _ = hasher.Write([]byte(trimmed))
|
||||
_, _ = hasher.Write([]byte{0})
|
||||
}
|
||||
|
||||
writeValue(req.SessionHash)
|
||||
writeValue(req.PreviousResponseID)
|
||||
writeValue(req.RequestedModel)
|
||||
if req.GroupID != nil {
|
||||
_, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10)))
|
||||
}
|
||||
|
||||
seed := hasher.Sum64()
|
||||
// 对“无会话锚点”的纯负载均衡请求引入时间熵,避免固定命中同一账号。
|
||||
if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" {
|
||||
seed ^= uint64(time.Now().UnixNano())
|
||||
}
|
||||
if seed == 0 {
|
||||
seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15
|
||||
}
|
||||
return seed
|
||||
}
|
||||
|
||||
func buildOpenAIWeightedSelectionOrder(
|
||||
candidates []openAIAccountCandidateScore,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
) []openAIAccountCandidateScore {
|
||||
if len(candidates) <= 1 {
|
||||
return append([]openAIAccountCandidateScore(nil), candidates...)
|
||||
}
|
||||
|
||||
pool := append([]openAIAccountCandidateScore(nil), candidates...)
|
||||
weights := make([]float64, len(pool))
|
||||
minScore := pool[0].score
|
||||
for i := 1; i < len(pool); i++ {
|
||||
if pool[i].score < minScore {
|
||||
minScore = pool[i].score
|
||||
}
|
||||
}
|
||||
for i := range pool {
|
||||
// 将 top-K 分值平移到正区间,避免“单一最高分账号”长期垄断。
|
||||
weight := (pool[i].score - minScore) + 1.0
|
||||
if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 {
|
||||
weight = 1.0
|
||||
}
|
||||
weights[i] = weight
|
||||
}
|
||||
|
||||
order := make([]openAIAccountCandidateScore, 0, len(pool))
|
||||
rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
|
||||
for len(pool) > 0 {
|
||||
total := 0.0
|
||||
for _, w := range weights {
|
||||
total += w
|
||||
}
|
||||
|
||||
selectedIdx := 0
|
||||
if total > 0 {
|
||||
r := rng.nextFloat64() * total
|
||||
acc := 0.0
|
||||
for i, w := range weights {
|
||||
acc += w
|
||||
if r <= acc {
|
||||
selectedIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
selectedIdx = int(rng.nextUint64() % uint64(len(pool)))
|
||||
}
|
||||
|
||||
order = append(order, pool[selectedIdx])
|
||||
pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...)
|
||||
weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...)
|
||||
}
|
||||
return order
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
ctx context.Context,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
) (*AccountSelectionResult, int, int, float64, error) {
|
||||
accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, 0, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
|
||||
}
|
||||
|
||||
filtered := make([]*Account, 0, len(accounts))
|
||||
loadReq := make([]AccountWithConcurrency, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
account := &accounts[i]
|
||||
if req.ExcludedIDs != nil {
|
||||
if _, excluded := req.ExcludedIDs[account.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !account.IsSchedulable() || !account.IsOpenAI() {
|
||||
continue
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, account)
|
||||
loadReq = append(loadReq, AccountWithConcurrency{
|
||||
ID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
})
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
|
||||
}
|
||||
|
||||
loadMap := map[int64]*AccountLoadInfo{}
|
||||
if s.service.concurrencyService != nil {
|
||||
if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
|
||||
loadMap = batchLoad
|
||||
}
|
||||
}
|
||||
|
||||
minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
|
||||
maxWaiting := 1
|
||||
loadRateSum := 0.0
|
||||
loadRateSumSquares := 0.0
|
||||
minTTFT, maxTTFT := 0.0, 0.0
|
||||
hasTTFTSample := false
|
||||
candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
|
||||
for _, account := range filtered {
|
||||
loadInfo := loadMap[account.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: account.ID}
|
||||
}
|
||||
if account.Priority < minPriority {
|
||||
minPriority = account.Priority
|
||||
}
|
||||
if account.Priority > maxPriority {
|
||||
maxPriority = account.Priority
|
||||
}
|
||||
if loadInfo.WaitingCount > maxWaiting {
|
||||
maxWaiting = loadInfo.WaitingCount
|
||||
}
|
||||
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
|
||||
if hasTTFT && ttft > 0 {
|
||||
if !hasTTFTSample {
|
||||
minTTFT, maxTTFT = ttft, ttft
|
||||
hasTTFTSample = true
|
||||
} else {
|
||||
if ttft < minTTFT {
|
||||
minTTFT = ttft
|
||||
}
|
||||
if ttft > maxTTFT {
|
||||
maxTTFT = ttft
|
||||
}
|
||||
}
|
||||
}
|
||||
loadRate := float64(loadInfo.LoadRate)
|
||||
loadRateSum += loadRate
|
||||
loadRateSumSquares += loadRate * loadRate
|
||||
candidates = append(candidates, openAIAccountCandidateScore{
|
||||
account: account,
|
||||
loadInfo: loadInfo,
|
||||
errorRate: errorRate,
|
||||
ttft: ttft,
|
||||
hasTTFT: hasTTFT,
|
||||
})
|
||||
}
|
||||
loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
|
||||
|
||||
weights := s.service.openAIWSSchedulerWeights()
|
||||
for i := range candidates {
|
||||
item := &candidates[i]
|
||||
priorityFactor := 1.0
|
||||
if maxPriority > minPriority {
|
||||
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
|
||||
}
|
||||
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
|
||||
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
|
||||
errorFactor := 1 - clamp01(item.errorRate)
|
||||
ttftFactor := 0.5
|
||||
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
||||
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
||||
}
|
||||
|
||||
item.score = weights.Priority*priorityFactor +
|
||||
weights.Load*loadFactor +
|
||||
weights.Queue*queueFactor +
|
||||
weights.ErrorRate*errorFactor +
|
||||
weights.TTFT*ttftFactor
|
||||
}
|
||||
|
||||
topK := s.service.openAIWSLBTopK()
|
||||
if topK > len(candidates) {
|
||||
topK = len(candidates)
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 1
|
||||
}
|
||||
rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
|
||||
selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
|
||||
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, len(candidates), topK, loadSkew, acquireErr
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}
|
||||
}
|
||||
|
||||
cfg := s.service.schedulingConfig()
|
||||
candidate := selectionOrder[0]
|
||||
return &AccountSelectionResult{
|
||||
Account: candidate.account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: candidate.account.ID,
|
||||
MaxConcurrency: candidate.account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, len(candidates), topK, loadSkew, nil
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
|
||||
// HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。
|
||||
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
|
||||
return true
|
||||
}
|
||||
if s == nil || s.service == nil || account == nil {
|
||||
return false
|
||||
}
|
||||
return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
|
||||
if s == nil || s.stats == nil {
|
||||
return
|
||||
}
|
||||
s.stats.report(accountID, success, firstTokenMs)
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.metrics.recordSwitch()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
|
||||
if s == nil {
|
||||
return OpenAIAccountSchedulerMetricsSnapshot{}
|
||||
}
|
||||
|
||||
selectTotal := s.metrics.selectTotal.Load()
|
||||
prevHit := s.metrics.stickyPreviousHitTotal.Load()
|
||||
sessionHit := s.metrics.stickySessionHitTotal.Load()
|
||||
switchTotal := s.metrics.accountSwitchTotal.Load()
|
||||
latencyTotal := s.metrics.latencyMsTotal.Load()
|
||||
loadSkewTotal := s.metrics.loadSkewMilliTotal.Load()
|
||||
|
||||
snapshot := OpenAIAccountSchedulerMetricsSnapshot{
|
||||
SelectTotal: selectTotal,
|
||||
StickyPreviousHitTotal: prevHit,
|
||||
StickySessionHitTotal: sessionHit,
|
||||
LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(),
|
||||
AccountSwitchTotal: switchTotal,
|
||||
SchedulerLatencyMsTotal: latencyTotal,
|
||||
RuntimeStatsAccountCount: s.stats.size(),
|
||||
}
|
||||
if selectTotal > 0 {
|
||||
snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal)
|
||||
snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal)
|
||||
snapshot.AccountSwitchRate = float64(switchTotal) / float64(selectTotal)
|
||||
snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal)
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
|
||||
// getOpenAIAccountScheduler lazily constructs the account scheduler (and its
// runtime stats) exactly once per service instance via sync.Once. Returns nil
// on a nil receiver so callers can fall back to plain load-aware selection.
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
	if s == nil {
		return nil
	}
	s.openaiSchedulerOnce.Do(func() {
		// Pre-seeded fields (e.g. injected in tests) are respected; only
		// missing pieces are created here.
		if s.openaiAccountStats == nil {
			s.openaiAccountStats = newOpenAIAccountRuntimeStats()
		}
		if s.openaiScheduler == nil {
			s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats)
		}
	})
	return s.openaiScheduler
}
|
||||
|
||||
// SelectAccountWithScheduler selects an upstream OpenAI account through the
// layered scheduler (previous-response sticky, session sticky, then load
// balance — see the openAIAccountScheduleLayer* decision values). When the
// scheduler cannot be built it falls back to SelectAccountWithLoadAwareness
// and labels the decision as the load-balance layer.
//
// The sticky-session lookup is best-effort: cache errors or a non-positive
// bound ID simply leave StickyAccountID at zero, letting the scheduler pick
// from scratch.
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
	ctx context.Context,
	groupID *int64,
	previousResponseID string,
	sessionHash string,
	requestedModel string,
	excludedIDs map[int64]struct{},
	requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
	decision := OpenAIAccountScheduleDecision{}
	scheduler := s.getOpenAIAccountScheduler()
	if scheduler == nil {
		selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
		decision.Layer = openAIAccountScheduleLayerLoadBalance
		return selection, decision, err
	}

	var stickyAccountID int64
	if sessionHash != "" && s.cache != nil {
		if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {
			stickyAccountID = accountID
		}
	}

	return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
		GroupID:            groupID,
		SessionHash:        sessionHash,
		StickyAccountID:    stickyAccountID,
		PreviousResponseID: previousResponseID,
		RequestedModel:     requestedModel,
		RequiredTransport:  requiredTransport,
		ExcludedIDs:        excludedIDs,
	})
}
|
||||
|
||||
// ReportOpenAIAccountScheduleResult relays a request outcome to the lazily
// created scheduler. A nil service (and hence nil scheduler) makes this a
// no-op.
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
	scheduler := s.getOpenAIAccountScheduler()
	if scheduler == nil {
		return
	}
	scheduler.ReportResult(accountID, success, firstTokenMs)
}
|
||||
|
||||
// RecordOpenAIAccountSwitch counts one account-switch event in the scheduler
// metrics. No-op when the scheduler is unavailable.
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
	scheduler := s.getOpenAIAccountScheduler()
	if scheduler == nil {
		return
	}
	scheduler.ReportSwitch()
}
|
||||
|
||||
// SnapshotOpenAIAccountSchedulerMetrics exposes the scheduler's metrics
// snapshot; a zero-value snapshot is returned when no scheduler exists.
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
	scheduler := s.getOpenAIAccountScheduler()
	if scheduler == nil {
		return OpenAIAccountSchedulerMetricsSnapshot{}
	}
	return scheduler.SnapshotMetrics()
}
|
||||
|
||||
// openAIWSSessionStickyTTL returns the configured sticky-session TTL for
// OpenAI WS scheduling, or the package default openaiStickySessionTTL when
// the config is absent or the configured seconds are non-positive.
func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration {
	if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 {
		return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
	}
	return openaiStickySessionTTL
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIWSLBTopK() int {
|
||||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 {
|
||||
return s.cfg.Gateway.OpenAIWS.LBTopK
|
||||
}
|
||||
return 7
|
||||
}
|
||||
|
||||
// openAIWSSchedulerWeights returns the score weights for the OpenAI WS
// scheduler. With a config present, the configured values are returned
// verbatim — individual zero weights are NOT defaulted. The hard-coded
// fallback applies only when the service or its config is nil.
func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
	if s != nil && s.cfg != nil {
		return GatewayOpenAIWSSchedulerScoreWeightsView{
			Priority:  s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority,
			Load:      s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load,
			Queue:     s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue,
			ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate,
			TTFT:      s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT,
		}
	}
	// Fallback defaults when no config is wired in.
	return GatewayOpenAIWSSchedulerScoreWeightsView{
		Priority:  1.0,
		Load:      1.0,
		Queue:     0.7,
		ErrorRate: 0.8,
		TTFT:      0.5,
	}
}
|
||||
|
||||
// GatewayOpenAIWSSchedulerScoreWeightsView is a read-only copy of the OpenAI
// WS scheduler score weights, decoupling scheduler code from the config
// package's type.
type GatewayOpenAIWSSchedulerScoreWeightsView struct {
	Priority  float64 // weight applied to account priority
	Load      float64 // weight applied to current load rate
	Queue     float64 // weight applied to waiting-queue length
	ErrorRate float64 // weight applied to observed error rate
	TTFT      float64 // weight applied to time-to-first-token
}
|
||||
|
||||
// clamp01 restricts v to the closed interval [0, 1]. Values already inside
// the interval (including NaN, which fails both comparisons) pass through
// unchanged.
func clamp01(v float64) float64 {
	if v < 0 {
		return 0
	}
	if v > 1 {
		return 1
	}
	return v
}
|
||||
|
||||
// calcLoadSkewByMoments computes the population standard deviation from the
// running first and second moments (sum and sum of squares) over count
// samples. Fewer than two samples yield zero skew, and a slightly negative
// variance caused by floating-point cancellation is clamped to zero.
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
	if count <= 1 {
		return 0
	}
	n := float64(count)
	mean := sum / n
	// Var(X) = E[X^2] - (E[X])^2; guard against tiny negative round-off.
	return math.Sqrt(math.Max(sumSquares/n-mean*mean, 0))
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func buildOpenAISchedulerBenchmarkCandidates(size int) []openAIAccountCandidateScore {
|
||||
if size <= 0 {
|
||||
return nil
|
||||
}
|
||||
candidates := make([]openAIAccountCandidateScore, 0, size)
|
||||
for i := 0; i < size; i++ {
|
||||
accountID := int64(10_000 + i)
|
||||
candidates = append(candidates, openAIAccountCandidateScore{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Priority: i % 7,
|
||||
},
|
||||
loadInfo: &AccountLoadInfo{
|
||||
AccountID: accountID,
|
||||
LoadRate: (i * 17) % 100,
|
||||
WaitingCount: (i * 11) % 13,
|
||||
},
|
||||
score: float64((i*29)%1000) / 100,
|
||||
errorRate: float64((i * 5) % 100 / 100),
|
||||
ttft: float64(30 + (i*3)%500),
|
||||
hasTTFT: i%3 != 0,
|
||||
})
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
// selectTopKOpenAICandidatesBySortBenchmark is the naive full-sort baseline
// for the benchmark below: clone the slice, sort it with
// isOpenAIAccountCandidateBetter, and take the first topK entries. topK is
// clamped to [1, len(candidates)]; an empty input yields nil.
func selectTopKOpenAICandidatesBySortBenchmark(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
	if len(candidates) == 0 {
		return nil
	}
	if topK <= 0 {
		topK = 1
	}
	// Copy first so the benchmark input is never mutated between iterations.
	ranked := append([]openAIAccountCandidateScore(nil), candidates...)
	sort.Slice(ranked, func(i, j int) bool {
		return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
	})
	if topK > len(ranked) {
		topK = len(ranked)
	}
	return ranked[:topK]
}
|
||||
|
||||
// BenchmarkOpenAIAccountSchedulerSelectTopK compares the production heap-based
// top-K selection against the naive full-sort baseline across several
// candidate-set sizes and K values.
func BenchmarkOpenAIAccountSchedulerSelectTopK(b *testing.B) {
	cases := []struct {
		name string
		size int
		topK int
	}{
		{name: "n_16_k_3", size: 16, topK: 3},
		{name: "n_64_k_3", size: 64, topK: 3},
		{name: "n_256_k_5", size: 256, topK: 5},
	}

	for _, tc := range cases {
		candidates := buildOpenAISchedulerBenchmarkCandidates(tc.size)
		b.Run(tc.name+"/heap_topk", func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				result := selectTopKOpenAICandidates(candidates, tc.topK)
				if len(result) == 0 {
					b.Fatal("unexpected empty result")
				}
			}
		})
		b.Run(tc.name+"/full_sort", func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				result := selectTopKOpenAICandidatesBySortBenchmark(candidates, tc.topK)
				if len(result) == 0 {
					b.Fatal("unexpected empty result")
				}
			}
		})
	}
}
|
||||
841
backend/internal/service/openai_account_scheduler_test.go
Normal file
841
backend/internal/service/openai_account_scheduler_test.go
Normal file
@@ -0,0 +1,841 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Verifies that a previous_response_id bound to an account routes back to it
// (previous-response sticky layer) and that the session binding is refreshed.
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(9)
	account := Account{
		ID:          1001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		},
	}
	cache := &stubGatewayCache{}
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
	cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600

	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}

	// Pre-bind the response ID to the account so the sticky layer can hit.
	store := svc.getOpenAIWSStateStore()
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_001", account.ID, time.Hour))

	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"resp_prev_001",
		"session_hash_001",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
	require.True(t, decision.StickyPreviousHit)
	require.Equal(t, account.ID, cache.sessionBindings["openai:session_hash_001"])
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
|
||||
|
||||
// Verifies that an existing session binding routes to the bound account via
// the session-sticky layer.
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(10)
	account := Account{
		ID:          2001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_abc": account.ID,
		},
	}

	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                &config.Config{},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}

	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_abc",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
	require.True(t, decision.StickySessionHit)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
|
||||
|
||||
// Verifies that a saturated sticky account is NOT abandoned for a lighter
// account: the scheduler keeps the sticky binding and returns a wait plan.
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(10100)
	accounts := []Account{
		{
			ID:          21001,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
		{
			ID:          21002,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    9,
		},
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_sticky_busy": 21001,
		},
	}
	cfg := &config.Config{}
	cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2
	cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true

	concurrencyCache := stubConcurrencyCache{
		acquireResults: map[int64]bool{
			21001: false, // sticky account is saturated
			21002: true,  // load-balance fallback would pick this one (the test requires no switch)
		},
		waitCounts: map[int64]int{
			21001: 999,
		},
		loadMap: map[int64]*AccountLoadInfo{
			21001: {AccountID: 21001, LoadRate: 90, WaitingCount: 9},
			21002: {AccountID: 21002, LoadRate: 1, WaitingCount: 0},
		},
	}

	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
	}

	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_sticky_busy",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(21001), selection.Account.ID, "busy sticky account should remain selected")
	require.False(t, selection.Acquired)
	require.NotNil(t, selection.WaitPlan)
	require.Equal(t, int64(21001), selection.WaitPlan.AccountID)
	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
	require.True(t, decision.StickySessionHit)
}
|
||||
|
||||
// Verifies that an account flagged openai_ws_force_http still satisfies the
// session-sticky layer when any transport is acceptable.
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP(t *testing.T) {
	ctx := context.Background()
	groupID := int64(1010)
	account := Account{
		ID:          2101,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_ws_force_http": true,
		},
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_force_http": account.ID,
		},
	}

	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                &config.Config{},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}

	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_force_http",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
	require.True(t, decision.StickySessionHit)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
|
||||
|
||||
// Verifies that a required WS-v2 transport overrides a sticky binding to an
// HTTP-only account: selection must fall through to the load-balance layer
// with only the WS-capable candidate remaining.
func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStickyHTTPAccount(t *testing.T) {
	ctx := context.Background()
	groupID := int64(1011)
	accounts := []Account{
		{
			ID:          2201,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
		{
			ID:          2202,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    5,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		},
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_ws_only": 2201,
		},
	}
	cfg := newOpenAIWSV2TestConfig()

	// Make the HTTP-only account the lower-load one, proving the required
	// transport forcibly filters candidates regardless of load.
	concurrencyCache := stubConcurrencyCache{
		loadMap: map[int64]*AccountLoadInfo{
			2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
			2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
		},
	}

	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
	}

	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_ws_only",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportResponsesWebsocketV2,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(2202), selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
	require.False(t, decision.StickySessionHit)
	require.Equal(t, 1, decision.CandidateCount)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
|
||||
|
||||
// Verifies that requiring WS-v2 with no capable account yields an error and
// a zero candidate count.
func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailableAccount(t *testing.T) {
	ctx := context.Background()
	groupID := int64(1012)
	accounts := []Account{
		{
			ID:          2301,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeOAuth,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
		},
	}

	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              &stubGatewayCache{},
		cfg:                newOpenAIWSV2TestConfig(),
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}

	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportResponsesWebsocketV2,
	)
	require.Error(t, err)
	require.Nil(t, selection)
	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
	require.Equal(t, 0, decision.CandidateCount)
}
|
||||
|
||||
// Verifies that when the top-1 load-balance candidate cannot be acquired,
// selection falls back to the next candidate inside the top-K window, and
// that the decision exposes candidate count, K and load skew.
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback(t *testing.T) {
	ctx := context.Background()
	groupID := int64(11)
	accounts := []Account{
		{
			ID:          3001,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
		{
			ID:          3002,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
		{
			ID:          3003,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
	}

	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.LBTopK = 2
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1

	concurrencyCache := stubConcurrencyCache{
		loadMap: map[int64]*AccountLoadInfo{
			3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
			3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
			3003: {AccountID: 3003, LoadRate: 10, WaitingCount: 0},
		},
		acquireResults: map[int64]bool{
			3003: false, // top-1 acquisition fails; must fall back to the next top-K candidate
			3002: true,
		},
	}

	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              &stubGatewayCache{},
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
	}

	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(3002), selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
	require.Equal(t, 3, decision.CandidateCount)
	require.Equal(t, 2, decision.TopK)
	require.Greater(t, decision.LoadSkew, 0.0)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
|
||||
|
||||
// Exercises the full metrics round-trip: select, report a result and a
// switch, then assert the snapshot counters moved.
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
	ctx := context.Background()
	groupID := int64(12)
	account := Account{
		ID:          4001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_metrics": account.ID,
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                &config.Config{},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}

	selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
	require.NoError(t, err)
	require.NotNil(t, selection)
	svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
	svc.RecordOpenAIAccountSwitch()

	snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
	require.GreaterOrEqual(t, snapshot.SelectTotal, int64(1))
	require.GreaterOrEqual(t, snapshot.StickySessionHitTotal, int64(1))
	require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
	require.GreaterOrEqual(t, snapshot.SchedulerLatencyMsAvg, float64(0))
	require.GreaterOrEqual(t, snapshot.StickyHitRatio, 0.0)
	require.GreaterOrEqual(t, snapshot.RuntimeStatsAccountCount, 1)
}
|
||||
|
||||
// intPtrForTest returns a pointer to a copy of v, for tests that need *int
// arguments.
func intPtrForTest(v int) *int {
	out := v
	return &out
}
|
||||
|
||||
// Verifies error-rate/TTFT aggregation in the runtime stats: three reports
// (one success, two failures with TTFTs 100 and 200) yield the expected
// smoothed error rate and TTFT average.
func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) {
	stats := newOpenAIAccountRuntimeStats()
	stats.report(1001, true, nil)
	firstTTFT := 100
	stats.report(1001, false, &firstTTFT)
	secondTTFT := 200
	stats.report(1001, false, &secondTTFT)

	errorRate, ttft, hasTTFT := stats.snapshot(1001)
	require.True(t, hasTTFT)
	require.InDelta(t, 0.36, errorRate, 1e-9)
	require.InDelta(t, 120.0, ttft, 1e-9)
	require.Equal(t, 1, stats.size())
}
|
||||
|
||||
// Hammers the runtime stats from 16 goroutines across 4 accounts and checks
// the aggregates stay within valid bounds (run with -race to catch data
// races).
func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) {
	stats := newOpenAIAccountRuntimeStats()

	const (
		accountCount = 4
		workers      = 16
		iterations   = 800
	)
	var wg sync.WaitGroup
	wg.Add(workers)
	for worker := 0; worker < workers; worker++ {
		worker := worker
		go func() {
			defer wg.Done()
			for i := 0; i < iterations; i++ {
				accountID := int64(i%accountCount + 1)
				success := (i+worker)%3 != 0
				ttft := 80 + (i+worker)%40
				stats.report(accountID, success, &ttft)
			}
		}()
	}
	wg.Wait()

	require.Equal(t, accountCount, stats.size())
	for accountID := int64(1); accountID <= accountCount; accountID++ {
		errorRate, ttft, hasTTFT := stats.snapshot(accountID)
		require.GreaterOrEqual(t, errorRate, 0.0)
		require.LessOrEqual(t, errorRate, 1.0)
		require.True(t, hasTTFT)
		require.Greater(t, ttft, 0.0)
	}
}
|
||||
|
||||
// Verifies the heap-based top-K ordering, including K larger than the input
// (whole list is ranked) and the expected tie-breaking between candidates.
func TestSelectTopKOpenAICandidates(t *testing.T) {
	candidates := []openAIAccountCandidateScore{
		{
			account:  &Account{ID: 11, Priority: 2},
			loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 1},
			score:    10.0,
		},
		{
			account:  &Account{ID: 12, Priority: 1},
			loadInfo: &AccountLoadInfo{LoadRate: 20, WaitingCount: 1},
			score:    9.5,
		},
		{
			account:  &Account{ID: 13, Priority: 1},
			loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 0},
			score:    10.0,
		},
		{
			account:  &Account{ID: 14, Priority: 0},
			loadInfo: &AccountLoadInfo{LoadRate: 40, WaitingCount: 0},
			score:    8.0,
		},
	}

	top2 := selectTopKOpenAICandidates(candidates, 2)
	require.Len(t, top2, 2)
	require.Equal(t, int64(13), top2[0].account.ID)
	require.Equal(t, int64(11), top2[1].account.ID)

	topAll := selectTopKOpenAICandidates(candidates, 8)
	require.Len(t, topAll, len(candidates))
	require.Equal(t, int64(13), topAll[0].account.ID)
	require.Equal(t, int64(11), topAll[1].account.ID)
	require.Equal(t, int64(12), topAll[2].account.ID)
	require.Equal(t, int64(14), topAll[3].account.ID)
}
|
||||
|
||||
// Verifies that the weighted selection order is deterministic for a fixed
// session seed (same request → identical ordering across calls).
func TestBuildOpenAIWeightedSelectionOrder_DeterministicBySessionSeed(t *testing.T) {
	candidates := []openAIAccountCandidateScore{
		{
			account:  &Account{ID: 101},
			loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 0},
			score:    4.2,
		},
		{
			account:  &Account{ID: 102},
			loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 1},
			score:    3.5,
		},
		{
			account:  &Account{ID: 103},
			loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2},
			score:    2.1,
		},
	}
	req := OpenAIAccountScheduleRequest{
		GroupID:        int64PtrForTest(99),
		SessionHash:    "session_seed_fixed",
		RequestedModel: "gpt-5.1",
	}

	first := buildOpenAIWeightedSelectionOrder(candidates, req)
	second := buildOpenAIWeightedSelectionOrder(candidates, req)
	require.Len(t, first, len(candidates))
	require.Len(t, second, len(candidates))
	for i := range first {
		require.Equal(t, first[i].account.ID, second[i].account.ID)
	}
}
|
||||
|
||||
// Verifies that, with equally loaded accounts, 60 distinct sessions spread
// across at least two accounts instead of constantly hitting one.
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesAcrossSessions(t *testing.T) {
	ctx := context.Background()
	groupID := int64(15)
	accounts := []Account{
		{
			ID:          5101,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 3,
			Priority:    0,
		},
		{
			ID:          5102,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 3,
			Priority:    0,
		},
		{
			ID:          5103,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 3,
			Priority:    0,
		},
	}
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.LBTopK = 3
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1

	concurrencyCache := stubConcurrencyCache{
		loadMap: map[int64]*AccountLoadInfo{
			5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
			5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
			5103: {AccountID: 5103, LoadRate: 20, WaitingCount: 1},
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              &stubGatewayCache{sessionBindings: map[string]int64{}},
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
	}

	selected := make(map[int64]int, len(accounts))
	for i := 0; i < 60; i++ {
		sessionHash := fmt.Sprintf("session_hash_lb_%d", i)
		selection, decision, err := svc.SelectAccountWithScheduler(
			ctx,
			&groupID,
			"",
			sessionHash,
			"gpt-5.1",
			nil,
			OpenAIUpstreamTransportAny,
		)
		require.NoError(t, err)
		require.NotNil(t, selection)
		require.NotNil(t, selection.Account)
		require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
		selected[selection.Account.ID]++
		if selection.ReleaseFunc != nil {
			selection.ReleaseFunc()
		}
	}

	// Multiple sessions should spread across several accounts, avoiding a
	// constant single-account hit.
	require.GreaterOrEqual(t, len(selected), 2)
}
|
||||
|
||||
// Verifies that without any session/group affinity the selection seed gains
// time-based entropy: two derivations 1ms apart must differ.
func TestDeriveOpenAISelectionSeed_NoAffinityAddsEntropy(t *testing.T) {
	req := OpenAIAccountScheduleRequest{
		RequestedModel: "gpt-5.1",
	}
	seed1 := deriveOpenAISelectionSeed(req)
	time.Sleep(1 * time.Millisecond)
	seed2 := deriveOpenAISelectionSeed(req)
	require.NotZero(t, seed1)
	require.NotZero(t, seed2)
	require.NotEqual(t, seed1, seed2)
}
|
||||
|
||||
// Verifies that NaN, +Inf and negative scores don't drop or duplicate
// candidates in the weighted selection order.
func TestBuildOpenAIWeightedSelectionOrder_HandlesInvalidScores(t *testing.T) {
	candidates := []openAIAccountCandidateScore{
		{
			account:  &Account{ID: 901},
			loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
			score:    math.NaN(),
		},
		{
			account:  &Account{ID: 902},
			loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
			score:    math.Inf(1),
		},
		{
			account:  &Account{ID: 903},
			loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
			score:    -1,
		},
	}
	req := OpenAIAccountScheduleRequest{
		SessionHash: "seed_invalid_scores",
	}

	order := buildOpenAIWeightedSelectionOrder(candidates, req)
	require.Len(t, order, len(candidates))
	seen := map[int64]struct{}{}
	for _, item := range order {
		seen[item.account.ID] = struct{}{}
	}
	require.Len(t, seen, len(candidates))
}
|
||||
|
||||
// Verifies that a zero seed still produces a usable RNG: successive values
// differ and floats stay within [0, 1).
func TestOpenAISelectionRNG_SeedZeroStillWorks(t *testing.T) {
	rng := newOpenAISelectionRNG(0)
	v1 := rng.nextUint64()
	v2 := rng.nextUint64()
	require.NotEqual(t, v1, v2)
	require.GreaterOrEqual(t, rng.nextFloat64(), 0.0)
	require.Less(t, rng.nextFloat64(), 1.0)
}
|
||||
|
||||
// Exercises the candidate heap's container/heap interface directly: push/pop
// round-trip plus the expected panic when a non-candidate element is pushed.
func TestOpenAIAccountCandidateHeap_PushPopAndInvalidType(t *testing.T) {
	h := openAIAccountCandidateHeap{}
	h.Push(openAIAccountCandidateScore{
		account:  &Account{ID: 7001},
		loadInfo: &AccountLoadInfo{LoadRate: 0, WaitingCount: 0},
		score:    1.0,
	})
	require.Equal(t, 1, h.Len())
	popped, ok := h.Pop().(openAIAccountCandidateScore)
	require.True(t, ok)
	require.Equal(t, int64(7001), popped.account.ID)
	require.Equal(t, 0, h.Len())

	require.Panics(t, func() {
		h.Push("bad_element_type")
	})
}
|
||||
|
||||
// Covers all three clamp01 branches: below range, above range, in range.
func TestClamp01_AllBranches(t *testing.T) {
	require.Equal(t, 0.0, clamp01(-0.2))
	require.Equal(t, 1.0, clamp01(1.3))
	require.Equal(t, 0.5, clamp01(0.5))
}
|
||||
|
||||
// Covers calcLoadSkewByMoments branches: single sample, negative-variance
// clamping, and a regular positive-variance case.
func TestCalcLoadSkewByMoments_Branches(t *testing.T) {
	require.Equal(t, 0.0, calcLoadSkewByMoments(1, 1, 1))
	// variance < 0 branch: clamp to 0 when sumSquares/count - mean^2 is negative.
	require.Equal(t, 0.0, calcLoadSkewByMoments(1, 0, 2))
	require.GreaterOrEqual(t, calcLoadSkewByMoments(6, 20, 3), 0.0)
}
|
||||
|
||||
// TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot feeds the
// scheduler's metrics with one result report, one account switch, and two
// recorded selections, then checks every aggregate exposed by
// SnapshotMetrics.
func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
	schedulerAny := newDefaultOpenAIAccountScheduler(&OpenAIGatewayService{}, nil)
	scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
	require.True(t, ok)

	ttft := 100
	scheduler.ReportResult(1001, true, &ttft)
	scheduler.ReportSwitch()
	// One load-balance-layer selection that hit the previous-account sticky
	// path...
	scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
		Layer:             openAIAccountScheduleLayerLoadBalance,
		LatencyMs:         8,
		LoadSkew:          0.5,
		StickyPreviousHit: true,
	})
	// ...and one session-sticky-layer selection.
	scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
		Layer:            openAIAccountScheduleLayerSessionSticky,
		LatencyMs:        6,
		LoadSkew:         0.2,
		StickySessionHit: true,
	})

	// Counters reflect exactly the two selections and one switch above;
	// averaged metrics only need to be positive since both samples were.
	snapshot := scheduler.SnapshotMetrics()
	require.Equal(t, int64(2), snapshot.SelectTotal)
	require.Equal(t, int64(1), snapshot.StickyPreviousHitTotal)
	require.Equal(t, int64(1), snapshot.StickySessionHitTotal)
	require.Equal(t, int64(1), snapshot.LoadBalanceSelectTotal)
	require.Equal(t, int64(1), snapshot.AccountSwitchTotal)
	require.Greater(t, snapshot.SchedulerLatencyMsAvg, 0.0)
	require.Greater(t, snapshot.StickyHitRatio, 0.0)
	require.Greater(t, snapshot.LoadSkewAvg, 0.0)
}
|
||||
|
||||
// TestOpenAIGatewayService_SchedulerWrappersAndDefaults checks both halves of
// the scheduler configuration surface: the built-in defaults used when no
// config is attached, and the values read back when Gateway.OpenAIWS settings
// are populated.
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
	// No config: wrapper methods must still work and fall back to defaults.
	svc := &OpenAIGatewayService{}
	ttft := 120
	svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
	svc.RecordOpenAIAccountSwitch()
	snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
	require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
	require.Equal(t, 7, svc.openAIWSLBTopK())
	require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())

	// Default score weights when SchedulerScoreWeights is unset.
	defaultWeights := svc.openAIWSSchedulerWeights()
	require.Equal(t, 1.0, defaultWeights.Priority)
	require.Equal(t, 1.0, defaultWeights.Load)
	require.Equal(t, 0.7, defaultWeights.Queue)
	require.Equal(t, 0.8, defaultWeights.ErrorRate)
	require.Equal(t, 0.5, defaultWeights.TTFT)

	// With explicit config every accessor must reflect the configured value.
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.LBTopK = 9
	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 180
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.2
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.3
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.4
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.5
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.6
	svcWithCfg := &OpenAIGatewayService{cfg: cfg}

	require.Equal(t, 9, svcWithCfg.openAIWSLBTopK())
	require.Equal(t, 180*time.Second, svcWithCfg.openAIWSSessionStickyTTL())
	customWeights := svcWithCfg.openAIWSSchedulerWeights()
	require.Equal(t, 0.2, customWeights.Priority)
	require.Equal(t, 0.3, customWeights.Load)
	require.Equal(t, 0.4, customWeights.Queue)
	require.Equal(t, 0.5, customWeights.ErrorRate)
	require.Equal(t, 0.6, customWeights.TTFT)
}
|
||||
|
||||
// TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches
// covers the nil-account branches (any/HTTP-SSE pass, websocket-v2 fails) and
// the positive path where an API-key account opts into websocket v2 via its
// Extra flag.
func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *testing.T) {
	scheduler := &defaultOpenAIAccountScheduler{}
	// nil account: compatible with "any" and HTTP SSE, not with WS v2.
	require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportAny))
	require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
	require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))

	// A schedulable API-key account with the ws-v2 Extra flag enabled is
	// compatible with the websocket-v2 transport.
	cfg := newOpenAIWSV2TestConfig()
	scheduler.service = &OpenAIGatewayService{cfg: cfg}
	account := &Account{
		ID:          8801,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		},
	}
	require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2))
}
|
||||
|
||||
// int64PtrForTest returns a pointer to a fresh copy of v, for tests that need
// *int64 literals.
func int64PtrForTest(v int64) *int64 {
	out := v
	return &out
}
|
||||
71
backend/internal/service/openai_client_transport.go
Normal file
71
backend/internal/service/openai_client_transport.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIClientTransport 表示客户端入站协议类型。
|
||||
type OpenAIClientTransport string
|
||||
|
||||
const (
|
||||
OpenAIClientTransportUnknown OpenAIClientTransport = ""
|
||||
OpenAIClientTransportHTTP OpenAIClientTransport = "http"
|
||||
OpenAIClientTransportWS OpenAIClientTransport = "ws"
|
||||
)
|
||||
|
||||
const openAIClientTransportContextKey = "openai_client_transport"
|
||||
|
||||
// SetOpenAIClientTransport 标记当前请求的客户端入站协议。
|
||||
func SetOpenAIClientTransport(c *gin.Context, transport OpenAIClientTransport) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
normalized := normalizeOpenAIClientTransport(transport)
|
||||
if normalized == OpenAIClientTransportUnknown {
|
||||
return
|
||||
}
|
||||
c.Set(openAIClientTransportContextKey, string(normalized))
|
||||
}
|
||||
|
||||
// GetOpenAIClientTransport 读取当前请求的客户端入站协议。
|
||||
func GetOpenAIClientTransport(c *gin.Context) OpenAIClientTransport {
|
||||
if c == nil {
|
||||
return OpenAIClientTransportUnknown
|
||||
}
|
||||
raw, ok := c.Get(openAIClientTransportContextKey)
|
||||
if !ok || raw == nil {
|
||||
return OpenAIClientTransportUnknown
|
||||
}
|
||||
|
||||
switch v := raw.(type) {
|
||||
case OpenAIClientTransport:
|
||||
return normalizeOpenAIClientTransport(v)
|
||||
case string:
|
||||
return normalizeOpenAIClientTransport(OpenAIClientTransport(v))
|
||||
default:
|
||||
return OpenAIClientTransportUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpenAIClientTransport(transport OpenAIClientTransport) OpenAIClientTransport {
|
||||
switch strings.ToLower(strings.TrimSpace(string(transport))) {
|
||||
case string(OpenAIClientTransportHTTP), "http_sse", "sse":
|
||||
return OpenAIClientTransportHTTP
|
||||
case string(OpenAIClientTransportWS), "websocket":
|
||||
return OpenAIClientTransportWS
|
||||
default:
|
||||
return OpenAIClientTransportUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func resolveOpenAIWSDecisionByClientTransport(
|
||||
decision OpenAIWSProtocolDecision,
|
||||
clientTransport OpenAIClientTransport,
|
||||
) OpenAIWSProtocolDecision {
|
||||
if clientTransport == OpenAIClientTransportHTTP {
|
||||
return openAIWSHTTPDecision("client_protocol_http")
|
||||
}
|
||||
return decision
|
||||
}
|
||||
107
backend/internal/service/openai_client_transport_test.go
Normal file
107
backend/internal/service/openai_client_transport_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOpenAIClientTransport_SetAndGet verifies the basic set/get round-trip
// on a gin context: unset reads as Unknown, and a later Set overwrites the
// previously stored transport.
func TestOpenAIClientTransport_SetAndGet(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)

	require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(c))

	SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
	require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))

	SetOpenAIClientTransport(c, OpenAIClientTransportWS)
	require.Equal(t, OpenAIClientTransportWS, GetOpenAIClientTransport(c))
}
|
||||
|
||||
func TestOpenAIClientTransport_GetNormalizesRawContextValue(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rawValue any
|
||||
want OpenAIClientTransport
|
||||
}{
|
||||
{
|
||||
name: "type_value_ws",
|
||||
rawValue: OpenAIClientTransportWS,
|
||||
want: OpenAIClientTransportWS,
|
||||
},
|
||||
{
|
||||
name: "http_sse_alias",
|
||||
rawValue: "http_sse",
|
||||
want: OpenAIClientTransportHTTP,
|
||||
},
|
||||
{
|
||||
name: "sse_alias",
|
||||
rawValue: "sSe",
|
||||
want: OpenAIClientTransportHTTP,
|
||||
},
|
||||
{
|
||||
name: "websocket_alias",
|
||||
rawValue: "WebSocket",
|
||||
want: OpenAIClientTransportWS,
|
||||
},
|
||||
{
|
||||
name: "invalid_string",
|
||||
rawValue: "tcp",
|
||||
want: OpenAIClientTransportUnknown,
|
||||
},
|
||||
{
|
||||
name: "invalid_type",
|
||||
rawValue: 123,
|
||||
want: OpenAIClientTransportUnknown,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Set(openAIClientTransportContextKey, tt.rawValue)
|
||||
require.Equal(t, tt.want, GetOpenAIClientTransport(c))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIClientTransport_NilAndUnknownInput(t *testing.T) {
|
||||
SetOpenAIClientTransport(nil, OpenAIClientTransportHTTP)
|
||||
require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(nil))
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
SetOpenAIClientTransport(c, OpenAIClientTransportUnknown)
|
||||
_, exists := c.Get(openAIClientTransportContextKey)
|
||||
require.False(t, exists)
|
||||
|
||||
SetOpenAIClientTransport(c, OpenAIClientTransport(" "))
|
||||
_, exists = c.Get(openAIClientTransportContextKey)
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
// TestResolveOpenAIWSDecisionByClientTransport verifies that an HTTP client
// transport downgrades a WS-v2 upstream decision to HTTP SSE with the
// "client_protocol_http" reason, while WS and Unknown leave it unchanged.
func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) {
	base := OpenAIWSProtocolDecision{
		Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
		Reason:    "ws_v2_enabled",
	}

	httpDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportHTTP)
	require.Equal(t, OpenAIUpstreamTransportHTTPSSE, httpDecision.Transport)
	require.Equal(t, "client_protocol_http", httpDecision.Reason)

	wsDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportWS)
	require.Equal(t, base, wsDecision)

	unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown)
	require.Equal(t, base, unknownDecision)
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -123,3 +123,19 @@ func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "parse request")
|
||||
}
|
||||
|
||||
// TestGetOpenAIRequestBodyMap_WriteBackContextCache verifies that a
// successful parse returns the decoded body and writes the same map back to
// the gin context under OpenAIParsedRequestBodyKey for later reuse.
func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)

	got, err := getOpenAIRequestBodyMap(c, []byte(`{"model":"gpt-5","stream":true}`))
	require.NoError(t, err)
	require.Equal(t, "gpt-5", got["model"])

	// The parsed map must have been cached on the context.
	cached, ok := c.Get(OpenAIParsedRequestBodyKey)
	require.True(t, ok)
	cachedMap, ok := cached.(map[string]any)
	require.True(t, ok)
	require.Equal(t, got, cachedMap)
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -55,6 +57,10 @@ func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, pl
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return r.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
type stubConcurrencyCache struct {
|
||||
ConcurrencyCache
|
||||
loadBatchErr error
|
||||
@@ -166,6 +172,54 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
|
||||
c.Request.Header.Set("session_id", "sess-fixed-value")
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
got := svc.GenerateSessionHash(c, nil)
|
||||
want := fmt.Sprintf("%016x", xxhash.Sum64String("sess-fixed-value"))
|
||||
require.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
|
||||
c.Request.Header.Set("session_id", "sess-legacy-check")
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
sessionHash := svc.GenerateSessionHash(c, nil)
|
||||
require.NotEmpty(t, sessionHash)
|
||||
require.NotNil(t, c.Request)
|
||||
require.NotNil(t, c.Request.Context())
|
||||
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
seed := "openai_ws_ingress:9:100:200"
|
||||
|
||||
got := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), seed)
|
||||
want := fmt.Sprintf("%016x", xxhash.Sum64String(seed))
|
||||
require.Equal(t, want, got)
|
||||
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
|
||||
empty := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), " ")
|
||||
require.Equal(t, "", empty)
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if c.waitCounts != nil {
|
||||
if count, ok := c.waitCounts[accountID]; ok {
|
||||
|
||||
@@ -0,0 +1,357 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
var (
|
||||
benchmarkToolContinuationBoolSink bool
|
||||
benchmarkWSParseStringSink string
|
||||
benchmarkWSParseMapSink map[string]any
|
||||
benchmarkUsageSink OpenAIUsage
|
||||
)
|
||||
|
||||
func BenchmarkToolContinuationValidationLegacy(b *testing.B) {
|
||||
reqBody := benchmarkToolContinuationRequestBody()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkToolContinuationBoolSink = legacyValidateFunctionCallOutputContext(reqBody)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkToolContinuationValidationOptimized(b *testing.B) {
|
||||
reqBody := benchmarkToolContinuationRequestBody()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkToolContinuationBoolSink = optimizedValidateFunctionCallOutputContext(reqBody)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWSIngressPayloadParseLegacy(b *testing.B) {
|
||||
raw := benchmarkWSIngressPayloadBytes()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
eventType, model, promptCacheKey, previousResponseID, payload, err := legacyParseWSIngressPayload(raw)
|
||||
if err == nil {
|
||||
benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID
|
||||
benchmarkWSParseMapSink = payload
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWSIngressPayloadParseOptimized(b *testing.B) {
|
||||
raw := benchmarkWSIngressPayloadBytes()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
eventType, model, promptCacheKey, previousResponseID, payload, err := optimizedParseWSIngressPayload(raw)
|
||||
if err == nil {
|
||||
benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID
|
||||
benchmarkWSParseMapSink = payload
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOpenAIUsageExtractLegacy(b *testing.B) {
|
||||
body := benchmarkOpenAIUsageJSONBytes()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
usage, ok := legacyExtractOpenAIUsageFromJSONBytes(body)
|
||||
if ok {
|
||||
benchmarkUsageSink = usage
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOpenAIUsageExtractOptimized(b *testing.B) {
|
||||
body := benchmarkOpenAIUsageJSONBytes()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
usage, ok := extractOpenAIUsageFromJSONBytes(body)
|
||||
if ok {
|
||||
benchmarkUsageSink = usage
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkToolContinuationRequestBody builds a fixed request body for the
// tool-continuation validation benchmarks: 24 plain text items followed by
// 10 triplets of (tool_call, function_call_output, item_reference) sharing
// one call_id each.
func benchmarkToolContinuationRequestBody() map[string]any {
	items := make([]any, 0, 64)
	for textIdx := 0; textIdx < 24; textIdx++ {
		items = append(items, map[string]any{
			"type": "text",
			"text": "benchmark text",
		})
	}
	for callIdx := 0; callIdx < 10; callIdx++ {
		callID := "call_" + strconv.Itoa(callIdx)
		items = append(items,
			map[string]any{"type": "tool_call", "call_id": callID},
			map[string]any{"type": "function_call_output", "call_id": callID},
			map[string]any{"type": "item_reference", "id": callID},
		)
	}
	return map[string]any{
		"model": "gpt-5.3-codex",
		"input": items,
	}
}
|
||||
|
||||
// benchmarkWSIngressPayloadBytes returns the fixed WS ingress payload used by
// the parse benchmarks.
func benchmarkWSIngressPayloadBytes() []byte {
	const payload = `{"type":"response.create","model":"gpt-5.3-codex","prompt_cache_key":"cache_bench","previous_response_id":"resp_prev_bench","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`
	return []byte(payload)
}
|
||||
|
||||
// benchmarkOpenAIUsageJSONBytes returns the fixed response JSON used by the
// usage-extraction benchmarks.
func benchmarkOpenAIUsageJSONBytes() []byte {
	const payload = `{"id":"resp_bench","object":"response","model":"gpt-5.3-codex","usage":{"input_tokens":3210,"output_tokens":987,"input_tokens_details":{"cached_tokens":456}}}`
	return []byte(payload)
}
|
||||
|
||||
// legacyValidateFunctionCallOutputContext is the pre-optimization validation
// path kept as a benchmark baseline. A request passes when it has no
// function_call_output items, or carries a previous_response_id, or includes
// tool-call context; otherwise every output's call_id must be present and
// covered by item_reference entries.
func legacyValidateFunctionCallOutputContext(reqBody map[string]any) bool {
	if !legacyHasFunctionCallOutput(reqBody) {
		return true
	}
	// A previous_response_id links the outputs to an earlier response.
	previousResponseID, _ := reqBody["previous_response_id"].(string)
	if strings.TrimSpace(previousResponseID) != "" {
		return true
	}
	// In-request tool calls provide the needed context directly.
	if legacyHasToolCallContext(reqBody) {
		return true
	}
	// An output without a call_id can never be matched — reject.
	if legacyHasFunctionCallOutputMissingCallID(reqBody) {
		return false
	}
	// Otherwise require an item_reference for every output call_id.
	callIDs := legacyFunctionCallOutputCallIDs(reqBody)
	return legacyHasItemReferenceForCallIDs(reqBody, callIDs)
}
|
||||
|
||||
// optimizedValidateFunctionCallOutputContext mirrors the legacy validation
// but computes all predicates in a single pass via
// ValidateFunctionCallOutputContext; benchmarked against the legacy version.
func optimizedValidateFunctionCallOutputContext(reqBody map[string]any) bool {
	validation := ValidateFunctionCallOutputContext(reqBody)
	if !validation.HasFunctionCallOutput {
		return true
	}
	// Same decision ladder as the legacy path, fed from the one-pass result.
	previousResponseID, _ := reqBody["previous_response_id"].(string)
	if strings.TrimSpace(previousResponseID) != "" {
		return true
	}
	if validation.HasToolCallContext {
		return true
	}
	if validation.HasFunctionCallOutputMissingCallID {
		return false
	}
	return validation.HasItemReferenceForAllCallIDs
}
|
||||
|
||||
// legacyHasFunctionCallOutput reports whether any entry of reqBody["input"]
// is a map whose "type" is "function_call_output". Baseline kept for
// benchmarks.
func legacyHasFunctionCallOutput(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, entry := range items {
		m, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		if kind, _ := m["type"].(string); kind == "function_call_output" {
			return true
		}
	}
	return false
}
|
||||
|
||||
// legacyHasToolCallContext reports whether reqBody["input"] contains a
// "tool_call" or "function_call" entry carrying a non-blank call_id.
// Baseline kept for benchmarks.
func legacyHasToolCallContext(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, entry := range items {
		m, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		kind, _ := m["type"].(string)
		if kind != "tool_call" && kind != "function_call" {
			continue
		}
		callID, isString := m["call_id"].(string)
		if isString && strings.TrimSpace(callID) != "" {
			return true
		}
	}
	return false
}
|
||||
|
||||
// legacyFunctionCallOutputCallIDs collects the distinct non-blank call_id
// values of all "function_call_output" entries in reqBody["input"]. It
// returns nil when there are none; the slice order is unspecified (map
// iteration). Baseline kept for benchmarks.
func legacyFunctionCallOutputCallIDs(reqBody map[string]any) []string {
	if reqBody == nil {
		return nil
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return nil
	}
	unique := map[string]struct{}{}
	for _, entry := range items {
		m, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		if kind, _ := m["type"].(string); kind != "function_call_output" {
			continue
		}
		callID, isString := m["call_id"].(string)
		if !isString || strings.TrimSpace(callID) == "" {
			continue
		}
		unique[callID] = struct{}{}
	}
	if len(unique) == 0 {
		return nil
	}
	out := make([]string, 0, len(unique))
	for id := range unique {
		out = append(out, id)
	}
	return out
}
|
||||
|
||||
// legacyHasFunctionCallOutputMissingCallID reports whether any
// "function_call_output" entry in reqBody["input"] lacks a non-blank
// call_id. Baseline kept for benchmarks.
func legacyHasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, entry := range items {
		m, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		if kind, _ := m["type"].(string); kind != "function_call_output" {
			continue
		}
		callID, _ := m["call_id"].(string)
		if strings.TrimSpace(callID) == "" {
			return true
		}
	}
	return false
}
|
||||
|
||||
// legacyHasItemReferenceForCallIDs reports whether every ID in callIDs is
// covered by an "item_reference" entry (matched on the trimmed "id" field) in
// reqBody["input"]. It returns false when callIDs is empty or no references
// exist. Baseline kept for benchmarks.
func legacyHasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
	if reqBody == nil || len(callIDs) == 0 {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}

	// Collect the set of referenced IDs in one pass.
	refs := map[string]struct{}{}
	for _, entry := range items {
		m, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		if kind, _ := m["type"].(string); kind != "item_reference" {
			continue
		}
		id, _ := m["id"].(string)
		if id = strings.TrimSpace(id); id != "" {
			refs[id] = struct{}{}
		}
	}
	if len(refs) == 0 {
		return false
	}

	// Every requested call ID must appear among the references.
	for _, want := range callIDs {
		if _, found := refs[want]; !found {
			return false
		}
	}
	return true
}
|
||||
|
||||
// legacyParseWSIngressPayload is the pre-optimization WS ingress parser kept
// as a benchmark baseline: it extracts the four header fields with gjson and
// then fully unmarshals the payload a second time, defaulting a missing
// "type" to "response.create" in both the return value and the map.
func legacyParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) {
	// First pass: targeted field extraction without decoding the whole body.
	values := gjson.GetManyBytes(raw, "type", "model", "prompt_cache_key", "previous_response_id")
	eventType = strings.TrimSpace(values[0].String())
	if eventType == "" {
		eventType = "response.create"
	}
	model = strings.TrimSpace(values[1].String())
	promptCacheKey = strings.TrimSpace(values[2].String())
	previousResponseID = strings.TrimSpace(values[3].String())
	// Second pass: full decode of the same bytes (the redundancy the
	// optimized variant removes).
	payload = make(map[string]any)
	if err = json.Unmarshal(raw, &payload); err != nil {
		return "", "", "", "", nil, err
	}
	if _, exists := payload["type"]; !exists {
		payload["type"] = "response.create"
	}
	return eventType, model, promptCacheKey, previousResponseID, payload, nil
}
|
||||
|
||||
// optimizedParseWSIngressPayload parses the WS ingress payload with a single
// json.Unmarshal and reads the header fields from the decoded map via
// openAIWSPayloadString, defaulting a missing "type" to "response.create";
// benchmarked against the legacy double-parse variant.
func optimizedParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) {
	payload = make(map[string]any)
	if err = json.Unmarshal(raw, &payload); err != nil {
		return "", "", "", "", nil, err
	}
	eventType = openAIWSPayloadString(payload, "type")
	if eventType == "" {
		// Keep the map consistent with the returned default.
		eventType = "response.create"
		payload["type"] = eventType
	}
	model = openAIWSPayloadString(payload, "model")
	promptCacheKey = openAIWSPayloadString(payload, "prompt_cache_key")
	previousResponseID = openAIWSPayloadString(payload, "previous_response_id")
	return eventType, model, promptCacheKey, previousResponseID, payload, nil
}
|
||||
|
||||
// legacyExtractOpenAIUsageFromJSONBytes decodes usage token counts from a
// response body via a full struct unmarshal; kept as a benchmark baseline
// for extractOpenAIUsageFromJSONBytes. Returns false when the body is not
// valid JSON.
func legacyExtractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
	// Anonymous struct mirrors only the "usage" subtree of the response.
	var response struct {
		Usage struct {
			InputTokens       int `json:"input_tokens"`
			OutputTokens      int `json:"output_tokens"`
			InputTokenDetails struct {
				CachedTokens int `json:"cached_tokens"`
			} `json:"input_tokens_details"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(body, &response); err != nil {
		return OpenAIUsage{}, false
	}
	return OpenAIUsage{
		InputTokens:          response.Usage.InputTokens,
		OutputTokens:         response.Usage.OutputTokens,
		CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens,
	}, true
}
|
||||
@@ -515,7 +515,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
|
||||
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
|
||||
require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
|
||||
require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {
|
||||
|
||||
@@ -5,17 +5,28 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
||||
|
||||
var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`)
|
||||
|
||||
type soraSessionChunk struct {
|
||||
index int
|
||||
value string
|
||||
}
|
||||
|
||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||
type OpenAIOAuthService struct {
|
||||
sessionStore *openai.SessionStore
|
||||
@@ -39,7 +50,7 @@ type OpenAIAuthURLResult struct {
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates an OpenAI OAuth authorization URL
|
||||
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) {
|
||||
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) {
|
||||
// Generate PKCE values
|
||||
state, err := openai.GenerateState()
|
||||
if err != nil {
|
||||
@@ -75,11 +86,14 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
normalizedPlatform := normalizeOpenAIOAuthPlatform(platform)
|
||||
clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform)
|
||||
|
||||
// Store session
|
||||
session := &openai.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
@@ -87,7 +101,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
// Build authorization URL
|
||||
authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI)
|
||||
authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform)
|
||||
|
||||
return &OpenAIAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
@@ -111,6 +125,7 @@ type OpenAITokenInfo struct {
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
|
||||
@@ -148,9 +163,13 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
if input.RedirectURI != "" {
|
||||
redirectURI = input.RedirectURI
|
||||
}
|
||||
clientID := strings.TrimSpace(session.ClientID)
|
||||
if clientID == "" {
|
||||
clientID = openai.ClientID
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -158,8 +177,10 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if parseErr != nil {
|
||||
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
|
||||
} else {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
@@ -173,6 +194,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
IDToken: tokenResp.IDToken,
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
ClientID: clientID,
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
@@ -200,8 +222,10 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if parseErr != nil {
|
||||
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
|
||||
} else {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
@@ -213,6 +237,9 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
}
|
||||
if trimmed := strings.TrimSpace(clientID); trimmed != "" {
|
||||
tokenInfo.ClientID = trimmed
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
@@ -226,6 +253,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
||||
|
||||
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
|
||||
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
|
||||
sessionToken = normalizeSoraSessionTokenInput(sessionToken)
|
||||
if strings.TrimSpace(sessionToken) == "" {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
|
||||
}
|
||||
@@ -245,7 +273,13 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
|
||||
client := newOpenAIOAuthHTTPClient(proxyURL)
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 120 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
|
||||
@@ -287,10 +321,141 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
|
||||
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
|
||||
ExpiresIn: expiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
ClientID: openai.SoraClientID,
|
||||
Email: strings.TrimSpace(sessionResp.User.Email),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeSoraSessionTokenInput(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1)
|
||||
if len(matches) == 0 {
|
||||
return sanitizeSessionToken(trimmed)
|
||||
}
|
||||
|
||||
chunkMatches := make([]soraSessionChunk, 0, len(matches))
|
||||
singleValues := make([]string, 0, len(matches))
|
||||
|
||||
for _, match := range matches {
|
||||
if len(match) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
value := sanitizeSessionToken(match[2])
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.TrimSpace(match[1]) == "" {
|
||||
singleValues = append(singleValues, value)
|
||||
continue
|
||||
}
|
||||
|
||||
idx, err := strconv.Atoi(strings.TrimSpace(match[1]))
|
||||
if err != nil || idx < 0 {
|
||||
continue
|
||||
}
|
||||
chunkMatches = append(chunkMatches, soraSessionChunk{
|
||||
index: idx,
|
||||
value: value,
|
||||
})
|
||||
}
|
||||
|
||||
if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" {
|
||||
return merged
|
||||
}
|
||||
|
||||
if len(singleValues) > 0 {
|
||||
return singleValues[len(singleValues)-1]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string {
|
||||
if len(chunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
byIndex := make(map[int]string, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
byIndex[chunk.index] = chunk.value
|
||||
}
|
||||
|
||||
if _, ok := byIndex[0]; !ok {
|
||||
return ""
|
||||
}
|
||||
if requireComplete {
|
||||
for idx := 0; idx <= requiredMaxIndex; idx++ {
|
||||
if _, ok := byIndex[idx]; !ok {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
orderedIndexes := make([]int, 0, len(byIndex))
|
||||
for idx := range byIndex {
|
||||
orderedIndexes = append(orderedIndexes, idx)
|
||||
}
|
||||
sort.Ints(orderedIndexes)
|
||||
|
||||
var builder strings.Builder
|
||||
for _, idx := range orderedIndexes {
|
||||
if _, err := builder.WriteString(byIndex[idx]); err != nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return sanitizeSessionToken(builder.String())
|
||||
}
|
||||
|
||||
func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string {
|
||||
if len(chunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
requiredMaxIndex := 0
|
||||
for _, chunk := range chunks {
|
||||
if chunk.index > requiredMaxIndex {
|
||||
requiredMaxIndex = chunk.index
|
||||
}
|
||||
}
|
||||
|
||||
groupStarts := make([]int, 0, len(chunks))
|
||||
for idx, chunk := range chunks {
|
||||
if chunk.index == 0 {
|
||||
groupStarts = append(groupStarts, idx)
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupStarts) == 0 {
|
||||
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
|
||||
}
|
||||
|
||||
for i := len(groupStarts) - 1; i >= 0; i-- {
|
||||
start := groupStarts[i]
|
||||
end := len(chunks)
|
||||
if i+1 < len(groupStarts) {
|
||||
end = groupStarts[i+1]
|
||||
}
|
||||
if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" {
|
||||
return merged
|
||||
}
|
||||
}
|
||||
|
||||
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
|
||||
}
|
||||
|
||||
// sanitizeSessionToken strips surrounding whitespace, wrapping quote
// characters, and a single trailing semicolon from a pasted session-token
// value.
func sanitizeSessionToken(raw string) string {
	cleaned := strings.Trim(strings.TrimSpace(raw), "\"'`")
	cleaned = strings.TrimSuffix(cleaned, ";")
	return strings.TrimSpace(cleaned)
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
|
||||
@@ -322,9 +487,12 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
||||
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
|
||||
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"refresh_token": tokenInfo.RefreshToken,
|
||||
"expires_at": expiresAt,
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"expires_at": expiresAt,
|
||||
}
|
||||
// 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌
|
||||
if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
|
||||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if tokenInfo.IDToken != "" {
|
||||
@@ -342,6 +510,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
||||
if tokenInfo.OrganizationID != "" {
|
||||
creds["organization_id"] = tokenInfo.OrganizationID
|
||||
}
|
||||
if strings.TrimSpace(tokenInfo.ClientID) != "" {
|
||||
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
|
||||
}
|
||||
|
||||
return creds
|
||||
}
|
||||
@@ -365,15 +536,11 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64
|
||||
return proxy.URL(), nil
|
||||
}
|
||||
|
||||
func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
|
||||
transport := &http.Transport{}
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
|
||||
transport.Proxy = http.ProxyURL(parsed)
|
||||
}
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
Transport: transport,
|
||||
func normalizeOpenAIOAuthPlatform(platform string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(platform)) {
|
||||
case PlatformSora:
|
||||
return openai.OAuthPlatformSora
|
||||
default:
|
||||
return openai.OAuthPlatformOpenAI
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openaiOAuthClientAuthURLStub struct{}
|
||||
|
||||
func (s *openaiOAuthClientAuthURLStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientAuthURLStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientAuthURLStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformOpenAI)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, result.AuthURL)
|
||||
require.NotEmpty(t, result.SessionID)
|
||||
|
||||
parsed, err := url.Parse(result.AuthURL)
|
||||
require.NoError(t, err)
|
||||
q := parsed.Query()
|
||||
require.Equal(t, openai.ClientID, q.Get("client_id"))
|
||||
require.Equal(t, "true", q.Get("codex_cli_simplified_flow"))
|
||||
|
||||
session, ok := svc.sessionStore.Get(result.SessionID)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, openai.ClientID, session.ClientID)
|
||||
}
|
||||
|
||||
// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的
|
||||
// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。
|
||||
func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) {
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, result.AuthURL)
|
||||
require.NotEmpty(t, result.SessionID)
|
||||
|
||||
parsed, err := url.Parse(result.AuthURL)
|
||||
require.NoError(t, err)
|
||||
q := parsed.Query()
|
||||
require.Equal(t, openai.ClientID, q.Get("client_id"))
|
||||
require.Empty(t, q.Get("codex_cli_simplified_flow"))
|
||||
|
||||
session, ok := svc.sessionStore.Get(result.SessionID)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, openai.ClientID, session.ClientID)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
|
||||
type openaiOAuthClientNoopStub struct{}
|
||||
|
||||
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -67,3 +68,106 @@ func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testi
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing access token")
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := strings.Join([]string{
|
||||
"Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly",
|
||||
}, "\n")
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := strings.Join([]string{
|
||||
"Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly",
|
||||
}, "\n")
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := strings.Join([]string{
|
||||
"set-cookie",
|
||||
"__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/",
|
||||
"set-cookie",
|
||||
"__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/",
|
||||
"set-cookie",
|
||||
"__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/",
|
||||
}, "\n")
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
|
||||
@@ -13,10 +13,12 @@ import (
|
||||
|
||||
type openaiOAuthClientStateStub struct {
|
||||
exchangeCalled int32
|
||||
lastClientID string
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
atomic.AddInt32(&s.exchangeCalled, 1)
|
||||
s.lastClientID = clientID
|
||||
return &openai.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rt",
|
||||
@@ -95,6 +97,8 @@ func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, "at", info.AccessToken)
|
||||
require.Equal(t, openai.ClientID, info.ClientID)
|
||||
require.Equal(t, openai.ClientID, client.lastClientID)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
|
||||
|
||||
_, ok := svc.sessionStore.Get("sid")
|
||||
|
||||
37
backend/internal/service/openai_previous_response_id.go
Normal file
37
backend/internal/service/openai_previous_response_id.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
	// Classification buckets for previous_response_id values.
	OpenAIPreviousResponseIDKindEmpty      = "empty"
	OpenAIPreviousResponseIDKindResponseID = "response_id"
	OpenAIPreviousResponseIDKindMessageID  = "message_id"
	OpenAIPreviousResponseIDKindUnknown    = "unknown"
)

var (
	// openAIResponseIDPattern matches canonical response IDs ("resp_…").
	openAIResponseIDPattern = regexp.MustCompile(`^resp_[A-Za-z0-9_-]{1,256}$`)
	// openAIMessageIDPattern matches message/item style IDs; it is evaluated
	// against a lower-cased input so the prefix check is case-insensitive.
	openAIMessageIDPattern = regexp.MustCompile(`^(msg|message|item|chatcmpl)_[A-Za-z0-9_-]{1,256}$`)
)

// ClassifyOpenAIPreviousResponseIDKind buckets a previous_response_id value so
// diagnostics can distinguish real response IDs from message/item IDs that
// clients sometimes send by mistake.
func ClassifyOpenAIPreviousResponseIDKind(id string) string {
	candidate := strings.TrimSpace(id)
	switch {
	case candidate == "":
		return OpenAIPreviousResponseIDKindEmpty
	case openAIResponseIDPattern.MatchString(candidate):
		return OpenAIPreviousResponseIDKindResponseID
	case openAIMessageIDPattern.MatchString(strings.ToLower(candidate)):
		return OpenAIPreviousResponseIDKindMessageID
	default:
		return OpenAIPreviousResponseIDKindUnknown
	}
}

// IsOpenAIPreviousResponseIDLikelyMessageID reports whether id looks like a
// message/item ID rather than a response ID.
func IsOpenAIPreviousResponseIDLikelyMessageID(id string) bool {
	return ClassifyOpenAIPreviousResponseIDKind(id) == OpenAIPreviousResponseIDKindMessageID
}
|
||||
34
backend/internal/service/openai_previous_response_id_test.go
Normal file
34
backend/internal/service/openai_previous_response_id_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestClassifyOpenAIPreviousResponseIDKind(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
want string
|
||||
}{
|
||||
{name: "empty", id: " ", want: OpenAIPreviousResponseIDKindEmpty},
|
||||
{name: "response_id", id: "resp_0906a621bc423a8d0169a108637ef88197b74b0e2f37ba358f", want: OpenAIPreviousResponseIDKindResponseID},
|
||||
{name: "message_id", id: "msg_123456", want: OpenAIPreviousResponseIDKindMessageID},
|
||||
{name: "item_id", id: "item_abcdef", want: OpenAIPreviousResponseIDKindMessageID},
|
||||
{name: "unknown", id: "foo_123456", want: OpenAIPreviousResponseIDKindUnknown},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := ClassifyOpenAIPreviousResponseIDKind(tc.id); got != tc.want {
|
||||
t.Fatalf("ClassifyOpenAIPreviousResponseIDKind(%q)=%q want=%q", tc.id, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOpenAIPreviousResponseIDLikelyMessageID(t *testing.T) {
|
||||
if !IsOpenAIPreviousResponseIDLikelyMessageID("msg_123") {
|
||||
t.Fatal("expected msg_123 to be identified as message id")
|
||||
}
|
||||
if IsOpenAIPreviousResponseIDLikelyMessageID("resp_123") {
|
||||
t.Fatal("expected resp_123 not to be identified as message id")
|
||||
}
|
||||
}
|
||||
214
backend/internal/service/openai_sticky_compat.go
Normal file
214
backend/internal/service/openai_sticky_compat.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type openAILegacySessionHashContextKey struct{}
|
||||
|
||||
var openAILegacySessionHashKey = openAILegacySessionHashContextKey{}
|
||||
|
||||
var (
|
||||
openAIStickyLegacyReadFallbackTotal atomic.Int64
|
||||
openAIStickyLegacyReadFallbackHit atomic.Int64
|
||||
openAIStickyLegacyDualWriteTotal atomic.Int64
|
||||
)
|
||||
|
||||
func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal int64) {
|
||||
return openAIStickyLegacyReadFallbackTotal.Load(),
|
||||
openAIStickyLegacyReadFallbackHit.Load(),
|
||||
openAIStickyLegacyDualWriteTotal.Load()
|
||||
}
|
||||
|
||||
// deriveOpenAISessionHashes derives both sticky-session cache hashes for a
// session ID: the current xxhash64-based hash (16 hex chars) and the legacy
// SHA-256 hash (64 hex chars) that older deployments keyed on. Both results
// are empty when the trimmed session ID is empty.
func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) {
	normalized := strings.TrimSpace(sessionID)
	if normalized == "" {
		return "", ""
	}

	// Current scheme: xxhash64, zero-padded to 16 hex characters.
	currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized))
	// Legacy scheme: full SHA-256 digest, hex encoded.
	sum := sha256.Sum256([]byte(normalized))
	legacyHash = hex.EncodeToString(sum[:])
	return currentHash, legacyHash
}
|
||||
|
||||
func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
trimmed := strings.TrimSpace(legacyHash)
|
||||
if trimmed == "" {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, openAILegacySessionHashKey, trimmed)
|
||||
}
|
||||
|
||||
func openAILegacySessionHashFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
value, _ := ctx.Value(openAILegacySessionHashKey).(string)
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
// attachOpenAILegacySessionHashToGin stores the legacy session hash on the gin
// request's context so service-layer code can read it back via
// openAILegacySessionHashFromContext. No-op when the gin context or its
// request is nil.
func attachOpenAILegacySessionHashToGin(c *gin.Context, legacyHash string) {
	if c == nil || c.Request == nil {
		return
	}
	c.Request = c.Request.WithContext(withOpenAILegacySessionHash(c.Request.Context(), legacyHash))
}
|
||||
|
||||
// openAISessionHashReadOldFallbackEnabled reports whether sticky-session
// lookups may fall back to the legacy hash key. Defaults to true when the
// service or its config is nil, keeping migration-era behavior safe.
func (s *OpenAIGatewayService) openAISessionHashReadOldFallbackEnabled() bool {
	if s == nil || s.cfg == nil {
		return true
	}
	return s.cfg.Gateway.OpenAIWS.SessionHashReadOldFallback
}
|
||||
|
||||
// openAISessionHashDualWriteOldEnabled reports whether sticky-session writes
// should also be mirrored to the legacy hash key. Defaults to true when the
// service or its config is nil, keeping migration-era behavior safe.
func (s *OpenAIGatewayService) openAISessionHashDualWriteOldEnabled() bool {
	if s == nil || s.cfg == nil {
		return true
	}
	return s.cfg.Gateway.OpenAIWS.SessionHashDualWriteOld
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAISessionCacheKey(sessionHash string) string {
|
||||
normalized := strings.TrimSpace(sessionHash)
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
return "openai:" + normalized
|
||||
}
|
||||
|
||||
// openAILegacySessionCacheKey builds the legacy sticky-session cache key from
// the legacy hash carried in ctx. Returns "" when no legacy hash is present,
// or when the legacy key would be identical to the primary key for
// sessionHash (avoiding redundant reads/writes on the same key).
func (s *OpenAIGatewayService) openAILegacySessionCacheKey(ctx context.Context, sessionHash string) string {
	legacyHash := openAILegacySessionHashFromContext(ctx)
	if legacyHash == "" {
		return ""
	}
	legacyKey := "openai:" + legacyHash
	// Skip the legacy key when it collides with the primary key.
	if legacyKey == s.openAISessionCacheKey(sessionHash) {
		return ""
	}
	return legacyKey
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIStickyLegacyTTL(ttl time.Duration) time.Duration {
|
||||
legacyTTL := ttl
|
||||
if legacyTTL <= 0 {
|
||||
legacyTTL = openaiStickySessionTTL
|
||||
}
|
||||
if legacyTTL > 10*time.Minute {
|
||||
return 10 * time.Minute
|
||||
}
|
||||
return legacyTTL
|
||||
}
|
||||
|
||||
// getStickySessionAccountID resolves the account bound to a sticky session.
// It reads the primary (current-hash) cache key first; on a miss it optionally
// falls back to the legacy-hash key (controlled by
// Gateway.OpenAIWS.SessionHashReadOldFallback) so bindings written by older
// deployments keep working during the hash migration.
//
// Returns (0, nil) when no cache is configured or the hash is blank. A legacy
// hit returns the legacy binding with a nil error; otherwise the primary
// lookup's result (and error, if any) is passed through unchanged.
func (s *OpenAIGatewayService) getStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
	if s == nil || s.cache == nil {
		return 0, nil
	}

	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return 0, nil
	}

	accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
	if err == nil && accountID > 0 {
		return accountID, nil
	}
	if !s.openAISessionHashReadOldFallbackEnabled() {
		return accountID, err
	}

	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey == "" {
		return accountID, err
	}

	// Counters are surfaced via openAIStickyCompatStats: attempts vs. hits.
	openAIStickyLegacyReadFallbackTotal.Add(1)
	legacyAccountID, legacyErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
	if legacyErr == nil && legacyAccountID > 0 {
		openAIStickyLegacyReadFallbackHit.Add(1)
		return legacyAccountID, nil
	}
	return accountID, err
}
|
||||
|
||||
// setStickySessionAccountID binds accountID to the sticky session under the
// primary cache key and, when Gateway.OpenAIWS.SessionHashDualWriteOld is
// enabled, mirrors the binding under the legacy-hash key (with a capped TTL)
// so rolled-back deployments can still find it. No-op for a missing cache,
// blank hash, or non-positive account ID. Either write's error is returned.
func (s *OpenAIGatewayService) setStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string, accountID int64, ttl time.Duration) error {
	if s == nil || s.cache == nil || accountID <= 0 {
		return nil
	}
	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return nil
	}

	if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), primaryKey, accountID, ttl); err != nil {
		return err
	}

	if !s.openAISessionHashDualWriteOldEnabled() {
		return nil
	}
	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey == "" {
		return nil
	}
	// Legacy copies use a shorter, capped TTL so they expire soon after migration.
	if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), legacyKey, accountID, s.openAIStickyLegacyTTL(ttl)); err != nil {
		return err
	}
	openAIStickyLegacyDualWriteTotal.Add(1)
	return nil
}
|
||||
|
||||
// refreshStickySessionTTL extends the sticky-session binding's lifetime on the
// primary key and, while either compatibility flag (read fallback or dual
// write) is still enabled, best-effort on the legacy key with a capped TTL.
// Legacy refresh errors are deliberately ignored; the primary refresh error is
// returned.
func (s *OpenAIGatewayService) refreshStickySessionTTL(ctx context.Context, groupID *int64, sessionHash string, ttl time.Duration) error {
	if s == nil || s.cache == nil {
		return nil
	}
	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return nil
	}

	err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), primaryKey, ttl)
	if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
		return err
	}

	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey != "" {
		// Best effort: legacy keys are transitional, so errors are ignored.
		_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), legacyKey, s.openAIStickyLegacyTTL(ttl))
	}
	return err
}
|
||||
|
||||
// deleteStickySessionAccountID removes a sticky-session binding from the
// primary cache key and, while either compatibility flag is enabled,
// best-effort from the legacy key so a stale legacy binding cannot resurrect a
// cleared session. Legacy delete errors are ignored; the primary delete error
// is returned.
func (s *OpenAIGatewayService) deleteStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) error {
	if s == nil || s.cache == nil {
		return nil
	}
	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return nil
	}

	err := s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
	if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
		return err
	}

	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey != "" {
		// Best effort: legacy keys are transitional, so errors are ignored.
		_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
	}
	return err
}
|
||||
96
backend/internal/service/openai_sticky_compat_test.go
Normal file
96
backend/internal/service/openai_sticky_compat_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetStickySessionAccountID_FallbackToLegacyKey(t *testing.T) {
|
||||
beforeFallbackTotal, beforeFallbackHit, _ := openAIStickyCompatStats()
|
||||
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{
|
||||
"openai:legacy-hash": 42,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{
|
||||
cache: cache,
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
OpenAIWS: config.GatewayOpenAIWSConfig{
|
||||
SessionHashReadOldFallback: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
|
||||
accountID, err := svc.getStickySessionAccountID(ctx, nil, "new-hash")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), accountID)
|
||||
|
||||
afterFallbackTotal, afterFallbackHit, _ := openAIStickyCompatStats()
|
||||
require.Equal(t, beforeFallbackTotal+1, afterFallbackTotal)
|
||||
require.Equal(t, beforeFallbackHit+1, afterFallbackHit)
|
||||
}
|
||||
|
||||
func TestSetStickySessionAccountID_DualWriteOldEnabled(t *testing.T) {
|
||||
_, _, beforeDualWriteTotal := openAIStickyCompatStats()
|
||||
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
|
||||
svc := &OpenAIGatewayService{
|
||||
cache: cache,
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
OpenAIWS: config.GatewayOpenAIWSConfig{
|
||||
SessionHashDualWriteOld: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
|
||||
err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"])
|
||||
require.Equal(t, int64(9), cache.sessionBindings["openai:legacy-hash"])
|
||||
|
||||
_, _, afterDualWriteTotal := openAIStickyCompatStats()
|
||||
require.Equal(t, beforeDualWriteTotal+1, afterDualWriteTotal)
|
||||
}
|
||||
|
||||
func TestSetStickySessionAccountID_DualWriteOldDisabled(t *testing.T) {
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
|
||||
svc := &OpenAIGatewayService{
|
||||
cache: cache,
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
OpenAIWS: config.GatewayOpenAIWSConfig{
|
||||
SessionHashDualWriteOld: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
|
||||
err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"])
|
||||
_, exists := cache.sessionBindings["openai:legacy-hash"]
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func TestSnapshotOpenAICompatibilityFallbackMetrics(t *testing.T) {
|
||||
before := SnapshotOpenAICompatibilityFallbackMetrics()
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true)
|
||||
_, _ = ThinkingEnabledFromContext(ctx)
|
||||
|
||||
after := SnapshotOpenAICompatibilityFallbackMetrics()
|
||||
require.GreaterOrEqual(t, after.MetadataLegacyFallbackTotal, before.MetadataLegacyFallbackTotal+1)
|
||||
require.GreaterOrEqual(t, after.MetadataLegacyFallbackThinkingEnabledTotal, before.MetadataLegacyFallbackThinkingEnabledTotal+1)
|
||||
}
|
||||
@@ -2,6 +2,24 @@ package service
|
||||
|
||||
import "strings"
|
||||
|
||||
// ToolContinuationSignals 聚合工具续链相关信号,避免重复遍历 input。
|
||||
type ToolContinuationSignals struct {
|
||||
HasFunctionCallOutput bool
|
||||
HasFunctionCallOutputMissingCallID bool
|
||||
HasToolCallContext bool
|
||||
HasItemReference bool
|
||||
HasItemReferenceForAllCallIDs bool
|
||||
FunctionCallOutputCallIDs []string
|
||||
}
|
||||
|
||||
// FunctionCallOutputValidation 汇总 function_call_output 关联性校验结果。
|
||||
type FunctionCallOutputValidation struct {
|
||||
HasFunctionCallOutput bool
|
||||
HasToolCallContext bool
|
||||
HasFunctionCallOutputMissingCallID bool
|
||||
HasItemReferenceForAllCallIDs bool
|
||||
}
|
||||
|
||||
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
|
||||
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
|
||||
// 或显式声明 tools/tool_choice。
|
||||
@@ -18,107 +36,191 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
|
||||
if hasToolChoiceSignal(reqBody) {
|
||||
return true
|
||||
}
|
||||
if inputHasType(reqBody, "function_call_output") {
|
||||
return true
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if inputHasType(reqBody, "item_reference") {
|
||||
return true
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType == "function_call_output" || itemType == "item_reference" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AnalyzeToolContinuationSignals 单次遍历 input,提取 function_call_output/tool_call/item_reference 相关信号。
|
||||
func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals {
|
||||
signals := ToolContinuationSignals{}
|
||||
if reqBody == nil {
|
||||
return signals
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return signals
|
||||
}
|
||||
|
||||
var callIDs map[string]struct{}
|
||||
var referenceIDs map[string]struct{}
|
||||
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
switch itemType {
|
||||
case "tool_call", "function_call":
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
if strings.TrimSpace(callID) != "" {
|
||||
signals.HasToolCallContext = true
|
||||
}
|
||||
case "function_call_output":
|
||||
signals.HasFunctionCallOutput = true
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
callID = strings.TrimSpace(callID)
|
||||
if callID == "" {
|
||||
signals.HasFunctionCallOutputMissingCallID = true
|
||||
continue
|
||||
}
|
||||
if callIDs == nil {
|
||||
callIDs = make(map[string]struct{})
|
||||
}
|
||||
callIDs[callID] = struct{}{}
|
||||
case "item_reference":
|
||||
signals.HasItemReference = true
|
||||
idValue, _ := itemMap["id"].(string)
|
||||
idValue = strings.TrimSpace(idValue)
|
||||
if idValue == "" {
|
||||
continue
|
||||
}
|
||||
if referenceIDs == nil {
|
||||
referenceIDs = make(map[string]struct{})
|
||||
}
|
||||
referenceIDs[idValue] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(callIDs) == 0 {
|
||||
return signals
|
||||
}
|
||||
signals.FunctionCallOutputCallIDs = make([]string, 0, len(callIDs))
|
||||
allReferenced := len(referenceIDs) > 0
|
||||
for callID := range callIDs {
|
||||
signals.FunctionCallOutputCallIDs = append(signals.FunctionCallOutputCallIDs, callID)
|
||||
if allReferenced {
|
||||
if _, ok := referenceIDs[callID]; !ok {
|
||||
allReferenced = false
|
||||
}
|
||||
}
|
||||
}
|
||||
signals.HasItemReferenceForAllCallIDs = allReferenced
|
||||
return signals
|
||||
}
|
||||
|
||||
// ValidateFunctionCallOutputContext 为 handler 提供低开销校验结果:
|
||||
// 1) 无 function_call_output 直接返回
|
||||
// 2) 若已存在 tool_call/function_call 上下文则提前返回
|
||||
// 3) 仅在无工具上下文时才构建 call_id / item_reference 集合
|
||||
func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation {
|
||||
result := FunctionCallOutputValidation{}
|
||||
if reqBody == nil {
|
||||
return result
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return result
|
||||
}
|
||||
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
switch itemType {
|
||||
case "function_call_output":
|
||||
result.HasFunctionCallOutput = true
|
||||
case "tool_call", "function_call":
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
if strings.TrimSpace(callID) != "" {
|
||||
result.HasToolCallContext = true
|
||||
}
|
||||
}
|
||||
if result.HasFunctionCallOutput && result.HasToolCallContext {
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
if !result.HasFunctionCallOutput || result.HasToolCallContext {
|
||||
return result
|
||||
}
|
||||
|
||||
callIDs := make(map[string]struct{})
|
||||
referenceIDs := make(map[string]struct{})
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
switch itemType {
|
||||
case "function_call_output":
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
callID = strings.TrimSpace(callID)
|
||||
if callID == "" {
|
||||
result.HasFunctionCallOutputMissingCallID = true
|
||||
continue
|
||||
}
|
||||
callIDs[callID] = struct{}{}
|
||||
case "item_reference":
|
||||
idValue, _ := itemMap["id"].(string)
|
||||
idValue = strings.TrimSpace(idValue)
|
||||
if idValue == "" {
|
||||
continue
|
||||
}
|
||||
referenceIDs[idValue] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(callIDs) == 0 || len(referenceIDs) == 0 {
|
||||
return result
|
||||
}
|
||||
allReferenced := true
|
||||
for callID := range callIDs {
|
||||
if _, ok := referenceIDs[callID]; !ok {
|
||||
allReferenced = false
|
||||
break
|
||||
}
|
||||
}
|
||||
result.HasItemReferenceForAllCallIDs = allReferenced
|
||||
return result
|
||||
}
|
||||
|
||||
// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。
|
||||
func HasFunctionCallOutput(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
return inputHasType(reqBody, "function_call_output")
|
||||
return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput
|
||||
}
|
||||
|
||||
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call,
|
||||
// 用于判断 function_call_output 是否具备可关联的上下文。
|
||||
func HasToolCallContext(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "tool_call" && itemType != "function_call" {
|
||||
continue
|
||||
}
|
||||
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext
|
||||
}
|
||||
|
||||
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
|
||||
// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。
|
||||
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
|
||||
if reqBody == nil {
|
||||
return nil
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ids := make(map[string]struct{})
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "function_call_output" {
|
||||
continue
|
||||
}
|
||||
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||
ids[callID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(ids))
|
||||
for id := range ids {
|
||||
result = append(result, id)
|
||||
}
|
||||
return result
|
||||
return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs
|
||||
}
|
||||
|
||||
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
|
||||
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "function_call_output" {
|
||||
continue
|
||||
}
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
if strings.TrimSpace(callID) == "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID
|
||||
}
|
||||
|
||||
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
|
||||
@@ -152,32 +254,13 @@ func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
|
||||
return false
|
||||
}
|
||||
for _, callID := range callIDs {
|
||||
if _, ok := referenceIDs[callID]; !ok {
|
||||
if _, ok := referenceIDs[strings.TrimSpace(callID)]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// inputHasType 判断 input 中是否存在指定类型的 item。
|
||||
func inputHasType(reqBody map[string]any, want string) bool {
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasNonEmptyString 判断字段是否为非空字符串。
|
||||
func hasNonEmptyString(value any) bool {
|
||||
stringValue, ok := value.(string)
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射
|
||||
@@ -62,169 +66,201 @@ func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, boo
|
||||
if data == "" || data == "\n" {
|
||||
return data, false
|
||||
}
|
||||
correctedBytes, corrected := c.CorrectToolCallsInSSEBytes([]byte(data))
|
||||
if !corrected {
|
||||
return data, false
|
||||
}
|
||||
return string(correctedBytes), true
|
||||
}
|
||||
|
||||
// 尝试解析 JSON
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||
// 不是有效的 JSON,直接返回原数据
|
||||
// CorrectToolCallsInSSEBytes 修正 SSE JSON 数据中的工具调用(字节路径)。
|
||||
// 返回修正后的数据和是否进行了修正。
|
||||
func (c *CodexToolCorrector) CorrectToolCallsInSSEBytes(data []byte) ([]byte, bool) {
|
||||
if len(bytes.TrimSpace(data)) == 0 {
|
||||
return data, false
|
||||
}
|
||||
if !mayContainToolCallPayload(data) {
|
||||
return data, false
|
||||
}
|
||||
if !gjson.ValidBytes(data) {
|
||||
// 不是有效 JSON,直接返回原数据
|
||||
return data, false
|
||||
}
|
||||
|
||||
updated := data
|
||||
corrected := false
|
||||
|
||||
// 处理 tool_calls 数组
|
||||
if toolCalls, ok := payload["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
collect := func(changed bool, next []byte) {
|
||||
if changed {
|
||||
corrected = true
|
||||
updated = next
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 function_call 对象
|
||||
if functionCall, ok := payload["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
if next, changed := c.correctToolCallsArrayAtPath(updated, "tool_calls"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
if next, changed := c.correctFunctionAtPath(updated, "function_call"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
if next, changed := c.correctToolCallsArrayAtPath(updated, "delta.tool_calls"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
if next, changed := c.correctFunctionAtPath(updated, "delta.function_call"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
|
||||
// 处理 delta.tool_calls
|
||||
if delta, ok := payload["delta"].(map[string]any); ok {
|
||||
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
choicesCount := int(gjson.GetBytes(updated, "choices.#").Int())
|
||||
for i := 0; i < choicesCount; i++ {
|
||||
prefix := "choices." + strconv.Itoa(i)
|
||||
if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".message.tool_calls"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
if next, changed := c.correctFunctionAtPath(updated, prefix+".message.function_call"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
|
||||
if choices, ok := payload["choices"].([]any); ok {
|
||||
for _, choice := range choices {
|
||||
if choiceMap, ok := choice.(map[string]any); ok {
|
||||
// 处理 message 中的工具调用
|
||||
if message, ok := choiceMap["message"].(map[string]any); ok {
|
||||
if toolCalls, ok := message["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := message["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// 处理 delta 中的工具调用
|
||||
if delta, ok := choiceMap["delta"].(map[string]any); ok {
|
||||
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".delta.tool_calls"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
if next, changed := c.correctFunctionAtPath(updated, prefix+".delta.function_call"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
}
|
||||
|
||||
if !corrected {
|
||||
return data, false
|
||||
}
|
||||
return updated, true
|
||||
}
|
||||
|
||||
// 序列化回 JSON
|
||||
correctedBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Failed to marshal corrected data: %v", err)
|
||||
func mayContainToolCallPayload(data []byte) bool {
|
||||
// 快速路径:多数 token / 文本事件不包含工具字段,避免进入 JSON 解析热路径。
|
||||
return bytes.Contains(data, []byte(`"tool_calls"`)) ||
|
||||
bytes.Contains(data, []byte(`"function_call"`)) ||
|
||||
bytes.Contains(data, []byte(`"function":{"name"`))
|
||||
}
|
||||
|
||||
// correctToolCallsArrayAtPath 修正指定路径下 tool_calls 数组中的工具名称。
|
||||
func (c *CodexToolCorrector) correctToolCallsArrayAtPath(data []byte, toolCallsPath string) ([]byte, bool) {
|
||||
count := int(gjson.GetBytes(data, toolCallsPath+".#").Int())
|
||||
if count <= 0 {
|
||||
return data, false
|
||||
}
|
||||
|
||||
return string(correctedBytes), true
|
||||
}
|
||||
|
||||
// correctToolCallsArray 修正工具调用数组中的工具名称
|
||||
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
|
||||
updated := data
|
||||
corrected := false
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCallMap, ok := toolCall.(map[string]any); ok {
|
||||
if function, ok := toolCallMap["function"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(function) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
functionPath := toolCallsPath + "." + strconv.Itoa(i) + ".function"
|
||||
if next, changed := c.correctFunctionAtPath(updated, functionPath); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
return corrected
|
||||
return updated, corrected
|
||||
}
|
||||
|
||||
// correctFunctionCall 修正单个函数调用的工具名称和参数
|
||||
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
|
||||
name, ok := functionCall["name"].(string)
|
||||
if !ok || name == "" {
|
||||
return false
|
||||
// correctFunctionAtPath 修正指定路径下单个函数调用的工具名称和参数。
|
||||
func (c *CodexToolCorrector) correctFunctionAtPath(data []byte, functionPath string) ([]byte, bool) {
|
||||
namePath := functionPath + ".name"
|
||||
nameResult := gjson.GetBytes(data, namePath)
|
||||
if !nameResult.Exists() || nameResult.Type != gjson.String {
|
||||
return data, false
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(nameResult.Str)
|
||||
if name == "" {
|
||||
return data, false
|
||||
}
|
||||
updated := data
|
||||
corrected := false
|
||||
|
||||
// 查找并修正工具名称
|
||||
if correctName, found := codexToolNameMapping[name]; found {
|
||||
functionCall["name"] = correctName
|
||||
c.recordCorrection(name, correctName)
|
||||
corrected = true
|
||||
name = correctName // 使用修正后的名称进行参数修正
|
||||
if next, err := sjson.SetBytes(updated, namePath, correctName); err == nil {
|
||||
updated = next
|
||||
c.recordCorrection(name, correctName)
|
||||
corrected = true
|
||||
name = correctName // 使用修正后的名称进行参数修正
|
||||
}
|
||||
}
|
||||
|
||||
// 修正工具参数(基于工具名称)
|
||||
if c.correctToolParameters(name, functionCall) {
|
||||
if next, changed := c.correctToolParametersAtPath(updated, functionPath+".arguments", name); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
}
|
||||
|
||||
return corrected
|
||||
return updated, corrected
|
||||
}
|
||||
|
||||
// correctToolParameters 修正工具参数以符合 OpenCode 规范
|
||||
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
|
||||
arguments, ok := functionCall["arguments"]
|
||||
if !ok {
|
||||
return false
|
||||
// correctToolParametersAtPath 修正指定路径下 arguments 参数。
|
||||
func (c *CodexToolCorrector) correctToolParametersAtPath(data []byte, argumentsPath, toolName string) ([]byte, bool) {
|
||||
if toolName != "bash" && toolName != "edit" {
|
||||
return data, false
|
||||
}
|
||||
|
||||
// arguments 可能是字符串(JSON)或已解析的 map
|
||||
var argsMap map[string]any
|
||||
switch v := arguments.(type) {
|
||||
case string:
|
||||
// 解析 JSON 字符串
|
||||
if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
|
||||
return false
|
||||
args := gjson.GetBytes(data, argumentsPath)
|
||||
if !args.Exists() {
|
||||
return data, false
|
||||
}
|
||||
|
||||
switch args.Type {
|
||||
case gjson.String:
|
||||
argsJSON := strings.TrimSpace(args.Str)
|
||||
if !gjson.Valid(argsJSON) {
|
||||
return data, false
|
||||
}
|
||||
case map[string]any:
|
||||
argsMap = v
|
||||
if !gjson.Parse(argsJSON).IsObject() {
|
||||
return data, false
|
||||
}
|
||||
nextArgsJSON, corrected := c.correctToolArgumentsJSON(argsJSON, toolName)
|
||||
if !corrected {
|
||||
return data, false
|
||||
}
|
||||
next, err := sjson.SetBytes(data, argumentsPath, nextArgsJSON)
|
||||
if err != nil {
|
||||
return data, false
|
||||
}
|
||||
return next, true
|
||||
case gjson.JSON:
|
||||
if !args.IsObject() || !gjson.Valid(args.Raw) {
|
||||
return data, false
|
||||
}
|
||||
nextArgsJSON, corrected := c.correctToolArgumentsJSON(args.Raw, toolName)
|
||||
if !corrected {
|
||||
return data, false
|
||||
}
|
||||
next, err := sjson.SetRawBytes(data, argumentsPath, []byte(nextArgsJSON))
|
||||
if err != nil {
|
||||
return data, false
|
||||
}
|
||||
return next, true
|
||||
default:
|
||||
return false
|
||||
return data, false
|
||||
}
|
||||
}
|
||||
|
||||
// correctToolArgumentsJSON 修正工具参数 JSON(对象字符串),返回修正后的 JSON 与是否变更。
|
||||
func (c *CodexToolCorrector) correctToolArgumentsJSON(argsJSON, toolName string) (string, bool) {
|
||||
if !gjson.Valid(argsJSON) {
|
||||
return argsJSON, false
|
||||
}
|
||||
if !gjson.Parse(argsJSON).IsObject() {
|
||||
return argsJSON, false
|
||||
}
|
||||
|
||||
updated := argsJSON
|
||||
corrected := false
|
||||
|
||||
// 根据工具名称应用特定的参数修正规则
|
||||
switch toolName {
|
||||
case "bash":
|
||||
// OpenCode bash 支持 workdir;有些来源会输出 work_dir。
|
||||
if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir {
|
||||
if workDir, exists := argsMap["work_dir"]; exists {
|
||||
argsMap["workdir"] = workDir
|
||||
delete(argsMap, "work_dir")
|
||||
if !gjson.Get(updated, "workdir").Exists() {
|
||||
if next, changed := moveJSONField(updated, "work_dir", "workdir"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
|
||||
}
|
||||
} else {
|
||||
if _, exists := argsMap["work_dir"]; exists {
|
||||
delete(argsMap, "work_dir")
|
||||
if next, changed := deleteJSONField(updated, "work_dir"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
|
||||
}
|
||||
@@ -232,67 +268,71 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
|
||||
|
||||
case "edit":
|
||||
// OpenCode edit 参数为 filePath/oldString/newString(camelCase)。
|
||||
if _, exists := argsMap["filePath"]; !exists {
|
||||
if filePath, exists := argsMap["file_path"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "file_path")
|
||||
if !gjson.Get(updated, "filePath").Exists() {
|
||||
if next, changed := moveJSONField(updated, "file_path", "filePath"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
|
||||
} else if filePath, exists := argsMap["path"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "path")
|
||||
} else if next, changed := moveJSONField(updated, "path", "filePath"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
|
||||
} else if filePath, exists := argsMap["file"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "file")
|
||||
} else if next, changed := moveJSONField(updated, "file", "filePath"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
|
||||
}
|
||||
}
|
||||
|
||||
if _, exists := argsMap["oldString"]; !exists {
|
||||
if oldString, exists := argsMap["old_string"]; exists {
|
||||
argsMap["oldString"] = oldString
|
||||
delete(argsMap, "old_string")
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
|
||||
}
|
||||
if next, changed := moveJSONField(updated, "old_string", "oldString"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
|
||||
}
|
||||
|
||||
if _, exists := argsMap["newString"]; !exists {
|
||||
if newString, exists := argsMap["new_string"]; exists {
|
||||
argsMap["newString"] = newString
|
||||
delete(argsMap, "new_string")
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
|
||||
}
|
||||
if next, changed := moveJSONField(updated, "new_string", "newString"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
|
||||
}
|
||||
|
||||
if _, exists := argsMap["replaceAll"]; !exists {
|
||||
if replaceAll, exists := argsMap["replace_all"]; exists {
|
||||
argsMap["replaceAll"] = replaceAll
|
||||
delete(argsMap, "replace_all")
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
|
||||
}
|
||||
if next, changed := moveJSONField(updated, "replace_all", "replaceAll"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
|
||||
}
|
||||
}
|
||||
return updated, corrected
|
||||
}
|
||||
|
||||
// 如果修正了参数,需要重新序列化
|
||||
if corrected {
|
||||
if _, wasString := arguments.(string); wasString {
|
||||
// 原本是字符串,序列化回字符串
|
||||
if newArgsJSON, err := json.Marshal(argsMap); err == nil {
|
||||
functionCall["arguments"] = string(newArgsJSON)
|
||||
}
|
||||
} else {
|
||||
// 原本是 map,直接赋值
|
||||
functionCall["arguments"] = argsMap
|
||||
}
|
||||
func moveJSONField(input, from, to string) (string, bool) {
|
||||
if gjson.Get(input, to).Exists() {
|
||||
return input, false
|
||||
}
|
||||
src := gjson.Get(input, from)
|
||||
if !src.Exists() {
|
||||
return input, false
|
||||
}
|
||||
next, err := sjson.SetRaw(input, to, src.Raw)
|
||||
if err != nil {
|
||||
return input, false
|
||||
}
|
||||
next, err = sjson.Delete(next, from)
|
||||
if err != nil {
|
||||
return input, false
|
||||
}
|
||||
return next, true
|
||||
}
|
||||
|
||||
return corrected
|
||||
func deleteJSONField(input, path string) (string, bool) {
|
||||
if !gjson.Get(input, path).Exists() {
|
||||
return input, false
|
||||
}
|
||||
next, err := sjson.Delete(input, path)
|
||||
if err != nil {
|
||||
return input, false
|
||||
}
|
||||
return next, true
|
||||
}
|
||||
|
||||
// recordCorrection 记录一次工具名称修正
|
||||
|
||||
@@ -5,6 +5,15 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMayContainToolCallPayload(t *testing.T) {
|
||||
if mayContainToolCallPayload([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)) {
|
||||
t.Fatalf("plain text event should not trigger tool-call parsing")
|
||||
}
|
||||
if !mayContainToolCallPayload([]byte(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)) {
|
||||
t.Fatalf("tool_calls event should trigger tool-call parsing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorrectToolCallsInSSEData(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
|
||||
190
backend/internal/service/openai_ws_account_sticky_test.go
Normal file
190
backend/internal/service/openai_ws_account_sticky_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
account := Account{
|
||||
ID: 2,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 2,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
openaiWSStateStore: store,
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, account.ID, selection.Account.ID)
|
||||
require.True(t, selection.Acquired)
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
account := Account{
|
||||
ID: 8,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
openaiWSStateStore: store,
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
account := Account{
|
||||
ID: 11,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"openai_ws_force_http": true,
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
openaiWSStateStore: store,
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 21,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 22,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 9,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2
|
||||
cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second
|
||||
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
acquireResults: map[int64]bool{
|
||||
21: false, // previous_response 命中的账号繁忙
|
||||
22: true, // 次优账号可用(若回退会命中)
|
||||
},
|
||||
waitCounts: map[int64]int{
|
||||
21: 999,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
openaiWSStateStore: store,
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(21), selection.Account.ID, "busy previous_response sticky account should remain selected")
|
||||
require.False(t, selection.Acquired)
|
||||
require.NotNil(t, selection.WaitPlan)
|
||||
require.Equal(t, int64(21), selection.WaitPlan.AccountID)
|
||||
}
|
||||
|
||||
func newOpenAIWSV2TestConfig() *config.Config {
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
|
||||
return cfg
|
||||
}
|
||||
285
backend/internal/service/openai_ws_client.go
Normal file
285
backend/internal/service/openai_ws_client.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
)
|
||||
|
||||
// openAIWSMessageReadLimitBytes caps a single inbound WS message at 16 MiB.
// coder/websocket's default per-message limit is 32KB, which large Codex WS
// events can exceed (see the note in Dial).
const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024

const (
	// Connection-pool tuning for the per-proxy HTTP transports used when
	// dialing through a proxy.
	openAIWSProxyTransportMaxIdleConns        = 128
	openAIWSProxyTransportMaxIdleConnsPerHost = 64
	openAIWSProxyTransportIdleConnTimeout     = 90 * time.Second

	// Bounds for the proxy-URL -> *http.Client cache: a hard entry cap plus
	// an idle TTL after which an unused client is evicted.
	openAIWSProxyClientCacheMaxEntries = 256
	openAIWSProxyClientCacheIdleTTL    = 15 * time.Minute
)
|
||||
|
||||
// OpenAIWSTransportMetricsSnapshot is a point-in-time view of how effectively
// the dialer's per-proxy HTTP client cache is being reused.
type OpenAIWSTransportMetricsSnapshot struct {
	ProxyClientCacheHits   int64   `json:"proxy_client_cache_hits"`
	ProxyClientCacheMisses int64   `json:"proxy_client_cache_misses"`
	TransportReuseRatio    float64 `json:"transport_reuse_ratio"` // hits / (hits+misses); 0 when no lookups occurred
}
|
||||
|
||||
// openAIWSClientConn abstracts a WS client connection so the underlying
// implementation can be swapped (e.g. for tests).
type openAIWSClientConn interface {
	WriteJSON(ctx context.Context, value any) error
	ReadMessage(ctx context.Context) ([]byte, error)
	Ping(ctx context.Context) error
	Close() error
}

// openAIWSClientDialer abstracts WS connection establishment. Dial returns
// the connection, the upstream HTTP status (non-zero only on handshake
// failure) and any response headers received during the handshake.
type openAIWSClientDialer interface {
	Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error)
}

// openAIWSTransportMetricsDialer is implemented by dialers able to report
// transport/cache reuse metrics.
type openAIWSTransportMetricsDialer interface {
	SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot
}
|
||||
|
||||
func newDefaultOpenAIWSClientDialer() openAIWSClientDialer {
|
||||
return &coderOpenAIWSClientDialer{
|
||||
proxyClients: make(map[string]*openAIWSProxyClientEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// coderOpenAIWSClientDialer is the production dialer backed by
// github.com/coder/websocket. It caches one *http.Client per proxy URL so
// proxied dials reuse transports instead of rebuilding them per connection.
type coderOpenAIWSClientDialer struct {
	proxyMu      sync.Mutex                           // guards proxyClients
	proxyClients map[string]*openAIWSProxyClientEntry // keyed by trimmed proxy URL
	proxyHits    atomic.Int64                         // cache hits in proxyHTTPClient
	proxyMisses  atomic.Int64                         // cache misses (a new client was built)
}

// openAIWSProxyClientEntry pairs a cached proxy HTTP client with its
// last-use timestamp (unix nanos), used for idle-TTL and capacity eviction.
type openAIWSProxyClientEntry struct {
	client           *http.Client
	lastUsedUnixNano int64
}
|
||||
|
||||
// Dial establishes a WS connection to wsURL with the given headers,
// optionally routed through proxyURL. On handshake failure it returns the
// upstream HTTP status and a cloned copy of the response headers alongside
// the error; on success the returned status is always 0.
func (d *coderOpenAIWSClientDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	targetURL := strings.TrimSpace(wsURL)
	if targetURL == "" {
		return nil, 0, nil, errors.New("ws url is empty")
	}

	opts := &coderws.DialOptions{
		HTTPHeader:      cloneHeader(headers),
		CompressionMode: coderws.CompressionContextTakeover,
	}
	// Only attach a proxy-aware HTTP client when a proxy is configured; the
	// client is cached per proxy URL so transports get reused across dials.
	if proxy := strings.TrimSpace(proxyURL); proxy != "" {
		proxyClient, err := d.proxyHTTPClient(proxy)
		if err != nil {
			return nil, 0, nil, err
		}
		opts.HTTPClient = proxyClient
	}

	conn, resp, err := coderws.Dial(ctx, targetURL, opts)
	if err != nil {
		// Surface whatever handshake response we got (status + headers) so
		// callers can classify the failure.
		status := 0
		respHeaders := http.Header(nil)
		if resp != nil {
			status = resp.StatusCode
			respHeaders = cloneHeader(resp.Header)
		}
		return nil, status, respHeaders, err
	}
	// coder/websocket defaults to a 32KB per-message read limit; Codex WS
	// events (e.g. rate_limits / large deltas) can exceed that threshold, so
	// raise the limit explicitly to avoid local read_fail (message too big).
	conn.SetReadLimit(openAIWSMessageReadLimitBytes)
	respHeaders := http.Header(nil)
	if resp != nil {
		respHeaders = cloneHeader(resp.Header)
	}
	return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil
}
|
||||
|
||||
// proxyHTTPClient returns a cached HTTP client configured for the given proxy
// URL, building and caching a new one on first use. Idle entries are swept
// before insertion and the cache is capped at
// openAIWSProxyClientCacheMaxEntries. Safe for concurrent use.
func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) {
	if d == nil {
		return nil, errors.New("openai ws dialer is nil")
	}
	normalizedProxy := strings.TrimSpace(proxy)
	if normalizedProxy == "" {
		return nil, errors.New("proxy url is empty")
	}
	parsedProxyURL, err := url.Parse(normalizedProxy)
	if err != nil {
		return nil, fmt.Errorf("invalid proxy url: %w", err)
	}
	now := time.Now().UnixNano()

	d.proxyMu.Lock()
	defer d.proxyMu.Unlock()
	// Fast path: reuse the cached client and refresh its last-use time.
	if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != nil && entry.client != nil {
		entry.lastUsedUnixNano = now
		d.proxyHits.Add(1)
		return entry.client, nil
	}
	// Miss: sweep idle entries before inserting a fresh client.
	d.cleanupProxyClientsLocked(now)
	transport := &http.Transport{
		Proxy:               http.ProxyURL(parsedProxyURL),
		MaxIdleConns:        openAIWSProxyTransportMaxIdleConns,
		MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost,
		IdleConnTimeout:     openAIWSProxyTransportIdleConnTimeout,
		TLSHandshakeTimeout: 10 * time.Second,
		ForceAttemptHTTP2:   true,
	}
	client := &http.Client{Transport: transport}
	d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{
		client:           client,
		lastUsedUnixNano: now,
	}
	// Enforce the hard entry cap (may evict least-recently-used clients; the
	// just-inserted entry carries the newest timestamp, so it survives).
	d.ensureProxyClientCapacityLocked()
	d.proxyMisses.Add(1)
	return client, nil
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) {
|
||||
if d == nil || len(d.proxyClients) == 0 {
|
||||
return
|
||||
}
|
||||
idleTTL := openAIWSProxyClientCacheIdleTTL
|
||||
if idleTTL <= 0 {
|
||||
return
|
||||
}
|
||||
now := time.Unix(0, nowUnixNano)
|
||||
for key, entry := range d.proxyClients {
|
||||
if entry == nil || entry.client == nil {
|
||||
delete(d.proxyClients, key)
|
||||
continue
|
||||
}
|
||||
lastUsed := time.Unix(0, entry.lastUsedUnixNano)
|
||||
if now.Sub(lastUsed) > idleTTL {
|
||||
closeOpenAIWSProxyClient(entry.client)
|
||||
delete(d.proxyClients, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() {
|
||||
if d == nil {
|
||||
return
|
||||
}
|
||||
maxEntries := openAIWSProxyClientCacheMaxEntries
|
||||
if maxEntries <= 0 {
|
||||
return
|
||||
}
|
||||
for len(d.proxyClients) > maxEntries {
|
||||
var oldestKey string
|
||||
var oldestLastUsed int64
|
||||
hasOldest := false
|
||||
for key, entry := range d.proxyClients {
|
||||
lastUsed := int64(0)
|
||||
if entry != nil {
|
||||
lastUsed = entry.lastUsedUnixNano
|
||||
}
|
||||
if !hasOldest || lastUsed < oldestLastUsed {
|
||||
hasOldest = true
|
||||
oldestKey = key
|
||||
oldestLastUsed = lastUsed
|
||||
}
|
||||
}
|
||||
if !hasOldest {
|
||||
return
|
||||
}
|
||||
if entry := d.proxyClients[oldestKey]; entry != nil {
|
||||
closeOpenAIWSProxyClient(entry.client)
|
||||
}
|
||||
delete(d.proxyClients, oldestKey)
|
||||
}
|
||||
}
|
||||
|
||||
// closeOpenAIWSProxyClient releases idle connections held by a cached proxy
// client. Only *http.Transport is handled; nil clients/transports and other
// transport implementations are left untouched.
func closeOpenAIWSProxyClient(client *http.Client) {
	if client == nil {
		return
	}
	// A type assertion on a nil Transport interface simply yields ok=false,
	// so no separate nil check on client.Transport is needed.
	transport, ok := client.Transport.(*http.Transport)
	if !ok || transport == nil {
		return
	}
	transport.CloseIdleConnections()
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
|
||||
if d == nil {
|
||||
return OpenAIWSTransportMetricsSnapshot{}
|
||||
}
|
||||
hits := d.proxyHits.Load()
|
||||
misses := d.proxyMisses.Load()
|
||||
total := hits + misses
|
||||
reuseRatio := 0.0
|
||||
if total > 0 {
|
||||
reuseRatio = float64(hits) / float64(total)
|
||||
}
|
||||
return OpenAIWSTransportMetricsSnapshot{
|
||||
ProxyClientCacheHits: hits,
|
||||
ProxyClientCacheMisses: misses,
|
||||
TransportReuseRatio: reuseRatio,
|
||||
}
|
||||
}
|
||||
|
||||
// coderOpenAIWSClientConn adapts *coderws.Conn to the openAIWSClientConn
// interface.
type coderOpenAIWSClientConn struct {
	conn *coderws.Conn
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
|
||||
if c == nil || c.conn == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return wsjson.Write(ctx, c.conn, value)
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) {
|
||||
if c == nil || c.conn == nil {
|
||||
return nil, errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
msgType, payload, err := c.conn.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch msgType {
|
||||
case coderws.MessageText, coderws.MessageBinary:
|
||||
return payload, nil
|
||||
default:
|
||||
return nil, errOpenAIWSConnClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
|
||||
if c == nil || c.conn == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return c.conn.Ping(ctx)
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) Close() error {
|
||||
if c == nil || c.conn == nil {
|
||||
return nil
|
||||
}
|
||||
// Close 为幂等,忽略重复关闭错误。
|
||||
_ = c.conn.Close(coderws.StatusNormalClosure, "")
|
||||
_ = c.conn.CloseNow()
|
||||
return nil
|
||||
}
|
||||
112
backend/internal/service/openai_ws_client_test.go
Normal file
112
backend/internal/service/openai_ws_client_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse verifies that repeated
// lookups for the same proxy URL reuse one cached *http.Client, while a
// different proxy URL gets a distinct client.
func TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)

	c1, err := impl.proxyHTTPClient("http://127.0.0.1:8080")
	require.NoError(t, err)
	c2, err := impl.proxyHTTPClient("http://127.0.0.1:8080")
	require.NoError(t, err)
	require.Same(t, c1, c2, "同一代理地址应复用同一个 HTTP 客户端")

	c3, err := impl.proxyHTTPClient("http://127.0.0.1:8081")
	require.NoError(t, err)
	require.NotSame(t, c1, c3, "不同代理地址应分离客户端")
}

// TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL verifies that a
// malformed proxy URL is rejected with an error.
func TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)

	_, err := impl.proxyHTTPClient("://bad")
	require.Error(t, err)
}

// TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot drives one cache hit
// and two misses, then checks the snapshot counters and reuse ratio.
func TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)

	_, err := impl.proxyHTTPClient("http://127.0.0.1:18080")
	require.NoError(t, err)
	_, err = impl.proxyHTTPClient("http://127.0.0.1:18080")
	require.NoError(t, err)
	_, err = impl.proxyHTTPClient("http://127.0.0.1:18081")
	require.NoError(t, err)

	snapshot := impl.SnapshotTransportMetrics()
	require.Equal(t, int64(1), snapshot.ProxyClientCacheHits)
	require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses)
	require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001)
}

// TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity fills the cache past
// its capacity cap and checks the entry-count bound holds.
func TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)

	total := openAIWSProxyClientCacheMaxEntries + 32
	for i := 0; i < total; i++ {
		_, err := impl.proxyHTTPClient(fmt.Sprintf("http://127.0.0.1:%d", 20000+i))
		require.NoError(t, err)
	}

	impl.proxyMu.Lock()
	cacheSize := len(impl.proxyClients)
	impl.proxyMu.Unlock()

	require.LessOrEqual(t, cacheSize, openAIWSProxyClientCacheMaxEntries, "代理客户端缓存应受容量上限约束")
}

// TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL backdates an entry
// beyond the idle TTL and checks that a later lookup evicts it.
func TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)

	oldProxy := "http://127.0.0.1:28080"
	_, err := impl.proxyHTTPClient(oldProxy)
	require.NoError(t, err)

	impl.proxyMu.Lock()
	oldEntry := impl.proxyClients[oldProxy]
	require.NotNil(t, oldEntry)
	oldEntry.lastUsedUnixNano = time.Now().Add(-openAIWSProxyClientCacheIdleTTL - time.Minute).UnixNano()
	impl.proxyMu.Unlock()

	// Fetch a client for a new proxy to trigger the TTL cleanup sweep.
	_, err = impl.proxyHTTPClient("http://127.0.0.1:28081")
	require.NoError(t, err)

	impl.proxyMu.Lock()
	_, exists := impl.proxyClients[oldProxy]
	impl.proxyMu.Unlock()

	require.False(t, exists, "超过空闲 TTL 的代理客户端应被回收")
}

// TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout checks that
// the cached client's transport carries the expected TLS handshake timeout.
func TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)

	client, err := impl.proxyHTTPClient("http://127.0.0.1:38080")
	require.NoError(t, err)
	require.NotNil(t, client)

	transport, ok := client.Transport.(*http.Transport)
	require.True(t, ok)
	require.NotNil(t, transport)
	require.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
}
|
||||
251
backend/internal/service/openai_ws_fallback_test.go
Normal file
251
backend/internal/service/openai_ws_fallback_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestClassifyOpenAIWSAcquireError pins the reason label returned for each
// class of connection-acquire failure (dial statuses, queue/preference
// sentinels, timeouts, and unknown/nil errors).
func TestClassifyOpenAIWSAcquireError(t *testing.T) {
	t.Run("dial_426_upgrade_required", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 426, Err: errors.New("upgrade required")}
		require.Equal(t, "upgrade_required", classifyOpenAIWSAcquireError(err))
	})

	t.Run("queue_full", func(t *testing.T) {
		require.Equal(t, "conn_queue_full", classifyOpenAIWSAcquireError(errOpenAIWSConnQueueFull))
	})

	t.Run("preferred_conn_unavailable", func(t *testing.T) {
		require.Equal(t, "preferred_conn_unavailable", classifyOpenAIWSAcquireError(errOpenAIWSPreferredConnUnavailable))
	})

	t.Run("acquire_timeout", func(t *testing.T) {
		require.Equal(t, "acquire_timeout", classifyOpenAIWSAcquireError(context.DeadlineExceeded))
	})

	t.Run("auth_failed_401", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 401, Err: errors.New("unauthorized")}
		require.Equal(t, "auth_failed", classifyOpenAIWSAcquireError(err))
	})

	t.Run("upstream_rate_limited", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 429, Err: errors.New("rate limited")}
		require.Equal(t, "upstream_rate_limited", classifyOpenAIWSAcquireError(err))
	})

	t.Run("upstream_5xx", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 502, Err: errors.New("bad gateway")}
		require.Equal(t, "upstream_5xx", classifyOpenAIWSAcquireError(err))
	})

	t.Run("dial_failed_other_status", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 418, Err: errors.New("teapot")}
		require.Equal(t, "dial_failed", classifyOpenAIWSAcquireError(err))
	})

	t.Run("other", func(t *testing.T) {
		require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(errors.New("x")))
	})

	t.Run("nil", func(t *testing.T) {
		require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(nil))
	})
}

// TestClassifyOpenAIWSDialError covers classification of handshake-level and
// context-cancellation dial failures.
func TestClassifyOpenAIWSDialError(t *testing.T) {
	t.Run("handshake_not_finished", func(t *testing.T) {
		err := &openAIWSDialError{
			StatusCode: http.StatusBadGateway,
			Err:        errors.New("WebSocket protocol error: Handshake not finished"),
		}
		require.Equal(t, "handshake_not_finished", classifyOpenAIWSDialError(err))
	})

	t.Run("context_deadline", func(t *testing.T) {
		err := &openAIWSDialError{
			StatusCode: 0,
			Err:        context.DeadlineExceeded,
		}
		require.Equal(t, "ctx_deadline_exceeded", classifyOpenAIWSDialError(err))
	})
}

// TestSummarizeOpenAIWSDialError checks the summarized status, class,
// placeholder close fields, and observability headers (Server / Via / CF-Ray
// / request id) extracted from a dial failure.
func TestSummarizeOpenAIWSDialError(t *testing.T) {
	err := &openAIWSDialError{
		StatusCode: http.StatusBadGateway,
		ResponseHeaders: http.Header{
			"Server":       []string{"cloudflare"},
			"Via":          []string{"1.1 example"},
			"Cf-Ray":       []string{"abcd1234"},
			"X-Request-Id": []string{"req_123"},
		},
		Err: errors.New("WebSocket protocol error: Handshake not finished"),
	}

	status, class, closeStatus, closeReason, server, via, cfRay, reqID := summarizeOpenAIWSDialError(err)
	require.Equal(t, http.StatusBadGateway, status)
	require.Equal(t, "handshake_not_finished", class)
	require.Equal(t, "-", closeStatus)
	require.Equal(t, "-", closeReason)
	require.Equal(t, "cloudflare", server)
	require.Equal(t, "1.1 example", via)
	require.Equal(t, "abcd1234", cfRay)
	require.Equal(t, "req_123", reqID)
}

// TestClassifyOpenAIWSErrorEvent checks that recoverable upstream error
// events are mapped to their reason codes.
func TestClassifyOpenAIWSErrorEvent(t *testing.T) {
	reason, recoverable := classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"upgrade_required","message":"Upgrade required"}}`))
	require.Equal(t, "upgrade_required", reason)
	require.True(t, recoverable)

	reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`))
	require.Equal(t, "previous_response_not_found", reason)
	require.True(t, recoverable)
}

// TestClassifyOpenAIWSReconnectReason checks the retryability decision for
// wrapped fallback reasons.
func TestClassifyOpenAIWSReconnectReason(t *testing.T) {
	reason, retryable := classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("policy_violation", errors.New("policy")))
	require.Equal(t, "policy_violation", reason)
	require.False(t, retryable)

	reason, retryable = classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("read_event", errors.New("io")))
	require.Equal(t, "read_event", reason)
	require.True(t, retryable)
}

// TestOpenAIWSErrorHTTPStatus maps upstream error-event types to their HTTP
// status codes.
func TestOpenAIWSErrorHTTPStatus(t *testing.T) {
	require.Equal(t, http.StatusBadRequest, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)))
	require.Equal(t, http.StatusUnauthorized, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"authentication_error","code":"invalid_api_key","message":"auth failed"}}`)))
	require.Equal(t, http.StatusForbidden, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"permission_error","code":"forbidden","message":"forbidden"}}`)))
	require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"rate limited"}}`)))
	require.Equal(t, http.StatusBadGateway, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"server_error","code":"server_error","message":"server"}}`)))
}
|
||||
|
||||
// TestResolveOpenAIWSFallbackErrorResponse checks the HTTP error mapping for
// wrapped fallback errors, including the dial-status passthrough, and that a
// plain (non-fallback) error is not resolved.
func TestResolveOpenAIWSFallbackErrorResponse(t *testing.T) {
	t.Run("previous_response_not_found", func(t *testing.T) {
		statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
			wrapOpenAIWSFallback("previous_response_not_found", errors.New("previous response not found")),
		)
		require.True(t, ok)
		require.Equal(t, http.StatusBadRequest, statusCode)
		require.Equal(t, "invalid_request_error", errType)
		require.Equal(t, "previous response not found", clientMessage)
		require.Equal(t, "previous response not found", upstreamMessage)
	})

	t.Run("auth_failed_uses_dial_status", func(t *testing.T) {
		statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
			wrapOpenAIWSFallback("auth_failed", &openAIWSDialError{
				StatusCode: http.StatusForbidden,
				Err:        errors.New("forbidden"),
			}),
		)
		require.True(t, ok)
		require.Equal(t, http.StatusForbidden, statusCode)
		require.Equal(t, "upstream_error", errType)
		require.Equal(t, "forbidden", clientMessage)
		require.Equal(t, "forbidden", upstreamMessage)
	})

	t.Run("non_fallback_error_not_resolved", func(t *testing.T) {
		_, _, _, _, ok := resolveOpenAIWSFallbackErrorResponse(errors.New("plain error"))
		require.False(t, ok)
	})
}

// TestOpenAIWSFallbackCooling exercises mark/check/clear plus natural expiry
// of the per-account fallback cooldown.
// NOTE(review): the real 1.2s sleep makes this test slow; consider an
// injectable clock.
func TestOpenAIWSFallbackCooling(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1

	require.False(t, svc.isOpenAIWSFallbackCooling(1))
	svc.markOpenAIWSFallbackCooling(1, "upgrade_required")
	require.True(t, svc.isOpenAIWSFallbackCooling(1))

	svc.clearOpenAIWSFallbackCooling(1)
	require.False(t, svc.isOpenAIWSFallbackCooling(1))

	svc.markOpenAIWSFallbackCooling(2, "x")
	time.Sleep(1200 * time.Millisecond)
	require.False(t, svc.isOpenAIWSFallbackCooling(2))
}

// TestOpenAIWSRetryBackoff checks exponential backoff growth capped at the
// configured maximum, with jitter disabled.
func TestOpenAIWSRetryBackoff(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 100
	svc.cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 400
	svc.cfg.Gateway.OpenAIWS.RetryJitterRatio = 0

	require.Equal(t, time.Duration(100)*time.Millisecond, svc.openAIWSRetryBackoff(1))
	require.Equal(t, time.Duration(200)*time.Millisecond, svc.openAIWSRetryBackoff(2))
	require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(3))
	require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(4))
}

// TestOpenAIWSRetryTotalBudget checks the configured total retry budget and
// its zero (disabled) default.
func TestOpenAIWSRetryTotalBudget(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 1200
	require.Equal(t, 1200*time.Millisecond, svc.openAIWSRetryTotalBudget())

	svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 0
	require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget())
}

// TestClassifyOpenAIWSReadFallbackReason maps WS close codes (and generic
// errors) to read-fallback reasons.
func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) {
	require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation}))
	require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig}))
	require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io")))
}

// TestOpenAIWSStoreDisabledConnMode checks mode resolution from the two
// config knobs: the legacy force flag and the explicit mode string.
func TestOpenAIWSStoreDisabledConnMode(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
	require.Equal(t, openAIWSStoreDisabledConnModeStrict, svc.openAIWSStoreDisabledConnMode())

	svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "adaptive"
	require.Equal(t, openAIWSStoreDisabledConnModeAdaptive, svc.openAIWSStoreDisabledConnMode())

	svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = ""
	svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false
	require.Equal(t, openAIWSStoreDisabledConnModeOff, svc.openAIWSStoreDisabledConnMode())
}

// TestShouldForceNewConnOnStoreDisabled checks the force-new-connection
// decision per mode and fallback reason.
func TestShouldForceNewConnOnStoreDisabled(t *testing.T) {
	require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeStrict, ""))
	require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeOff, "policy_violation"))

	require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "policy_violation"))
	require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "prewarm_message_too_big"))
	require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "read_event"))
}

// TestOpenAIWSRetryMetricsSnapshot checks the retry counters accumulate and
// surface in the snapshot.
func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) {
	svc := &OpenAIGatewayService{}
	svc.recordOpenAIWSRetryAttempt(150 * time.Millisecond)
	svc.recordOpenAIWSRetryAttempt(0)
	svc.recordOpenAIWSRetryExhausted()
	svc.recordOpenAIWSNonRetryableFastFallback()

	snapshot := svc.SnapshotOpenAIWSRetryMetrics()
	require.Equal(t, int64(2), snapshot.RetryAttemptsTotal)
	require.Equal(t, int64(150), snapshot.RetryBackoffMsTotal)
	require.Equal(t, int64(1), snapshot.RetryExhaustedTotal)
	require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal)
}

// TestShouldLogOpenAIWSPayloadSchema checks that the first attempt always
// logs the payload schema and later attempts follow the sample rate.
func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}

	svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 0
	require.True(t, svc.shouldLogOpenAIWSPayloadSchema(1), "首次尝试应始终记录 payload_schema")
	require.False(t, svc.shouldLogOpenAIWSPayloadSchema(2))

	svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 1
	require.True(t, svc.shouldLogOpenAIWSPayloadSchema(2))
}
|
||||
3955
backend/internal/service/openai_ws_forwarder.go
Normal file
3955
backend/internal/service/openai_ws_forwarder.go
Normal file
File diff suppressed because it is too large
Load Diff
127
backend/internal/service/openai_ws_forwarder_benchmark_test.go
Normal file
127
backend/internal/service/openai_ws_forwarder_benchmark_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// Package-level sinks keep benchmark results observable so the compiler
// cannot dead-code-eliminate the measured calls.
var (
	benchmarkOpenAIWSPayloadJSONSink string
	benchmarkOpenAIWSStringSink      string
	benchmarkOpenAIWSBoolSink        bool
	benchmarkOpenAIWSBytesSink       []byte
)

// BenchmarkOpenAIWSForwarderHotPath measures the per-turn hot path: payload
// build, retry-payload strategy, metadata injection, summaries, and JSON
// serialization.
func BenchmarkOpenAIWSForwarderHotPath(b *testing.B) {
	cfg := &config.Config{}
	svc := &OpenAIGatewayService{cfg: cfg}
	account := &Account{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
	reqBody := benchmarkOpenAIWSHotPathRequest()

	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		payload := svc.buildOpenAIWSCreatePayload(reqBody, account)
		_, _ = applyOpenAIWSRetryPayloadStrategy(payload, 2)
		setOpenAIWSTurnMetadata(payload, `{"trace":"bench","turn":"1"}`)

		benchmarkOpenAIWSStringSink = openAIWSPayloadString(payload, "previous_response_id")
		benchmarkOpenAIWSBoolSink = payload["tools"] != nil
		benchmarkOpenAIWSStringSink = summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)
		benchmarkOpenAIWSStringSink = summarizeOpenAIWSInput(payload["input"])
		benchmarkOpenAIWSPayloadJSONSink = payloadAsJSON(payload)
	}
}

// benchmarkOpenAIWSHotPathRequest builds a representative response.create
// request body: 24 function tool schemas and 16 input messages.
func benchmarkOpenAIWSHotPathRequest() map[string]any {
	tools := make([]map[string]any, 0, 24)
	for i := 0; i < 24; i++ {
		tools = append(tools, map[string]any{
			"type":        "function",
			"name":        fmt.Sprintf("tool_%02d", i),
			"description": "benchmark tool schema",
			"parameters": map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{"type": "string"},
					"limit": map[string]any{"type": "number"},
				},
				"required": []string{"query"},
			},
		})
	}

	input := make([]map[string]any, 0, 16)
	for i := 0; i < 16; i++ {
		input = append(input, map[string]any{
			"role":    "user",
			"type":    "message",
			"content": fmt.Sprintf("benchmark message %d", i),
		})
	}

	return map[string]any{
		"type":                 "response.create",
		"model":                "gpt-5.3-codex",
		"input":                input,
		"tools":                tools,
		"parallel_tool_calls":  true,
		"previous_response_id": "resp_benchmark_prev",
		"prompt_cache_key":     "bench-cache-key",
		"reasoning":            map[string]any{"effort": "medium"},
		"instructions":         "benchmark instructions",
		"store":                false,
	}
}

// BenchmarkOpenAIWSEventEnvelopeParse measures envelope parsing of a
// completed event.
func BenchmarkOpenAIWSEventEnvelopeParse(b *testing.B) {
	event := []byte(`{"type":"response.completed","response":{"id":"resp_bench_1","model":"gpt-5.1","usage":{"input_tokens":12,"output_tokens":8}}}`)
	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		eventType, responseID, response := parseOpenAIWSEventEnvelope(event)
		benchmarkOpenAIWSStringSink = eventType
		benchmarkOpenAIWSStringSink = responseID
		benchmarkOpenAIWSBoolSink = response.Exists()
	}
}

// BenchmarkOpenAIWSErrorEventFieldReuse measures parsing the error-event
// fields once and reusing the raw values across the classify / summarize /
// status helpers.
func BenchmarkOpenAIWSErrorEventFieldReuse(b *testing.B) {
	event := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(event)
		benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw)
		code, errType, errMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw)
		benchmarkOpenAIWSStringSink = code
		benchmarkOpenAIWSStringSink = errType
		benchmarkOpenAIWSStringSink = errMsg
		benchmarkOpenAIWSBoolSink = openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) > 0
	}
}

// BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath measures the fast path
// when the event contains no model field to rewrite.
func BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath(b *testing.B) {
	event := []byte(`{"type":"response.output_text.delta","delta":"hello world"}`)
	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
	}
}

// BenchmarkReplaceOpenAIWSMessageModel_DualReplace measures rewriting both
// the top-level and the nested response model fields.
func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) {
	event := []byte(`{"type":"response.completed","model":"gpt-5.1","response":{"id":"resp_1","model":"gpt-5.1"}}`)
	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
	}
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestParseOpenAIWSEventEnvelope checks envelope extraction for events with a
// nested response object and for events carrying only a top-level id.
func TestParseOpenAIWSEventEnvelope(t *testing.T) {
	eventType, responseID, response := parseOpenAIWSEventEnvelope([]byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`))
	require.Equal(t, "response.completed", eventType)
	require.Equal(t, "resp_1", responseID)
	require.True(t, response.Exists())
	require.Equal(t, `{"id":"resp_1","model":"gpt-5.1"}`, response.Raw)

	eventType, responseID, response = parseOpenAIWSEventEnvelope([]byte(`{"type":"response.delta","id":"evt_1"}`))
	require.Equal(t, "response.delta", eventType)
	require.Equal(t, "evt_1", responseID)
	require.False(t, response.Exists())
}

// TestParseOpenAIWSResponseUsageFromCompletedEvent checks token usage —
// including cached input tokens — is extracted from a completed event.
func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
	usage := &OpenAIUsage{}
	parseOpenAIWSResponseUsageFromCompletedEvent(
		[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`),
		usage,
	)
	require.Equal(t, 11, usage.InputTokens)
	require.Equal(t, 7, usage.OutputTokens)
	require.Equal(t, 3, usage.CacheReadInputTokens)
}

// TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper asserts the *FromRaw
// helpers agree with their message-parsing wrappers for the same error event.
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
	message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
	codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)

	wrappedReason, wrappedRecoverable := classifyOpenAIWSErrorEvent(message)
	rawReason, rawRecoverable := classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw)
	require.Equal(t, wrappedReason, rawReason)
	require.Equal(t, wrappedRecoverable, rawRecoverable)

	wrappedStatus := openAIWSErrorHTTPStatus(message)
	rawStatus := openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
	require.Equal(t, wrappedStatus, rawStatus)
	require.Equal(t, http.StatusBadRequest, rawStatus)

	wrappedCode, wrappedType, wrappedMsg := summarizeOpenAIWSErrorEventFields(message)
	rawCode, rawType, rawMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw)
	require.Equal(t, wrappedCode, rawCode)
	require.Equal(t, wrappedType, rawType)
	require.Equal(t, wrappedMsg, rawMsg)
}
|
||||
|
||||
func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) {
|
||||
require.False(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)))
|
||||
require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"tool_calls":[{"id":"tc1"}]}}`)))
|
||||
require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"function_call"}}`)))
|
||||
}
|
||||
|
||||
func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) {
|
||||
noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`)
|
||||
require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, "gpt-5.1", "custom-model")))
|
||||
|
||||
rootOnly := []byte(`{"type":"response.created","model":"gpt-5.1"}`)
|
||||
require.Equal(t, `{"type":"response.created","model":"custom-model"}`, string(replaceOpenAIWSMessageModel(rootOnly, "gpt-5.1", "custom-model")))
|
||||
|
||||
responseOnly := []byte(`{"type":"response.completed","response":{"model":"gpt-5.1"}}`)
|
||||
require.Equal(t, `{"type":"response.completed","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(responseOnly, "gpt-5.1", "custom-model")))
|
||||
|
||||
both := []byte(`{"model":"gpt-5.1","response":{"model":"gpt-5.1"}}`)
|
||||
require.Equal(t, `{"model":"custom-model","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(both, "gpt-5.1", "custom-model")))
|
||||
}
|
||||
2483
backend/internal/service/openai_ws_forwarder_ingress_session_test.go
Normal file
2483
backend/internal/service/openai_ws_forwarder_ingress_session_test.go
Normal file
File diff suppressed because it is too large
Load Diff
714
backend/internal/service/openai_ws_forwarder_ingress_test.go
Normal file
714
backend/internal/service/openai_ws_forwarder_ingress_test.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestIsOpenAIWSClientDisconnectError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{name: "nil", err: nil, want: false},
|
||||
{name: "io_eof", err: io.EOF, want: true},
|
||||
{name: "net_closed", err: net.ErrClosed, want: true},
|
||||
{name: "context_canceled", err: context.Canceled, want: true},
|
||||
{name: "ws_normal_closure", err: coderws.CloseError{Code: coderws.StatusNormalClosure}, want: true},
|
||||
{name: "ws_going_away", err: coderws.CloseError{Code: coderws.StatusGoingAway}, want: true},
|
||||
{name: "ws_no_status", err: coderws.CloseError{Code: coderws.StatusNoStatusRcvd}, want: true},
|
||||
{name: "ws_abnormal_1006", err: coderws.CloseError{Code: coderws.StatusAbnormalClosure}, want: true},
|
||||
{name: "ws_policy_violation", err: coderws.CloseError{Code: coderws.StatusPolicyViolation}, want: false},
|
||||
{name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true},
|
||||
{name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true},
|
||||
{name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tt.want, isOpenAIWSClientDisconnectError(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.False(t, isOpenAIWSIngressPreviousResponseNotFound(nil))
|
||||
require.False(t, isOpenAIWSIngressPreviousResponseNotFound(errors.New("plain error")))
|
||||
require.False(t, isOpenAIWSIngressPreviousResponseNotFound(
|
||||
wrapOpenAIWSIngressTurnError("read_upstream", errors.New("upstream read failed"), false),
|
||||
))
|
||||
require.False(t, isOpenAIWSIngressPreviousResponseNotFound(
|
||||
wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), true),
|
||||
))
|
||||
require.True(t, isOpenAIWSIngressPreviousResponseNotFound(
|
||||
wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), false),
|
||||
))
|
||||
}
|
||||
|
||||
func TestOpenAIWSIngressPreviousResponseRecoveryEnabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var nilService *OpenAIGatewayService
|
||||
require.True(t, nilService.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil service should default to enabled")
|
||||
|
||||
svcWithNilCfg := &OpenAIGatewayService{}
|
||||
require.True(t, svcWithNilCfg.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil config should default to enabled")
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
}
|
||||
require.False(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled(), "explicit config default should be false")
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
|
||||
require.True(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled())
|
||||
}
|
||||
|
||||
func TestDropPreviousResponseIDFromRawPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("empty_payload", func(t *testing.T) {
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayload(nil)
|
||||
require.NoError(t, err)
|
||||
require.False(t, removed)
|
||||
require.Empty(t, updated)
|
||||
})
|
||||
|
||||
t.Run("payload_without_previous_response_id", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
|
||||
require.NoError(t, err)
|
||||
require.False(t, removed)
|
||||
require.Equal(t, string(payload), string(updated))
|
||||
})
|
||||
|
||||
t.Run("normal_delete_success", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
|
||||
require.NoError(t, err)
|
||||
require.True(t, removed)
|
||||
require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
|
||||
})
|
||||
|
||||
t.Run("duplicate_keys_are_removed", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_a","input":[],"previous_response_id":"resp_b"}`)
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
|
||||
require.NoError(t, err)
|
||||
require.True(t, removed)
|
||||
require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
|
||||
})
|
||||
|
||||
t.Run("nil_delete_fn_uses_default_delete_logic", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, nil)
|
||||
require.NoError(t, err)
|
||||
require.True(t, removed)
|
||||
require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
|
||||
})
|
||||
|
||||
t.Run("delete_error", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, func(_ []byte, _ string) ([]byte, error) {
|
||||
return nil, errors.New("delete failed")
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.False(t, removed)
|
||||
require.Equal(t, string(payload), string(updated))
|
||||
})
|
||||
|
||||
t.Run("malformed_json_is_still_best_effort_deleted", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_abc"`)
|
||||
require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists())
|
||||
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
|
||||
require.NoError(t, err)
|
||||
require.True(t, removed)
|
||||
require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
|
||||
})
|
||||
}
|
||||
|
||||
func TestAlignStoreDisabledPreviousResponseID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("empty_payload", func(t *testing.T) {
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(nil, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.False(t, changed)
|
||||
require.Empty(t, updated)
|
||||
})
|
||||
|
||||
t.Run("empty_expected", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_old"}`)
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "")
|
||||
require.NoError(t, err)
|
||||
require.False(t, changed)
|
||||
require.Equal(t, string(payload), string(updated))
|
||||
})
|
||||
|
||||
t.Run("missing_previous_response_id", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.False(t, changed)
|
||||
require.Equal(t, string(payload), string(updated))
|
||||
})
|
||||
|
||||
t.Run("already_aligned", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_target"}`)
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.False(t, changed)
|
||||
require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
|
||||
})
|
||||
|
||||
t.Run("mismatch_rewrites_to_expected", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_old","input":[]}`)
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
|
||||
})
|
||||
|
||||
t.Run("duplicate_keys_rewrites_to_single_expected", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_old_1","input":[],"previous_response_id":"resp_old_2"}`)
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetPreviousResponseIDToRawPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("empty_payload", func(t *testing.T) {
|
||||
updated, err := setPreviousResponseIDToRawPayload(nil, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, updated)
|
||||
})
|
||||
|
||||
t.Run("empty_previous_response_id", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
|
||||
updated, err := setPreviousResponseIDToRawPayload(payload, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(payload), string(updated))
|
||||
})
|
||||
|
||||
t.Run("set_previous_response_id_when_missing", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
|
||||
updated, err := setPreviousResponseIDToRawPayload(payload, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
|
||||
require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String())
|
||||
})
|
||||
|
||||
t.Run("overwrite_existing_previous_response_id", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_old"}`)
|
||||
updated, err := setPreviousResponseIDToRawPayload(payload, "resp_new")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "resp_new", gjson.GetBytes(updated, "previous_response_id").String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
storeDisabled bool
|
||||
turn int
|
||||
hasFunctionCallOutput bool
|
||||
currentPreviousResponse string
|
||||
expectedPrevious string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "infer_when_all_conditions_match",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "resp_1",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "skip_when_store_enabled",
|
||||
storeDisabled: false,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_on_first_turn",
|
||||
storeDisabled: true,
|
||||
turn: 1,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_without_function_call_output",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: false,
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_when_request_already_has_previous_response_id",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
currentPreviousResponse: "resp_client",
|
||||
expectedPrevious: "resp_1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "skip_when_last_turn_response_id_missing",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "trim_whitespace_before_judgement",
|
||||
storeDisabled: true,
|
||||
turn: 2,
|
||||
hasFunctionCallOutput: true,
|
||||
expectedPrevious: " resp_2 ",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := shouldInferIngressFunctionCallOutputPreviousResponseID(
|
||||
tt.storeDisabled,
|
||||
tt.turn,
|
||||
tt.hasFunctionCallOutput,
|
||||
tt.currentPreviousResponse,
|
||||
tt.expectedPrevious,
|
||||
)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIWSInputIsPrefixExtended(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
previous []byte
|
||||
current []byte
|
||||
want bool
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "both_missing_input",
|
||||
previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
|
||||
current: []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_1"}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "previous_missing_current_empty_array",
|
||||
previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
|
||||
current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "previous_missing_current_non_empty_array",
|
||||
previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
|
||||
current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"input_text","text":"hello"}]}`),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "array_prefix_match",
|
||||
previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`),
|
||||
current: []byte(`{"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "array_prefix_mismatch",
|
||||
previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`),
|
||||
current: []byte(`{"input":[{"type":"input_text","text":"different"}]}`),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "current_shorter_than_previous",
|
||||
previous: []byte(`{"input":[{"type":"input_text","text":"a"},{"type":"input_text","text":"b"}]}`),
|
||||
current: []byte(`{"input":[{"type":"input_text","text":"a"}]}`),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "previous_has_input_current_missing",
|
||||
previous: []byte(`{"input":[{"type":"input_text","text":"a"}]}`),
|
||||
current: []byte(`{"model":"gpt-5.1"}`),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "input_string_treated_as_single_item",
|
||||
previous: []byte(`{"input":"hello"}`),
|
||||
current: []byte(`{"input":"hello"}`),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "current_invalid_input_json",
|
||||
previous: []byte(`{"input":[]}`),
|
||||
current: []byte(`{"input":[}`),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid_input_json",
|
||||
previous: []byte(`{"input":[}`),
|
||||
current: []byte(`{"input":[]}`),
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := openAIWSInputIsPrefixExtended(tt.previous, tt.current)
|
||||
if tt.expectErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIWSJSONForCompare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
normalized, err := normalizeOpenAIWSJSONForCompare([]byte(`{"b":2,"a":1}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `{"a":1,"b":2}`, string(normalized))
|
||||
|
||||
_, err = normalizeOpenAIWSJSONForCompare([]byte(" "))
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = normalizeOpenAIWSJSONForCompare([]byte(`{"a":`))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIWSJSONForCompareOrRaw(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, `{"a":1,"b":2}`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"b":2,"a":1}`))))
|
||||
require.Equal(t, `{"a":`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"a":`))))
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
normalized, err := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(
|
||||
[]byte(`{"model":"gpt-5.1","input":[1],"previous_response_id":"resp_x","metadata":{"b":2,"a":1}}`),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.False(t, gjson.GetBytes(normalized, "input").Exists())
|
||||
require.False(t, gjson.GetBytes(normalized, "previous_response_id").Exists())
|
||||
require.Equal(t, float64(1), gjson.GetBytes(normalized, "metadata.a").Float())
|
||||
|
||||
_, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(nil)
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID([]byte(`[]`))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOpenAIWSExtractNormalizedInputSequence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("empty_payload", func(t *testing.T) {
|
||||
items, exists, err := openAIWSExtractNormalizedInputSequence(nil)
|
||||
require.NoError(t, err)
|
||||
require.False(t, exists)
|
||||
require.Nil(t, items)
|
||||
})
|
||||
|
||||
t.Run("input_missing", func(t *testing.T) {
|
||||
items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"type":"response.create"}`))
|
||||
require.NoError(t, err)
|
||||
require.False(t, exists)
|
||||
require.Nil(t, items)
|
||||
})
|
||||
|
||||
t.Run("input_array", func(t *testing.T) {
|
||||
items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[{"type":"input_text","text":"hello"}]}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
require.Len(t, items, 1)
|
||||
})
|
||||
|
||||
t.Run("input_object", func(t *testing.T) {
|
||||
items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":{"type":"input_text","text":"hello"}}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
require.Len(t, items, 1)
|
||||
})
|
||||
|
||||
t.Run("input_string", func(t *testing.T) {
|
||||
items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":"hello"}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
require.Len(t, items, 1)
|
||||
require.Equal(t, `"hello"`, string(items[0]))
|
||||
})
|
||||
|
||||
t.Run("input_number", func(t *testing.T) {
|
||||
items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":42}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
require.Len(t, items, 1)
|
||||
require.Equal(t, "42", string(items[0]))
|
||||
})
|
||||
|
||||
t.Run("input_bool", func(t *testing.T) {
|
||||
items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":true}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
require.Len(t, items, 1)
|
||||
require.Equal(t, "true", string(items[0]))
|
||||
})
|
||||
|
||||
t.Run("input_null", func(t *testing.T) {
|
||||
items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":null}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
require.Len(t, items, 1)
|
||||
require.Equal(t, "null", string(items[0]))
|
||||
})
|
||||
|
||||
t.Run("input_invalid_array_json", func(t *testing.T) {
|
||||
items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[}`))
|
||||
require.Error(t, err)
|
||||
require.True(t, exists)
|
||||
require.Nil(t, items)
|
||||
})
|
||||
}
|
||||
|
||||
func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
previousPayload := []byte(`{
|
||||
"type":"response.create",
|
||||
"model":"gpt-5.1",
|
||||
"store":false,
|
||||
"tools":[{"type":"function","name":"tool_a"}],
|
||||
"input":[{"type":"input_text","text":"hello"}]
|
||||
}`)
|
||||
currentStrictPayload := []byte(`{
|
||||
"type":"response.create",
|
||||
"model":"gpt-5.1",
|
||||
"store":false,
|
||||
"tools":[{"name":"tool_a","type":"function"}],
|
||||
"previous_response_id":"resp_turn_1",
|
||||
"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]
|
||||
}`)
|
||||
|
||||
t.Run("strict_incremental_keep", func(t *testing.T) {
|
||||
keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false)
|
||||
require.NoError(t, err)
|
||||
require.True(t, keep)
|
||||
require.Equal(t, "strict_incremental_ok", reason)
|
||||
})
|
||||
|
||||
t.Run("missing_previous_response_id", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
|
||||
keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, keep)
|
||||
require.Equal(t, "missing_previous_response_id", reason)
|
||||
})
|
||||
|
||||
t.Run("missing_last_turn_response_id", func(t *testing.T) {
|
||||
keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, keep)
|
||||
require.Equal(t, "missing_last_turn_response_id", reason)
|
||||
})
|
||||
|
||||
t.Run("previous_response_id_mismatch", func(t *testing.T) {
|
||||
keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, keep)
|
||||
require.Equal(t, "previous_response_id_mismatch", reason)
|
||||
})
|
||||
|
||||
t.Run("missing_previous_turn_payload", func(t *testing.T) {
|
||||
keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, keep)
|
||||
require.Equal(t, "missing_previous_turn_payload", reason)
|
||||
})
|
||||
|
||||
t.Run("non_input_changed", func(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"type":"response.create",
|
||||
"model":"gpt-5.1-mini",
|
||||
"store":false,
|
||||
"tools":[{"type":"function","name":"tool_a"}],
|
||||
"previous_response_id":"resp_turn_1",
|
||||
"input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]
|
||||
}`)
|
||||
keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, keep)
|
||||
require.Equal(t, "non_input_changed", reason)
|
||||
})
|
||||
|
||||
t.Run("delta_input_keeps_previous_response_id", func(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"type":"response.create",
|
||||
"model":"gpt-5.1",
|
||||
"store":false,
|
||||
"tools":[{"type":"function","name":"tool_a"}],
|
||||
"previous_response_id":"resp_turn_1",
|
||||
"input":[{"type":"input_text","text":"different"}]
|
||||
}`)
|
||||
keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
|
||||
require.NoError(t, err)
|
||||
require.True(t, keep)
|
||||
require.Equal(t, "strict_incremental_ok", reason)
|
||||
})
|
||||
|
||||
t.Run("function_call_output_keeps_previous_response_id", func(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"type":"response.create",
|
||||
"model":"gpt-5.1",
|
||||
"store":false,
|
||||
"previous_response_id":"resp_external",
|
||||
"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
|
||||
}`)
|
||||
keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, keep)
|
||||
require.Equal(t, "has_function_call_output", reason)
|
||||
})
|
||||
|
||||
t.Run("non_input_compare_error", func(t *testing.T) {
|
||||
keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false)
|
||||
require.Error(t, err)
|
||||
require.False(t, keep)
|
||||
require.Equal(t, "non_input_compare_error", reason)
|
||||
})
|
||||
|
||||
t.Run("current_payload_compare_error", func(t *testing.T) {
|
||||
keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false)
|
||||
require.Error(t, err)
|
||||
require.False(t, keep)
|
||||
require.Equal(t, "non_input_compare_error", reason)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildOpenAIWSReplayInputSequence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lastFull := []json.RawMessage{
|
||||
json.RawMessage(`{"type":"input_text","text":"hello"}`),
|
||||
}
|
||||
|
||||
t.Run("no_previous_response_id_use_current", func(t *testing.T) {
|
||||
items, exists, err := buildOpenAIWSReplayInputSequence(
|
||||
lastFull,
|
||||
true,
|
||||
[]byte(`{"input":[{"type":"input_text","text":"new"}]}`),
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
require.Len(t, items, 1)
|
||||
require.Equal(t, "new", gjson.GetBytes(items[0], "text").String())
|
||||
})
|
||||
|
||||
t.Run("previous_response_id_delta_append", func(t *testing.T) {
|
||||
items, exists, err := buildOpenAIWSReplayInputSequence(
|
||||
lastFull,
|
||||
true,
|
||||
[]byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`),
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
require.Len(t, items, 2)
|
||||
require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
|
||||
require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
|
||||
})
|
||||
|
||||
t.Run("previous_response_id_full_input_replace", func(t *testing.T) {
|
||||
items, exists, err := buildOpenAIWSReplayInputSequence(
|
||||
lastFull,
|
||||
true,
|
||||
[]byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]}`),
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
require.Len(t, items, 2)
|
||||
require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
|
||||
require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetOpenAIWSPayloadInputSequence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("set_items", func(t *testing.T) {
|
||||
original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`)
|
||||
items := []json.RawMessage{
|
||||
json.RawMessage(`{"type":"input_text","text":"hello"}`),
|
||||
json.RawMessage(`{"type":"input_text","text":"world"}`),
|
||||
}
|
||||
updated, err := setOpenAIWSPayloadInputSequence(original, items, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello", gjson.GetBytes(updated, "input.0.text").String())
|
||||
require.Equal(t, "world", gjson.GetBytes(updated, "input.1.text").String())
|
||||
})
|
||||
|
||||
t.Run("preserve_empty_array_not_null", func(t *testing.T) {
|
||||
original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`)
|
||||
updated, err := setOpenAIWSPayloadInputSequence(original, nil, true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(updated, "input").IsArray())
|
||||
require.Len(t, gjson.GetBytes(updated, "input").Array(), 0)
|
||||
require.False(t, gjson.GetBytes(updated, "input").Type == gjson.Null)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloneOpenAIWSRawMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nil_slice", func(t *testing.T) {
|
||||
cloned := cloneOpenAIWSRawMessages(nil)
|
||||
require.Nil(t, cloned)
|
||||
})
|
||||
|
||||
t.Run("empty_slice", func(t *testing.T) {
|
||||
items := make([]json.RawMessage, 0)
|
||||
cloned := cloneOpenAIWSRawMessages(items)
|
||||
require.NotNil(t, cloned)
|
||||
require.Len(t, cloned, 0)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey(t *testing.T) {
|
||||
payload := map[string]any{
|
||||
"model": "gpt-5.3-codex",
|
||||
"prompt_cache_key": "pcache_123",
|
||||
"include": []any{"reasoning.encrypted_content"},
|
||||
"text": map[string]any{
|
||||
"verbosity": "low",
|
||||
},
|
||||
"tools": []any{map[string]any{"type": "function"}},
|
||||
}
|
||||
|
||||
strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 3)
|
||||
require.Equal(t, "trim_optional_fields", strategy)
|
||||
require.Contains(t, removed, "include")
|
||||
require.NotContains(t, removed, "prompt_cache_key")
|
||||
require.Equal(t, "pcache_123", payload["prompt_cache_key"])
|
||||
require.NotContains(t, payload, "include")
|
||||
require.Contains(t, payload, "text")
|
||||
}
|
||||
|
||||
func TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields(t *testing.T) {
|
||||
payload := map[string]any{
|
||||
"prompt_cache_key": "pcache_456",
|
||||
"instructions": "long instructions",
|
||||
"tools": []any{map[string]any{"type": "function"}},
|
||||
"parallel_tool_calls": true,
|
||||
"tool_choice": "auto",
|
||||
"include": []any{"reasoning.encrypted_content"},
|
||||
"text": map[string]any{"verbosity": "high"},
|
||||
}
|
||||
|
||||
strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 6)
|
||||
require.Equal(t, "trim_optional_fields", strategy)
|
||||
require.Contains(t, removed, "include")
|
||||
require.NotContains(t, removed, "prompt_cache_key")
|
||||
require.Equal(t, "pcache_456", payload["prompt_cache_key"])
|
||||
require.Contains(t, payload, "instructions")
|
||||
require.Contains(t, payload, "tools")
|
||||
require.Contains(t, payload, "tool_choice")
|
||||
require.Contains(t, payload, "parallel_tool_calls")
|
||||
require.Contains(t, payload, "text")
|
||||
}
|
||||
1306
backend/internal/service/openai_ws_forwarder_success_test.go
Normal file
1306
backend/internal/service/openai_ws_forwarder_success_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1706
backend/internal/service/openai_ws_pool.go
Normal file
1706
backend/internal/service/openai_ws_pool.go
Normal file
File diff suppressed because it is too large
Load Diff
58
backend/internal/service/openai_ws_pool_benchmark_test.go
Normal file
58
backend/internal/service/openai_ws_pool_benchmark_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// BenchmarkOpenAIWSPoolAcquire measures the acquire/release hot path of the
// WS connection pool under parallel load, using a counting test dialer so no
// real network connections are made.
func BenchmarkOpenAIWSPoolAcquire(b *testing.B) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1

	pool := newOpenAIWSConnPool(cfg)
	// Stub dialer: keeps the benchmark in-process and deterministic.
	pool.setClientDialerForTest(&openAIWSCountingDialer{})

	account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	req := openAIWSAcquireRequest{
		Account: account,
		WSURL:   "wss://example.com/v1/responses",
	}
	ctx := context.Background()

	// Warm-up acquire so the timed region does not include first-dial cost.
	lease, err := pool.Acquire(ctx, req)
	if err != nil {
		b.Fatalf("warm acquire failed: %v", err)
	}
	lease.Release()

	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			var (
				got        *openAIWSConnLease
				acquireErr error
			)
			// A concurrently-closed pooled conn surfaces as
			// errOpenAIWSConnClosed; retry a bounded number of times since
			// that is an expected transient under parallel churn.
			for retry := 0; retry < 3; retry++ {
				got, acquireErr = pool.Acquire(ctx, req)
				if acquireErr == nil {
					break
				}
				if !errors.Is(acquireErr, errOpenAIWSConnClosed) {
					break
				}
			}
			if acquireErr != nil {
				b.Fatalf("acquire failed: %v", acquireErr)
			}
			got.Release()
		}
	})
}
|
||||
1709
backend/internal/service/openai_ws_pool_test.go
Normal file
1709
backend/internal/service/openai_ws_pool_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1218
backend/internal/service/openai_ws_protocol_forward_test.go
Normal file
1218
backend/internal/service/openai_ws_protocol_forward_test.go
Normal file
File diff suppressed because it is too large
Load Diff
117
backend/internal/service/openai_ws_protocol_resolver.go
Normal file
117
backend/internal/service/openai_ws_protocol_resolver.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package service
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
// OpenAIUpstreamTransport identifies the transport protocol used toward the
// OpenAI upstream.
type OpenAIUpstreamTransport string

const (
	// OpenAIUpstreamTransportAny places no constraint on the transport.
	OpenAIUpstreamTransportAny OpenAIUpstreamTransport = ""
	// OpenAIUpstreamTransportHTTPSSE is the HTTP + server-sent-events fallback.
	OpenAIUpstreamTransportHTTPSSE OpenAIUpstreamTransport = "http_sse"
	// OpenAIUpstreamTransportResponsesWebsocket is the v1 Responses WebSocket transport.
	OpenAIUpstreamTransportResponsesWebsocket OpenAIUpstreamTransport = "responses_websockets"
	// OpenAIUpstreamTransportResponsesWebsocketV2 is the v2 Responses WebSocket transport.
	OpenAIUpstreamTransportResponsesWebsocketV2 OpenAIUpstreamTransport = "responses_websockets_v2"
)

// OpenAIWSProtocolDecision is the outcome of a protocol resolution: which
// transport to use and a machine-readable reason tag for observability.
type OpenAIWSProtocolDecision struct {
	Transport OpenAIUpstreamTransport
	Reason    string
}

// OpenAIWSProtocolResolver decides which upstream transport protocol to use
// for a given OpenAI account.
type OpenAIWSProtocolResolver interface {
	Resolve(account *Account) OpenAIWSProtocolDecision
}

// defaultOpenAIWSProtocolResolver resolves the transport from global gateway
// configuration combined with per-account settings.
type defaultOpenAIWSProtocolResolver struct {
	cfg *config.Config
}
|
||||
|
||||
// NewOpenAIWSProtocolResolver creates the default protocol resolver backed by
// the given gateway configuration.
func NewOpenAIWSProtocolResolver(cfg *config.Config) OpenAIWSProtocolResolver {
	return &defaultOpenAIWSProtocolResolver{cfg: cfg}
}
|
||||
|
||||
// Resolve decides the upstream transport for the account. The checks are
// ordered from hard vetoes (nil account, non-OpenAI platform, per-account and
// global force-HTTP flags, missing config, feature disabled) through
// auth-type gating, to the final transport choice. When the v2 mode router is
// enabled, the per-account ingress mode and concurrency drive the decision;
// otherwise the legacy per-account boolean toggle is consulted. Every
// fallback to HTTP carries a distinct reason tag for observability.
func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProtocolDecision {
	if account == nil {
		return openAIWSHTTPDecision("account_missing")
	}
	if !account.IsOpenAI() {
		return openAIWSHTTPDecision("platform_not_openai")
	}
	if account.IsOpenAIWSForceHTTPEnabled() {
		return openAIWSHTTPDecision("account_force_http")
	}
	if r == nil || r.cfg == nil {
		return openAIWSHTTPDecision("config_missing")
	}

	wsCfg := r.cfg.Gateway.OpenAIWS
	if wsCfg.ForceHTTP {
		return openAIWSHTTPDecision("global_force_http")
	}
	if !wsCfg.Enabled {
		return openAIWSHTTPDecision("global_disabled")
	}
	// WS may be enabled independently for OAuth and API-key accounts; any
	// other auth type falls back to HTTP.
	if account.IsOpenAIOAuth() {
		if !wsCfg.OAuthEnabled {
			return openAIWSHTTPDecision("oauth_disabled")
		}
	} else if account.IsOpenAIApiKey() {
		if !wsCfg.APIKeyEnabled {
			return openAIWSHTTPDecision("apikey_disabled")
		}
	} else {
		return openAIWSHTTPDecision("unknown_auth_type")
	}
	if wsCfg.ModeRouterV2Enabled {
		// v2 mode router: the per-account ingress mode (falling back to the
		// configured default) decides whether WS is used at all.
		mode := account.ResolveOpenAIResponsesWebSocketV2Mode(wsCfg.IngressModeDefault)
		switch mode {
		case OpenAIWSIngressModeOff:
			return openAIWSHTTPDecision("account_mode_off")
		case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
			// continue
		default:
			// Unknown modes are treated the same as "off".
			return openAIWSHTTPDecision("account_mode_off")
		}
		// WS multiplexing requires a positive per-account concurrency.
		if account.Concurrency <= 0 {
			return openAIWSHTTPDecision("account_concurrency_invalid")
		}
		// Prefer v2 over v1 when both are enabled globally.
		if wsCfg.ResponsesWebsocketsV2 {
			return OpenAIWSProtocolDecision{
				Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
				Reason:    "ws_v2_mode_" + mode,
			}
		}
		if wsCfg.ResponsesWebsockets {
			return OpenAIWSProtocolDecision{
				Transport: OpenAIUpstreamTransportResponsesWebsocket,
				Reason:    "ws_v1_mode_" + mode,
			}
		}
		return openAIWSHTTPDecision("feature_disabled")
	}
	// Legacy path (mode router disabled): a per-account boolean toggle gates WS.
	if !account.IsOpenAIResponsesWebSocketV2Enabled() {
		return openAIWSHTTPDecision("account_disabled")
	}
	if wsCfg.ResponsesWebsocketsV2 {
		return OpenAIWSProtocolDecision{
			Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
			Reason:    "ws_v2_enabled",
		}
	}
	if wsCfg.ResponsesWebsockets {
		return OpenAIWSProtocolDecision{
			Transport: OpenAIUpstreamTransportResponsesWebsocket,
			Reason:    "ws_v1_enabled",
		}
	}
	return openAIWSHTTPDecision("feature_disabled")
}
|
||||
|
||||
func openAIWSHTTPDecision(reason string) OpenAIWSProtocolDecision {
|
||||
return OpenAIWSProtocolDecision{
|
||||
Transport: OpenAIUpstreamTransportHTTPSSE,
|
||||
Reason: reason,
|
||||
}
|
||||
}
|
||||
203
backend/internal/service/openai_ws_protocol_resolver_test.go
Normal file
203
backend/internal/service/openai_ws_protocol_resolver_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOpenAIWSProtocolResolver_Resolve covers the legacy (mode-router-off)
// resolution paths: v2 preference, v1 fallback, per-account and global
// force-HTTP/disable switches, auth-type gating, and legacy extra-key
// compatibility. Subtest names are kept as-is (they are runtime strings).
func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) {
	// Base config: WS enabled globally, v2 on, v1 off.
	baseCfg := &config.Config{}
	baseCfg.Gateway.OpenAIWS.Enabled = true
	baseCfg.Gateway.OpenAIWS.OAuthEnabled = true
	baseCfg.Gateway.OpenAIWS.APIKeyEnabled = true
	baseCfg.Gateway.OpenAIWS.ResponsesWebsockets = false
	baseCfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true

	// OAuth account with the per-account v2 toggle on.
	openAIOAuthEnabled := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
		},
	}

	t.Run("v2优先", func(t *testing.T) {
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})

	t.Run("v2关闭时回退v1", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
		cfg.Gateway.OpenAIWS.ResponsesWebsockets = true

		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocket, decision.Transport)
		require.Equal(t, "ws_v1_enabled", decision.Reason)
	})

	t.Run("透传开关不影响WS协议判定", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
			"openai_passthrough":                           true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})

	t.Run("账号级强制HTTP", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
			"openai_ws_force_http":                         true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_force_http", decision.Reason)
	})

	t.Run("全局关闭保持HTTP", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.Enabled = false
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "global_disabled", decision.Reason)
	})

	t.Run("账号开关关闭保持HTTP", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": false,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_disabled", decision.Reason)
	})

	t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
		// The API-key-specific extra key must be ignored for OAuth accounts.
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_disabled", decision.Reason)
	})

	t.Run("兼容旧键openai_ws_enabled", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_ws_enabled": true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})

	t.Run("按账号类型开关控制", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.OAuthEnabled = false
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "oauth_disabled", decision.Reason)
	})

	t.Run("API Key 账号关闭开关时回退HTTP", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.APIKeyEnabled = false
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "apikey_disabled", decision.Reason)
	})

	t.Run("未知认证类型回退HTTP", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     "unknown_type",
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "unknown_auth_type", decision.Reason)
	})
}
|
||||
|
||||
// TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2 covers the v2 mode-router
// resolution paths: dedicated/off modes, legacy boolean mapping to shared
// mode, and the positive-concurrency requirement.
func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
	// Config with the mode router enabled and "shared" as the default mode.
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
	cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared

	account := &Account{
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
		},
	}

	t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
	})

	t.Run("off mode routes to http", func(t *testing.T) {
		offAccount := &Account{
			Platform:    PlatformOpenAI,
			Type:        AccountTypeOAuth,
			Concurrency: 1,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(offAccount)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_mode_off", decision.Reason)
	})

	t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
		// An account configured only with the old boolean toggle should be
		// treated as "shared" mode by the v2 router.
		legacyAccount := &Account{
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Concurrency: 1,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_mode_shared", decision.Reason)
	})

	t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
		// Concurrency is the zero value here, which the router must reject.
		invalidConcurrency := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_concurrency_invalid", decision.Reason)
	})
}
|
||||
440
backend/internal/service/openai_ws_state_store.go
Normal file
440
backend/internal/service/openai_ws_state_store.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	// Namespace prefix for response_id -> account_id keys in GatewayCache.
	openAIWSResponseAccountCachePrefix = "openai:response:"
	// Minimum interval between lazy cleanup sweeps.
	openAIWSStateStoreCleanupInterval = time.Minute
	// Max entries scanned per map in one sweep, bounding each sweep's cost.
	openAIWSStateStoreCleanupMaxPerMap = 512
	// Hard cap per in-memory map; reaching it evicts an arbitrary entry.
	openAIWSStateStoreMaxEntriesPerMap = 65536
	// Upper bound for each Redis round-trip issued by the store.
	openAIWSStateStoreRedisTimeout = 3 * time.Second
)

// openAIWSAccountBinding is a TTL-bounded response_id -> account_id entry.
type openAIWSAccountBinding struct {
	accountID int64
	expiresAt time.Time
}

// openAIWSConnBinding is a TTL-bounded response_id -> conn_id entry.
type openAIWSConnBinding struct {
	connID    string
	expiresAt time.Time
}

// openAIWSTurnStateBinding is a TTL-bounded session -> turn-state entry.
type openAIWSTurnStateBinding struct {
	turnState string
	expiresAt time.Time
}

// openAIWSSessionConnBinding is a TTL-bounded session -> conn_id entry.
type openAIWSSessionConnBinding struct {
	connID    string
	expiresAt time.Time
}

// OpenAIWSStateStore manages WSv2 stickiness state.
// - response_id -> account_id is used for continuation routing.
// - response_id -> conn_id is used for in-connection context reuse.
//
// response_id -> account_id prefers GatewayCache (Redis) and also maintains a
// local hot cache. response_id -> conn_id is valid only within this process.
type OpenAIWSStateStore interface {
	BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error
	GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error)
	DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error

	BindResponseConn(responseID, connID string, ttl time.Duration)
	GetResponseConn(responseID string) (string, bool)
	DeleteResponseConn(responseID string)

	BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
	GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
	DeleteSessionTurnState(groupID int64, sessionHash string)

	BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
	GetSessionConn(groupID int64, sessionHash string) (string, bool)
	DeleteSessionConn(groupID int64, sessionHash string)
}

// defaultOpenAIWSStateStore implements OpenAIWSStateStore with four
// independently-locked in-memory maps plus an optional Redis-backed cache for
// the response -> account binding.
type defaultOpenAIWSStateStore struct {
	cache GatewayCache

	responseToAccountMu sync.RWMutex
	responseToAccount   map[string]openAIWSAccountBinding
	responseToConnMu    sync.RWMutex
	responseToConn      map[string]openAIWSConnBinding
	sessionToTurnStateMu sync.RWMutex
	sessionToTurnState   map[string]openAIWSTurnStateBinding
	sessionToConnMu      sync.RWMutex
	sessionToConn        map[string]openAIWSSessionConnBinding

	// lastCleanupUnixNano gates the lazy cleanup sweep via CAS.
	lastCleanupUnixNano atomic.Int64
}
|
||||
|
||||
// NewOpenAIWSStateStore 创建默认 WS 状态存储。
|
||||
func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
|
||||
store := &defaultOpenAIWSStateStore{
|
||||
cache: cache,
|
||||
responseToAccount: make(map[string]openAIWSAccountBinding, 256),
|
||||
responseToConn: make(map[string]openAIWSConnBinding, 256),
|
||||
sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256),
|
||||
sessionToConn: make(map[string]openAIWSSessionConnBinding, 256),
|
||||
}
|
||||
store.lastCleanupUnixNano.Store(time.Now().UnixNano())
|
||||
return store
|
||||
}
|
||||
|
||||
// BindResponseAccount records response_id -> account_id so continuation
// requests referencing this response route back to the same account. The
// binding goes into the local hot cache and, when a GatewayCache is
// configured, is mirrored to Redis with the same TTL. Blank IDs and
// non-positive account IDs are silently ignored.
func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" || accountID <= 0 {
		return nil
	}
	// Non-positive TTLs are coerced to the default (see normalizeOpenAIWSTTL).
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()

	expiresAt := time.Now().Add(ttl)
	s.responseToAccountMu.Lock()
	// Evict an arbitrary entry when the map is at its hard cap.
	ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap)
	s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt}
	s.responseToAccountMu.Unlock()

	if s.cache == nil {
		return nil
	}
	cacheKey := openAIWSResponseAccountCacheKey(id)
	// Bound the Redis round-trip so a slow cache cannot stall the caller.
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
}
|
||||
|
||||
// GetResponseAccount resolves the account bound to a response ID. It first
// consults the local hot cache (honoring expiry) and falls back to the shared
// GatewayCache. A zero account ID with a nil error signals a miss.
func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return 0, nil
	}
	s.maybeCleanup()

	now := time.Now()
	s.responseToAccountMu.RLock()
	if binding, ok := s.responseToAccount[id]; ok {
		if now.Before(binding.expiresAt) {
			accountID := binding.accountID
			s.responseToAccountMu.RUnlock()
			return accountID, nil
		}
	}
	s.responseToAccountMu.RUnlock()

	if s.cache == nil {
		return 0, nil
	}

	cacheKey := openAIWSResponseAccountCacheKey(id)
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey)
	if err != nil || accountID <= 0 {
		// Cache read failures do not block the main flow; degrade to a miss.
		return 0, nil
	}
	return accountID, nil
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" {
|
||||
return nil
|
||||
}
|
||||
s.responseToAccountMu.Lock()
|
||||
delete(s.responseToAccount, id)
|
||||
s.responseToAccountMu.Unlock()
|
||||
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id))
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
conn := strings.TrimSpace(connID)
|
||||
if id == "" || conn == "" {
|
||||
return
|
||||
}
|
||||
ttl = normalizeOpenAIWSTTL(ttl)
|
||||
s.maybeCleanup()
|
||||
|
||||
s.responseToConnMu.Lock()
|
||||
ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap)
|
||||
s.responseToConn[id] = openAIWSConnBinding{
|
||||
connID: conn,
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
s.responseToConnMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" {
|
||||
return "", false
|
||||
}
|
||||
s.maybeCleanup()
|
||||
|
||||
now := time.Now()
|
||||
s.responseToConnMu.RLock()
|
||||
binding, ok := s.responseToConn[id]
|
||||
s.responseToConnMu.RUnlock()
|
||||
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
|
||||
return "", false
|
||||
}
|
||||
return binding.connID, true
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
s.responseToConnMu.Lock()
|
||||
delete(s.responseToConn, id)
|
||||
s.responseToConnMu.Unlock()
|
||||
}
|
||||
|
||||
// BindSessionTurnState stores the latest turn state for a (group, session)
// pair so a subsequent turn on the same session can resume it. Blank keys or
// states are ignored; the key embeds groupID for cross-group isolation.
func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	state := strings.TrimSpace(turnState)
	if key == "" || state == "" {
		return
	}
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()

	s.sessionToTurnStateMu.Lock()
	// Keep the map within its hard cap before inserting.
	ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap)
	s.sessionToTurnState[key] = openAIWSTurnStateBinding{
		turnState: state,
		expiresAt: time.Now().Add(ttl),
	}
	s.sessionToTurnStateMu.Unlock()
}
|
||||
|
||||
// GetSessionTurnState returns the stored turn state for a (group, session)
// pair; false when the binding is missing, expired, or blank.
func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if key == "" {
		return "", false
	}
	s.maybeCleanup()

	now := time.Now()
	s.sessionToTurnStateMu.RLock()
	binding, ok := s.sessionToTurnState[key]
	s.sessionToTurnStateMu.RUnlock()
	if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.turnState) == "" {
		return "", false
	}
	return binding.turnState, true
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
s.sessionToTurnStateMu.Lock()
|
||||
delete(s.sessionToTurnState, key)
|
||||
s.sessionToTurnStateMu.Unlock()
|
||||
}
|
||||
|
||||
// BindSessionConn pins a (group, session) pair to an in-process connection ID
// so subsequent turns on the session reuse the same connection. Blank keys or
// connection IDs are ignored.
func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
	// Session-conn bindings share the group-scoped key scheme of turn state.
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	conn := strings.TrimSpace(connID)
	if key == "" || conn == "" {
		return
	}
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()

	s.sessionToConnMu.Lock()
	// Keep the map within its hard cap before inserting.
	ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap)
	s.sessionToConn[key] = openAIWSSessionConnBinding{
		connID:    conn,
		expiresAt: time.Now().Add(ttl),
	}
	s.sessionToConnMu.Unlock()
}
|
||||
|
||||
// GetSessionConn returns the connection pinned to a (group, session) pair;
// false when the binding is missing, expired, or blank.
func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if key == "" {
		return "", false
	}
	s.maybeCleanup()

	now := time.Now()
	s.sessionToConnMu.RLock()
	binding, ok := s.sessionToConn[key]
	s.sessionToConnMu.RUnlock()
	if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
		return "", false
	}
	return binding.connID, true
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
s.sessionToConnMu.Lock()
|
||||
delete(s.sessionToConn, key)
|
||||
s.sessionToConnMu.Unlock()
|
||||
}
|
||||
|
||||
// maybeCleanup runs an incremental expired-entry sweep at most once per
// cleanup interval. The interval gate is a CAS on the last-run timestamp, so
// concurrent callers elect a single sweeper without taking a lock.
func (s *defaultOpenAIWSStateStore) maybeCleanup() {
	if s == nil {
		return
	}
	now := time.Now()
	last := time.Unix(0, s.lastCleanupUnixNano.Load())
	if now.Sub(last) < openAIWSStateStoreCleanupInterval {
		return
	}
	// Only the goroutine that wins the CAS performs this sweep.
	if !s.lastCleanupUnixNano.CompareAndSwap(last.UnixNano(), now.UnixNano()) {
		return
	}

	// Incremental, capped cleanup: avoids a one-shot full scan that could
	// block for a long time at large scale.
	s.responseToAccountMu.Lock()
	cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap)
	s.responseToAccountMu.Unlock()

	s.responseToConnMu.Lock()
	cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap)
	s.responseToConnMu.Unlock()

	s.sessionToTurnStateMu.Lock()
	cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap)
	s.sessionToTurnStateMu.Unlock()

	s.sessionToConnMu.Lock()
	cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap)
	s.sessionToConnMu.Unlock()
}
|
||||
|
||||
func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureBindingCapacity keeps a binding map within maxEntries by evicting one
// arbitrary existing entry before a brand-new key would push it over the cap.
// Maps below the cap, non-positive caps, and updates to already-present keys
// are left untouched.
func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
	if maxEntries <= 0 || len(bindings) < maxEntries {
		return
	}
	if _, present := bindings[incomingKey]; present {
		return
	}
	// Fixed-cap guard: drop one arbitrary victim (map iteration order is
	// random) so memory stays bounded.
	for victim := range bindings {
		delete(bindings, victim)
		break
	}
}
|
||||
|
||||
// normalizeOpenAIWSResponseID canonicalizes a response ID by stripping
// surrounding whitespace; an all-whitespace ID normalizes to "".
func normalizeOpenAIWSResponseID(raw string) string {
	return strings.TrimSpace(raw)
}
|
||||
|
||||
func openAIWSResponseAccountCacheKey(responseID string) string {
|
||||
sum := sha256.Sum256([]byte(responseID))
|
||||
return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
|
||||
if ttl <= 0 {
|
||||
return time.Hour
|
||||
}
|
||||
return ttl
|
||||
}
|
||||
|
||||
// openAIWSSessionTurnStateKey builds the group-scoped map key for a session
// hash ("<groupID>:<hash>"), returning "" for blank hashes so callers can
// treat it as invalid input.
func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
	trimmed := strings.TrimSpace(sessionHash)
	if trimmed == "" {
		return ""
	}
	return fmt.Sprintf("%d:%s", groupID, trimmed)
}
|
||||
|
||||
func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithTimeout(ctx, openAIWSStateStoreRedisTimeout)
|
||||
}
|
||||
235
backend/internal/service/openai_ws_state_store_test.go
Normal file
235
backend/internal/service/openai_ws_state_store_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpenAIWSStateStore_BindGetDeleteResponseAccount(t *testing.T) {
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
ctx := context.Background()
|
||||
groupID := int64(7)
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_abc", 101, time.Minute))
|
||||
|
||||
accountID, err := store.GetResponseAccount(ctx, groupID, "resp_abc")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(101), accountID)
|
||||
|
||||
require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_abc"))
|
||||
accountID, err = store.GetResponseAccount(ctx, groupID, "resp_abc")
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, accountID)
|
||||
}
|
||||
|
||||
func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) {
|
||||
store := NewOpenAIWSStateStore(nil)
|
||||
store.BindResponseConn("resp_conn", "conn_1", 30*time.Millisecond)
|
||||
|
||||
connID, ok := store.GetResponseConn("resp_conn")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "conn_1", connID)
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
_, ok = store.GetResponseConn("resp_conn")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) {
|
||||
store := NewOpenAIWSStateStore(nil)
|
||||
store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond)
|
||||
|
||||
state, ok := store.GetSessionTurnState(9, "session_hash_1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "turn_state_1", state)
|
||||
|
||||
// group 隔离
|
||||
_, ok = store.GetSessionTurnState(10, "session_hash_1")
|
||||
require.False(t, ok)
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
_, ok = store.GetSessionTurnState(9, "session_hash_1")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) {
|
||||
store := NewOpenAIWSStateStore(nil)
|
||||
store.BindSessionConn(9, "session_hash_conn_1", "conn_1", 30*time.Millisecond)
|
||||
|
||||
connID, ok := store.GetSessionConn(9, "session_hash_conn_1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "conn_1", connID)
|
||||
|
||||
// group 隔离
|
||||
_, ok = store.GetSessionConn(10, "session_hash_conn_1")
|
||||
require.False(t, ok)
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
_, ok = store.GetSessionConn(9, "session_hash_conn_1")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) {
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
ctx := context.Background()
|
||||
groupID := int64(17)
|
||||
responseID := "resp_cache_stale"
|
||||
cacheKey := openAIWSResponseAccountCacheKey(responseID)
|
||||
|
||||
cache.sessionBindings[cacheKey] = 501
|
||||
accountID, err := store.GetResponseAccount(ctx, groupID, responseID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(501), accountID)
|
||||
|
||||
delete(cache.sessionBindings, cacheKey)
|
||||
accountID, err = store.GetResponseAccount(ctx, groupID, responseID)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射")
|
||||
}
|
||||
|
||||
// TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally seeds the
// store with 2048 already-expired response->conn bindings and verifies that
// maybeCleanup removes them incrementally: one pass makes progress without
// necessarily clearing everything, and repeated passes drain the map.
func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) {
	raw := NewOpenAIWSStateStore(nil)
	// The test reaches into unexported state, so it needs the concrete type.
	store, ok := raw.(*defaultOpenAIWSStateStore)
	require.True(t, ok)

	// Every seeded binding expired a minute ago, so all are removal candidates.
	expiredAt := time.Now().Add(-time.Minute)
	total := 2048
	store.responseToConnMu.Lock()
	for i := 0; i < total; i++ {
		store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{
			connID:    "conn_incremental",
			expiresAt: expiredAt,
		}
	}
	store.responseToConnMu.Unlock()

	// Backdate the last-cleanup timestamp by two intervals so maybeCleanup
	// considers a new round due (presumably gated on
	// openAIWSStateStoreCleanupInterval — confirm against maybeCleanup).
	store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
	store.maybeCleanup()

	store.responseToConnMu.RLock()
	remainingAfterFirst := len(store.responseToConn)
	store.responseToConnMu.RUnlock()
	require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展")
	require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键")

	// Further rounds — each forced due by backdating the timestamp again —
	// should eventually remove every expired key.
	for i := 0; i < 8; i++ {
		store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
		store.maybeCleanup()
	}

	store.responseToConnMu.RLock()
	remaining := len(store.responseToConn)
	store.responseToConnMu.RUnlock()
	require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键")
}
|
||||
|
||||
func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) {
|
||||
bindings := map[string]int{
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
}
|
||||
|
||||
ensureBindingCapacity(bindings, "c", 2)
|
||||
bindings["c"] = 3
|
||||
|
||||
require.Len(t, bindings, 2)
|
||||
require.Equal(t, 3, bindings["c"])
|
||||
}
|
||||
|
||||
func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) {
|
||||
bindings := map[string]int{
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
}
|
||||
|
||||
ensureBindingCapacity(bindings, "a", 2)
|
||||
bindings["a"] = 9
|
||||
|
||||
require.Len(t, bindings, 2)
|
||||
require.Equal(t, 9, bindings["a"])
|
||||
}
|
||||
|
||||
// openAIWSStateStoreTimeoutProbeCache is a test double passed to
// NewOpenAIWSStateStore. Each method records whether its context carried a
// deadline and, if so, how far in the future that deadline was, so tests can
// assert the store attaches a short per-operation timeout.
type openAIWSStateStoreTimeoutProbeCache struct {
	setHasDeadline    bool // SetSessionAccountID saw a context deadline
	getHasDeadline    bool // GetSessionAccountID saw a context deadline
	deleteHasDeadline bool // DeleteSessionAccountID saw a context deadline

	setDeadlineDelta time.Duration // time.Until(deadline) observed in Set
	getDeadlineDelta time.Duration // time.Until(deadline) observed in Get
	delDeadlineDelta time.Duration // time.Until(deadline) observed in Delete
}
|
||||
|
||||
func (c *openAIWSStateStoreTimeoutProbeCache) GetSessionAccountID(ctx context.Context, _ int64, _ string) (int64, error) {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
c.getHasDeadline = true
|
||||
c.getDeadlineDelta = time.Until(deadline)
|
||||
}
|
||||
return 123, nil
|
||||
}
|
||||
|
||||
func (c *openAIWSStateStoreTimeoutProbeCache) SetSessionAccountID(ctx context.Context, _ int64, _ string, _ int64, _ time.Duration) error {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
c.setHasDeadline = true
|
||||
c.setDeadlineDelta = time.Until(deadline)
|
||||
}
|
||||
return errors.New("set failed")
|
||||
}
|
||||
|
||||
// RefreshSessionTTL is a no-op success in this probe; the tests using it do
// not exercise the TTL-refresh path.
func (c *openAIWSStateStoreTimeoutProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error {
	return nil
}
|
||||
|
||||
func (c *openAIWSStateStoreTimeoutProbeCache) DeleteSessionAccountID(ctx context.Context, _ int64, _ string) error {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
c.deleteHasDeadline = true
|
||||
c.delDeadlineDelta = time.Until(deadline)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestOpenAIWSStateStore_RedisOpsUseShortTimeout verifies that each
// Redis-backed cache call made by the state store carries its own short
// timeout context (observed deadline between 2s and 3s away), and that a
// local-cache hit on the read path skips Redis entirely.
func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) {
	probe := &openAIWSStateStoreTimeoutProbeCache{}
	store := NewOpenAIWSStateStore(probe)
	ctx := context.Background()
	groupID := int64(5)

	// The probe's Set always fails, so the bind surfaces an error — while,
	// per the Get assertion below, the account is still bound locally.
	err := store.BindResponseAccount(ctx, groupID, "resp_timeout_probe", 11, time.Minute)
	require.Error(t, err)

	accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe")
	require.NoError(t, getErr)
	require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号")

	require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe"))

	require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文")
	require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文")
	require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取")
	// The 2s..3s slack brackets imply a configured timeout of roughly 3s.
	require.Greater(t, probe.setDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second)
	require.Greater(t, probe.delDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe.delDeadlineDelta, 3*time.Second)

	// A fresh store has no local entry, so this read must reach the probe's
	// Get (returning its fixed 123) and must also carry the short timeout.
	probe2 := &openAIWSStateStoreTimeoutProbeCache{}
	store2 := NewOpenAIWSStateStore(probe2)
	accountID2, err2 := store2.GetResponseAccount(ctx, groupID, "resp_cache_only")
	require.NoError(t, err2)
	require.Equal(t, int64(123), accountID2)
	require.True(t, probe2.getHasDeadline, "GetSessionAccountID 在缓存未命中时应携带独立超时上下文")
	require.Greater(t, probe2.getDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second)
}
|
||||
|
||||
func TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) {
|
||||
ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
|
||||
defer cancel()
|
||||
require.NotNil(t, ctx)
|
||||
_, ok := ctx.Deadline()
|
||||
require.True(t, ok, "应附加短超时")
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user