mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-27 01:44:48 +08:00
feat(sync): full code sync from release
This commit is contained in:
@@ -3,6 +3,8 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"hash/fnv"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -50,6 +52,14 @@ type Account struct {
|
||||
AccountGroups []AccountGroup
|
||||
GroupIDs []int64
|
||||
Groups []*Group
|
||||
|
||||
// model_mapping 热路径缓存(非持久化字段)
|
||||
modelMappingCache map[string]string
|
||||
modelMappingCacheReady bool
|
||||
modelMappingCacheCredentialsPtr uintptr
|
||||
modelMappingCacheRawPtr uintptr
|
||||
modelMappingCacheRawLen int
|
||||
modelMappingCacheRawSig uint64
|
||||
}
|
||||
|
||||
type TempUnschedulableRule struct {
|
||||
@@ -349,6 +359,39 @@ func parseTempUnschedInt(value any) int {
|
||||
}
|
||||
|
||||
// GetModelMapping returns the account's effective model mapping, caching the
// resolved result on the Account so the hot path avoids re-parsing credentials.
//
// Cache validity is checked in three layers: the Credentials map identity,
// the raw model_mapping map identity and length, and finally a content
// signature of the raw mapping — so both wholesale replacement and in-place
// mutation of the raw map invalidate the cache.
//
// NOTE(review): the cache fields are written without synchronization; this
// assumes a *Account is not shared across goroutines while GetModelMapping
// runs — TODO confirm with callers.
func (a *Account) GetModelMapping() map[string]string {
	credentialsPtr := mapPtr(a.Credentials)
	rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
	rawPtr := mapPtr(rawMapping)
	rawLen := len(rawMapping)
	rawSig := uint64(0)
	rawSigReady := false

	// Fast path: identities and length unchanged — verify the content
	// signature to catch in-place value edits that keep the same length.
	if a.modelMappingCacheReady &&
		a.modelMappingCacheCredentialsPtr == credentialsPtr &&
		a.modelMappingCacheRawPtr == rawPtr &&
		a.modelMappingCacheRawLen == rawLen {
		rawSig = modelMappingSignature(rawMapping)
		rawSigReady = true
		if a.modelMappingCacheRawSig == rawSig {
			return a.modelMappingCache
		}
	}

	// Slow path: resolve from scratch, then refresh all cache markers.
	mapping := a.resolveModelMapping(rawMapping)
	if !rawSigReady {
		rawSig = modelMappingSignature(rawMapping)
	}

	a.modelMappingCache = mapping
	a.modelMappingCacheReady = true
	a.modelMappingCacheCredentialsPtr = credentialsPtr
	a.modelMappingCacheRawPtr = rawPtr
	a.modelMappingCacheRawLen = rawLen
	a.modelMappingCacheRawSig = rawSig
	return mapping
}
|
||||
|
||||
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
|
||||
if a.Credentials == nil {
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
@@ -356,32 +399,31 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["model_mapping"]
|
||||
if !ok || raw == nil {
|
||||
if len(rawMapping) == 0 {
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
result := make(map[string]string)
|
||||
for k, v := range m {
|
||||
if s, ok := v.(string); ok {
|
||||
result[k] = s
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
ensureAntigravityDefaultPassthroughs(result, []string{
|
||||
"gemini-3-flash",
|
||||
"gemini-3.1-pro-high",
|
||||
"gemini-3.1-pro-low",
|
||||
})
|
||||
}
|
||||
return result
|
||||
|
||||
result := make(map[string]string)
|
||||
for k, v := range rawMapping {
|
||||
if s, ok := v.(string); ok {
|
||||
result[k] = s
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
ensureAntigravityDefaultPassthroughs(result, []string{
|
||||
"gemini-3-flash",
|
||||
"gemini-3.1-pro-high",
|
||||
"gemini-3.1-pro-low",
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Antigravity 平台使用默认映射
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
@@ -389,6 +431,37 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// mapPtr returns the identity of m's backing map header as a uintptr, or 0
// for a nil map. It is used purely as a cheap identity marker in cache checks,
// never dereferenced.
func mapPtr(m map[string]any) uintptr {
	if m != nil {
		return reflect.ValueOf(m).Pointer()
	}
	return 0
}
|
||||
|
||||
// modelMappingSignature computes an order-independent FNV-1a content signature
// of rawMapping, used by GetModelMapping to detect in-place edits of the raw
// model_mapping map. A nil or empty map hashes to 0.
//
// Each field is length-prefixed before hashing so keys/values containing
// arbitrary bytes cannot produce colliding concatenations (the previous
// NUL/0xff separator scheme was ambiguous for strings containing those
// bytes, which could serve a stale cache). Non-string values are hashed only
// by a type tag, matching resolveModelMapping, which ignores their contents.
// The signature is only ever compared against earlier results of this same
// function, so the encoding change is internal-only.
func modelMappingSignature(rawMapping map[string]any) uint64 {
	if len(rawMapping) == 0 {
		return 0
	}
	// Sort keys so iteration order does not affect the signature.
	keys := make([]string, 0, len(rawMapping))
	for k := range rawMapping {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	h := fnv.New64a()
	// writeField hashes "<len>:<bytes>", an unambiguous framing.
	writeField := func(s string) {
		_, _ = h.Write([]byte(strconv.Itoa(len(s))))
		_, _ = h.Write([]byte{':'})
		_, _ = h.Write([]byte(s))
	}
	for _, k := range keys {
		writeField(k)
		if v, ok := rawMapping[k].(string); ok {
			_, _ = h.Write([]byte{0}) // tag: string value follows
			writeField(v)
		} else {
			_, _ = h.Write([]byte{1}) // tag: non-string value (contents ignored)
		}
	}
	return h.Sum64()
}
|
||||
|
||||
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
|
||||
if mapping == nil || model == "" {
|
||||
return
|
||||
@@ -742,6 +815,159 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOpenAIResponsesWebSocketV2Enabled 返回 OpenAI 账号是否开启 Responses WebSocket v2。
|
||||
//
|
||||
// 分类型新字段:
|
||||
// - OAuth 账号:accounts.extra.openai_oauth_responses_websockets_v2_enabled
|
||||
// - API Key 账号:accounts.extra.openai_apikey_responses_websockets_v2_enabled
|
||||
//
|
||||
// 兼容字段:
|
||||
// - accounts.extra.responses_websockets_v2_enabled
|
||||
// - accounts.extra.openai_ws_enabled(历史开关)
|
||||
//
|
||||
// 优先级:
|
||||
// 1. 按账号类型读取分类型字段
|
||||
// 2. 分类型字段缺失时,回退兼容字段
|
||||
func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
|
||||
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if a.IsOpenAIOAuth() {
|
||||
if enabled, ok := a.Extra["openai_oauth_responses_websockets_v2_enabled"].(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
if a.IsOpenAIApiKey() {
|
||||
if enabled, ok := a.Extra["openai_apikey_responses_websockets_v2_enabled"].(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
if enabled, ok := a.Extra["responses_websockets_v2_enabled"].(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
if enabled, ok := a.Extra["openai_ws_enabled"].(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Recognized values for the OpenAI Responses WebSocket v2 ingress mode; see
// normalizeOpenAIWSIngressMode and ResolveOpenAIResponsesWebSocketV2Mode.
const (
	OpenAIWSIngressModeOff       = "off"
	OpenAIWSIngressModeShared    = "shared"
	OpenAIWSIngressModeDedicated = "dedicated"
)
|
||||
|
||||
func normalizeOpenAIWSIngressMode(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case OpenAIWSIngressModeOff:
|
||||
return OpenAIWSIngressModeOff
|
||||
case OpenAIWSIngressModeShared:
|
||||
return OpenAIWSIngressModeShared
|
||||
case OpenAIWSIngressModeDedicated:
|
||||
return OpenAIWSIngressModeDedicated
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
|
||||
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
return OpenAIWSIngressModeShared
|
||||
}
|
||||
|
||||
// ResolveOpenAIResponsesWebSocketV2Mode returns the account's effective WSv2
// ingress mode (off/shared/dedicated).
//
// Priority order:
//  1. Type-specific mode key (string), e.g. openai_oauth_responses_websockets_v2_mode.
//  2. Type-specific enabled key (bool), mapped true→shared, false→off.
//  3. Legacy enabled keys (bool): responses_websockets_v2_enabled, openai_ws_enabled.
//  4. defaultMode (invalid/empty input falls back to shared).
//
// Non-OpenAI or nil accounts always resolve to off; an account without Extra
// resolves straight to the normalized default.
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
	resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
	if a == nil || !a.IsOpenAI() {
		return OpenAIWSIngressModeOff
	}
	if a.Extra == nil {
		return resolvedDefault
	}

	// resolveModeString reads a string-typed mode key; ok is false when the
	// key is absent, not a string, or not a recognized mode.
	resolveModeString := func(key string) (string, bool) {
		raw, ok := a.Extra[key]
		if !ok {
			return "", false
		}
		mode, ok := raw.(string)
		if !ok {
			return "", false
		}
		normalized := normalizeOpenAIWSIngressMode(mode)
		if normalized == "" {
			return "", false
		}
		return normalized, true
	}
	// resolveBoolMode reads a bool-typed enabled key and maps it to a mode:
	// true→shared, false→off; ok is false when the key is absent or not a bool.
	resolveBoolMode := func(key string) (string, bool) {
		raw, ok := a.Extra[key]
		if !ok {
			return "", false
		}
		enabled, ok := raw.(bool)
		if !ok {
			return "", false
		}
		if enabled {
			return OpenAIWSIngressModeShared, true
		}
		return OpenAIWSIngressModeOff, true
	}

	if a.IsOpenAIOAuth() {
		if mode, ok := resolveModeString("openai_oauth_responses_websockets_v2_mode"); ok {
			return mode
		}
		if mode, ok := resolveBoolMode("openai_oauth_responses_websockets_v2_enabled"); ok {
			return mode
		}
	}
	if a.IsOpenAIApiKey() {
		if mode, ok := resolveModeString("openai_apikey_responses_websockets_v2_mode"); ok {
			return mode
		}
		if mode, ok := resolveBoolMode("openai_apikey_responses_websockets_v2_enabled"); ok {
			return mode
		}
	}
	// Legacy compatibility keys, tried only after the type-specific keys.
	if mode, ok := resolveBoolMode("responses_websockets_v2_enabled"); ok {
		return mode
	}
	if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
		return mode
	}
	return resolvedDefault
}
|
||||
|
||||
// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
|
||||
// 字段:accounts.extra.openai_ws_force_http。
|
||||
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
|
||||
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := a.Extra["openai_ws_force_http"].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// IsOpenAIWSAllowStoreRecoveryEnabled 返回账号级 store 恢复开关。
|
||||
// 字段:accounts.extra.openai_ws_allow_store_recovery。
|
||||
func (a *Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool {
|
||||
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := a.Extra["openai_ws_allow_store_recovery"].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。
|
||||
func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
|
||||
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
|
||||
|
||||
@@ -134,3 +134,161 @@ func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
|
||||
require.False(t, otherPlatform.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
}
|
||||
|
||||
// TestAccount_IsOpenAIResponsesWebSocketV2Enabled covers the priority order of
// the type-specific and legacy extra keys read by
// IsOpenAIResponsesWebSocketV2Enabled, plus the non-OpenAI guard.
func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
	// OAuth account reads the OAuth-specific key.
	t.Run("OAuth使用OAuth专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})

	// API Key account reads the API-Key-specific key.
	t.Run("API Key使用API Key专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})

	// An OAuth account must not honor the API-Key-specific key.
	t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})

	// A present type-specific key wins over both legacy keys.
	t.Run("分类型新键优先于兼容键", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":              true,
				"openai_ws_enabled":                            true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})

	// Missing type-specific key falls back to the legacy key.
	t.Run("分类型键缺失时回退兼容键", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})

	// Non-OpenAI accounts are always off regardless of keys.
	t.Run("非OpenAI账号默认关闭", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
}
|
||||
|
||||
// TestAccount_ResolveOpenAIResponsesWebSocketV2Mode covers the mode resolution
// cascade: type-specific mode string > type-specific enabled bool > legacy
// enabled bools > (normalized) default, plus the non-OpenAI guard.
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
	t.Run("default fallback to shared", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra:    map[string]any{},
		}
		// Empty or invalid defaults normalize to shared.
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
	})

	t.Run("oauth mode field has highest priority", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode":    OpenAIWSIngressModeDedicated,
				"openai_oauth_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":              false,
			},
		}
		require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
	})

	t.Run("legacy enabled maps to shared", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
	})

	t.Run("legacy disabled maps to off", func(t *testing.T) {
		// Type-specific enabled=false must win over the legacy true value.
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":               true,
			},
		}
		require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
	})

	t.Run("non openai always off", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
			},
		}
		require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
	})
}
|
||||
|
||||
// TestAccount_OpenAIWSExtraFlags covers the account-level force-HTTP and
// store-recovery extra flags, including absent keys, nil receivers, and
// non-OpenAI accounts.
func TestAccount_OpenAIWSExtraFlags(t *testing.T) {
	account := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_ws_force_http":           true,
			"openai_ws_allow_store_recovery": true,
		},
	}
	require.True(t, account.IsOpenAIWSForceHTTPEnabled())
	require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled())

	// Keys absent: both flags default to off.
	off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
	require.False(t, off.IsOpenAIWSForceHTTPEnabled())
	require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled())

	// Nil receiver must be safe and report false.
	var nilAccount *Account
	require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled())

	// Non-OpenAI accounts ignore the flag even when set.
	nonOpenAI := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_ws_allow_store_recovery": true,
		},
	}
	require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled())
}
|
||||
|
||||
@@ -119,6 +119,10 @@ type AccountService struct {
|
||||
groupRepo GroupRepository
|
||||
}
|
||||
|
||||
// groupExistenceBatchChecker is an optional capability of GroupRepository:
// implementations can verify the existence of many group IDs in a single
// query. Callers type-assert for it and fall back to per-ID lookups when the
// repository does not implement it.
type groupExistenceBatchChecker interface {
	// ExistsByIDs reports, for each requested ID, whether a group exists.
	ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
|
||||
|
||||
// NewAccountService 创建账号服务实例
|
||||
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
|
||||
return &AccountService{
|
||||
@@ -131,11 +135,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
|
||||
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
|
||||
// 验证分组是否存在(如果指定了分组)
|
||||
if len(req.GroupIDs) > 0 {
|
||||
for _, groupID := range req.GroupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,11 +257,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if req.GroupIDs != nil {
|
||||
for _, groupID := range *req.GroupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,6 +298,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if s.groupRepo == nil {
|
||||
return fmt.Errorf("group repository not configured")
|
||||
}
|
||||
|
||||
if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok {
|
||||
existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check groups exists: %w", err)
|
||||
}
|
||||
for _, groupID := range groupIDs {
|
||||
if groupID <= 0 {
|
||||
return fmt.Errorf("get group: %w", ErrGroupNotFound)
|
||||
}
|
||||
if !existsByID[groupID] {
|
||||
return fmt.Errorf("get group: %w", ErrGroupNotFound)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新账号状态
|
||||
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
|
||||
@@ -598,9 +598,102 @@ func ceilSeconds(d time.Duration) int {
|
||||
return sec
|
||||
}
|
||||
|
||||
// testSoraAPIKeyAccountConnection tests connectivity for a Sora apikey-type
// account: it sends a lightweight prompt-enhance request to the upstream
// base_url to verify both reachability and API key validity, streaming
// progress to the client as SSE events.
//
// Returns nil once a terminal event has been sent; all failure paths report
// through sendErrorAndEnd rather than a plain error.
func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
	ctx := c.Request.Context()

	apiKey := account.GetCredential("api_key")
	if apiKey == "" {
		return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
	}

	baseURL := account.GetBaseURL()
	if baseURL == "" {
		return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
	}

	// Validate and normalize base_url before building the upstream request.
	normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
	}
	upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"

	// Switch the response into SSE streaming mode before emitting any events.
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.Flush()

	// Per-account rate limit on test requests; wait is the remaining cooldown.
	if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
		msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
		return s.sendErrorAndEnd(c, msg)
	}

	s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})

	// Lightweight prompt-enhance request used purely as a connectivity probe.
	testPayload := map[string]any{
		"model":    "prompt-enhance-short-10s",
		"messages": []map[string]string{{"role": "user", "content": "test"}},
		"stream":   false,
	}
	payloadBytes, _ := json.Marshal(testPayload)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
	if err != nil {
		return s.sendErrorAndEnd(c, "构建测试请求失败")
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	// Route through the account's proxy when one is configured.
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
	}
	defer func() { _ = resp.Body.Close() }()

	// Cap the body read so a misbehaving upstream cannot balloon memory.
	respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))

	if resp.StatusCode == http.StatusOK {
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
		s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
		return nil
	}

	if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
		return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
	}

	// A 400 still proves connectivity and auth; the probe payload itself may
	// be rejected by upstream validation, which is treated as success.
	if resp.StatusCode == http.StatusBadRequest {
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
		s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
		return nil
	}

	return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
}
|
||||
|
||||
// testSoraAccountConnection 测试 Sora 账号的连接
|
||||
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
|
||||
// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
|
||||
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
|
||||
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
|
||||
// apikey 类型走独立测试流程
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
return s.testSoraAPIKeyAccountConnection(c, account)
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
recorder := &soraProbeRecorder{}
|
||||
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
@@ -33,8 +35,8 @@ type UsageLogRepository interface {
|
||||
|
||||
// Admin dashboard stats
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
@@ -62,6 +64,10 @@ type UsageLogRepository interface {
|
||||
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
|
||||
}
|
||||
|
||||
// accountWindowStatsBatchReader is an optional capability of
// UsageLogRepository: implementations can fetch window stats for many
// accounts in a single query. GetTodayStatsBatch type-asserts for it and
// falls back to per-account queries otherwise.
type accountWindowStatsBatchReader interface {
	GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error)
}
|
||||
|
||||
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
|
||||
type apiUsageCache struct {
|
||||
response *ClaudeUsageResponse
|
||||
@@ -297,7 +303,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
dayStart := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||
}
|
||||
@@ -319,7 +325,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
||||
minuteStart := now.Truncate(time.Minute)
|
||||
minuteResetAt := minuteStart.Add(time.Minute)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
||||
}
|
||||
@@ -440,6 +446,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetTodayStatsBatch returns today's usage stats for each account ID. It
// prefers a single batched SQL query and falls back to bounded-parallel
// per-account queries when batching is unavailable or fails.
//
// Non-positive and duplicate IDs are dropped. The returned map always holds
// an entry (possibly zero-valued) for every unique valid input ID; individual
// per-account query failures are deliberately best-effort and surface as
// zero stats rather than an error.
func (s *AccountUsageService) GetTodayStatsBatch(ctx context.Context, accountIDs []int64) (map[int64]*WindowStats, error) {
	// Deduplicate and drop invalid IDs while preserving input order.
	uniqueIDs := make([]int64, 0, len(accountIDs))
	seen := make(map[int64]struct{}, len(accountIDs))
	for _, accountID := range accountIDs {
		if accountID <= 0 {
			continue
		}
		if _, exists := seen[accountID]; exists {
			continue
		}
		seen[accountID] = struct{}{}
		uniqueIDs = append(uniqueIDs, accountID)
	}

	result := make(map[int64]*WindowStats, len(uniqueIDs))
	if len(uniqueIDs) == 0 {
		return result, nil
	}

	startTime := timezone.Today()
	// Fast path: one batched query when the repository supports it. On error
	// we silently fall through to the per-account path below.
	if batchReader, ok := s.usageLogRepo.(accountWindowStatsBatchReader); ok {
		statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, uniqueIDs, startTime)
		if err == nil {
			for _, accountID := range uniqueIDs {
				result[accountID] = windowStatsFromAccountStats(statsByAccount[accountID])
			}
			return result, nil
		}
	}

	// Fallback: query each account concurrently with at most 8 in flight;
	// mu guards the shared result map.
	var mu sync.Mutex
	g, gctx := errgroup.WithContext(ctx)
	g.SetLimit(8)

	for _, accountID := range uniqueIDs {
		id := accountID
		g.Go(func() error {
			stats, err := s.usageLogRepo.GetAccountWindowStats(gctx, id, startTime)
			if err != nil {
				// Best-effort: a failed account simply gets zero stats below.
				return nil
			}
			mu.Lock()
			result[id] = windowStatsFromAccountStats(stats)
			mu.Unlock()
			return nil
		})
	}

	// Workers never return errors, so Wait's result carries no information.
	_ = g.Wait()

	// Guarantee an entry for every requested account.
	for _, accountID := range uniqueIDs {
		if _, ok := result[accountID]; !ok {
			result[accountID] = &WindowStats{}
		}
	}
	return result, nil
}
|
||||
|
||||
func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
|
||||
if stats == nil {
|
||||
return &WindowStats{}
|
||||
}
|
||||
return &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
StandardCost: stats.StandardCost,
|
||||
UserCost: stats.UserCost,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
|
||||
@@ -314,3 +314,72 @@ func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T
|
||||
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace verifies the
// model-mapping cache is refreshed when the entire Credentials map is swapped
// for a new one (identity change of the outer map).
func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) {
	account := &Account{
		Credentials: map[string]any{
			"model_mapping": map[string]any{
				"claude-3-5-sonnet": "upstream-a",
			},
		},
	}

	first := account.GetModelMapping()
	if first["claude-3-5-sonnet"] != "upstream-a" {
		t.Fatalf("unexpected first mapping: %v", first)
	}

	// Replace the whole Credentials map; the cached result must not survive.
	account.Credentials = map[string]any{
		"model_mapping": map[string]any{
			"claude-3-5-sonnet": "upstream-b",
		},
	}
	second := account.GetModelMapping()
	if second["claude-3-5-sonnet"] != "upstream-b" {
		t.Fatalf("expected cache invalidated after credentials replace, got: %v", second)
	}
}
|
||||
|
||||
// TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange verifies the
// model-mapping cache is refreshed when the raw mapping is mutated in place
// by adding a key (length change, same map identity).
func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) {
	rawMapping := map[string]any{
		"claude-sonnet": "sonnet-a",
	}
	account := &Account{
		Credentials: map[string]any{
			"model_mapping": rawMapping,
		},
	}

	first := account.GetModelMapping()
	if len(first) != 1 {
		t.Fatalf("unexpected first mapping length: %d", len(first))
	}

	// Mutate the raw map in place; the length check must catch this.
	rawMapping["claude-opus"] = "opus-b"
	second := account.GetModelMapping()
	if second["claude-opus"] != "opus-b" {
		t.Fatalf("expected cache invalidated after mapping len change, got: %v", second)
	}
}
|
||||
|
||||
// TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange verifies the
// model-mapping cache is refreshed when a value is changed in place (same map
// identity and same length), which only the content signature can detect.
func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) {
	rawMapping := map[string]any{
		"claude-sonnet": "sonnet-a",
	}
	account := &Account{
		Credentials: map[string]any{
			"model_mapping": rawMapping,
		},
	}

	first := account.GetModelMapping()
	if first["claude-sonnet"] != "sonnet-a" {
		t.Fatalf("unexpected first mapping: %v", first)
	}

	// Same key, same length — only the signature comparison can catch this.
	rawMapping["claude-sonnet"] = "sonnet-b"
	second := account.GetModelMapping()
	if second["claude-sonnet"] != "sonnet-b" {
		t.Fatalf("expected cache invalidated after in-place value change, got: %v", second)
	}
}
|
||||
|
||||
@@ -83,13 +83,14 @@ type AdminService interface {
|
||||
|
||||
// CreateUserInput represents input for creating a new user via admin operations.
|
||||
type CreateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
Username string
|
||||
Notes string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
AllowedGroups []int64
|
||||
Email string
|
||||
Password string
|
||||
Username string
|
||||
Notes string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
AllowedGroups []int64
|
||||
SoraStorageQuotaBytes int64
|
||||
}
|
||||
|
||||
type UpdateUserInput struct {
|
||||
@@ -103,7 +104,8 @@ type UpdateUserInput struct {
|
||||
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64
|
||||
GroupRates map[int64]*float64
|
||||
SoraStorageQuotaBytes *int64
|
||||
}
|
||||
|
||||
type CreateGroupInput struct {
|
||||
@@ -135,6 +137,8 @@ type CreateGroupInput struct {
|
||||
MCPXMLInject *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -169,6 +173,8 @@ type UpdateGroupInput struct {
|
||||
MCPXMLInject *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -402,6 +408,14 @@ type adminServiceImpl struct {
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
}
|
||||
|
||||
type userGroupRateBatchReader interface {
|
||||
GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
|
||||
}
|
||||
|
||||
type groupExistenceBatchReader interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
func NewAdminService(
|
||||
userRepo UserRepository,
|
||||
@@ -442,18 +456,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
|
||||
}
|
||||
// 批量加载用户专属分组倍率
|
||||
if s.userGroupRateRepo != nil && len(users) > 0 {
|
||||
for i := range users {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
|
||||
continue
|
||||
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
|
||||
userIDs := make([]int64, 0, len(users))
|
||||
for i := range users {
|
||||
userIDs = append(userIDs, users[i].ID)
|
||||
}
|
||||
users[i].GroupRates = rates
|
||||
ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err)
|
||||
s.loadUserGroupRatesOneByOne(ctx, users)
|
||||
} else {
|
||||
for i := range users {
|
||||
if rates, ok := ratesByUser[users[i].ID]; ok {
|
||||
users[i].GroupRates = rates
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.loadUserGroupRatesOneByOne(ctx, users)
|
||||
}
|
||||
}
|
||||
return users, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return
|
||||
}
|
||||
for i := range users {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
|
||||
continue
|
||||
}
|
||||
users[i].GroupRates = rates
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -473,14 +512,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
|
||||
|
||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
|
||||
user := &User{
|
||||
Email: input.Email,
|
||||
Username: input.Username,
|
||||
Notes: input.Notes,
|
||||
Role: RoleUser, // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
Status: StatusActive,
|
||||
AllowedGroups: input.AllowedGroups,
|
||||
Email: input.Email,
|
||||
Username: input.Username,
|
||||
Notes: input.Notes,
|
||||
Role: RoleUser, // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
Status: StatusActive,
|
||||
AllowedGroups: input.AllowedGroups,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
}
|
||||
if err := user.SetPassword(input.Password); err != nil {
|
||||
return nil, err
|
||||
@@ -534,6 +574,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
user.AllowedGroups = *input.AllowedGroups
|
||||
}
|
||||
|
||||
if input.SoraStorageQuotaBytes != nil {
|
||||
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
|
||||
}
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -820,6 +864,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
ModelRouting: input.ModelRouting,
|
||||
MCPXMLInject: mcpXMLInject,
|
||||
SupportedModelScopes: input.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -982,6 +1027,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.SoraVideoPricePerRequestHD != nil {
|
||||
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
|
||||
}
|
||||
if input.SoraStorageQuotaBytes != nil {
|
||||
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
|
||||
}
|
||||
|
||||
// Claude Code 客户端限制
|
||||
if input.ClaudeCodeOnly != nil {
|
||||
@@ -1188,6 +1236,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
}
|
||||
}
|
||||
|
||||
// Sora apikey 账号的 base_url 必填校验
|
||||
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
|
||||
baseURL, _ := input.Credentials["base_url"].(string)
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("sora apikey 账号必须设置 base_url")
|
||||
}
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
|
||||
}
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
Name: input.Name,
|
||||
Notes: normalizeAccountNotes(input.Notes),
|
||||
@@ -1301,12 +1361,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
|
||||
}
|
||||
|
||||
// Sora apikey 账号的 base_url 必填校验
|
||||
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
|
||||
baseURL, _ := account.Credentials["base_url"].(string)
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("sora apikey 账号必须设置 base_url")
|
||||
}
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
|
||||
}
|
||||
}
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if input.GroupIDs != nil {
|
||||
for _, groupID := range *input.GroupIDs {
|
||||
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查混合渠道风险(除非用户已确认)
|
||||
@@ -1348,11 +1418,18 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
if len(input.AccountIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
if input.GroupIDs != nil {
|
||||
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
|
||||
|
||||
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
|
||||
platformByID := map[int64]string{}
|
||||
groupAccountsByID := map[int64][]Account{}
|
||||
groupNameByID := map[int64]string{}
|
||||
if needMixedChannelCheck {
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
|
||||
if err != nil {
|
||||
@@ -1366,6 +1443,13 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groupAccountsByID = loadedAccounts
|
||||
groupNameByID = loadedNames
|
||||
}
|
||||
|
||||
if input.RateMultiplier != nil {
|
||||
@@ -1409,11 +1493,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
// Handle group bindings per account (requires individual operations).
|
||||
for _, accountID := range input.AccountIDs {
|
||||
entry := BulkUpdateAccountResult{AccountID: accountID}
|
||||
platform := ""
|
||||
|
||||
if input.GroupIDs != nil {
|
||||
// 检查混合渠道风险(除非用户已确认)
|
||||
if !input.SkipMixedChannelCheck {
|
||||
platform := platformByID[accountID]
|
||||
platform = platformByID[accountID]
|
||||
if platform == "" {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -1426,7 +1511,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
}
|
||||
platform = account.Platform
|
||||
}
|
||||
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
|
||||
if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil {
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
result.Failed++
|
||||
@@ -1444,6 +1529,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
result.Results = append(result.Results, entry)
|
||||
continue
|
||||
}
|
||||
if !input.SkipMixedChannelCheck && platform != "" {
|
||||
updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform)
|
||||
}
|
||||
}
|
||||
|
||||
entry.Success = true
|
||||
@@ -2115,6 +2203,135 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) {
|
||||
accountsByGroup := make(map[int64][]Account)
|
||||
groupNameByID := make(map[int64]string)
|
||||
if len(groupIDs) == 0 {
|
||||
return accountsByGroup, groupNameByID, nil
|
||||
}
|
||||
|
||||
seen := make(map[int64]struct{}, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
if groupID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[groupID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[groupID] = struct{}{}
|
||||
|
||||
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err)
|
||||
}
|
||||
accountsByGroup[groupID] = accounts
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if group != nil {
|
||||
groupNameByID[groupID] = group.Name
|
||||
}
|
||||
}
|
||||
|
||||
return accountsByGroup, groupNameByID, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if s.groupRepo == nil {
|
||||
return errors.New("group repository not configured")
|
||||
}
|
||||
|
||||
if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok {
|
||||
existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check groups exists: %w", err)
|
||||
}
|
||||
for _, groupID := range groupIDs {
|
||||
if groupID <= 0 || !existsByID[groupID] {
|
||||
return fmt.Errorf("get group: %w", ErrGroupNotFound)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
|
||||
return fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error {
|
||||
currentPlatform := getAccountPlatform(currentAccountPlatform)
|
||||
if currentPlatform == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
accounts := accountsByGroup[groupID]
|
||||
for _, account := range accounts {
|
||||
if currentAccountID > 0 && account.ID == currentAccountID {
|
||||
continue
|
||||
}
|
||||
|
||||
otherPlatform := getAccountPlatform(account.Platform)
|
||||
if otherPlatform == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentPlatform != otherPlatform {
|
||||
groupName := fmt.Sprintf("Group %d", groupID)
|
||||
if name := strings.TrimSpace(groupNameByID[groupID]); name != "" {
|
||||
groupName = name
|
||||
}
|
||||
|
||||
return &MixedChannelError{
|
||||
GroupID: groupID,
|
||||
GroupName: groupName,
|
||||
CurrentPlatform: currentPlatform,
|
||||
OtherPlatform: otherPlatform,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) {
|
||||
if len(groupIDs) == 0 || accountID <= 0 || platform == "" {
|
||||
return
|
||||
}
|
||||
for _, groupID := range groupIDs {
|
||||
if groupID <= 0 {
|
||||
continue
|
||||
}
|
||||
accounts := accountsByGroup[groupID]
|
||||
found := false
|
||||
for i := range accounts {
|
||||
if accounts[i].ID != accountID {
|
||||
continue
|
||||
}
|
||||
accounts[i].Platform = platform
|
||||
found = true
|
||||
break
|
||||
}
|
||||
if !found {
|
||||
accounts = append(accounts, Account{
|
||||
ID: accountID,
|
||||
Platform: platform,
|
||||
})
|
||||
}
|
||||
accountsByGroup[groupID] = accounts
|
||||
}
|
||||
}
|
||||
|
||||
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
|
||||
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
|
||||
|
||||
@@ -15,6 +15,7 @@ type accountRepoStubForBulkUpdate struct {
|
||||
bulkUpdateErr error
|
||||
bulkUpdateIDs []int64
|
||||
bindGroupErrByID map[int64]error
|
||||
bindGroupsCalls []int64
|
||||
getByIDsAccounts []*Account
|
||||
getByIDsErr error
|
||||
getByIDsCalled bool
|
||||
@@ -22,6 +23,8 @@ type accountRepoStubForBulkUpdate struct {
|
||||
getByIDAccounts map[int64]*Account
|
||||
getByIDErrByID map[int64]error
|
||||
getByIDCalled []int64
|
||||
listByGroupData map[int64][]Account
|
||||
listByGroupErr map[int64]error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||
@@ -33,6 +36,7 @@ func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
|
||||
s.bindGroupsCalls = append(s.bindGroupsCalls, accountID)
|
||||
if err, ok := s.bindGroupErrByID[accountID]; ok {
|
||||
return err
|
||||
}
|
||||
@@ -59,6 +63,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||
return rows, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
@@ -86,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
|
||||
2: errors.New("bind failed"),
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
svc := &adminServiceImpl{
|
||||
accountRepo: repo,
|
||||
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}},
|
||||
}
|
||||
|
||||
groupIDs := []int64{10}
|
||||
schedulable := false
|
||||
@@ -105,3 +122,51 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
|
||||
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
|
||||
require.Len(t, result.Results, 3)
|
||||
}
|
||||
|
||||
func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
groupIDs := []int64{10}
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1},
|
||||
GroupIDs: &groupIDs,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "group repository not configured")
|
||||
}
|
||||
|
||||
func TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{
|
||||
getByIDsAccounts: []*Account{
|
||||
{ID: 1, Platform: PlatformAnthropic},
|
||||
{ID: 2, Platform: PlatformAntigravity},
|
||||
},
|
||||
listByGroupData: map[int64][]Account{
|
||||
10: {},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{
|
||||
accountRepo: repo,
|
||||
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "目标分组"}},
|
||||
}
|
||||
|
||||
groupIDs := []int64{10}
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1, 2},
|
||||
GroupIDs: &groupIDs,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, result.Success)
|
||||
require.Equal(t, 1, result.Failed)
|
||||
require.ElementsMatch(t, []int64{1}, result.SuccessIDs)
|
||||
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
|
||||
require.Len(t, result.Results, 2)
|
||||
require.Contains(t, result.Results[1].Error, "mixed channel")
|
||||
require.Equal(t, []int64{1}, repo.bindGroupsCalls)
|
||||
}
|
||||
|
||||
106
backend/internal/service/admin_service_list_users_test.go
Normal file
106
backend/internal/service/admin_service_list_users_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userRepoStubForListUsers struct {
|
||||
userRepoStub
|
||||
users []User
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
if s.err != nil {
|
||||
return nil, nil, s.err
|
||||
}
|
||||
out := make([]User, len(s.users))
|
||||
copy(out, s.users)
|
||||
return out, &pagination.PaginationResult{
|
||||
Total: int64(len(out)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type userGroupRateRepoStubForListUsers struct {
|
||||
batchCalls int
|
||||
singleCall []int64
|
||||
|
||||
batchErr error
|
||||
batchData map[int64]map[int64]float64
|
||||
|
||||
singleErr map[int64]error
|
||||
singleData map[int64]map[int64]float64
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) {
|
||||
s.batchCalls++
|
||||
if s.batchErr != nil {
|
||||
return nil, s.batchErr
|
||||
}
|
||||
return s.batchData, nil
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) {
|
||||
s.singleCall = append(s.singleCall, userID)
|
||||
if err, ok := s.singleErr[userID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rates, ok := s.singleData[userID]; ok {
|
||||
return rates, nil
|
||||
}
|
||||
return map[int64]float64{}, nil
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) {
|
||||
panic("unexpected GetByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
|
||||
panic("unexpected SyncUserGroupRates call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||
panic("unexpected DeleteByGroupID call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error {
|
||||
panic("unexpected DeleteByUserID call")
|
||||
}
|
||||
|
||||
func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
|
||||
userRepo := &userRepoStubForListUsers{
|
||||
users: []User{
|
||||
{ID: 101, Username: "u1"},
|
||||
{ID: 202, Username: "u2"},
|
||||
},
|
||||
}
|
||||
rateRepo := &userGroupRateRepoStubForListUsers{
|
||||
batchErr: errors.New("batch unavailable"),
|
||||
singleData: map[int64]map[int64]float64{
|
||||
101: {11: 1.1},
|
||||
202: {22: 2.2},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
userGroupRateRepo: rateRepo,
|
||||
}
|
||||
|
||||
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), total)
|
||||
require.Len(t, users, 2)
|
||||
require.Equal(t, 1, rateRepo.batchCalls)
|
||||
require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall)
|
||||
require.Equal(t, 1.1, users[0].GroupRates[11])
|
||||
require.Equal(t, 2.2, users[1].GroupRates[22])
|
||||
}
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -2291,7 +2290,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||||
|
||||
// isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记
|
||||
func isSingleAccountRetry(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool)
|
||||
v, _ := SingleAccountRetryFromContext(ctx)
|
||||
return v
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
)
|
||||
|
||||
// API Key status constants
|
||||
const (
|
||||
@@ -19,11 +23,14 @@ type APIKey struct {
|
||||
Status string
|
||||
IPWhitelist []string
|
||||
IPBlacklist []string
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
// 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。
|
||||
CompiledIPWhitelist *ip.CompiledIPRules `json:"-"`
|
||||
CompiledIPBlacklist *ip.CompiledIPRules `json:"-"`
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
|
||||
// Quota fields
|
||||
Quota float64 // Quota limit in USD (0 = unlimited)
|
||||
|
||||
@@ -298,5 +298,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
||||
}
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey
|
||||
}
|
||||
|
||||
@@ -158,6 +158,14 @@ func NewAPIKeyService(
|
||||
return svc
|
||||
}
|
||||
|
||||
func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) {
|
||||
if apiKey == nil {
|
||||
return
|
||||
}
|
||||
apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist)
|
||||
apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist)
|
||||
}
|
||||
|
||||
// GenerateKey 生成随机API Key
|
||||
func (s *APIKeyService) GenerateKey() (string, error) {
|
||||
// 生成32字节随机数据
|
||||
@@ -332,6 +340,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
}
|
||||
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
@@ -363,6 +372,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
@@ -375,6 +385,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
}
|
||||
@@ -391,6 +402,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
} else {
|
||||
@@ -402,6 +414,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
}
|
||||
@@ -411,6 +424,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
apiKey.Key = key
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
@@ -510,6 +524,7 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
}
|
||||
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
@@ -308,6 +308,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyTurnstileForRegister 在注册场景下验证 Turnstile。
|
||||
// 当邮箱验证开启且已提交验证码时,说明验证码发送阶段已完成 Turnstile 校验,
|
||||
// 此处跳过二次校验,避免一次性 token 在注册提交时重复使用导致误报失败。
|
||||
func (s *AuthService) VerifyTurnstileForRegister(ctx context.Context, token, remoteIP, verifyCode string) error {
|
||||
if s.IsEmailVerifyEnabled(ctx) && strings.TrimSpace(verifyCode) != "" {
|
||||
logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verify flow detected, skip duplicate Turnstile check on register")
|
||||
return nil
|
||||
}
|
||||
return s.VerifyTurnstile(ctx, token, remoteIP)
|
||||
}
|
||||
|
||||
// VerifyTurnstile 验证Turnstile token
|
||||
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
|
||||
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type turnstileVerifierSpy struct {
|
||||
called int
|
||||
lastToken string
|
||||
result *TurnstileVerifyResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) {
|
||||
s.called++
|
||||
s.lastToken = token
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
if s.result != nil {
|
||||
return s.result, nil
|
||||
}
|
||||
return &TurnstileVerifyResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Mode: "release",
|
||||
},
|
||||
Turnstile: config.TurnstileConfig{
|
||||
Required: true,
|
||||
},
|
||||
}
|
||||
|
||||
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
|
||||
turnstileService := NewTurnstileService(settingService, verifier)
|
||||
|
||||
return NewAuthService(
|
||||
&userRepoStub{},
|
||||
nil, // redeemRepo
|
||||
nil, // refreshTokenCache
|
||||
cfg,
|
||||
settingService,
|
||||
nil, // emailService
|
||||
turnstileService,
|
||||
nil, // emailQueueService
|
||||
nil, // promoService
|
||||
)
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) {
|
||||
verifier := &turnstileVerifierSpy{}
|
||||
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
SettingKeyTurnstileEnabled: "true",
|
||||
SettingKeyTurnstileSecretKey: "secret",
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, verifier)
|
||||
|
||||
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, verifier.called)
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) {
|
||||
verifier := &turnstileVerifierSpy{}
|
||||
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
SettingKeyTurnstileEnabled: "true",
|
||||
SettingKeyTurnstileSecretKey: "secret",
|
||||
}, verifier)
|
||||
|
||||
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "")
|
||||
require.ErrorIs(t, err, ErrTurnstileVerificationFailed)
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) {
|
||||
verifier := &turnstileVerifierSpy{}
|
||||
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
SettingKeyTurnstileEnabled: "true",
|
||||
SettingKeyTurnstileSecretKey: "secret",
|
||||
}, verifier)
|
||||
|
||||
err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, verifier.called)
|
||||
require.Equal(t, "turnstile-token", verifier.lastToken)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
@@ -58,6 +60,7 @@ const (
|
||||
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
|
||||
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
|
||||
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
|
||||
balanceLoadTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
// cacheWriteTask 缓存写入任务
|
||||
@@ -82,6 +85,9 @@ type BillingCacheService struct {
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
cacheWriteStopOnce sync.Once
|
||||
cacheWriteMu sync.RWMutex
|
||||
stopped atomic.Bool
|
||||
balanceLoadSF singleflight.Group
|
||||
// 丢弃日志节流计数器(减少高负载下日志噪音)
|
||||
cacheWriteDropFullCount uint64
|
||||
cacheWriteDropFullLastLog int64
|
||||
@@ -105,35 +111,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
|
||||
// Stop 关闭缓存写入工作池
|
||||
func (s *BillingCacheService) Stop() {
|
||||
s.cacheWriteStopOnce.Do(func() {
|
||||
if s.cacheWriteChan == nil {
|
||||
s.stopped.Store(true)
|
||||
|
||||
s.cacheWriteMu.Lock()
|
||||
ch := s.cacheWriteChan
|
||||
if ch != nil {
|
||||
close(ch)
|
||||
}
|
||||
s.cacheWriteMu.Unlock()
|
||||
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
close(s.cacheWriteChan)
|
||||
s.cacheWriteWg.Wait()
|
||||
s.cacheWriteChan = nil
|
||||
|
||||
s.cacheWriteMu.Lock()
|
||||
if s.cacheWriteChan == ch {
|
||||
s.cacheWriteChan = nil
|
||||
}
|
||||
s.cacheWriteMu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) startCacheWriteWorkers() {
|
||||
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
|
||||
ch := make(chan cacheWriteTask, cacheWriteBufferSize)
|
||||
s.cacheWriteChan = ch
|
||||
for i := 0; i < cacheWriteWorkerCount; i++ {
|
||||
s.cacheWriteWg.Add(1)
|
||||
go s.cacheWriteWorker()
|
||||
go s.cacheWriteWorker(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
|
||||
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
|
||||
if s.cacheWriteChan == nil {
|
||||
if s.stopped.Load() {
|
||||
s.logCacheWriteDrop(task, "closed")
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
// 队列已关闭时可能触发 panic,记录后静默失败。
|
||||
s.logCacheWriteDrop(task, "closed")
|
||||
enqueued = false
|
||||
}
|
||||
}()
|
||||
|
||||
s.cacheWriteMu.RLock()
|
||||
defer s.cacheWriteMu.RUnlock()
|
||||
|
||||
if s.cacheWriteChan == nil {
|
||||
s.logCacheWriteDrop(task, "closed")
|
||||
return false
|
||||
}
|
||||
|
||||
select {
|
||||
case s.cacheWriteChan <- task:
|
||||
return true
|
||||
@@ -144,9 +167,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) cacheWriteWorker() {
|
||||
func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
|
||||
defer s.cacheWriteWg.Done()
|
||||
for task := range s.cacheWriteChan {
|
||||
for task := range ch {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
switch task.kind {
|
||||
case cacheWriteSetBalance:
|
||||
@@ -243,19 +266,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库读取
|
||||
balance, err = s.getUserBalanceFromDB(ctx, userID)
|
||||
// 缓存未命中:singleflight 合并同一 userID 的并发回源请求。
|
||||
value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) {
|
||||
loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout)
|
||||
defer cancel()
|
||||
|
||||
balance, err := s.getUserBalanceFromDB(loadCtx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetBalance,
|
||||
userID: userID,
|
||||
balance: balance,
|
||||
})
|
||||
return balance, nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetBalance,
|
||||
userID: userID,
|
||||
balance: balance,
|
||||
})
|
||||
|
||||
balance, ok := value.(float64)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unexpected balance type: %T", value)
|
||||
}
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type billingCacheMissStub struct {
|
||||
setBalanceCalls atomic.Int64
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
return 0, errors.New("cache miss")
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
s.setBalanceCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
|
||||
return nil, errors.New("cache miss")
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type balanceLoadUserRepoStub struct {
|
||||
mockUserRepo
|
||||
calls atomic.Int64
|
||||
delay time.Duration
|
||||
balance float64
|
||||
}
|
||||
|
||||
func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
s.calls.Add(1)
|
||||
if s.delay > 0 {
|
||||
select {
|
||||
case <-time.After(s.delay):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
return &User{ID: id, Balance: s.balance}, nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
||||
cache := &billingCacheMissStub{}
|
||||
userRepo := &balanceLoadUserRepoStub{
|
||||
delay: 80 * time.Millisecond,
|
||||
balance: 12.34,
|
||||
}
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
const goroutines = 16
|
||||
start := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, goroutines)
|
||||
balCh := make(chan float64, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
bal, err := svc.GetUserBalance(context.Background(), 99)
|
||||
errCh <- err
|
||||
balCh <- bal
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
close(balCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for bal := range balCh {
|
||||
require.Equal(t, 12.34, bal)
|
||||
}
|
||||
|
||||
require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并")
|
||||
require.Eventually(t, func() bool {
|
||||
return cache.setBalanceCalls.Load() >= 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
@@ -73,3 +73,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
|
||||
svc.Stop()
|
||||
|
||||
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteDeductBalance,
|
||||
userID: 1,
|
||||
amount: 1,
|
||||
})
|
||||
require.False(t, enqueued)
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) {
|
||||
|
||||
// 费率倍数 1.5x
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
|
||||
require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5
|
||||
|
||||
// 费率倍数 2.0x
|
||||
|
||||
@@ -78,7 +78,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
|
||||
|
||||
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
|
||||
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
|
||||
if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku {
|
||||
if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku {
|
||||
return true // 绕过 system prompt 检查,UA 已在 Step 1 验证
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@@ -18,6 +20,7 @@ type ConcurrencyCache interface {
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
|
||||
|
||||
// 账号等待队列(账号级)
|
||||
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
|
||||
@@ -42,15 +45,25 @@ type ConcurrencyCache interface {
|
||||
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID for concurrency slot tracking
|
||||
// Uses 8 random bytes (16 hex chars) for uniqueness
|
||||
func generateRequestID() string {
|
||||
var (
|
||||
requestIDPrefix = initRequestIDPrefix()
|
||||
requestIDCounter atomic.Uint64
|
||||
)
|
||||
|
||||
func initRequestIDPrefix() string {
|
||||
b := make([]byte, 8)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to nanosecond timestamp (extremely rare case)
|
||||
return fmt.Sprintf("%x", time.Now().UnixNano())
|
||||
if _, err := rand.Read(b); err == nil {
|
||||
return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36)
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
|
||||
return "r" + strconv.FormatUint(fallback, 36)
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID for concurrency slot tracking.
|
||||
// Format: {process_random_prefix}-{base36_counter}
|
||||
func generateRequestID() string {
|
||||
seq := requestIDCounter.Add(1)
|
||||
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
|
||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
||||
// Returns a map of accountID -> current concurrency count
|
||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int)
|
||||
|
||||
for _, accountID := range accountIDs {
|
||||
count, err := s.cache.GetAccountConcurrency(ctx, accountID)
|
||||
if err != nil {
|
||||
// If key doesn't exist in Redis, count is 0
|
||||
count = 0
|
||||
}
|
||||
result[accountID] = count
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]int{}, nil
|
||||
}
|
||||
|
||||
return result, nil
|
||||
if s.cache == nil {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
result[accountID] = 0
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -12,20 +14,20 @@ import (
|
||||
|
||||
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
|
||||
type stubConcurrencyCacheForTest struct {
|
||||
acquireResult bool
|
||||
acquireErr error
|
||||
releaseErr error
|
||||
concurrency int
|
||||
acquireResult bool
|
||||
acquireErr error
|
||||
releaseErr error
|
||||
concurrency int
|
||||
concurrencyErr error
|
||||
waitAllowed bool
|
||||
waitErr error
|
||||
waitCount int
|
||||
waitCountErr error
|
||||
loadBatch map[int64]*AccountLoadInfo
|
||||
loadBatchErr error
|
||||
waitAllowed bool
|
||||
waitErr error
|
||||
waitCount int
|
||||
waitCountErr error
|
||||
loadBatch map[int64]*AccountLoadInfo
|
||||
loadBatchErr error
|
||||
usersLoadBatch map[int64]*UserLoadInfo
|
||||
usersLoadErr error
|
||||
cleanupErr error
|
||||
cleanupErr error
|
||||
|
||||
// 记录调用
|
||||
releasedAccountIDs []int64
|
||||
@@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) {
|
||||
return c.concurrency, c.concurrencyErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
if c.concurrencyErr != nil {
|
||||
return nil, c.concurrencyErr
|
||||
}
|
||||
result[accountID] = c.concurrency
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
|
||||
return c.waitAllowed, c.waitErr
|
||||
}
|
||||
@@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
|
||||
require.True(t, result.Acquired)
|
||||
}
|
||||
|
||||
func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) {
|
||||
id1 := generateRequestID()
|
||||
id2 := generateRequestID()
|
||||
require.NotEmpty(t, id1)
|
||||
require.NotEmpty(t, id2)
|
||||
|
||||
p1 := strings.Split(id1, "-")
|
||||
p2 := strings.Split(id2, "-")
|
||||
require.Len(t, p1, 2)
|
||||
require.Len(t, p2, 2)
|
||||
require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致")
|
||||
|
||||
n1, err := strconv.ParseUint(p1[1], 36, 64)
|
||||
require.NoError(t, err)
|
||||
n2, err := strconv.ParseUint(p2[1], 36, 64)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n1+1, n2, "计数器应单调递增")
|
||||
}
|
||||
|
||||
func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) {
|
||||
expected := map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60},
|
||||
|
||||
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get usage trend with filters: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model stats with filters: %w", err)
|
||||
}
|
||||
|
||||
252
backend/internal/service/data_management_grpc.go
Normal file
252
backend/internal/service/data_management_grpc.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
type DataManagementPostgresConfig struct {
|
||||
Host string `json:"host"`
|
||||
Port int32 `json:"port"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password,omitempty"`
|
||||
PasswordConfigured bool `json:"password_configured"`
|
||||
Database string `json:"database"`
|
||||
SSLMode string `json:"ssl_mode"`
|
||||
ContainerName string `json:"container_name"`
|
||||
}
|
||||
|
||||
type DataManagementRedisConfig struct {
|
||||
Addr string `json:"addr"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password,omitempty"`
|
||||
PasswordConfigured bool `json:"password_configured"`
|
||||
DB int32 `json:"db"`
|
||||
ContainerName string `json:"container_name"`
|
||||
}
|
||||
|
||||
type DataManagementS3Config struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key,omitempty"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
}
|
||||
|
||||
type DataManagementConfig struct {
|
||||
SourceMode string `json:"source_mode"`
|
||||
BackupRoot string `json:"backup_root"`
|
||||
SQLitePath string `json:"sqlite_path,omitempty"`
|
||||
RetentionDays int32 `json:"retention_days"`
|
||||
KeepLast int32 `json:"keep_last"`
|
||||
ActivePostgresID string `json:"active_postgres_profile_id"`
|
||||
ActiveRedisID string `json:"active_redis_profile_id"`
|
||||
Postgres DataManagementPostgresConfig `json:"postgres"`
|
||||
Redis DataManagementRedisConfig `json:"redis"`
|
||||
S3 DataManagementS3Config `json:"s3"`
|
||||
ActiveS3ProfileID string `json:"active_s3_profile_id"`
|
||||
}
|
||||
|
||||
type DataManagementTestS3Result struct {
|
||||
OK bool `json:"ok"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type DataManagementCreateBackupJobInput struct {
|
||||
BackupType string
|
||||
UploadToS3 bool
|
||||
TriggeredBy string
|
||||
IdempotencyKey string
|
||||
S3ProfileID string
|
||||
PostgresID string
|
||||
RedisID string
|
||||
}
|
||||
|
||||
type DataManagementListBackupJobsInput struct {
|
||||
PageSize int32
|
||||
PageToken string
|
||||
Status string
|
||||
BackupType string
|
||||
}
|
||||
|
||||
type DataManagementArtifactInfo struct {
|
||||
LocalPath string `json:"local_path"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
SHA256 string `json:"sha256"`
|
||||
}
|
||||
|
||||
type DataManagementS3ObjectInfo struct {
|
||||
Bucket string `json:"bucket"`
|
||||
Key string `json:"key"`
|
||||
ETag string `json:"etag"`
|
||||
}
|
||||
|
||||
type DataManagementBackupJob struct {
|
||||
JobID string `json:"job_id"`
|
||||
BackupType string `json:"backup_type"`
|
||||
Status string `json:"status"`
|
||||
TriggeredBy string `json:"triggered_by"`
|
||||
IdempotencyKey string `json:"idempotency_key,omitempty"`
|
||||
UploadToS3 bool `json:"upload_to_s3"`
|
||||
S3ProfileID string `json:"s3_profile_id,omitempty"`
|
||||
PostgresID string `json:"postgres_profile_id,omitempty"`
|
||||
RedisID string `json:"redis_profile_id,omitempty"`
|
||||
StartedAt string `json:"started_at,omitempty"`
|
||||
FinishedAt string `json:"finished_at,omitempty"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
Artifact DataManagementArtifactInfo `json:"artifact"`
|
||||
S3Object DataManagementS3ObjectInfo `json:"s3"`
|
||||
}
|
||||
|
||||
type DataManagementSourceProfile struct {
|
||||
SourceType string `json:"source_type"`
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Config DataManagementSourceConfig `json:"config"`
|
||||
PasswordConfigured bool `json:"password_configured"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
type DataManagementSourceConfig struct {
|
||||
Host string `json:"host"`
|
||||
Port int32 `json:"port"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Database string `json:"database"`
|
||||
SSLMode string `json:"ssl_mode"`
|
||||
Addr string `json:"addr"`
|
||||
Username string `json:"username"`
|
||||
DB int32 `json:"db"`
|
||||
ContainerName string `json:"container_name"`
|
||||
}
|
||||
|
||||
type DataManagementCreateSourceProfileInput struct {
|
||||
SourceType string
|
||||
ProfileID string
|
||||
Name string
|
||||
Config DataManagementSourceConfig
|
||||
SetActive bool
|
||||
}
|
||||
|
||||
type DataManagementUpdateSourceProfileInput struct {
|
||||
SourceType string
|
||||
ProfileID string
|
||||
Name string
|
||||
Config DataManagementSourceConfig
|
||||
}
|
||||
|
||||
type DataManagementS3Profile struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
S3 DataManagementS3Config `json:"s3"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
type DataManagementCreateS3ProfileInput struct {
|
||||
ProfileID string
|
||||
Name string
|
||||
S3 DataManagementS3Config
|
||||
SetActive bool
|
||||
}
|
||||
|
||||
type DataManagementUpdateS3ProfileInput struct {
|
||||
ProfileID string
|
||||
Name string
|
||||
S3 DataManagementS3Config
|
||||
}
|
||||
|
||||
type DataManagementListBackupJobsResult struct {
|
||||
Items []DataManagementBackupJob `json:"items"`
|
||||
NextPageToken string `json:"next_page_token,omitempty"`
|
||||
}
|
||||
|
||||
func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) {
|
||||
_ = ctx
|
||||
return DataManagementConfig{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) {
|
||||
_, _ = ctx, cfg
|
||||
return DataManagementConfig{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) {
|
||||
_, _ = ctx, sourceType
|
||||
return nil, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementSourceProfile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementSourceProfile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error {
|
||||
_, _, _ = ctx, sourceType, profileID
|
||||
return s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) {
|
||||
_, _, _ = ctx, sourceType, profileID
|
||||
return DataManagementSourceProfile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) {
|
||||
_, _ = ctx, cfg
|
||||
return DataManagementTestS3Result{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) {
|
||||
_ = ctx
|
||||
return nil, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementS3Profile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementS3Profile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error {
|
||||
_, _ = ctx, profileID
|
||||
return s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) {
|
||||
_, _ = ctx, profileID
|
||||
return DataManagementS3Profile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) CreateBackupJob(ctx context.Context, input DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementBackupJob{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementListBackupJobsResult{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) {
|
||||
_, _ = ctx, jobID
|
||||
return DataManagementBackupJob{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) deprecatedError() error {
|
||||
return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
|
||||
}
|
||||
36
backend/internal/service/data_management_grpc_test.go
Normal file
36
backend/internal/service/data_management_grpc_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDataManagementService_DeprecatedRPCMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "datamanagement.sock")
|
||||
svc := NewDataManagementServiceWithOptions(socketPath, 0)
|
||||
|
||||
_, err := svc.GetConfig(context.Background())
|
||||
assertDeprecatedDataManagementError(t, err, socketPath)
|
||||
|
||||
_, err = svc.CreateBackupJob(context.Background(), DataManagementCreateBackupJobInput{BackupType: "full"})
|
||||
assertDeprecatedDataManagementError(t, err, socketPath)
|
||||
|
||||
err = svc.DeleteS3Profile(context.Background(), "s3-default")
|
||||
assertDeprecatedDataManagementError(t, err, socketPath)
|
||||
}
|
||||
|
||||
func assertDeprecatedDataManagementError(t *testing.T, err error, socketPath string) {
|
||||
t.Helper()
|
||||
|
||||
require.Error(t, err)
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
require.Equal(t, 503, statusCode)
|
||||
require.Equal(t, DataManagementDeprecatedReason, status.Reason)
|
||||
require.Equal(t, socketPath, status.Metadata["socket_path"])
|
||||
}
|
||||
99
backend/internal/service/data_management_service.go
Normal file
99
backend/internal/service/data_management_service.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultDataManagementAgentSocketPath = "/tmp/sub2api-datamanagement.sock"
|
||||
LegacyBackupAgentSocketPath = "/tmp/sub2api-backup.sock"
|
||||
|
||||
DataManagementDeprecatedReason = "DATA_MANAGEMENT_DEPRECATED"
|
||||
DataManagementAgentSocketMissingReason = "DATA_MANAGEMENT_AGENT_SOCKET_MISSING"
|
||||
DataManagementAgentUnavailableReason = "DATA_MANAGEMENT_AGENT_UNAVAILABLE"
|
||||
|
||||
// Deprecated: keep old names for compatibility.
|
||||
DefaultBackupAgentSocketPath = DefaultDataManagementAgentSocketPath
|
||||
BackupAgentSocketMissingReason = DataManagementAgentSocketMissingReason
|
||||
BackupAgentUnavailableReason = DataManagementAgentUnavailableReason
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDataManagementDeprecated = infraerrors.ServiceUnavailable(
|
||||
DataManagementDeprecatedReason,
|
||||
"data management feature is deprecated",
|
||||
)
|
||||
ErrDataManagementAgentSocketMissing = infraerrors.ServiceUnavailable(
|
||||
DataManagementAgentSocketMissingReason,
|
||||
"data management agent socket is missing",
|
||||
)
|
||||
ErrDataManagementAgentUnavailable = infraerrors.ServiceUnavailable(
|
||||
DataManagementAgentUnavailableReason,
|
||||
"data management agent is unavailable",
|
||||
)
|
||||
|
||||
// Deprecated: keep old names for compatibility.
|
||||
ErrBackupAgentSocketMissing = ErrDataManagementAgentSocketMissing
|
||||
ErrBackupAgentUnavailable = ErrDataManagementAgentUnavailable
|
||||
)
|
||||
|
||||
type DataManagementAgentHealth struct {
|
||||
Enabled bool
|
||||
Reason string
|
||||
SocketPath string
|
||||
Agent *DataManagementAgentInfo
|
||||
}
|
||||
|
||||
type DataManagementAgentInfo struct {
|
||||
Status string
|
||||
Version string
|
||||
UptimeSeconds int64
|
||||
}
|
||||
|
||||
type DataManagementService struct {
|
||||
socketPath string
|
||||
dialTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewDataManagementService() *DataManagementService {
|
||||
return NewDataManagementServiceWithOptions(DefaultDataManagementAgentSocketPath, 500*time.Millisecond)
|
||||
}
|
||||
|
||||
func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService {
|
||||
path := strings.TrimSpace(socketPath)
|
||||
if path == "" {
|
||||
path = DefaultDataManagementAgentSocketPath
|
||||
}
|
||||
if dialTimeout <= 0 {
|
||||
dialTimeout = 500 * time.Millisecond
|
||||
}
|
||||
return &DataManagementService{
|
||||
socketPath: path,
|
||||
dialTimeout: dialTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DataManagementService) SocketPath() string {
|
||||
if s == nil || strings.TrimSpace(s.socketPath) == "" {
|
||||
return DefaultDataManagementAgentSocketPath
|
||||
}
|
||||
return s.socketPath
|
||||
}
|
||||
|
||||
func (s *DataManagementService) GetAgentHealth(ctx context.Context) DataManagementAgentHealth {
|
||||
_ = ctx
|
||||
return DataManagementAgentHealth{
|
||||
Enabled: false,
|
||||
Reason: DataManagementDeprecatedReason,
|
||||
SocketPath: s.SocketPath(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DataManagementService) EnsureAgentEnabled(ctx context.Context) error {
|
||||
_ = ctx
|
||||
return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
|
||||
}
|
||||
37
backend/internal/service/data_management_service_test.go
Normal file
37
backend/internal/service/data_management_service_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDataManagementService_GetAgentHealth_Deprecated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "unused.sock")
|
||||
svc := NewDataManagementServiceWithOptions(socketPath, 0)
|
||||
health := svc.GetAgentHealth(context.Background())
|
||||
|
||||
require.False(t, health.Enabled)
|
||||
require.Equal(t, DataManagementDeprecatedReason, health.Reason)
|
||||
require.Equal(t, socketPath, health.SocketPath)
|
||||
require.Nil(t, health.Agent)
|
||||
}
|
||||
|
||||
func TestDataManagementService_EnsureAgentEnabled_Deprecated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "unused.sock")
|
||||
svc := NewDataManagementServiceWithOptions(socketPath, 100)
|
||||
err := svc.EnsureAgentEnabled(context.Background())
|
||||
require.Error(t, err)
|
||||
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
require.Equal(t, 503, statusCode)
|
||||
require.Equal(t, DataManagementDeprecatedReason, status.Reason)
|
||||
require.Equal(t, socketPath, status.Metadata["socket_path"])
|
||||
}
|
||||
@@ -104,6 +104,7 @@ const (
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
|
||||
// OEM设置
|
||||
SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
@@ -170,6 +171,27 @@ const (
|
||||
|
||||
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
|
||||
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
|
||||
|
||||
// =========================
|
||||
// Sora S3 存储配置
|
||||
// =========================
|
||||
|
||||
SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
|
||||
SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
|
||||
SettingKeySoraS3Region = "sora_s3_region" // S3 区域
|
||||
SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
|
||||
SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
|
||||
SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储)
|
||||
SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
|
||||
SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等)
|
||||
SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选)
|
||||
SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON)
|
||||
|
||||
// =========================
|
||||
// Sora 用户存储配额
|
||||
// =========================
|
||||
|
||||
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
|
||||
@@ -279,10 +279,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotE
|
||||
wantPassthrough: true,
|
||||
},
|
||||
{
|
||||
name: "404 generic not found passes through as 404",
|
||||
name: "404 generic not found does not passthrough",
|
||||
statusCode: http.StatusNotFound,
|
||||
respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
|
||||
wantPassthrough: true,
|
||||
wantPassthrough: false,
|
||||
},
|
||||
{
|
||||
name: "400 Invalid URL does not passthrough",
|
||||
|
||||
@@ -136,3 +136,67 @@ func TestDroppedBetaSet(t *testing.T) {
|
||||
require.Contains(t, extended, claude.BetaClaudeCode)
|
||||
require.Len(t, extended, len(claude.DroppedBetas)+1)
|
||||
}
|
||||
|
||||
func TestBuildBetaTokenSet(t *testing.T) {
|
||||
got := buildBetaTokenSet([]string{"foo", "", "bar", "foo"})
|
||||
require.Len(t, got, 2)
|
||||
require.Contains(t, got, "foo")
|
||||
require.Contains(t, got, "bar")
|
||||
require.NotContains(t, got, "")
|
||||
|
||||
empty := buildBetaTokenSet(nil)
|
||||
require.Empty(t, empty)
|
||||
}
|
||||
|
||||
func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) {
|
||||
header := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
|
||||
got := stripBetaTokensWithSet(header, map[string]struct{}{})
|
||||
require.Equal(t, header, got)
|
||||
}
|
||||
|
||||
func TestIsCountTokensUnsupported404(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact endpoint not found",
|
||||
statusCode: 404,
|
||||
body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains count_tokens and not found",
|
||||
statusCode: 404,
|
||||
body: `{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "generic 404",
|
||||
statusCode: 404,
|
||||
body: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "404 with empty error message",
|
||||
statusCode: 404,
|
||||
body: `{"error":{"message":"","type":"not_found_error"}}`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non-404 status",
|
||||
statusCode: 400,
|
||||
body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isCountTokensUnsupported404(tt.statusCode, []byte(tt.body))
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1892,6 +1892,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
result[accountID] = 0
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,141 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCollectSelectionFailureStats(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
model := "sora2-landscape-10s"
|
||||
resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339)
|
||||
|
||||
accounts := []Account{
|
||||
// excluded
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
},
|
||||
// unschedulable
|
||||
{
|
||||
ID: 2,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: false,
|
||||
},
|
||||
// platform filtered
|
||||
{
|
||||
ID: 3,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
},
|
||||
// model unsupported
|
||||
{
|
||||
ID: 4,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-image": "gpt-image",
|
||||
},
|
||||
},
|
||||
},
|
||||
// model rate limited
|
||||
{
|
||||
ID: 5,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"model_rate_limits": map[string]any{
|
||||
model: map[string]any{
|
||||
"rate_limit_reset_at": resetAt,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// eligible
|
||||
{
|
||||
ID: 6,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
},
|
||||
}
|
||||
|
||||
excluded := map[int64]struct{}{1: {}}
|
||||
stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false)
|
||||
|
||||
if stats.Total != 6 {
|
||||
t.Fatalf("total=%d want=6", stats.Total)
|
||||
}
|
||||
if stats.Excluded != 1 {
|
||||
t.Fatalf("excluded=%d want=1", stats.Excluded)
|
||||
}
|
||||
if stats.Unschedulable != 1 {
|
||||
t.Fatalf("unschedulable=%d want=1", stats.Unschedulable)
|
||||
}
|
||||
if stats.PlatformFiltered != 1 {
|
||||
t.Fatalf("platform_filtered=%d want=1", stats.PlatformFiltered)
|
||||
}
|
||||
if stats.ModelUnsupported != 1 {
|
||||
t.Fatalf("model_unsupported=%d want=1", stats.ModelUnsupported)
|
||||
}
|
||||
if stats.ModelRateLimited != 1 {
|
||||
t.Fatalf("model_rate_limited=%d want=1", stats.ModelRateLimited)
|
||||
}
|
||||
if stats.Eligible != 1 {
|
||||
t.Fatalf("eligible=%d want=1", stats.Eligible)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
acc := &Account{
|
||||
ID: 7,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: false,
|
||||
}
|
||||
|
||||
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
|
||||
if diagnosis.Category != "unschedulable" {
|
||||
t.Fatalf("category=%s want=unschedulable", diagnosis.Category)
|
||||
}
|
||||
if diagnosis.Detail != "schedulable=false" {
|
||||
t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
model := "sora2-landscape-10s"
|
||||
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
|
||||
acc := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"model_rate_limits": map[string]any{
|
||||
model: map[string]any{
|
||||
"rate_limit_reset_at": resetAt,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false)
|
||||
if diagnosis.Category != "model_rate_limited" {
|
||||
t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category)
|
||||
}
|
||||
if !strings.Contains(diagnosis.Detail, "remaining=") {
|
||||
t.Fatalf("detail=%s want contains remaining=", diagnosis.Detail)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected sora model to be supported when model_mapping is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-4o": "gpt-4o",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected sora model to be supported when mapping has no sora selectors")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"sora2": "sora2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") {
|
||||
t.Fatalf("expected family selector sora2 to support sora2-landscape-15s")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"sy_8": "sy_8",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
account := &Account{
|
||||
Platform: PlatformSora,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-image": "gpt-image",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
|
||||
t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
now := time.Now()
|
||||
past := now.Add(-1 * time.Minute)
|
||||
future := now.Add(5 * time.Minute)
|
||||
|
||||
acc := &Account{
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
AutoPauseOnExpired: true,
|
||||
ExpiresAt: &past,
|
||||
OverloadUntil: &future,
|
||||
RateLimitResetAt: &future,
|
||||
}
|
||||
|
||||
if !svc.isAccountSchedulableForSelection(acc) {
|
||||
t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
future := time.Now().Add(5 * time.Minute)
|
||||
|
||||
acc := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateLimitResetAt: &future,
|
||||
}
|
||||
|
||||
if svc.isAccountSchedulableForSelection(acc) {
|
||||
t.Fatalf("expected non-sora account to keep generic schedulable checks")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
model := "sora2-landscape-10s"
|
||||
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
|
||||
globalResetAt := time.Now().Add(2 * time.Minute)
|
||||
|
||||
acc := &Account{
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateLimitResetAt: &globalResetAt,
|
||||
Extra: map[string]any{
|
||||
"model_rate_limits": map[string]any{
|
||||
model: map[string]any{
|
||||
"rate_limit_reset_at": resetAt,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) {
|
||||
t.Fatalf("expected sora account to be blocked by model scope rate limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
future := time.Now().Add(3 * time.Minute)
|
||||
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateLimitResetAt: &future,
|
||||
},
|
||||
}
|
||||
|
||||
stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
|
||||
if stats.Unschedulable != 0 || stats.Eligible != 1 {
|
||||
t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible)
|
||||
}
|
||||
}
|
||||
@@ -105,12 +105,12 @@ func TestCalculateMaxWait_Scenarios(t *testing.T) {
|
||||
concurrency int
|
||||
expected int
|
||||
}{
|
||||
{5, 25}, // 5 + 20
|
||||
{10, 30}, // 10 + 20
|
||||
{1, 21}, // 1 + 20
|
||||
{0, 21}, // min(1) + 20
|
||||
{-1, 21}, // min(1) + 20
|
||||
{-10, 21}, // min(1) + 20
|
||||
{5, 25}, // 5 + 20
|
||||
{10, 30}, // 10 + 20
|
||||
{1, 21}, // 1 + 20
|
||||
{0, 21}, // min(1) + 20
|
||||
{-1, 21}, // min(1) + 20
|
||||
{-10, 21}, // min(1) + 20
|
||||
{100, 120}, // 100 + 20
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -53,6 +53,7 @@ type GeminiMessagesCompatService struct {
|
||||
httpUpstream HTTPUpstream
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
cfg *config.Config
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
}
|
||||
|
||||
func NewGeminiMessagesCompatService(
|
||||
@@ -76,6 +77,7 @@ func NewGeminiMessagesCompatService(
|
||||
httpUpstream: httpUpstream,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
cfg: cfg,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,6 +231,16 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
|
||||
account *Account,
|
||||
requestedModel, platform string,
|
||||
useMixedScheduling bool,
|
||||
) bool {
|
||||
return s.isAccountUsableForRequestWithPrecheck(ctx, account, requestedModel, platform, useMixedScheduling, nil)
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) isAccountUsableForRequestWithPrecheck(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
requestedModel, platform string,
|
||||
useMixedScheduling bool,
|
||||
precheckResult map[int64]bool,
|
||||
) bool {
|
||||
// 检查模型调度能力
|
||||
// Check model scheduling capability
|
||||
@@ -250,7 +262,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
|
||||
|
||||
// 速率限制预检
|
||||
// Rate limit precheck
|
||||
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
|
||||
if !s.passesRateLimitPreCheckWithCache(ctx, account, requestedModel, precheckResult) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -272,15 +284,17 @@ func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account
|
||||
return false
|
||||
}
|
||||
|
||||
// passesRateLimitPreCheck 执行速率限制预检。
|
||||
// 返回 true 表示通过预检或无需预检。
|
||||
//
|
||||
// passesRateLimitPreCheck performs rate limit precheck.
|
||||
// Returns true if passed or precheck not required.
|
||||
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
|
||||
func (s *GeminiMessagesCompatService) passesRateLimitPreCheckWithCache(ctx context.Context, account *Account, requestedModel string, precheckResult map[int64]bool) bool {
|
||||
if s.rateLimitService == nil || requestedModel == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
if precheckResult != nil {
|
||||
if ok, exists := precheckResult[account.ID]; exists {
|
||||
return ok
|
||||
}
|
||||
}
|
||||
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||
@@ -302,6 +316,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
|
||||
useMixedScheduling bool,
|
||||
) *Account {
|
||||
var selected *Account
|
||||
precheckResult := s.buildPreCheckUsageResultMap(ctx, accounts, requestedModel)
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
@@ -312,7 +327,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
|
||||
}
|
||||
|
||||
// 检查账号是否可用于当前请求
|
||||
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
|
||||
if !s.isAccountUsableForRequestWithPrecheck(ctx, acc, requestedModel, platform, useMixedScheduling, precheckResult) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -330,6 +345,23 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
|
||||
return selected
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) buildPreCheckUsageResultMap(ctx context.Context, accounts []Account, requestedModel string) map[int64]bool {
|
||||
if s.rateLimitService == nil || requestedModel == "" || len(accounts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
candidates := make([]*Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
candidates = append(candidates, &accounts[i])
|
||||
}
|
||||
|
||||
result, err := s.rateLimitService.PreCheckUsageBatch(ctx, candidates, requestedModel)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheckBatch] failed: %v", err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
|
||||
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
|
||||
//
|
||||
@@ -2390,7 +2422,7 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
}
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
@@ -2415,8 +2447,8 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||
}
|
||||
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
if s.responseHeaderFilter != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||
}
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
@@ -2557,7 +2589,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||||
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
|
||||
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.responseHeaderFilter)
|
||||
if wwwAuthenticate != "" {
|
||||
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
|
||||
}
|
||||
|
||||
@@ -32,6 +32,9 @@ type Group struct {
|
||||
SoraVideoPricePerRequest *float64
|
||||
SoraVideoPricePerRequestHD *float64
|
||||
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
)
|
||||
|
||||
const modelRateLimitsKey = "model_rate_limits"
|
||||
@@ -73,7 +71,7 @@ func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requ
|
||||
return ""
|
||||
}
|
||||
// thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking)
|
||||
if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
|
||||
if enabled, ok := ThinkingEnabledFromContext(ctx); ok {
|
||||
modelKey = applyThinkingModelSuffix(modelKey, enabled)
|
||||
}
|
||||
return modelKey
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
// OpenAIOAuthClient interface for OpenAI OAuth operations
|
||||
type OpenAIOAuthClient interface {
|
||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
||||
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
|
||||
}
|
||||
|
||||
@@ -14,10 +14,10 @@ import (
|
||||
// --- mock: ClaudeOAuthClient ---
|
||||
|
||||
type mockClaudeOAuthClient struct {
|
||||
getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
|
||||
getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
|
||||
exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
|
||||
refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
|
||||
getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
|
||||
getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
|
||||
exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
|
||||
refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
|
||||
}
|
||||
|
||||
func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
@@ -437,9 +437,9 @@ func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
|
||||
|
||||
// 无 refresh_token 的账号
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
ID: 1,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "some-token",
|
||||
},
|
||||
@@ -460,9 +460,9 @@ func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) {
|
||||
defer svc.Stop()
|
||||
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
ID: 2,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "some-token",
|
||||
"refresh_token": "",
|
||||
|
||||
909
backend/internal/service/openai_account_scheduler.go
Normal file
909
backend/internal/service/openai_account_scheduler.go
Normal file
@@ -0,0 +1,909 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
|
||||
openAIAccountScheduleLayerSessionSticky = "session_hash"
|
||||
openAIAccountScheduleLayerLoadBalance = "load_balance"
|
||||
)
|
||||
|
||||
type OpenAIAccountScheduleRequest struct {
|
||||
GroupID *int64
|
||||
SessionHash string
|
||||
StickyAccountID int64
|
||||
PreviousResponseID string
|
||||
RequestedModel string
|
||||
RequiredTransport OpenAIUpstreamTransport
|
||||
ExcludedIDs map[int64]struct{}
|
||||
}
|
||||
|
||||
type OpenAIAccountScheduleDecision struct {
|
||||
Layer string
|
||||
StickyPreviousHit bool
|
||||
StickySessionHit bool
|
||||
CandidateCount int
|
||||
TopK int
|
||||
LatencyMs int64
|
||||
LoadSkew float64
|
||||
SelectedAccountID int64
|
||||
SelectedAccountType string
|
||||
}
|
||||
|
||||
type OpenAIAccountSchedulerMetricsSnapshot struct {
|
||||
SelectTotal int64
|
||||
StickyPreviousHitTotal int64
|
||||
StickySessionHitTotal int64
|
||||
LoadBalanceSelectTotal int64
|
||||
AccountSwitchTotal int64
|
||||
SchedulerLatencyMsTotal int64
|
||||
SchedulerLatencyMsAvg float64
|
||||
StickyHitRatio float64
|
||||
AccountSwitchRate float64
|
||||
LoadSkewAvg float64
|
||||
RuntimeStatsAccountCount int
|
||||
}
|
||||
|
||||
type OpenAIAccountScheduler interface {
|
||||
Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error)
|
||||
ReportResult(accountID int64, success bool, firstTokenMs *int)
|
||||
ReportSwitch()
|
||||
SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot
|
||||
}
|
||||
|
||||
type openAIAccountSchedulerMetrics struct {
|
||||
selectTotal atomic.Int64
|
||||
stickyPreviousHitTotal atomic.Int64
|
||||
stickySessionHitTotal atomic.Int64
|
||||
loadBalanceSelectTotal atomic.Int64
|
||||
accountSwitchTotal atomic.Int64
|
||||
latencyMsTotal atomic.Int64
|
||||
loadSkewMilliTotal atomic.Int64
|
||||
}
|
||||
|
||||
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.selectTotal.Add(1)
|
||||
m.latencyMsTotal.Add(decision.LatencyMs)
|
||||
m.loadSkewMilliTotal.Add(int64(math.Round(decision.LoadSkew * 1000)))
|
||||
if decision.StickyPreviousHit {
|
||||
m.stickyPreviousHitTotal.Add(1)
|
||||
}
|
||||
if decision.StickySessionHit {
|
||||
m.stickySessionHitTotal.Add(1)
|
||||
}
|
||||
if decision.Layer == openAIAccountScheduleLayerLoadBalance {
|
||||
m.loadBalanceSelectTotal.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *openAIAccountSchedulerMetrics) recordSwitch() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.accountSwitchTotal.Add(1)
|
||||
}
|
||||
|
||||
type openAIAccountRuntimeStats struct {
|
||||
accounts sync.Map
|
||||
accountCount atomic.Int64
|
||||
}
|
||||
|
||||
type openAIAccountRuntimeStat struct {
|
||||
errorRateEWMABits atomic.Uint64
|
||||
ttftEWMABits atomic.Uint64
|
||||
}
|
||||
|
||||
func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats {
|
||||
return &openAIAccountRuntimeStats{}
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat {
|
||||
if value, ok := s.accounts.Load(accountID); ok {
|
||||
stat, _ := value.(*openAIAccountRuntimeStat)
|
||||
if stat != nil {
|
||||
return stat
|
||||
}
|
||||
}
|
||||
|
||||
stat := &openAIAccountRuntimeStat{}
|
||||
stat.ttftEWMABits.Store(math.Float64bits(math.NaN()))
|
||||
actual, loaded := s.accounts.LoadOrStore(accountID, stat)
|
||||
if !loaded {
|
||||
s.accountCount.Add(1)
|
||||
return stat
|
||||
}
|
||||
existing, _ := actual.(*openAIAccountRuntimeStat)
|
||||
if existing != nil {
|
||||
return existing
|
||||
}
|
||||
return stat
|
||||
}
|
||||
|
||||
func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
|
||||
for {
|
||||
oldBits := target.Load()
|
||||
oldValue := math.Float64frombits(oldBits)
|
||||
newValue := alpha*sample + (1-alpha)*oldValue
|
||||
if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) {
|
||||
if s == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
const alpha = 0.2
|
||||
stat := s.loadOrCreate(accountID)
|
||||
|
||||
errorSample := 1.0
|
||||
if success {
|
||||
errorSample = 0.0
|
||||
}
|
||||
updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha)
|
||||
|
||||
if firstTokenMs != nil && *firstTokenMs > 0 {
|
||||
ttft := float64(*firstTokenMs)
|
||||
ttftBits := math.Float64bits(ttft)
|
||||
for {
|
||||
oldBits := stat.ttftEWMABits.Load()
|
||||
oldValue := math.Float64frombits(oldBits)
|
||||
if math.IsNaN(oldValue) {
|
||||
if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
newValue := alpha*ttft + (1-alpha)*oldValue
|
||||
if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) {
|
||||
if s == nil || accountID <= 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
value, ok := s.accounts.Load(accountID)
|
||||
if !ok {
|
||||
return 0, 0, false
|
||||
}
|
||||
stat, _ := value.(*openAIAccountRuntimeStat)
|
||||
if stat == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load()))
|
||||
ttftValue := math.Float64frombits(stat.ttftEWMABits.Load())
|
||||
if math.IsNaN(ttftValue) {
|
||||
return errorRate, 0, false
|
||||
}
|
||||
return errorRate, ttftValue, true
|
||||
}
|
||||
|
||||
func (s *openAIAccountRuntimeStats) size() int {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return int(s.accountCount.Load())
|
||||
}
|
||||
|
||||
// defaultOpenAIAccountScheduler is the built-in OpenAIAccountScheduler. It
// layers previous-response stickiness, session stickiness, and a scored
// load-balance fallback on top of the gateway service, while recording
// per-decision metrics and per-account runtime stats.
type defaultOpenAIAccountScheduler struct {
	service *OpenAIGatewayService         // gateway service used for lookups, slot acquisition and sticky bindings
	metrics openAIAccountSchedulerMetrics // aggregate scheduling counters
	stats   *openAIAccountRuntimeStats    // per-account EWMA error-rate / TTFT store
}
|
||||
|
||||
func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler {
|
||||
if stats == nil {
|
||||
stats = newOpenAIAccountRuntimeStats()
|
||||
}
|
||||
return &defaultOpenAIAccountScheduler{
|
||||
service: service,
|
||||
stats: stats,
|
||||
}
|
||||
}
|
||||
|
||||
// Select picks an account for the request through three layers, in order:
//  1. previous-response stickiness (route back to the account that produced
//     req.PreviousResponseID),
//  2. session-hash stickiness,
//  3. scored load-balance fallback.
// The decision struct records which layer won, latency, and candidate stats;
// it is flushed to metrics via the deferred recordSelect even on error paths.
func (s *defaultOpenAIAccountScheduler) Select(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
	decision := OpenAIAccountScheduleDecision{}
	start := time.Now()
	defer func() {
		decision.LatencyMs = time.Since(start).Milliseconds()
		s.metrics.recordSelect(decision)
	}()

	// Layer 1: route back to the account bound to the previous response ID.
	previousResponseID := strings.TrimSpace(req.PreviousResponseID)
	if previousResponseID != "" {
		selection, err := s.service.SelectAccountByPreviousResponseID(
			ctx,
			req.GroupID,
			previousResponseID,
			req.RequestedModel,
			req.ExcludedIDs,
		)
		if err != nil {
			return nil, decision, err
		}
		if selection != nil && selection.Account != nil {
			// Drop the sticky hit when the account cannot serve the required transport.
			if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
				selection = nil
			}
		}
		if selection != nil && selection.Account != nil {
			decision.Layer = openAIAccountScheduleLayerPreviousResponse
			decision.StickyPreviousHit = true
			decision.SelectedAccountID = selection.Account.ID
			decision.SelectedAccountType = selection.Account.Type
			// Also refresh the session binding so follow-up requests stick; best-effort.
			if req.SessionHash != "" {
				_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID)
			}
			return selection, decision, nil
		}
	}

	// Layer 2: sticky session by hash.
	selection, err := s.selectBySessionHash(ctx, req)
	if err != nil {
		return nil, decision, err
	}
	if selection != nil && selection.Account != nil {
		decision.Layer = openAIAccountScheduleLayerSessionSticky
		decision.StickySessionHit = true
		decision.SelectedAccountID = selection.Account.ID
		decision.SelectedAccountType = selection.Account.Type
		return selection, decision, nil
	}

	// Layer 3: scored load-balance fallback; decision fields are filled even on error.
	selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req)
	decision.Layer = openAIAccountScheduleLayerLoadBalance
	decision.CandidateCount = candidateCount
	decision.TopK = topK
	decision.LoadSkew = loadSkew
	if err != nil {
		return nil, decision, err
	}
	if selection != nil && selection.Account != nil {
		decision.SelectedAccountID = selection.Account.ID
		decision.SelectedAccountType = selection.Account.Type
	}
	return selection, decision, nil
}
|
||||
|
||||
// selectBySessionHash resolves the account previously bound to the request's
// session hash and tries to reuse it. It returns (nil, nil) — "no sticky hit,
// fall through to load balancing" — whenever the binding is missing, excluded,
// or unusable. Bindings that can never be served again (account gone, wrong
// platform, incompatible transport) are deleted; a model mismatch alone keeps
// the binding but skips it for this request.
func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, error) {
	sessionHash := strings.TrimSpace(req.SessionHash)
	if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil {
		return nil, nil
	}

	// Prefer the caller-resolved sticky ID; otherwise look it up in the cache.
	accountID := req.StickyAccountID
	if accountID <= 0 {
		var err error
		accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash)
		if err != nil || accountID <= 0 {
			return nil, nil
		}
	}
	if accountID <= 0 {
		return nil, nil
	}
	if req.ExcludedIDs != nil {
		if _, excluded := req.ExcludedIDs[accountID]; excluded {
			return nil, nil
		}
	}

	account, err := s.service.getSchedulableAccount(ctx, accountID)
	if err != nil || account == nil {
		// Bound account no longer schedulable: drop the stale binding (best-effort).
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	// Model mismatch: keep the binding (other models may still use it), just miss.
	if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
		return nil, nil
	}
	if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}

	// Fast path: grab a concurrency slot now and refresh the binding's TTL.
	result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
	if acquireErr == nil && result.Acquired {
		_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
		return &AccountSelectionResult{
			Account:     account,
			Acquired:    true,
			ReleaseFunc: result.ReleaseFunc,
		}, nil
	}

	// Account busy: hand back a wait plan so the caller can queue on the sticky
	// account instead of switching, preserving session affinity.
	cfg := s.service.schedulingConfig()
	if s.service.concurrencyService != nil {
		return &AccountSelectionResult{
			Account: account,
			WaitPlan: &AccountWaitPlan{
				AccountID:      accountID,
				MaxConcurrency: account.Concurrency,
				Timeout:        cfg.StickySessionWaitTimeout,
				MaxWaiting:     cfg.StickySessionMaxWaiting,
			},
		}, nil
	}
	return nil, nil
}
|
||||
|
||||
// openAIAccountCandidateScore bundles one schedulable account with its load
// snapshot and the composite score used to rank it during load-balance
// selection.
type openAIAccountCandidateScore struct {
	account   *Account         // candidate account
	loadInfo  *AccountLoadInfo // concurrency/queue snapshot (builders substitute an empty value when missing)
	score     float64          // weighted composite score; higher is better
	errorRate float64          // EWMA error rate, expected in [0,1]
	ttft      float64          // EWMA time-to-first-token (ms); meaningful only when hasTTFT
	hasTTFT   bool             // whether a TTFT sample exists for this account
}
|
||||
|
||||
type openAIAccountCandidateHeap []openAIAccountCandidateScore
|
||||
|
||||
func (h openAIAccountCandidateHeap) Len() int {
|
||||
return len(h)
|
||||
}
|
||||
|
||||
func (h openAIAccountCandidateHeap) Less(i, j int) bool {
|
||||
// 最小堆根节点保存“最差”候选,便于 O(log k) 维护 topK。
|
||||
return isOpenAIAccountCandidateBetter(h[j], h[i])
|
||||
}
|
||||
|
||||
func (h openAIAccountCandidateHeap) Swap(i, j int) {
|
||||
h[i], h[j] = h[j], h[i]
|
||||
}
|
||||
|
||||
func (h *openAIAccountCandidateHeap) Push(x any) {
|
||||
candidate, ok := x.(openAIAccountCandidateScore)
|
||||
if !ok {
|
||||
panic("openAIAccountCandidateHeap: invalid element type")
|
||||
}
|
||||
*h = append(*h, candidate)
|
||||
}
|
||||
|
||||
func (h *openAIAccountCandidateHeap) Pop() any {
|
||||
old := *h
|
||||
n := len(old)
|
||||
last := old[n-1]
|
||||
*h = old[:n-1]
|
||||
return last
|
||||
}
|
||||
|
||||
func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool {
|
||||
if left.score != right.score {
|
||||
return left.score > right.score
|
||||
}
|
||||
if left.account.Priority != right.account.Priority {
|
||||
return left.account.Priority < right.account.Priority
|
||||
}
|
||||
if left.loadInfo.LoadRate != right.loadInfo.LoadRate {
|
||||
return left.loadInfo.LoadRate < right.loadInfo.LoadRate
|
||||
}
|
||||
if left.loadInfo.WaitingCount != right.loadInfo.WaitingCount {
|
||||
return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount
|
||||
}
|
||||
return left.account.ID < right.account.ID
|
||||
}
|
||||
|
||||
func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 1
|
||||
}
|
||||
if topK >= len(candidates) {
|
||||
ranked := append([]openAIAccountCandidateScore(nil), candidates...)
|
||||
sort.Slice(ranked, func(i, j int) bool {
|
||||
return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
|
||||
})
|
||||
return ranked
|
||||
}
|
||||
|
||||
best := make(openAIAccountCandidateHeap, 0, topK)
|
||||
for _, candidate := range candidates {
|
||||
if len(best) < topK {
|
||||
heap.Push(&best, candidate)
|
||||
continue
|
||||
}
|
||||
if isOpenAIAccountCandidateBetter(candidate, best[0]) {
|
||||
best[0] = candidate
|
||||
heap.Fix(&best, 0)
|
||||
}
|
||||
}
|
||||
|
||||
ranked := make([]openAIAccountCandidateScore, len(best))
|
||||
copy(ranked, best)
|
||||
sort.Slice(ranked, func(i, j int) bool {
|
||||
return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
|
||||
})
|
||||
return ranked
|
||||
}
|
||||
|
||||
// openAISelectionRNG is a tiny deterministic PRNG (xorshift64*) used for
// weighted candidate sampling without touching the global math/rand source.
type openAISelectionRNG struct {
	state uint64
}

// newOpenAISelectionRNG seeds the generator. A zero seed is replaced with a
// fixed odd constant, since xorshift64* would stay stuck at state 0 forever.
func newOpenAISelectionRNG(seed uint64) openAISelectionRNG {
	if seed == 0 {
		seed = 0x9e3779b97f4a7c15
	}
	return openAISelectionRNG{state: seed}
}

// nextUint64 advances the xorshift64* state and returns the next output.
func (r *openAISelectionRNG) nextUint64() uint64 {
	s := r.state
	s ^= s >> 12
	s ^= s << 25
	s ^= s >> 27
	r.state = s
	return s * 2685821657736338717
}

// nextFloat64 maps the next 53 random bits onto the half-open interval [0, 1).
func (r *openAISelectionRNG) nextFloat64() float64 {
	return float64(r.nextUint64()>>11) / (1 << 53)
}
|
||||
|
||||
func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 {
|
||||
hasher := fnv.New64a()
|
||||
writeValue := func(value string) {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return
|
||||
}
|
||||
_, _ = hasher.Write([]byte(trimmed))
|
||||
_, _ = hasher.Write([]byte{0})
|
||||
}
|
||||
|
||||
writeValue(req.SessionHash)
|
||||
writeValue(req.PreviousResponseID)
|
||||
writeValue(req.RequestedModel)
|
||||
if req.GroupID != nil {
|
||||
_, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10)))
|
||||
}
|
||||
|
||||
seed := hasher.Sum64()
|
||||
// 对“无会话锚点”的纯负载均衡请求引入时间熵,避免固定命中同一账号。
|
||||
if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" {
|
||||
seed ^= uint64(time.Now().UnixNano())
|
||||
}
|
||||
if seed == 0 {
|
||||
seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15
|
||||
}
|
||||
return seed
|
||||
}
|
||||
|
||||
// buildOpenAIWeightedSelectionOrder produces the slot-acquisition attempt
// order over the ranked candidates using weighted sampling WITHOUT
// replacement: each round draws one candidate with probability proportional
// to its (shifted) score, removes it, and repeats until the pool is empty.
// The RNG is seeded from the request, so anchored requests draw a stable
// order while anonymous ones get time entropy. The input slice is not
// mutated; the working pool is a copy.
func buildOpenAIWeightedSelectionOrder(
	candidates []openAIAccountCandidateScore,
	req OpenAIAccountScheduleRequest,
) []openAIAccountCandidateScore {
	if len(candidates) <= 1 {
		return append([]openAIAccountCandidateScore(nil), candidates...)
	}

	pool := append([]openAIAccountCandidateScore(nil), candidates...)
	weights := make([]float64, len(pool))
	minScore := pool[0].score
	for i := 1; i < len(pool); i++ {
		if pool[i].score < minScore {
			minScore = pool[i].score
		}
	}
	for i := range pool {
		// Shift top-K scores into a positive range so the single highest
		// scorer cannot monopolize selection forever.
		weight := (pool[i].score - minScore) + 1.0
		if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 {
			weight = 1.0
		}
		weights[i] = weight
	}

	order := make([]openAIAccountCandidateScore, 0, len(pool))
	rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
	for len(pool) > 0 {
		total := 0.0
		for _, w := range weights {
			total += w
		}

		// Roulette-wheel draw over the remaining weights; falls back to a
		// uniform pick if the total weight degenerates to zero.
		selectedIdx := 0
		if total > 0 {
			r := rng.nextFloat64() * total
			acc := 0.0
			for i, w := range weights {
				acc += w
				if r <= acc {
					selectedIdx = i
					break
				}
			}
		} else {
			selectedIdx = int(rng.nextUint64() % uint64(len(pool)))
		}

		// Move the drawn candidate to the output and delete it from both
		// parallel slices so the next round samples without replacement.
		order = append(order, pool[selectedIdx])
		pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...)
		weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...)
	}
	return order
}
|
||||
|
||||
// selectByLoadBalance performs the layer-3 scored selection. It filters the
// group's schedulable OpenAI accounts, scores each against priority, load,
// queue depth, EWMA error rate and normalized TTFT, keeps the top-K, orders
// them by weighted sampling, and tries to acquire a concurrency slot in that
// order. If every attempt is busy it returns a wait plan on the best-ordered
// candidate. Returns (selection, candidateCount, topK, loadSkew, err); the
// extra values feed the schedule decision/metrics.
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, int, int, float64, error) {
	accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID)
	if err != nil {
		return nil, 0, 0, 0, err
	}
	if len(accounts) == 0 {
		return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
	}

	// Filter: excluded IDs, schedulability/platform, model support, transport.
	filtered := make([]*Account, 0, len(accounts))
	loadReq := make([]AccountWithConcurrency, 0, len(accounts))
	for i := range accounts {
		account := &accounts[i]
		if req.ExcludedIDs != nil {
			if _, excluded := req.ExcludedIDs[account.ID]; excluded {
				continue
			}
		}
		if !account.IsSchedulable() || !account.IsOpenAI() {
			continue
		}
		if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
			continue
		}
		if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
			continue
		}
		filtered = append(filtered, account)
		loadReq = append(loadReq, AccountWithConcurrency{
			ID:             account.ID,
			MaxConcurrency: account.Concurrency,
		})
	}
	if len(filtered) == 0 {
		return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
	}

	// Batch-fetch load info; on failure fall back to empty (zero-load) entries.
	loadMap := map[int64]*AccountLoadInfo{}
	if s.service.concurrencyService != nil {
		if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
			loadMap = batchLoad
		}
	}

	// First pass: collect per-candidate stats and the ranges/moments needed to
	// normalize the scoring factors (priority span, max waiting, TTFT span,
	// load-rate first/second moments for the skew metric).
	minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
	maxWaiting := 1
	loadRateSum := 0.0
	loadRateSumSquares := 0.0
	minTTFT, maxTTFT := 0.0, 0.0
	hasTTFTSample := false
	candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
	for _, account := range filtered {
		loadInfo := loadMap[account.ID]
		if loadInfo == nil {
			loadInfo = &AccountLoadInfo{AccountID: account.ID}
		}
		if account.Priority < minPriority {
			minPriority = account.Priority
		}
		if account.Priority > maxPriority {
			maxPriority = account.Priority
		}
		if loadInfo.WaitingCount > maxWaiting {
			maxWaiting = loadInfo.WaitingCount
		}
		errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
		if hasTTFT && ttft > 0 {
			if !hasTTFTSample {
				minTTFT, maxTTFT = ttft, ttft
				hasTTFTSample = true
			} else {
				if ttft < minTTFT {
					minTTFT = ttft
				}
				if ttft > maxTTFT {
					maxTTFT = ttft
				}
			}
		}
		loadRate := float64(loadInfo.LoadRate)
		loadRateSum += loadRate
		loadRateSumSquares += loadRate * loadRate
		candidates = append(candidates, openAIAccountCandidateScore{
			account:   account,
			loadInfo:  loadInfo,
			errorRate: errorRate,
			ttft:      ttft,
			hasTTFT:   hasTTFT,
		})
	}
	loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))

	// Second pass: composite score. Every factor is normalized into [0,1]
	// (1 = best) and combined with the configured weights.
	weights := s.service.openAIWSSchedulerWeights()
	for i := range candidates {
		item := &candidates[i]
		priorityFactor := 1.0
		if maxPriority > minPriority {
			priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
		}
		loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
		queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
		errorFactor := 1 - clamp01(item.errorRate)
		// Neutral TTFT factor when there is no usable sample or no spread.
		ttftFactor := 0.5
		if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
			ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
		}

		item.score = weights.Priority*priorityFactor +
			weights.Load*loadFactor +
			weights.Queue*queueFactor +
			weights.ErrorRate*errorFactor +
			weights.TTFT*ttftFactor
	}

	topK := s.service.openAIWSLBTopK()
	if topK > len(candidates) {
		topK = len(candidates)
	}
	if topK <= 0 {
		topK = 1
	}
	rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
	selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)

	// Try to acquire a slot in sampled order; first success wins and (best
	// effort) binds the session hash for future stickiness.
	for i := 0; i < len(selectionOrder); i++ {
		candidate := selectionOrder[i]
		result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
		if acquireErr != nil {
			return nil, len(candidates), topK, loadSkew, acquireErr
		}
		if result != nil && result.Acquired {
			if req.SessionHash != "" {
				_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
			}
			return &AccountSelectionResult{
				Account:     candidate.account,
				Acquired:    true,
				ReleaseFunc: result.ReleaseFunc,
			}, len(candidates), topK, loadSkew, nil
		}
	}

	// All candidates busy: queue on the first sampled candidate with the
	// fallback wait limits.
	cfg := s.service.schedulingConfig()
	candidate := selectionOrder[0]
	return &AccountSelectionResult{
		Account: candidate.account,
		WaitPlan: &AccountWaitPlan{
			AccountID:      candidate.account.ID,
			MaxConcurrency: candidate.account.Concurrency,
			Timeout:        cfg.FallbackWaitTimeout,
			MaxWaiting:     cfg.FallbackMaxWaiting,
		},
	}, len(candidates), topK, loadSkew, nil
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
|
||||
// HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。
|
||||
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
|
||||
return true
|
||||
}
|
||||
if s == nil || s.service == nil || account == nil {
|
||||
return false
|
||||
}
|
||||
return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
|
||||
if s == nil || s.stats == nil {
|
||||
return
|
||||
}
|
||||
s.stats.report(accountID, success, firstTokenMs)
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.metrics.recordSwitch()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
|
||||
if s == nil {
|
||||
return OpenAIAccountSchedulerMetricsSnapshot{}
|
||||
}
|
||||
|
||||
selectTotal := s.metrics.selectTotal.Load()
|
||||
prevHit := s.metrics.stickyPreviousHitTotal.Load()
|
||||
sessionHit := s.metrics.stickySessionHitTotal.Load()
|
||||
switchTotal := s.metrics.accountSwitchTotal.Load()
|
||||
latencyTotal := s.metrics.latencyMsTotal.Load()
|
||||
loadSkewTotal := s.metrics.loadSkewMilliTotal.Load()
|
||||
|
||||
snapshot := OpenAIAccountSchedulerMetricsSnapshot{
|
||||
SelectTotal: selectTotal,
|
||||
StickyPreviousHitTotal: prevHit,
|
||||
StickySessionHitTotal: sessionHit,
|
||||
LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(),
|
||||
AccountSwitchTotal: switchTotal,
|
||||
SchedulerLatencyMsTotal: latencyTotal,
|
||||
RuntimeStatsAccountCount: s.stats.size(),
|
||||
}
|
||||
if selectTotal > 0 {
|
||||
snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal)
|
||||
snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal)
|
||||
snapshot.AccountSwitchRate = float64(switchTotal) / float64(selectTotal)
|
||||
snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal)
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
|
||||
// getOpenAIAccountScheduler lazily initializes and returns the scheduler
// singleton. sync.Once guarantees the init closure runs at most once;
// pre-populated fields (e.g. injected by tests) are kept as-is. Returns nil
// on a nil service.
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
	if s == nil {
		return nil
	}
	s.openaiSchedulerOnce.Do(func() {
		if s.openaiAccountStats == nil {
			s.openaiAccountStats = newOpenAIAccountRuntimeStats()
		}
		if s.openaiScheduler == nil {
			s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats)
		}
	})
	return s.openaiScheduler
}
|
||||
|
||||
// SelectAccountWithScheduler is the service-level entry point for scheduled
// account selection. It resolves the sticky-session account ID up front (so
// the scheduler can skip a second cache lookup) and delegates to the layered
// scheduler; when no scheduler is available it falls back to plain
// load-aware selection and labels the decision accordingly.
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
	ctx context.Context,
	groupID *int64,
	previousResponseID string,
	sessionHash string,
	requestedModel string,
	excludedIDs map[int64]struct{},
	requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
	decision := OpenAIAccountScheduleDecision{}
	scheduler := s.getOpenAIAccountScheduler()
	if scheduler == nil {
		// No scheduler wired in: degrade to the legacy load-aware path.
		selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
		decision.Layer = openAIAccountScheduleLayerLoadBalance
		return selection, decision, err
	}

	// Pre-resolve the sticky binding; lookup failures simply leave it unset.
	var stickyAccountID int64
	if sessionHash != "" && s.cache != nil {
		if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {
			stickyAccountID = accountID
		}
	}

	return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
		GroupID:            groupID,
		SessionHash:        sessionHash,
		StickyAccountID:    stickyAccountID,
		PreviousResponseID: previousResponseID,
		RequestedModel:     requestedModel,
		RequiredTransport:  requiredTransport,
		ExcludedIDs:        excludedIDs,
	})
}
|
||||
|
||||
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
|
||||
scheduler := s.getOpenAIAccountScheduler()
|
||||
if scheduler == nil {
|
||||
return
|
||||
}
|
||||
scheduler.ReportResult(accountID, success, firstTokenMs)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
|
||||
scheduler := s.getOpenAIAccountScheduler()
|
||||
if scheduler == nil {
|
||||
return
|
||||
}
|
||||
scheduler.ReportSwitch()
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
|
||||
scheduler := s.getOpenAIAccountScheduler()
|
||||
if scheduler == nil {
|
||||
return OpenAIAccountSchedulerMetricsSnapshot{}
|
||||
}
|
||||
return scheduler.SnapshotMetrics()
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration {
|
||||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 {
|
||||
return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
|
||||
}
|
||||
return openaiStickySessionTTL
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIWSLBTopK() int {
|
||||
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 {
|
||||
return s.cfg.Gateway.OpenAIWS.LBTopK
|
||||
}
|
||||
return 7
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
|
||||
if s != nil && s.cfg != nil {
|
||||
return GatewayOpenAIWSSchedulerScoreWeightsView{
|
||||
Priority: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority,
|
||||
Load: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load,
|
||||
Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue,
|
||||
ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate,
|
||||
TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT,
|
||||
}
|
||||
}
|
||||
return GatewayOpenAIWSSchedulerScoreWeightsView{
|
||||
Priority: 1.0,
|
||||
Load: 1.0,
|
||||
Queue: 0.7,
|
||||
ErrorRate: 0.8,
|
||||
TTFT: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
// GatewayOpenAIWSSchedulerScoreWeightsView is a read-only copy of the
// scheduler score weights, decoupled from the config package's structure.
// Each weight multiplies the corresponding normalized [0,1] factor when the
// candidate's composite score is computed.
type GatewayOpenAIWSSchedulerScoreWeightsView struct {
	Priority  float64 // weight of the account-priority factor
	Load      float64 // weight of the (inverted) load-rate factor
	Queue     float64 // weight of the (inverted) waiting-queue factor
	ErrorRate float64 // weight of the (inverted) EWMA error-rate factor
	TTFT      float64 // weight of the (inverted, min-max normalized) TTFT factor
}
|
||||
|
||||
// clamp01 limits value to the closed interval [0, 1].
func clamp01(value float64) float64 {
	if value < 0 {
		return 0
	}
	if value > 1 {
		return 1
	}
	return value
}
|
||||
|
||||
// calcLoadSkewByMoments computes the population standard deviation of the
// load rates from their running first and second moments (sum and sum of
// squares). Floating-point rounding can push the variance estimate slightly
// negative, so it is floored at zero before the square root. Fewer than two
// samples yield zero skew.
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
	if count <= 1 {
		return 0
	}
	n := float64(count)
	mean := sum / n
	variance := sumSquares/n - mean*mean
	return math.Sqrt(math.Max(variance, 0))
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func buildOpenAISchedulerBenchmarkCandidates(size int) []openAIAccountCandidateScore {
|
||||
if size <= 0 {
|
||||
return nil
|
||||
}
|
||||
candidates := make([]openAIAccountCandidateScore, 0, size)
|
||||
for i := 0; i < size; i++ {
|
||||
accountID := int64(10_000 + i)
|
||||
candidates = append(candidates, openAIAccountCandidateScore{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Priority: i % 7,
|
||||
},
|
||||
loadInfo: &AccountLoadInfo{
|
||||
AccountID: accountID,
|
||||
LoadRate: (i * 17) % 100,
|
||||
WaitingCount: (i * 11) % 13,
|
||||
},
|
||||
score: float64((i*29)%1000) / 100,
|
||||
errorRate: float64((i * 5) % 100 / 100),
|
||||
ttft: float64(30 + (i*3)%500),
|
||||
hasTTFT: i%3 != 0,
|
||||
})
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func selectTopKOpenAICandidatesBySortBenchmark(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 1
|
||||
}
|
||||
ranked := append([]openAIAccountCandidateScore(nil), candidates...)
|
||||
sort.Slice(ranked, func(i, j int) bool {
|
||||
return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
|
||||
})
|
||||
if topK > len(ranked) {
|
||||
topK = len(ranked)
|
||||
}
|
||||
return ranked[:topK]
|
||||
}
|
||||
|
||||
// BenchmarkOpenAIAccountSchedulerSelectTopK compares the heap-based top-K
// selection against the full-sort baseline over several candidate-set sizes.
// Candidates are built once per case and reused across iterations; the
// non-empty-result check keeps the compiler from eliding the call.
func BenchmarkOpenAIAccountSchedulerSelectTopK(b *testing.B) {
	cases := []struct {
		name string
		size int
		topK int
	}{
		{name: "n_16_k_3", size: 16, topK: 3},
		{name: "n_64_k_3", size: 64, topK: 3},
		{name: "n_256_k_5", size: 256, topK: 5},
	}

	for _, tc := range cases {
		candidates := buildOpenAISchedulerBenchmarkCandidates(tc.size)
		b.Run(tc.name+"/heap_topk", func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				result := selectTopKOpenAICandidates(candidates, tc.topK)
				if len(result) == 0 {
					b.Fatal("unexpected empty result")
				}
			}
		})
		b.Run(tc.name+"/full_sort", func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				result := selectTopKOpenAICandidatesBySortBenchmark(candidates, tc.topK)
				if len(result) == 0 {
					b.Fatal("unexpected empty result")
				}
			}
		})
	}
}
|
||||
841
backend/internal/service/openai_account_scheduler_test.go
Normal file
841
backend/internal/service/openai_account_scheduler_test.go
Normal file
@@ -0,0 +1,841 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky
// verifies layer 1 of the scheduler: a request carrying a previous response ID
// routes back to the account bound to that ID, the decision is labelled as a
// previous-response sticky hit, and the session hash gets (re)bound to the
// same account as a side effect.
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(9)
	account := Account{
		ID:          1001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		},
	}
	cache := &stubGatewayCache{}
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
	cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600

	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}

	// Pre-bind the previous response ID to the account under test.
	store := svc.getOpenAIWSStateStore()
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_001", account.ID, time.Hour))

	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"resp_prev_001",
		"session_hash_001",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
	require.True(t, decision.StickyPreviousHit)
	// The sticky hit must also bind the session hash for follow-up requests.
	require.Equal(t, account.ID, cache.sessionBindings["openai:session_hash_001"])
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
|
||||
|
||||
// TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky verifies
// layer 2 of the scheduler: with no previous response ID, a session hash whose
// cache binding points at a schedulable account must select that account and
// label the decision as a session-sticky hit.
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(10)
	account := Account{
		ID:          2001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
	}
	// Pre-seed the cache binding that layer 2 should resolve.
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_abc": account.ID,
		},
	}

	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                &config.Config{},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}

	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_abc",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
	require.True(t, decision.StickySessionHit)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsSticky(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10100)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 21001,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
},
|
||||
{
|
||||
ID: 21002,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 9,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{
|
||||
"openai:session_hash_sticky_busy": 21001,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2
|
||||
cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
acquireResults: map[int64]bool{
|
||||
21001: false, // sticky 账号已满
|
||||
21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
|
||||
},
|
||||
waitCounts: map[int64]int{
|
||||
21001: 999,
|
||||
},
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
21001: {AccountID: 21001, LoadRate: 90, WaitingCount: 9},
|
||||
21002: {AccountID: 21002, LoadRate: 1, WaitingCount: 0},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
&groupID,
|
||||
"",
|
||||
"session_hash_sticky_busy",
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(21001), selection.Account.ID, "busy sticky account should remain selected")
|
||||
require.False(t, selection.Acquired)
|
||||
require.NotNil(t, selection.WaitPlan)
|
||||
require.Equal(t, int64(21001), selection.WaitPlan.AccountID)
|
||||
require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
|
||||
require.True(t, decision.StickySessionHit)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(1010)
|
||||
account := Account{
|
||||
ID: 2101,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"openai_ws_force_http": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{
|
||||
"openai:session_hash_force_http": account.ID,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: &config.Config{},
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
&groupID,
|
||||
"",
|
||||
"session_hash_force_http",
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, account.ID, selection.Account.ID)
|
||||
require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
|
||||
require.True(t, decision.StickySessionHit)
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStickyHTTPAccount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(1011)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 2201,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
},
|
||||
{
|
||||
ID: 2202,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 5,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{
|
||||
"openai:session_hash_ws_only": 2201,
|
||||
},
|
||||
}
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
|
||||
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
|
||||
2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
&groupID,
|
||||
"",
|
||||
"session_hash_ws_only",
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(2202), selection.Account.ID)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
require.False(t, decision.StickySessionHit)
|
||||
require.Equal(t, 1, decision.CandidateCount)
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailableAccount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(1012)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 2301,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: newOpenAIWSV2TestConfig(),
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
&groupID,
|
||||
"",
|
||||
"",
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, selection)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
require.Equal(t, 0, decision.CandidateCount)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(11)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 3001,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
},
|
||||
{
|
||||
ID: 3002,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
},
|
||||
{
|
||||
ID: 3003,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.OpenAIWS.LBTopK = 2
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
|
||||
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
|
||||
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
|
||||
3003: {AccountID: 3003, LoadRate: 10, WaitingCount: 0},
|
||||
},
|
||||
acquireResults: map[int64]bool{
|
||||
3003: false, // top1 失败,必须回退到 top-K 的下一候选
|
||||
3002: true,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
&groupID,
|
||||
"",
|
||||
"",
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(3002), selection.Account.ID)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
require.Equal(t, 3, decision.CandidateCount)
|
||||
require.Equal(t, 2, decision.TopK)
|
||||
require.Greater(t, decision.LoadSkew, 0.0)
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(12)
|
||||
account := Account{
|
||||
ID: 4001,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{
|
||||
"openai:session_hash_metrics": account.ID,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: &config.Config{},
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
|
||||
svc.RecordOpenAIAccountSwitch()
|
||||
|
||||
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
|
||||
require.GreaterOrEqual(t, snapshot.SelectTotal, int64(1))
|
||||
require.GreaterOrEqual(t, snapshot.StickySessionHitTotal, int64(1))
|
||||
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
|
||||
require.GreaterOrEqual(t, snapshot.SchedulerLatencyMsAvg, float64(0))
|
||||
require.GreaterOrEqual(t, snapshot.StickyHitRatio, 0.0)
|
||||
require.GreaterOrEqual(t, snapshot.RuntimeStatsAccountCount, 1)
|
||||
}
|
||||
|
||||
// intPtrForTest returns a pointer to a copy of v, for tests that need an *int literal.
func intPtrForTest(v int) *int {
	out := v
	return &out
}
|
||||
|
||||
func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) {
|
||||
stats := newOpenAIAccountRuntimeStats()
|
||||
stats.report(1001, true, nil)
|
||||
firstTTFT := 100
|
||||
stats.report(1001, false, &firstTTFT)
|
||||
secondTTFT := 200
|
||||
stats.report(1001, false, &secondTTFT)
|
||||
|
||||
errorRate, ttft, hasTTFT := stats.snapshot(1001)
|
||||
require.True(t, hasTTFT)
|
||||
require.InDelta(t, 0.36, errorRate, 1e-9)
|
||||
require.InDelta(t, 120.0, ttft, 1e-9)
|
||||
require.Equal(t, 1, stats.size())
|
||||
}
|
||||
|
||||
func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) {
|
||||
stats := newOpenAIAccountRuntimeStats()
|
||||
|
||||
const (
|
||||
accountCount = 4
|
||||
workers = 16
|
||||
iterations = 800
|
||||
)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(workers)
|
||||
for worker := 0; worker < workers; worker++ {
|
||||
worker := worker
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iterations; i++ {
|
||||
accountID := int64(i%accountCount + 1)
|
||||
success := (i+worker)%3 != 0
|
||||
ttft := 80 + (i+worker)%40
|
||||
stats.report(accountID, success, &ttft)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
require.Equal(t, accountCount, stats.size())
|
||||
for accountID := int64(1); accountID <= accountCount; accountID++ {
|
||||
errorRate, ttft, hasTTFT := stats.snapshot(accountID)
|
||||
require.GreaterOrEqual(t, errorRate, 0.0)
|
||||
require.LessOrEqual(t, errorRate, 1.0)
|
||||
require.True(t, hasTTFT)
|
||||
require.Greater(t, ttft, 0.0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectTopKOpenAICandidates(t *testing.T) {
|
||||
candidates := []openAIAccountCandidateScore{
|
||||
{
|
||||
account: &Account{ID: 11, Priority: 2},
|
||||
loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 1},
|
||||
score: 10.0,
|
||||
},
|
||||
{
|
||||
account: &Account{ID: 12, Priority: 1},
|
||||
loadInfo: &AccountLoadInfo{LoadRate: 20, WaitingCount: 1},
|
||||
score: 9.5,
|
||||
},
|
||||
{
|
||||
account: &Account{ID: 13, Priority: 1},
|
||||
loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 0},
|
||||
score: 10.0,
|
||||
},
|
||||
{
|
||||
account: &Account{ID: 14, Priority: 0},
|
||||
loadInfo: &AccountLoadInfo{LoadRate: 40, WaitingCount: 0},
|
||||
score: 8.0,
|
||||
},
|
||||
}
|
||||
|
||||
top2 := selectTopKOpenAICandidates(candidates, 2)
|
||||
require.Len(t, top2, 2)
|
||||
require.Equal(t, int64(13), top2[0].account.ID)
|
||||
require.Equal(t, int64(11), top2[1].account.ID)
|
||||
|
||||
topAll := selectTopKOpenAICandidates(candidates, 8)
|
||||
require.Len(t, topAll, len(candidates))
|
||||
require.Equal(t, int64(13), topAll[0].account.ID)
|
||||
require.Equal(t, int64(11), topAll[1].account.ID)
|
||||
require.Equal(t, int64(12), topAll[2].account.ID)
|
||||
require.Equal(t, int64(14), topAll[3].account.ID)
|
||||
}
|
||||
|
||||
func TestBuildOpenAIWeightedSelectionOrder_DeterministicBySessionSeed(t *testing.T) {
|
||||
candidates := []openAIAccountCandidateScore{
|
||||
{
|
||||
account: &Account{ID: 101},
|
||||
loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 0},
|
||||
score: 4.2,
|
||||
},
|
||||
{
|
||||
account: &Account{ID: 102},
|
||||
loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 1},
|
||||
score: 3.5,
|
||||
},
|
||||
{
|
||||
account: &Account{ID: 103},
|
||||
loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2},
|
||||
score: 2.1,
|
||||
},
|
||||
}
|
||||
req := OpenAIAccountScheduleRequest{
|
||||
GroupID: int64PtrForTest(99),
|
||||
SessionHash: "session_seed_fixed",
|
||||
RequestedModel: "gpt-5.1",
|
||||
}
|
||||
|
||||
first := buildOpenAIWeightedSelectionOrder(candidates, req)
|
||||
second := buildOpenAIWeightedSelectionOrder(candidates, req)
|
||||
require.Len(t, first, len(candidates))
|
||||
require.Len(t, second, len(candidates))
|
||||
for i := range first {
|
||||
require.Equal(t, first[i].account.ID, second[i].account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesAcrossSessions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(15)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 5101,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 3,
|
||||
Priority: 0,
|
||||
},
|
||||
{
|
||||
ID: 5102,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 3,
|
||||
Priority: 0,
|
||||
},
|
||||
{
|
||||
ID: 5103,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 3,
|
||||
Priority: 0,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.OpenAIWS.LBTopK = 3
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
|
||||
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
|
||||
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
|
||||
5103: {AccountID: 5103, LoadRate: 20, WaitingCount: 1},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selected := make(map[int64]int, len(accounts))
|
||||
for i := 0; i < 60; i++ {
|
||||
sessionHash := fmt.Sprintf("session_hash_lb_%d", i)
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
&groupID,
|
||||
"",
|
||||
sessionHash,
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
selected[selection.Account.ID]++
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
// 多 session 应该能打散到多个账号,避免“恒定单账号命中”。
|
||||
require.GreaterOrEqual(t, len(selected), 2)
|
||||
}
|
||||
|
||||
func TestDeriveOpenAISelectionSeed_NoAffinityAddsEntropy(t *testing.T) {
|
||||
req := OpenAIAccountScheduleRequest{
|
||||
RequestedModel: "gpt-5.1",
|
||||
}
|
||||
seed1 := deriveOpenAISelectionSeed(req)
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
seed2 := deriveOpenAISelectionSeed(req)
|
||||
require.NotZero(t, seed1)
|
||||
require.NotZero(t, seed2)
|
||||
require.NotEqual(t, seed1, seed2)
|
||||
}
|
||||
|
||||
func TestBuildOpenAIWeightedSelectionOrder_HandlesInvalidScores(t *testing.T) {
|
||||
candidates := []openAIAccountCandidateScore{
|
||||
{
|
||||
account: &Account{ID: 901},
|
||||
loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
|
||||
score: math.NaN(),
|
||||
},
|
||||
{
|
||||
account: &Account{ID: 902},
|
||||
loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
|
||||
score: math.Inf(1),
|
||||
},
|
||||
{
|
||||
account: &Account{ID: 903},
|
||||
loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
|
||||
score: -1,
|
||||
},
|
||||
}
|
||||
req := OpenAIAccountScheduleRequest{
|
||||
SessionHash: "seed_invalid_scores",
|
||||
}
|
||||
|
||||
order := buildOpenAIWeightedSelectionOrder(candidates, req)
|
||||
require.Len(t, order, len(candidates))
|
||||
seen := map[int64]struct{}{}
|
||||
for _, item := range order {
|
||||
seen[item.account.ID] = struct{}{}
|
||||
}
|
||||
require.Len(t, seen, len(candidates))
|
||||
}
|
||||
|
||||
func TestOpenAISelectionRNG_SeedZeroStillWorks(t *testing.T) {
|
||||
rng := newOpenAISelectionRNG(0)
|
||||
v1 := rng.nextUint64()
|
||||
v2 := rng.nextUint64()
|
||||
require.NotEqual(t, v1, v2)
|
||||
require.GreaterOrEqual(t, rng.nextFloat64(), 0.0)
|
||||
require.Less(t, rng.nextFloat64(), 1.0)
|
||||
}
|
||||
|
||||
func TestOpenAIAccountCandidateHeap_PushPopAndInvalidType(t *testing.T) {
|
||||
h := openAIAccountCandidateHeap{}
|
||||
h.Push(openAIAccountCandidateScore{
|
||||
account: &Account{ID: 7001},
|
||||
loadInfo: &AccountLoadInfo{LoadRate: 0, WaitingCount: 0},
|
||||
score: 1.0,
|
||||
})
|
||||
require.Equal(t, 1, h.Len())
|
||||
popped, ok := h.Pop().(openAIAccountCandidateScore)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(7001), popped.account.ID)
|
||||
require.Equal(t, 0, h.Len())
|
||||
|
||||
require.Panics(t, func() {
|
||||
h.Push("bad_element_type")
|
||||
})
|
||||
}
|
||||
|
||||
func TestClamp01_AllBranches(t *testing.T) {
|
||||
require.Equal(t, 0.0, clamp01(-0.2))
|
||||
require.Equal(t, 1.0, clamp01(1.3))
|
||||
require.Equal(t, 0.5, clamp01(0.5))
|
||||
}
|
||||
|
||||
func TestCalcLoadSkewByMoments_Branches(t *testing.T) {
|
||||
require.Equal(t, 0.0, calcLoadSkewByMoments(1, 1, 1))
|
||||
// variance < 0 分支:sumSquares/count - mean^2 为负值时应钳制为 0。
|
||||
require.Equal(t, 0.0, calcLoadSkewByMoments(1, 0, 2))
|
||||
require.GreaterOrEqual(t, calcLoadSkewByMoments(6, 20, 3), 0.0)
|
||||
}
|
||||
|
||||
func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
|
||||
schedulerAny := newDefaultOpenAIAccountScheduler(&OpenAIGatewayService{}, nil)
|
||||
scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
|
||||
require.True(t, ok)
|
||||
|
||||
ttft := 100
|
||||
scheduler.ReportResult(1001, true, &ttft)
|
||||
scheduler.ReportSwitch()
|
||||
scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
|
||||
Layer: openAIAccountScheduleLayerLoadBalance,
|
||||
LatencyMs: 8,
|
||||
LoadSkew: 0.5,
|
||||
StickyPreviousHit: true,
|
||||
})
|
||||
scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
|
||||
Layer: openAIAccountScheduleLayerSessionSticky,
|
||||
LatencyMs: 6,
|
||||
LoadSkew: 0.2,
|
||||
StickySessionHit: true,
|
||||
})
|
||||
|
||||
snapshot := scheduler.SnapshotMetrics()
|
||||
require.Equal(t, int64(2), snapshot.SelectTotal)
|
||||
require.Equal(t, int64(1), snapshot.StickyPreviousHitTotal)
|
||||
require.Equal(t, int64(1), snapshot.StickySessionHitTotal)
|
||||
require.Equal(t, int64(1), snapshot.LoadBalanceSelectTotal)
|
||||
require.Equal(t, int64(1), snapshot.AccountSwitchTotal)
|
||||
require.Greater(t, snapshot.SchedulerLatencyMsAvg, 0.0)
|
||||
require.Greater(t, snapshot.StickyHitRatio, 0.0)
|
||||
require.Greater(t, snapshot.LoadSkewAvg, 0.0)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
ttft := 120
|
||||
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
|
||||
svc.RecordOpenAIAccountSwitch()
|
||||
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
|
||||
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
|
||||
require.Equal(t, 7, svc.openAIWSLBTopK())
|
||||
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
|
||||
|
||||
defaultWeights := svc.openAIWSSchedulerWeights()
|
||||
require.Equal(t, 1.0, defaultWeights.Priority)
|
||||
require.Equal(t, 1.0, defaultWeights.Load)
|
||||
require.Equal(t, 0.7, defaultWeights.Queue)
|
||||
require.Equal(t, 0.8, defaultWeights.ErrorRate)
|
||||
require.Equal(t, 0.5, defaultWeights.TTFT)
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.OpenAIWS.LBTopK = 9
|
||||
cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 180
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.2
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.3
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.4
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.5
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.6
|
||||
svcWithCfg := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
require.Equal(t, 9, svcWithCfg.openAIWSLBTopK())
|
||||
require.Equal(t, 180*time.Second, svcWithCfg.openAIWSSessionStickyTTL())
|
||||
customWeights := svcWithCfg.openAIWSSchedulerWeights()
|
||||
require.Equal(t, 0.2, customWeights.Priority)
|
||||
require.Equal(t, 0.3, customWeights.Load)
|
||||
require.Equal(t, 0.4, customWeights.Queue)
|
||||
require.Equal(t, 0.5, customWeights.ErrorRate)
|
||||
require.Equal(t, 0.6, customWeights.TTFT)
|
||||
}
|
||||
|
||||
func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *testing.T) {
|
||||
scheduler := &defaultOpenAIAccountScheduler{}
|
||||
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportAny))
|
||||
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
|
||||
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
|
||||
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
scheduler.service = &OpenAIGatewayService{cfg: cfg}
|
||||
account := &Account{
|
||||
ID: 8801,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2))
|
||||
}
|
||||
|
||||
// int64PtrForTest returns a pointer to a copy of v, for tests that need an *int64 literal.
func int64PtrForTest(v int64) *int64 {
	out := v
	return &out
}
|
||||
71
backend/internal/service/openai_client_transport.go
Normal file
71
backend/internal/service/openai_client_transport.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIClientTransport identifies the inbound protocol a client used to reach the
// gateway.
type OpenAIClientTransport string

const (
	// OpenAIClientTransportUnknown means the inbound protocol was not (or could not
	// be) recorded.
	OpenAIClientTransportUnknown OpenAIClientTransport = ""
	// OpenAIClientTransportHTTP marks a plain HTTP (incl. SSE) client connection.
	OpenAIClientTransportHTTP OpenAIClientTransport = "http"
	// OpenAIClientTransportWS marks a websocket client connection.
	OpenAIClientTransportWS OpenAIClientTransport = "ws"
)

// openAIClientTransportContextKey is the gin-context key under which the normalized
// client transport is stored.
const openAIClientTransportContextKey = "openai_client_transport"
|
||||
|
||||
// SetOpenAIClientTransport 标记当前请求的客户端入站协议。
|
||||
func SetOpenAIClientTransport(c *gin.Context, transport OpenAIClientTransport) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
normalized := normalizeOpenAIClientTransport(transport)
|
||||
if normalized == OpenAIClientTransportUnknown {
|
||||
return
|
||||
}
|
||||
c.Set(openAIClientTransportContextKey, string(normalized))
|
||||
}
|
||||
|
||||
// GetOpenAIClientTransport 读取当前请求的客户端入站协议。
|
||||
func GetOpenAIClientTransport(c *gin.Context) OpenAIClientTransport {
|
||||
if c == nil {
|
||||
return OpenAIClientTransportUnknown
|
||||
}
|
||||
raw, ok := c.Get(openAIClientTransportContextKey)
|
||||
if !ok || raw == nil {
|
||||
return OpenAIClientTransportUnknown
|
||||
}
|
||||
|
||||
switch v := raw.(type) {
|
||||
case OpenAIClientTransport:
|
||||
return normalizeOpenAIClientTransport(v)
|
||||
case string:
|
||||
return normalizeOpenAIClientTransport(OpenAIClientTransport(v))
|
||||
default:
|
||||
return OpenAIClientTransportUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpenAIClientTransport(transport OpenAIClientTransport) OpenAIClientTransport {
|
||||
switch strings.ToLower(strings.TrimSpace(string(transport))) {
|
||||
case string(OpenAIClientTransportHTTP), "http_sse", "sse":
|
||||
return OpenAIClientTransportHTTP
|
||||
case string(OpenAIClientTransportWS), "websocket":
|
||||
return OpenAIClientTransportWS
|
||||
default:
|
||||
return OpenAIClientTransportUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func resolveOpenAIWSDecisionByClientTransport(
|
||||
decision OpenAIWSProtocolDecision,
|
||||
clientTransport OpenAIClientTransport,
|
||||
) OpenAIWSProtocolDecision {
|
||||
if clientTransport == OpenAIClientTransportHTTP {
|
||||
return openAIWSHTTPDecision("client_protocol_http")
|
||||
}
|
||||
return decision
|
||||
}
|
||||
107
backend/internal/service/openai_client_transport_test.go
Normal file
107
backend/internal/service/openai_client_transport_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpenAIClientTransport_SetAndGet(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(c))
|
||||
|
||||
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
|
||||
require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
|
||||
|
||||
SetOpenAIClientTransport(c, OpenAIClientTransportWS)
|
||||
require.Equal(t, OpenAIClientTransportWS, GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestOpenAIClientTransport_GetNormalizesRawContextValue(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rawValue any
|
||||
want OpenAIClientTransport
|
||||
}{
|
||||
{
|
||||
name: "type_value_ws",
|
||||
rawValue: OpenAIClientTransportWS,
|
||||
want: OpenAIClientTransportWS,
|
||||
},
|
||||
{
|
||||
name: "http_sse_alias",
|
||||
rawValue: "http_sse",
|
||||
want: OpenAIClientTransportHTTP,
|
||||
},
|
||||
{
|
||||
name: "sse_alias",
|
||||
rawValue: "sSe",
|
||||
want: OpenAIClientTransportHTTP,
|
||||
},
|
||||
{
|
||||
name: "websocket_alias",
|
||||
rawValue: "WebSocket",
|
||||
want: OpenAIClientTransportWS,
|
||||
},
|
||||
{
|
||||
name: "invalid_string",
|
||||
rawValue: "tcp",
|
||||
want: OpenAIClientTransportUnknown,
|
||||
},
|
||||
{
|
||||
name: "invalid_type",
|
||||
rawValue: 123,
|
||||
want: OpenAIClientTransportUnknown,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Set(openAIClientTransportContextKey, tt.rawValue)
|
||||
require.Equal(t, tt.want, GetOpenAIClientTransport(c))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIClientTransport_NilAndUnknownInput(t *testing.T) {
|
||||
SetOpenAIClientTransport(nil, OpenAIClientTransportHTTP)
|
||||
require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(nil))
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
SetOpenAIClientTransport(c, OpenAIClientTransportUnknown)
|
||||
_, exists := c.Get(openAIClientTransportContextKey)
|
||||
require.False(t, exists)
|
||||
|
||||
SetOpenAIClientTransport(c, OpenAIClientTransport(" "))
|
||||
_, exists = c.Get(openAIClientTransportContextKey)
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) {
|
||||
base := OpenAIWSProtocolDecision{
|
||||
Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
Reason: "ws_v2_enabled",
|
||||
}
|
||||
|
||||
httpDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportHTTP)
|
||||
require.Equal(t, OpenAIUpstreamTransportHTTPSSE, httpDecision.Transport)
|
||||
require.Equal(t, "client_protocol_http", httpDecision.Reason)
|
||||
|
||||
wsDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportWS)
|
||||
require.Equal(t, base, wsDecision)
|
||||
|
||||
unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown)
|
||||
require.Equal(t, base, unknownDecision)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -123,3 +123,19 @@ func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "parse request")
|
||||
}
|
||||
|
||||
func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
got, err := getOpenAIRequestBodyMap(c, []byte(`{"model":"gpt-5","stream":true}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "gpt-5", got["model"])
|
||||
|
||||
cached, ok := c.Get(OpenAIParsedRequestBodyKey)
|
||||
require.True(t, ok)
|
||||
cachedMap, ok := cached.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, got, cachedMap)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -166,6 +168,54 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
|
||||
c.Request.Header.Set("session_id", "sess-fixed-value")
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
got := svc.GenerateSessionHash(c, nil)
|
||||
want := fmt.Sprintf("%016x", xxhash.Sum64String("sess-fixed-value"))
|
||||
require.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
|
||||
c.Request.Header.Set("session_id", "sess-legacy-check")
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
sessionHash := svc.GenerateSessionHash(c, nil)
|
||||
require.NotEmpty(t, sessionHash)
|
||||
require.NotNil(t, c.Request)
|
||||
require.NotNil(t, c.Request.Context())
|
||||
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
seed := "openai_ws_ingress:9:100:200"
|
||||
|
||||
got := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), seed)
|
||||
want := fmt.Sprintf("%016x", xxhash.Sum64String(seed))
|
||||
require.Equal(t, want, got)
|
||||
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
|
||||
|
||||
empty := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), " ")
|
||||
require.Equal(t, "", empty)
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if c.waitCounts != nil {
|
||||
if count, ok := c.waitCounts[accountID]; ok {
|
||||
|
||||
@@ -0,0 +1,357 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
var (
|
||||
benchmarkToolContinuationBoolSink bool
|
||||
benchmarkWSParseStringSink string
|
||||
benchmarkWSParseMapSink map[string]any
|
||||
benchmarkUsageSink OpenAIUsage
|
||||
)
|
||||
|
||||
func BenchmarkToolContinuationValidationLegacy(b *testing.B) {
|
||||
reqBody := benchmarkToolContinuationRequestBody()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkToolContinuationBoolSink = legacyValidateFunctionCallOutputContext(reqBody)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkToolContinuationValidationOptimized(b *testing.B) {
|
||||
reqBody := benchmarkToolContinuationRequestBody()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkToolContinuationBoolSink = optimizedValidateFunctionCallOutputContext(reqBody)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWSIngressPayloadParseLegacy(b *testing.B) {
|
||||
raw := benchmarkWSIngressPayloadBytes()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
eventType, model, promptCacheKey, previousResponseID, payload, err := legacyParseWSIngressPayload(raw)
|
||||
if err == nil {
|
||||
benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID
|
||||
benchmarkWSParseMapSink = payload
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWSIngressPayloadParseOptimized(b *testing.B) {
|
||||
raw := benchmarkWSIngressPayloadBytes()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
eventType, model, promptCacheKey, previousResponseID, payload, err := optimizedParseWSIngressPayload(raw)
|
||||
if err == nil {
|
||||
benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID
|
||||
benchmarkWSParseMapSink = payload
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOpenAIUsageExtractLegacy(b *testing.B) {
|
||||
body := benchmarkOpenAIUsageJSONBytes()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
usage, ok := legacyExtractOpenAIUsageFromJSONBytes(body)
|
||||
if ok {
|
||||
benchmarkUsageSink = usage
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOpenAIUsageExtractOptimized(b *testing.B) {
|
||||
body := benchmarkOpenAIUsageJSONBytes()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
usage, ok := extractOpenAIUsageFromJSONBytes(body)
|
||||
if ok {
|
||||
benchmarkUsageSink = usage
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkToolContinuationRequestBody() map[string]any {
|
||||
input := make([]any, 0, 64)
|
||||
for i := 0; i < 24; i++ {
|
||||
input = append(input, map[string]any{
|
||||
"type": "text",
|
||||
"text": "benchmark text",
|
||||
})
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
callID := "call_" + strconv.Itoa(i)
|
||||
input = append(input, map[string]any{
|
||||
"type": "tool_call",
|
||||
"call_id": callID,
|
||||
})
|
||||
input = append(input, map[string]any{
|
||||
"type": "function_call_output",
|
||||
"call_id": callID,
|
||||
})
|
||||
input = append(input, map[string]any{
|
||||
"type": "item_reference",
|
||||
"id": callID,
|
||||
})
|
||||
}
|
||||
return map[string]any{
|
||||
"model": "gpt-5.3-codex",
|
||||
"input": input,
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkWSIngressPayloadBytes() []byte {
|
||||
return []byte(`{"type":"response.create","model":"gpt-5.3-codex","prompt_cache_key":"cache_bench","previous_response_id":"resp_prev_bench","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`)
|
||||
}
|
||||
|
||||
func benchmarkOpenAIUsageJSONBytes() []byte {
|
||||
return []byte(`{"id":"resp_bench","object":"response","model":"gpt-5.3-codex","usage":{"input_tokens":3210,"output_tokens":987,"input_tokens_details":{"cached_tokens":456}}}`)
|
||||
}
|
||||
|
||||
func legacyValidateFunctionCallOutputContext(reqBody map[string]any) bool {
|
||||
if !legacyHasFunctionCallOutput(reqBody) {
|
||||
return true
|
||||
}
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) != "" {
|
||||
return true
|
||||
}
|
||||
if legacyHasToolCallContext(reqBody) {
|
||||
return true
|
||||
}
|
||||
if legacyHasFunctionCallOutputMissingCallID(reqBody) {
|
||||
return false
|
||||
}
|
||||
callIDs := legacyFunctionCallOutputCallIDs(reqBody)
|
||||
return legacyHasItemReferenceForCallIDs(reqBody, callIDs)
|
||||
}
|
||||
|
||||
func optimizedValidateFunctionCallOutputContext(reqBody map[string]any) bool {
|
||||
validation := ValidateFunctionCallOutputContext(reqBody)
|
||||
if !validation.HasFunctionCallOutput {
|
||||
return true
|
||||
}
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) != "" {
|
||||
return true
|
||||
}
|
||||
if validation.HasToolCallContext {
|
||||
return true
|
||||
}
|
||||
if validation.HasFunctionCallOutputMissingCallID {
|
||||
return false
|
||||
}
|
||||
return validation.HasItemReferenceForAllCallIDs
|
||||
}
|
||||
|
||||
func legacyHasFunctionCallOutput(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType == "function_call_output" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func legacyHasToolCallContext(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "tool_call" && itemType != "function_call" {
|
||||
continue
|
||||
}
|
||||
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func legacyFunctionCallOutputCallIDs(reqBody map[string]any) []string {
|
||||
if reqBody == nil {
|
||||
return nil
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ids := make(map[string]struct{})
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "function_call_output" {
|
||||
continue
|
||||
}
|
||||
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||
ids[callID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
callIDs := make([]string, 0, len(ids))
|
||||
for id := range ids {
|
||||
callIDs = append(callIDs, id)
|
||||
}
|
||||
return callIDs
|
||||
}
|
||||
|
||||
func legacyHasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "function_call_output" {
|
||||
continue
|
||||
}
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
if strings.TrimSpace(callID) == "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func legacyHasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
|
||||
if reqBody == nil || len(callIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
referenceIDs := make(map[string]struct{})
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "item_reference" {
|
||||
continue
|
||||
}
|
||||
idValue, _ := itemMap["id"].(string)
|
||||
idValue = strings.TrimSpace(idValue)
|
||||
if idValue == "" {
|
||||
continue
|
||||
}
|
||||
referenceIDs[idValue] = struct{}{}
|
||||
}
|
||||
if len(referenceIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, callID := range callIDs {
|
||||
if _, ok := referenceIDs[callID]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func legacyParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) {
|
||||
values := gjson.GetManyBytes(raw, "type", "model", "prompt_cache_key", "previous_response_id")
|
||||
eventType = strings.TrimSpace(values[0].String())
|
||||
if eventType == "" {
|
||||
eventType = "response.create"
|
||||
}
|
||||
model = strings.TrimSpace(values[1].String())
|
||||
promptCacheKey = strings.TrimSpace(values[2].String())
|
||||
previousResponseID = strings.TrimSpace(values[3].String())
|
||||
payload = make(map[string]any)
|
||||
if err = json.Unmarshal(raw, &payload); err != nil {
|
||||
return "", "", "", "", nil, err
|
||||
}
|
||||
if _, exists := payload["type"]; !exists {
|
||||
payload["type"] = "response.create"
|
||||
}
|
||||
return eventType, model, promptCacheKey, previousResponseID, payload, nil
|
||||
}
|
||||
|
||||
func optimizedParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) {
|
||||
payload = make(map[string]any)
|
||||
if err = json.Unmarshal(raw, &payload); err != nil {
|
||||
return "", "", "", "", nil, err
|
||||
}
|
||||
eventType = openAIWSPayloadString(payload, "type")
|
||||
if eventType == "" {
|
||||
eventType = "response.create"
|
||||
payload["type"] = eventType
|
||||
}
|
||||
model = openAIWSPayloadString(payload, "model")
|
||||
promptCacheKey = openAIWSPayloadString(payload, "prompt_cache_key")
|
||||
previousResponseID = openAIWSPayloadString(payload, "previous_response_id")
|
||||
return eventType, model, promptCacheKey, previousResponseID, payload, nil
|
||||
}
|
||||
|
||||
func legacyExtractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
|
||||
var response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
return OpenAIUsage{}, false
|
||||
}
|
||||
return OpenAIUsage{
|
||||
InputTokens: response.Usage.InputTokens,
|
||||
OutputTokens: response.Usage.OutputTokens,
|
||||
CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens,
|
||||
}, true
|
||||
}
|
||||
@@ -515,7 +515,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
|
||||
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
|
||||
require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
|
||||
require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {
|
||||
|
||||
@@ -5,8 +5,12 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -16,6 +20,13 @@ import (
|
||||
|
||||
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
|
||||
|
||||
var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`)
|
||||
|
||||
type soraSessionChunk struct {
|
||||
index int
|
||||
value string
|
||||
}
|
||||
|
||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||
type OpenAIOAuthService struct {
|
||||
sessionStore *openai.SessionStore
|
||||
@@ -39,7 +50,7 @@ type OpenAIAuthURLResult struct {
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates an OpenAI OAuth authorization URL
|
||||
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) {
|
||||
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) {
|
||||
// Generate PKCE values
|
||||
state, err := openai.GenerateState()
|
||||
if err != nil {
|
||||
@@ -75,11 +86,14 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
normalizedPlatform := normalizeOpenAIOAuthPlatform(platform)
|
||||
clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform)
|
||||
|
||||
// Store session
|
||||
session := &openai.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
@@ -87,7 +101,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
// Build authorization URL
|
||||
authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI)
|
||||
authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform)
|
||||
|
||||
return &OpenAIAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
@@ -111,6 +125,7 @@ type OpenAITokenInfo struct {
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
|
||||
@@ -148,9 +163,13 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
if input.RedirectURI != "" {
|
||||
redirectURI = input.RedirectURI
|
||||
}
|
||||
clientID := strings.TrimSpace(session.ClientID)
|
||||
if clientID == "" {
|
||||
clientID = openai.ClientID
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -158,8 +177,10 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if parseErr != nil {
|
||||
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
|
||||
} else {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
@@ -173,6 +194,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
|
||||
IDToken: tokenResp.IDToken,
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
ClientID: clientID,
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
@@ -200,8 +222,10 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if parseErr != nil {
|
||||
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
|
||||
} else {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
@@ -213,6 +237,9 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
}
|
||||
if trimmed := strings.TrimSpace(clientID); trimmed != "" {
|
||||
tokenInfo.ClientID = trimmed
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
@@ -226,6 +253,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
|
||||
|
||||
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
|
||||
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
|
||||
sessionToken = normalizeSoraSessionTokenInput(sessionToken)
|
||||
if strings.TrimSpace(sessionToken) == "" {
|
||||
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
|
||||
}
|
||||
@@ -287,10 +315,141 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
|
||||
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
|
||||
ExpiresIn: expiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
ClientID: openai.SoraClientID,
|
||||
Email: strings.TrimSpace(sessionResp.User.Email),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeSoraSessionTokenInput(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1)
|
||||
if len(matches) == 0 {
|
||||
return sanitizeSessionToken(trimmed)
|
||||
}
|
||||
|
||||
chunkMatches := make([]soraSessionChunk, 0, len(matches))
|
||||
singleValues := make([]string, 0, len(matches))
|
||||
|
||||
for _, match := range matches {
|
||||
if len(match) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
value := sanitizeSessionToken(match[2])
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.TrimSpace(match[1]) == "" {
|
||||
singleValues = append(singleValues, value)
|
||||
continue
|
||||
}
|
||||
|
||||
idx, err := strconv.Atoi(strings.TrimSpace(match[1]))
|
||||
if err != nil || idx < 0 {
|
||||
continue
|
||||
}
|
||||
chunkMatches = append(chunkMatches, soraSessionChunk{
|
||||
index: idx,
|
||||
value: value,
|
||||
})
|
||||
}
|
||||
|
||||
if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" {
|
||||
return merged
|
||||
}
|
||||
|
||||
if len(singleValues) > 0 {
|
||||
return singleValues[len(singleValues)-1]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string {
|
||||
if len(chunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
byIndex := make(map[int]string, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
byIndex[chunk.index] = chunk.value
|
||||
}
|
||||
|
||||
if _, ok := byIndex[0]; !ok {
|
||||
return ""
|
||||
}
|
||||
if requireComplete {
|
||||
for idx := 0; idx <= requiredMaxIndex; idx++ {
|
||||
if _, ok := byIndex[idx]; !ok {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
orderedIndexes := make([]int, 0, len(byIndex))
|
||||
for idx := range byIndex {
|
||||
orderedIndexes = append(orderedIndexes, idx)
|
||||
}
|
||||
sort.Ints(orderedIndexes)
|
||||
|
||||
var builder strings.Builder
|
||||
for _, idx := range orderedIndexes {
|
||||
if _, err := builder.WriteString(byIndex[idx]); err != nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return sanitizeSessionToken(builder.String())
|
||||
}
|
||||
|
||||
func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string {
|
||||
if len(chunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
requiredMaxIndex := 0
|
||||
for _, chunk := range chunks {
|
||||
if chunk.index > requiredMaxIndex {
|
||||
requiredMaxIndex = chunk.index
|
||||
}
|
||||
}
|
||||
|
||||
groupStarts := make([]int, 0, len(chunks))
|
||||
for idx, chunk := range chunks {
|
||||
if chunk.index == 0 {
|
||||
groupStarts = append(groupStarts, idx)
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupStarts) == 0 {
|
||||
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
|
||||
}
|
||||
|
||||
for i := len(groupStarts) - 1; i >= 0; i-- {
|
||||
start := groupStarts[i]
|
||||
end := len(chunks)
|
||||
if i+1 < len(groupStarts) {
|
||||
end = groupStarts[i+1]
|
||||
}
|
||||
if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" {
|
||||
return merged
|
||||
}
|
||||
}
|
||||
|
||||
return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
|
||||
}
|
||||
|
||||
func sanitizeSessionToken(raw string) string {
|
||||
token := strings.TrimSpace(raw)
|
||||
token = strings.Trim(token, "\"'`")
|
||||
token = strings.TrimSuffix(token, ";")
|
||||
return strings.TrimSpace(token)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
|
||||
@@ -322,9 +481,12 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
||||
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
|
||||
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"refresh_token": tokenInfo.RefreshToken,
|
||||
"expires_at": expiresAt,
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"expires_at": expiresAt,
|
||||
}
|
||||
// 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌
|
||||
if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
|
||||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if tokenInfo.IDToken != "" {
|
||||
@@ -342,6 +504,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
|
||||
if tokenInfo.OrganizationID != "" {
|
||||
creds["organization_id"] = tokenInfo.OrganizationID
|
||||
}
|
||||
if strings.TrimSpace(tokenInfo.ClientID) != "" {
|
||||
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
|
||||
}
|
||||
|
||||
return creds
|
||||
}
|
||||
@@ -377,3 +542,12 @@ func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpenAIOAuthPlatform(platform string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(platform)) {
|
||||
case PlatformSora:
|
||||
return openai.OAuthPlatformSora
|
||||
default:
|
||||
return openai.OAuthPlatformOpenAI
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openaiOAuthClientAuthURLStub struct{}
|
||||
|
||||
func (s *openaiOAuthClientAuthURLStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientAuthURLStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientAuthURLStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformOpenAI)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, result.AuthURL)
|
||||
require.NotEmpty(t, result.SessionID)
|
||||
|
||||
parsed, err := url.Parse(result.AuthURL)
|
||||
require.NoError(t, err)
|
||||
q := parsed.Query()
|
||||
require.Equal(t, openai.ClientID, q.Get("client_id"))
|
||||
require.Equal(t, "true", q.Get("codex_cli_simplified_flow"))
|
||||
|
||||
session, ok := svc.sessionStore.Get(result.SessionID)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, openai.ClientID, session.ClientID)
|
||||
}
|
||||
|
||||
// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的
|
||||
// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。
|
||||
func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) {
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, result.AuthURL)
|
||||
require.NotEmpty(t, result.SessionID)
|
||||
|
||||
parsed, err := url.Parse(result.AuthURL)
|
||||
require.NoError(t, err)
|
||||
q := parsed.Query()
|
||||
require.Equal(t, openai.ClientID, q.Get("client_id"))
|
||||
require.Empty(t, q.Get("codex_cli_simplified_flow"))
|
||||
|
||||
session, ok := svc.sessionStore.Get(result.SessionID)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, openai.ClientID, session.ClientID)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
|
||||
type openaiOAuthClientNoopStub struct{}
|
||||
|
||||
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -67,3 +68,106 @@ func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testi
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing access token")
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := strings.Join([]string{
|
||||
"Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly",
|
||||
}, "\n")
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := strings.Join([]string{
|
||||
"Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly",
|
||||
"Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly",
|
||||
}, "\n")
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
origin := openAISoraSessionAuthURL
|
||||
openAISoraSessionAuthURL = server.URL
|
||||
defer func() { openAISoraSessionAuthURL = origin }()
|
||||
|
||||
svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
|
||||
defer svc.Stop()
|
||||
|
||||
raw := strings.Join([]string{
|
||||
"set-cookie",
|
||||
"__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/",
|
||||
"set-cookie",
|
||||
"__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/",
|
||||
"set-cookie",
|
||||
"__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/",
|
||||
}, "\n")
|
||||
info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "at-token", info.AccessToken)
|
||||
}
|
||||
|
||||
@@ -13,10 +13,12 @@ import (
|
||||
|
||||
type openaiOAuthClientStateStub struct {
|
||||
exchangeCalled int32
|
||||
lastClientID string
|
||||
}
|
||||
|
||||
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
atomic.AddInt32(&s.exchangeCalled, 1)
|
||||
s.lastClientID = clientID
|
||||
return &openai.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rt",
|
||||
@@ -95,6 +97,8 @@ func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, "at", info.AccessToken)
|
||||
require.Equal(t, openai.ClientID, info.ClientID)
|
||||
require.Equal(t, openai.ClientID, client.lastClientID)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
|
||||
|
||||
_, ok := svc.sessionStore.Get("sid")
|
||||
|
||||
37
backend/internal/service/openai_previous_response_id.go
Normal file
37
backend/internal/service/openai_previous_response_id.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
OpenAIPreviousResponseIDKindEmpty = "empty"
|
||||
OpenAIPreviousResponseIDKindResponseID = "response_id"
|
||||
OpenAIPreviousResponseIDKindMessageID = "message_id"
|
||||
OpenAIPreviousResponseIDKindUnknown = "unknown"
|
||||
)
|
||||
|
||||
var (
|
||||
openAIResponseIDPattern = regexp.MustCompile(`^resp_[A-Za-z0-9_-]{1,256}$`)
|
||||
openAIMessageIDPattern = regexp.MustCompile(`^(msg|message|item|chatcmpl)_[A-Za-z0-9_-]{1,256}$`)
|
||||
)
|
||||
|
||||
// ClassifyOpenAIPreviousResponseIDKind classifies previous_response_id to improve diagnostics.
|
||||
func ClassifyOpenAIPreviousResponseIDKind(id string) string {
|
||||
trimmed := strings.TrimSpace(id)
|
||||
if trimmed == "" {
|
||||
return OpenAIPreviousResponseIDKindEmpty
|
||||
}
|
||||
if openAIResponseIDPattern.MatchString(trimmed) {
|
||||
return OpenAIPreviousResponseIDKindResponseID
|
||||
}
|
||||
if openAIMessageIDPattern.MatchString(strings.ToLower(trimmed)) {
|
||||
return OpenAIPreviousResponseIDKindMessageID
|
||||
}
|
||||
return OpenAIPreviousResponseIDKindUnknown
|
||||
}
|
||||
|
||||
func IsOpenAIPreviousResponseIDLikelyMessageID(id string) bool {
|
||||
return ClassifyOpenAIPreviousResponseIDKind(id) == OpenAIPreviousResponseIDKindMessageID
|
||||
}
|
||||
34
backend/internal/service/openai_previous_response_id_test.go
Normal file
34
backend/internal/service/openai_previous_response_id_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestClassifyOpenAIPreviousResponseIDKind(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
want string
|
||||
}{
|
||||
{name: "empty", id: " ", want: OpenAIPreviousResponseIDKindEmpty},
|
||||
{name: "response_id", id: "resp_0906a621bc423a8d0169a108637ef88197b74b0e2f37ba358f", want: OpenAIPreviousResponseIDKindResponseID},
|
||||
{name: "message_id", id: "msg_123456", want: OpenAIPreviousResponseIDKindMessageID},
|
||||
{name: "item_id", id: "item_abcdef", want: OpenAIPreviousResponseIDKindMessageID},
|
||||
{name: "unknown", id: "foo_123456", want: OpenAIPreviousResponseIDKindUnknown},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := ClassifyOpenAIPreviousResponseIDKind(tc.id); got != tc.want {
|
||||
t.Fatalf("ClassifyOpenAIPreviousResponseIDKind(%q)=%q want=%q", tc.id, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOpenAIPreviousResponseIDLikelyMessageID(t *testing.T) {
|
||||
if !IsOpenAIPreviousResponseIDLikelyMessageID("msg_123") {
|
||||
t.Fatal("expected msg_123 to be identified as message id")
|
||||
}
|
||||
if IsOpenAIPreviousResponseIDLikelyMessageID("resp_123") {
|
||||
t.Fatal("expected resp_123 not to be identified as message id")
|
||||
}
|
||||
}
|
||||
214
backend/internal/service/openai_sticky_compat.go
Normal file
214
backend/internal/service/openai_sticky_compat.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type openAILegacySessionHashContextKey struct{}
|
||||
|
||||
var openAILegacySessionHashKey = openAILegacySessionHashContextKey{}
|
||||
|
||||
var (
|
||||
openAIStickyLegacyReadFallbackTotal atomic.Int64
|
||||
openAIStickyLegacyReadFallbackHit atomic.Int64
|
||||
openAIStickyLegacyDualWriteTotal atomic.Int64
|
||||
)
|
||||
|
||||
func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal int64) {
|
||||
return openAIStickyLegacyReadFallbackTotal.Load(),
|
||||
openAIStickyLegacyReadFallbackHit.Load(),
|
||||
openAIStickyLegacyDualWriteTotal.Load()
|
||||
}
|
||||
|
||||
func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) {
|
||||
normalized := strings.TrimSpace(sessionID)
|
||||
if normalized == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized))
|
||||
sum := sha256.Sum256([]byte(normalized))
|
||||
legacyHash = hex.EncodeToString(sum[:])
|
||||
return currentHash, legacyHash
|
||||
}
|
||||
|
||||
func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
trimmed := strings.TrimSpace(legacyHash)
|
||||
if trimmed == "" {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, openAILegacySessionHashKey, trimmed)
|
||||
}
|
||||
|
||||
func openAILegacySessionHashFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
value, _ := ctx.Value(openAILegacySessionHashKey).(string)
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
func attachOpenAILegacySessionHashToGin(c *gin.Context, legacyHash string) {
|
||||
if c == nil || c.Request == nil {
|
||||
return
|
||||
}
|
||||
c.Request = c.Request.WithContext(withOpenAILegacySessionHash(c.Request.Context(), legacyHash))
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAISessionHashReadOldFallbackEnabled() bool {
|
||||
if s == nil || s.cfg == nil {
|
||||
return true
|
||||
}
|
||||
return s.cfg.Gateway.OpenAIWS.SessionHashReadOldFallback
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAISessionHashDualWriteOldEnabled() bool {
|
||||
if s == nil || s.cfg == nil {
|
||||
return true
|
||||
}
|
||||
return s.cfg.Gateway.OpenAIWS.SessionHashDualWriteOld
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAISessionCacheKey(sessionHash string) string {
|
||||
normalized := strings.TrimSpace(sessionHash)
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
return "openai:" + normalized
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAILegacySessionCacheKey(ctx context.Context, sessionHash string) string {
|
||||
legacyHash := openAILegacySessionHashFromContext(ctx)
|
||||
if legacyHash == "" {
|
||||
return ""
|
||||
}
|
||||
legacyKey := "openai:" + legacyHash
|
||||
if legacyKey == s.openAISessionCacheKey(sessionHash) {
|
||||
return ""
|
||||
}
|
||||
return legacyKey
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) openAIStickyLegacyTTL(ttl time.Duration) time.Duration {
|
||||
legacyTTL := ttl
|
||||
if legacyTTL <= 0 {
|
||||
legacyTTL = openaiStickySessionTTL
|
||||
}
|
||||
if legacyTTL > 10*time.Minute {
|
||||
return 10 * time.Minute
|
||||
}
|
||||
return legacyTTL
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
|
||||
if s == nil || s.cache == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
primaryKey := s.openAISessionCacheKey(sessionHash)
|
||||
if primaryKey == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
|
||||
if err == nil && accountID > 0 {
|
||||
return accountID, nil
|
||||
}
|
||||
if !s.openAISessionHashReadOldFallbackEnabled() {
|
||||
return accountID, err
|
||||
}
|
||||
|
||||
legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
|
||||
if legacyKey == "" {
|
||||
return accountID, err
|
||||
}
|
||||
|
||||
openAIStickyLegacyReadFallbackTotal.Add(1)
|
||||
legacyAccountID, legacyErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
|
||||
if legacyErr == nil && legacyAccountID > 0 {
|
||||
openAIStickyLegacyReadFallbackHit.Add(1)
|
||||
return legacyAccountID, nil
|
||||
}
|
||||
return accountID, err
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) setStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if s == nil || s.cache == nil || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
primaryKey := s.openAISessionCacheKey(sessionHash)
|
||||
if primaryKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), primaryKey, accountID, ttl); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !s.openAISessionHashDualWriteOldEnabled() {
|
||||
return nil
|
||||
}
|
||||
legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
|
||||
if legacyKey == "" {
|
||||
return nil
|
||||
}
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), legacyKey, accountID, s.openAIStickyLegacyTTL(ttl)); err != nil {
|
||||
return err
|
||||
}
|
||||
openAIStickyLegacyDualWriteTotal.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) refreshStickySessionTTL(ctx context.Context, groupID *int64, sessionHash string, ttl time.Duration) error {
|
||||
if s == nil || s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
primaryKey := s.openAISessionCacheKey(sessionHash)
|
||||
if primaryKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), primaryKey, ttl)
|
||||
if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
|
||||
return err
|
||||
}
|
||||
|
||||
legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
|
||||
if legacyKey != "" {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), legacyKey, s.openAIStickyLegacyTTL(ttl))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) deleteStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) error {
|
||||
if s == nil || s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
primaryKey := s.openAISessionCacheKey(sessionHash)
|
||||
if primaryKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
|
||||
if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
|
||||
return err
|
||||
}
|
||||
|
||||
legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
|
||||
if legacyKey != "" {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
|
||||
}
|
||||
return err
|
||||
}
|
||||
96
backend/internal/service/openai_sticky_compat_test.go
Normal file
96
backend/internal/service/openai_sticky_compat_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetStickySessionAccountID_FallbackToLegacyKey(t *testing.T) {
|
||||
beforeFallbackTotal, beforeFallbackHit, _ := openAIStickyCompatStats()
|
||||
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{
|
||||
"openai:legacy-hash": 42,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{
|
||||
cache: cache,
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
OpenAIWS: config.GatewayOpenAIWSConfig{
|
||||
SessionHashReadOldFallback: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
|
||||
accountID, err := svc.getStickySessionAccountID(ctx, nil, "new-hash")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), accountID)
|
||||
|
||||
afterFallbackTotal, afterFallbackHit, _ := openAIStickyCompatStats()
|
||||
require.Equal(t, beforeFallbackTotal+1, afterFallbackTotal)
|
||||
require.Equal(t, beforeFallbackHit+1, afterFallbackHit)
|
||||
}
|
||||
|
||||
func TestSetStickySessionAccountID_DualWriteOldEnabled(t *testing.T) {
|
||||
_, _, beforeDualWriteTotal := openAIStickyCompatStats()
|
||||
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
|
||||
svc := &OpenAIGatewayService{
|
||||
cache: cache,
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
OpenAIWS: config.GatewayOpenAIWSConfig{
|
||||
SessionHashDualWriteOld: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
|
||||
err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"])
|
||||
require.Equal(t, int64(9), cache.sessionBindings["openai:legacy-hash"])
|
||||
|
||||
_, _, afterDualWriteTotal := openAIStickyCompatStats()
|
||||
require.Equal(t, beforeDualWriteTotal+1, afterDualWriteTotal)
|
||||
}
|
||||
|
||||
func TestSetStickySessionAccountID_DualWriteOldDisabled(t *testing.T) {
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
|
||||
svc := &OpenAIGatewayService{
|
||||
cache: cache,
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
OpenAIWS: config.GatewayOpenAIWSConfig{
|
||||
SessionHashDualWriteOld: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
|
||||
err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"])
|
||||
_, exists := cache.sessionBindings["openai:legacy-hash"]
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func TestSnapshotOpenAICompatibilityFallbackMetrics(t *testing.T) {
|
||||
before := SnapshotOpenAICompatibilityFallbackMetrics()
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true)
|
||||
_, _ = ThinkingEnabledFromContext(ctx)
|
||||
|
||||
after := SnapshotOpenAICompatibilityFallbackMetrics()
|
||||
require.GreaterOrEqual(t, after.MetadataLegacyFallbackTotal, before.MetadataLegacyFallbackTotal+1)
|
||||
require.GreaterOrEqual(t, after.MetadataLegacyFallbackThinkingEnabledTotal, before.MetadataLegacyFallbackThinkingEnabledTotal+1)
|
||||
}
|
||||
@@ -2,6 +2,24 @@ package service
|
||||
|
||||
import "strings"
|
||||
|
||||
// ToolContinuationSignals 聚合工具续链相关信号,避免重复遍历 input。
|
||||
type ToolContinuationSignals struct {
|
||||
HasFunctionCallOutput bool
|
||||
HasFunctionCallOutputMissingCallID bool
|
||||
HasToolCallContext bool
|
||||
HasItemReference bool
|
||||
HasItemReferenceForAllCallIDs bool
|
||||
FunctionCallOutputCallIDs []string
|
||||
}
|
||||
|
||||
// FunctionCallOutputValidation 汇总 function_call_output 关联性校验结果。
|
||||
type FunctionCallOutputValidation struct {
|
||||
HasFunctionCallOutput bool
|
||||
HasToolCallContext bool
|
||||
HasFunctionCallOutputMissingCallID bool
|
||||
HasItemReferenceForAllCallIDs bool
|
||||
}
|
||||
|
||||
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
|
||||
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
|
||||
// 或显式声明 tools/tool_choice。
|
||||
@@ -18,107 +36,191 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
|
||||
if hasToolChoiceSignal(reqBody) {
|
||||
return true
|
||||
}
|
||||
if inputHasType(reqBody, "function_call_output") {
|
||||
return true
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if inputHasType(reqBody, "item_reference") {
|
||||
return true
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType == "function_call_output" || itemType == "item_reference" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AnalyzeToolContinuationSignals 单次遍历 input,提取 function_call_output/tool_call/item_reference 相关信号。
|
||||
func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals {
|
||||
signals := ToolContinuationSignals{}
|
||||
if reqBody == nil {
|
||||
return signals
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return signals
|
||||
}
|
||||
|
||||
var callIDs map[string]struct{}
|
||||
var referenceIDs map[string]struct{}
|
||||
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
switch itemType {
|
||||
case "tool_call", "function_call":
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
if strings.TrimSpace(callID) != "" {
|
||||
signals.HasToolCallContext = true
|
||||
}
|
||||
case "function_call_output":
|
||||
signals.HasFunctionCallOutput = true
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
callID = strings.TrimSpace(callID)
|
||||
if callID == "" {
|
||||
signals.HasFunctionCallOutputMissingCallID = true
|
||||
continue
|
||||
}
|
||||
if callIDs == nil {
|
||||
callIDs = make(map[string]struct{})
|
||||
}
|
||||
callIDs[callID] = struct{}{}
|
||||
case "item_reference":
|
||||
signals.HasItemReference = true
|
||||
idValue, _ := itemMap["id"].(string)
|
||||
idValue = strings.TrimSpace(idValue)
|
||||
if idValue == "" {
|
||||
continue
|
||||
}
|
||||
if referenceIDs == nil {
|
||||
referenceIDs = make(map[string]struct{})
|
||||
}
|
||||
referenceIDs[idValue] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(callIDs) == 0 {
|
||||
return signals
|
||||
}
|
||||
signals.FunctionCallOutputCallIDs = make([]string, 0, len(callIDs))
|
||||
allReferenced := len(referenceIDs) > 0
|
||||
for callID := range callIDs {
|
||||
signals.FunctionCallOutputCallIDs = append(signals.FunctionCallOutputCallIDs, callID)
|
||||
if allReferenced {
|
||||
if _, ok := referenceIDs[callID]; !ok {
|
||||
allReferenced = false
|
||||
}
|
||||
}
|
||||
}
|
||||
signals.HasItemReferenceForAllCallIDs = allReferenced
|
||||
return signals
|
||||
}
|
||||
|
||||
// ValidateFunctionCallOutputContext 为 handler 提供低开销校验结果:
|
||||
// 1) 无 function_call_output 直接返回
|
||||
// 2) 若已存在 tool_call/function_call 上下文则提前返回
|
||||
// 3) 仅在无工具上下文时才构建 call_id / item_reference 集合
|
||||
func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation {
|
||||
result := FunctionCallOutputValidation{}
|
||||
if reqBody == nil {
|
||||
return result
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return result
|
||||
}
|
||||
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
switch itemType {
|
||||
case "function_call_output":
|
||||
result.HasFunctionCallOutput = true
|
||||
case "tool_call", "function_call":
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
if strings.TrimSpace(callID) != "" {
|
||||
result.HasToolCallContext = true
|
||||
}
|
||||
}
|
||||
if result.HasFunctionCallOutput && result.HasToolCallContext {
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
if !result.HasFunctionCallOutput || result.HasToolCallContext {
|
||||
return result
|
||||
}
|
||||
|
||||
callIDs := make(map[string]struct{})
|
||||
referenceIDs := make(map[string]struct{})
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
switch itemType {
|
||||
case "function_call_output":
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
callID = strings.TrimSpace(callID)
|
||||
if callID == "" {
|
||||
result.HasFunctionCallOutputMissingCallID = true
|
||||
continue
|
||||
}
|
||||
callIDs[callID] = struct{}{}
|
||||
case "item_reference":
|
||||
idValue, _ := itemMap["id"].(string)
|
||||
idValue = strings.TrimSpace(idValue)
|
||||
if idValue == "" {
|
||||
continue
|
||||
}
|
||||
referenceIDs[idValue] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(callIDs) == 0 || len(referenceIDs) == 0 {
|
||||
return result
|
||||
}
|
||||
allReferenced := true
|
||||
for callID := range callIDs {
|
||||
if _, ok := referenceIDs[callID]; !ok {
|
||||
allReferenced = false
|
||||
break
|
||||
}
|
||||
}
|
||||
result.HasItemReferenceForAllCallIDs = allReferenced
|
||||
return result
|
||||
}
|
||||
|
||||
// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。
|
||||
func HasFunctionCallOutput(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
return inputHasType(reqBody, "function_call_output")
|
||||
return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput
|
||||
}
|
||||
|
||||
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call,
|
||||
// 用于判断 function_call_output 是否具备可关联的上下文。
|
||||
func HasToolCallContext(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "tool_call" && itemType != "function_call" {
|
||||
continue
|
||||
}
|
||||
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext
|
||||
}
|
||||
|
||||
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
|
||||
// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。
|
||||
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
|
||||
if reqBody == nil {
|
||||
return nil
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ids := make(map[string]struct{})
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "function_call_output" {
|
||||
continue
|
||||
}
|
||||
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||
ids[callID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(ids))
|
||||
for id := range ids {
|
||||
result = append(result, id)
|
||||
}
|
||||
return result
|
||||
return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs
|
||||
}
|
||||
|
||||
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
|
||||
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "function_call_output" {
|
||||
continue
|
||||
}
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
if strings.TrimSpace(callID) == "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID
|
||||
}
|
||||
|
||||
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
|
||||
@@ -152,32 +254,13 @@ func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
|
||||
return false
|
||||
}
|
||||
for _, callID := range callIDs {
|
||||
if _, ok := referenceIDs[callID]; !ok {
|
||||
if _, ok := referenceIDs[strings.TrimSpace(callID)]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// inputHasType 判断 input 中是否存在指定类型的 item。
|
||||
func inputHasType(reqBody map[string]any, want string) bool {
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasNonEmptyString 判断字段是否为非空字符串。
|
||||
func hasNonEmptyString(value any) bool {
|
||||
stringValue, ok := value.(string)
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射
|
||||
@@ -62,169 +66,201 @@ func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, boo
|
||||
if data == "" || data == "\n" {
|
||||
return data, false
|
||||
}
|
||||
correctedBytes, corrected := c.CorrectToolCallsInSSEBytes([]byte(data))
|
||||
if !corrected {
|
||||
return data, false
|
||||
}
|
||||
return string(correctedBytes), true
|
||||
}
|
||||
|
||||
// 尝试解析 JSON
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||
// 不是有效的 JSON,直接返回原数据
|
||||
// CorrectToolCallsInSSEBytes 修正 SSE JSON 数据中的工具调用(字节路径)。
|
||||
// 返回修正后的数据和是否进行了修正。
|
||||
func (c *CodexToolCorrector) CorrectToolCallsInSSEBytes(data []byte) ([]byte, bool) {
|
||||
if len(bytes.TrimSpace(data)) == 0 {
|
||||
return data, false
|
||||
}
|
||||
if !mayContainToolCallPayload(data) {
|
||||
return data, false
|
||||
}
|
||||
if !gjson.ValidBytes(data) {
|
||||
// 不是有效 JSON,直接返回原数据
|
||||
return data, false
|
||||
}
|
||||
|
||||
updated := data
|
||||
corrected := false
|
||||
|
||||
// 处理 tool_calls 数组
|
||||
if toolCalls, ok := payload["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
collect := func(changed bool, next []byte) {
|
||||
if changed {
|
||||
corrected = true
|
||||
updated = next
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 function_call 对象
|
||||
if functionCall, ok := payload["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
if next, changed := c.correctToolCallsArrayAtPath(updated, "tool_calls"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
if next, changed := c.correctFunctionAtPath(updated, "function_call"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
if next, changed := c.correctToolCallsArrayAtPath(updated, "delta.tool_calls"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
if next, changed := c.correctFunctionAtPath(updated, "delta.function_call"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
|
||||
// 处理 delta.tool_calls
|
||||
if delta, ok := payload["delta"].(map[string]any); ok {
|
||||
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
choicesCount := int(gjson.GetBytes(updated, "choices.#").Int())
|
||||
for i := 0; i < choicesCount; i++ {
|
||||
prefix := "choices." + strconv.Itoa(i)
|
||||
if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".message.tool_calls"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
if next, changed := c.correctFunctionAtPath(updated, prefix+".message.function_call"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
|
||||
if choices, ok := payload["choices"].([]any); ok {
|
||||
for _, choice := range choices {
|
||||
if choiceMap, ok := choice.(map[string]any); ok {
|
||||
// 处理 message 中的工具调用
|
||||
if message, ok := choiceMap["message"].(map[string]any); ok {
|
||||
if toolCalls, ok := message["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := message["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// 处理 delta 中的工具调用
|
||||
if delta, ok := choiceMap["delta"].(map[string]any); ok {
|
||||
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".delta.tool_calls"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
if next, changed := c.correctFunctionAtPath(updated, prefix+".delta.function_call"); changed {
|
||||
collect(changed, next)
|
||||
}
|
||||
}
|
||||
|
||||
if !corrected {
|
||||
return data, false
|
||||
}
|
||||
return updated, true
|
||||
}
|
||||
|
||||
// 序列化回 JSON
|
||||
correctedBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Failed to marshal corrected data: %v", err)
|
||||
// mayContainToolCallPayload reports whether data might carry tool-call
// fields. It is a cheap substring pre-filter: the vast majority of plain
// token/text SSE events contain none of these markers, so they skip the
// expensive JSON-parsing hot path entirely.
func mayContainToolCallPayload(data []byte) bool {
	for _, marker := range [][]byte{
		[]byte(`"tool_calls"`),
		[]byte(`"function_call"`),
		[]byte(`"function":{"name"`),
	} {
		if bytes.Contains(data, marker) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// correctToolCallsArrayAtPath 修正指定路径下 tool_calls 数组中的工具名称。
|
||||
func (c *CodexToolCorrector) correctToolCallsArrayAtPath(data []byte, toolCallsPath string) ([]byte, bool) {
|
||||
count := int(gjson.GetBytes(data, toolCallsPath+".#").Int())
|
||||
if count <= 0 {
|
||||
return data, false
|
||||
}
|
||||
|
||||
return string(correctedBytes), true
|
||||
}
|
||||
|
||||
// correctToolCallsArray 修正工具调用数组中的工具名称
|
||||
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
|
||||
updated := data
|
||||
corrected := false
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCallMap, ok := toolCall.(map[string]any); ok {
|
||||
if function, ok := toolCallMap["function"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(function) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
functionPath := toolCallsPath + "." + strconv.Itoa(i) + ".function"
|
||||
if next, changed := c.correctFunctionAtPath(updated, functionPath); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
return corrected
|
||||
return updated, corrected
|
||||
}
|
||||
|
||||
// correctFunctionCall 修正单个函数调用的工具名称和参数
|
||||
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
|
||||
name, ok := functionCall["name"].(string)
|
||||
if !ok || name == "" {
|
||||
return false
|
||||
// correctFunctionAtPath 修正指定路径下单个函数调用的工具名称和参数。
|
||||
func (c *CodexToolCorrector) correctFunctionAtPath(data []byte, functionPath string) ([]byte, bool) {
|
||||
namePath := functionPath + ".name"
|
||||
nameResult := gjson.GetBytes(data, namePath)
|
||||
if !nameResult.Exists() || nameResult.Type != gjson.String {
|
||||
return data, false
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(nameResult.Str)
|
||||
if name == "" {
|
||||
return data, false
|
||||
}
|
||||
updated := data
|
||||
corrected := false
|
||||
|
||||
// 查找并修正工具名称
|
||||
if correctName, found := codexToolNameMapping[name]; found {
|
||||
functionCall["name"] = correctName
|
||||
c.recordCorrection(name, correctName)
|
||||
corrected = true
|
||||
name = correctName // 使用修正后的名称进行参数修正
|
||||
if next, err := sjson.SetBytes(updated, namePath, correctName); err == nil {
|
||||
updated = next
|
||||
c.recordCorrection(name, correctName)
|
||||
corrected = true
|
||||
name = correctName // 使用修正后的名称进行参数修正
|
||||
}
|
||||
}
|
||||
|
||||
// 修正工具参数(基于工具名称)
|
||||
if c.correctToolParameters(name, functionCall) {
|
||||
if next, changed := c.correctToolParametersAtPath(updated, functionPath+".arguments", name); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
}
|
||||
|
||||
return corrected
|
||||
return updated, corrected
|
||||
}
|
||||
|
||||
// correctToolParameters 修正工具参数以符合 OpenCode 规范
|
||||
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
|
||||
arguments, ok := functionCall["arguments"]
|
||||
if !ok {
|
||||
return false
|
||||
// correctToolParametersAtPath 修正指定路径下 arguments 参数。
|
||||
func (c *CodexToolCorrector) correctToolParametersAtPath(data []byte, argumentsPath, toolName string) ([]byte, bool) {
|
||||
if toolName != "bash" && toolName != "edit" {
|
||||
return data, false
|
||||
}
|
||||
|
||||
// arguments 可能是字符串(JSON)或已解析的 map
|
||||
var argsMap map[string]any
|
||||
switch v := arguments.(type) {
|
||||
case string:
|
||||
// 解析 JSON 字符串
|
||||
if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
|
||||
return false
|
||||
args := gjson.GetBytes(data, argumentsPath)
|
||||
if !args.Exists() {
|
||||
return data, false
|
||||
}
|
||||
|
||||
switch args.Type {
|
||||
case gjson.String:
|
||||
argsJSON := strings.TrimSpace(args.Str)
|
||||
if !gjson.Valid(argsJSON) {
|
||||
return data, false
|
||||
}
|
||||
case map[string]any:
|
||||
argsMap = v
|
||||
if !gjson.Parse(argsJSON).IsObject() {
|
||||
return data, false
|
||||
}
|
||||
nextArgsJSON, corrected := c.correctToolArgumentsJSON(argsJSON, toolName)
|
||||
if !corrected {
|
||||
return data, false
|
||||
}
|
||||
next, err := sjson.SetBytes(data, argumentsPath, nextArgsJSON)
|
||||
if err != nil {
|
||||
return data, false
|
||||
}
|
||||
return next, true
|
||||
case gjson.JSON:
|
||||
if !args.IsObject() || !gjson.Valid(args.Raw) {
|
||||
return data, false
|
||||
}
|
||||
nextArgsJSON, corrected := c.correctToolArgumentsJSON(args.Raw, toolName)
|
||||
if !corrected {
|
||||
return data, false
|
||||
}
|
||||
next, err := sjson.SetRawBytes(data, argumentsPath, []byte(nextArgsJSON))
|
||||
if err != nil {
|
||||
return data, false
|
||||
}
|
||||
return next, true
|
||||
default:
|
||||
return false
|
||||
return data, false
|
||||
}
|
||||
}
|
||||
|
||||
// correctToolArgumentsJSON 修正工具参数 JSON(对象字符串),返回修正后的 JSON 与是否变更。
|
||||
func (c *CodexToolCorrector) correctToolArgumentsJSON(argsJSON, toolName string) (string, bool) {
|
||||
if !gjson.Valid(argsJSON) {
|
||||
return argsJSON, false
|
||||
}
|
||||
if !gjson.Parse(argsJSON).IsObject() {
|
||||
return argsJSON, false
|
||||
}
|
||||
|
||||
updated := argsJSON
|
||||
corrected := false
|
||||
|
||||
// 根据工具名称应用特定的参数修正规则
|
||||
switch toolName {
|
||||
case "bash":
|
||||
// OpenCode bash 支持 workdir;有些来源会输出 work_dir。
|
||||
if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir {
|
||||
if workDir, exists := argsMap["work_dir"]; exists {
|
||||
argsMap["workdir"] = workDir
|
||||
delete(argsMap, "work_dir")
|
||||
if !gjson.Get(updated, "workdir").Exists() {
|
||||
if next, changed := moveJSONField(updated, "work_dir", "workdir"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
|
||||
}
|
||||
} else {
|
||||
if _, exists := argsMap["work_dir"]; exists {
|
||||
delete(argsMap, "work_dir")
|
||||
if next, changed := deleteJSONField(updated, "work_dir"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
|
||||
}
|
||||
@@ -232,67 +268,71 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
|
||||
|
||||
case "edit":
|
||||
// OpenCode edit 参数为 filePath/oldString/newString(camelCase)。
|
||||
if _, exists := argsMap["filePath"]; !exists {
|
||||
if filePath, exists := argsMap["file_path"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "file_path")
|
||||
if !gjson.Get(updated, "filePath").Exists() {
|
||||
if next, changed := moveJSONField(updated, "file_path", "filePath"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
|
||||
} else if filePath, exists := argsMap["path"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "path")
|
||||
} else if next, changed := moveJSONField(updated, "path", "filePath"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
|
||||
} else if filePath, exists := argsMap["file"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "file")
|
||||
} else if next, changed := moveJSONField(updated, "file", "filePath"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
|
||||
}
|
||||
}
|
||||
|
||||
if _, exists := argsMap["oldString"]; !exists {
|
||||
if oldString, exists := argsMap["old_string"]; exists {
|
||||
argsMap["oldString"] = oldString
|
||||
delete(argsMap, "old_string")
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
|
||||
}
|
||||
if next, changed := moveJSONField(updated, "old_string", "oldString"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
|
||||
}
|
||||
|
||||
if _, exists := argsMap["newString"]; !exists {
|
||||
if newString, exists := argsMap["new_string"]; exists {
|
||||
argsMap["newString"] = newString
|
||||
delete(argsMap, "new_string")
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
|
||||
}
|
||||
if next, changed := moveJSONField(updated, "new_string", "newString"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
|
||||
}
|
||||
|
||||
if _, exists := argsMap["replaceAll"]; !exists {
|
||||
if replaceAll, exists := argsMap["replace_all"]; exists {
|
||||
argsMap["replaceAll"] = replaceAll
|
||||
delete(argsMap, "replace_all")
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
|
||||
}
|
||||
if next, changed := moveJSONField(updated, "replace_all", "replaceAll"); changed {
|
||||
updated = next
|
||||
corrected = true
|
||||
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
|
||||
}
|
||||
}
|
||||
return updated, corrected
|
||||
}
|
||||
|
||||
// 如果修正了参数,需要重新序列化
|
||||
if corrected {
|
||||
if _, wasString := arguments.(string); wasString {
|
||||
// 原本是字符串,序列化回字符串
|
||||
if newArgsJSON, err := json.Marshal(argsMap); err == nil {
|
||||
functionCall["arguments"] = string(newArgsJSON)
|
||||
}
|
||||
} else {
|
||||
// 原本是 map,直接赋值
|
||||
functionCall["arguments"] = argsMap
|
||||
}
|
||||
func moveJSONField(input, from, to string) (string, bool) {
|
||||
if gjson.Get(input, to).Exists() {
|
||||
return input, false
|
||||
}
|
||||
src := gjson.Get(input, from)
|
||||
if !src.Exists() {
|
||||
return input, false
|
||||
}
|
||||
next, err := sjson.SetRaw(input, to, src.Raw)
|
||||
if err != nil {
|
||||
return input, false
|
||||
}
|
||||
next, err = sjson.Delete(next, from)
|
||||
if err != nil {
|
||||
return input, false
|
||||
}
|
||||
return next, true
|
||||
}
|
||||
|
||||
return corrected
|
||||
func deleteJSONField(input, path string) (string, bool) {
|
||||
if !gjson.Get(input, path).Exists() {
|
||||
return input, false
|
||||
}
|
||||
next, err := sjson.Delete(input, path)
|
||||
if err != nil {
|
||||
return input, false
|
||||
}
|
||||
return next, true
|
||||
}
|
||||
|
||||
// recordCorrection 记录一次工具名称修正
|
||||
|
||||
@@ -5,6 +5,15 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMayContainToolCallPayload(t *testing.T) {
|
||||
if mayContainToolCallPayload([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)) {
|
||||
t.Fatalf("plain text event should not trigger tool-call parsing")
|
||||
}
|
||||
if !mayContainToolCallPayload([]byte(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)) {
|
||||
t.Fatalf("tool_calls event should trigger tool-call parsing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorrectToolCallsInSSEData(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
|
||||
190
backend/internal/service/openai_ws_account_sticky_test.go
Normal file
190
backend/internal/service/openai_ws_account_sticky_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
account := Account{
|
||||
ID: 2,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 2,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
openaiWSStateStore: store,
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, account.ID, selection.Account.ID)
|
||||
require.True(t, selection.Acquired)
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
account := Account{
|
||||
ID: 8,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
openaiWSStateStore: store,
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
account := Account{
|
||||
ID: 11,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
"openai_ws_force_http": true,
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
openaiWSStateStore: store,
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(23)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 21,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 22,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 9,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cache := &stubGatewayCache{}
|
||||
store := NewOpenAIWSStateStore(cache)
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2
|
||||
cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second
|
||||
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
acquireResults: map[int64]bool{
|
||||
21: false, // previous_response 命中的账号繁忙
|
||||
22: true, // 次优账号可用(若回退会命中)
|
||||
},
|
||||
waitCounts: map[int64]int{
|
||||
21: 999,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
openaiWSStateStore: store,
|
||||
}
|
||||
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour))
|
||||
|
||||
selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(21), selection.Account.ID, "busy previous_response sticky account should remain selected")
|
||||
require.False(t, selection.Acquired)
|
||||
require.NotNil(t, selection.WaitPlan)
|
||||
require.Equal(t, int64(21), selection.WaitPlan.AccountID)
|
||||
}
|
||||
|
||||
// newOpenAIWSV2TestConfig builds a minimal gateway config for the OpenAI WS
// v2 scheduling tests: every WS feature flag (base, OAuth, API-key, the
// responses-websockets v2 path) is enabled and the sticky response-ID TTL is
// set to one hour (3600 seconds).
func newOpenAIWSV2TestConfig() *config.Config {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
	return cfg
}
|
||||
285
backend/internal/service/openai_ws_client.go
Normal file
285
backend/internal/service/openai_ws_client.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
)
|
||||
|
||||
const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024
|
||||
const (
|
||||
openAIWSProxyTransportMaxIdleConns = 128
|
||||
openAIWSProxyTransportMaxIdleConnsPerHost = 64
|
||||
openAIWSProxyTransportIdleConnTimeout = 90 * time.Second
|
||||
openAIWSProxyClientCacheMaxEntries = 256
|
||||
openAIWSProxyClientCacheIdleTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
type OpenAIWSTransportMetricsSnapshot struct {
|
||||
ProxyClientCacheHits int64 `json:"proxy_client_cache_hits"`
|
||||
ProxyClientCacheMisses int64 `json:"proxy_client_cache_misses"`
|
||||
TransportReuseRatio float64 `json:"transport_reuse_ratio"`
|
||||
}
|
||||
|
||||
// openAIWSClientConn 抽象 WS 客户端连接,便于替换底层实现。
|
||||
type openAIWSClientConn interface {
|
||||
WriteJSON(ctx context.Context, value any) error
|
||||
ReadMessage(ctx context.Context) ([]byte, error)
|
||||
Ping(ctx context.Context) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// openAIWSClientDialer 抽象 WS 建连器。
|
||||
type openAIWSClientDialer interface {
|
||||
Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error)
|
||||
}
|
||||
|
||||
type openAIWSTransportMetricsDialer interface {
|
||||
SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot
|
||||
}
|
||||
|
||||
func newDefaultOpenAIWSClientDialer() openAIWSClientDialer {
|
||||
return &coderOpenAIWSClientDialer{
|
||||
proxyClients: make(map[string]*openAIWSProxyClientEntry),
|
||||
}
|
||||
}
|
||||
|
||||
type coderOpenAIWSClientDialer struct {
|
||||
proxyMu sync.Mutex
|
||||
proxyClients map[string]*openAIWSProxyClientEntry
|
||||
proxyHits atomic.Int64
|
||||
proxyMisses atomic.Int64
|
||||
}
|
||||
|
||||
type openAIWSProxyClientEntry struct {
|
||||
client *http.Client
|
||||
lastUsedUnixNano int64
|
||||
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) Dial(
|
||||
ctx context.Context,
|
||||
wsURL string,
|
||||
headers http.Header,
|
||||
proxyURL string,
|
||||
) (openAIWSClientConn, int, http.Header, error) {
|
||||
targetURL := strings.TrimSpace(wsURL)
|
||||
if targetURL == "" {
|
||||
return nil, 0, nil, errors.New("ws url is empty")
|
||||
}
|
||||
|
||||
opts := &coderws.DialOptions{
|
||||
HTTPHeader: cloneHeader(headers),
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
}
|
||||
if proxy := strings.TrimSpace(proxyURL); proxy != "" {
|
||||
proxyClient, err := d.proxyHTTPClient(proxy)
|
||||
if err != nil {
|
||||
return nil, 0, nil, err
|
||||
}
|
||||
opts.HTTPClient = proxyClient
|
||||
}
|
||||
|
||||
conn, resp, err := coderws.Dial(ctx, targetURL, opts)
|
||||
if err != nil {
|
||||
status := 0
|
||||
respHeaders := http.Header(nil)
|
||||
if resp != nil {
|
||||
status = resp.StatusCode
|
||||
respHeaders = cloneHeader(resp.Header)
|
||||
}
|
||||
return nil, status, respHeaders, err
|
||||
}
|
||||
// coder/websocket 默认单消息读取上限为 32KB,Codex WS 事件(如 rate_limits/大 delta)
|
||||
// 可能超过该阈值,需显式提高上限,避免本地 read_fail(message too big)。
|
||||
conn.SetReadLimit(openAIWSMessageReadLimitBytes)
|
||||
respHeaders := http.Header(nil)
|
||||
if resp != nil {
|
||||
respHeaders = cloneHeader(resp.Header)
|
||||
}
|
||||
return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil
|
||||
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) {
|
||||
if d == nil {
|
||||
return nil, errors.New("openai ws dialer is nil")
|
||||
}
|
||||
normalizedProxy := strings.TrimSpace(proxy)
|
||||
if normalizedProxy == "" {
|
||||
return nil, errors.New("proxy url is empty")
|
||||
}
|
||||
parsedProxyURL, err := url.Parse(normalizedProxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy url: %w", err)
|
||||
}
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
d.proxyMu.Lock()
|
||||
defer d.proxyMu.Unlock()
|
||||
if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != nil && entry.client != nil {
|
||||
entry.lastUsedUnixNano = now
|
||||
d.proxyHits.Add(1)
|
||||
return entry.client, nil
|
||||
}
|
||||
d.cleanupProxyClientsLocked(now)
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedProxyURL),
|
||||
MaxIdleConns: openAIWSProxyTransportMaxIdleConns,
|
||||
MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost,
|
||||
IdleConnTimeout: openAIWSProxyTransportIdleConnTimeout,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{
|
||||
client: client,
|
||||
lastUsedUnixNano: now,
|
||||
}
|
||||
d.ensureProxyClientCapacityLocked()
|
||||
d.proxyMisses.Add(1)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) {
|
||||
if d == nil || len(d.proxyClients) == 0 {
|
||||
return
|
||||
}
|
||||
idleTTL := openAIWSProxyClientCacheIdleTTL
|
||||
if idleTTL <= 0 {
|
||||
return
|
||||
}
|
||||
now := time.Unix(0, nowUnixNano)
|
||||
for key, entry := range d.proxyClients {
|
||||
if entry == nil || entry.client == nil {
|
||||
delete(d.proxyClients, key)
|
||||
continue
|
||||
}
|
||||
lastUsed := time.Unix(0, entry.lastUsedUnixNano)
|
||||
if now.Sub(lastUsed) > idleTTL {
|
||||
closeOpenAIWSProxyClient(entry.client)
|
||||
delete(d.proxyClients, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() {
|
||||
if d == nil {
|
||||
return
|
||||
}
|
||||
maxEntries := openAIWSProxyClientCacheMaxEntries
|
||||
if maxEntries <= 0 {
|
||||
return
|
||||
}
|
||||
for len(d.proxyClients) > maxEntries {
|
||||
var oldestKey string
|
||||
var oldestLastUsed int64
|
||||
hasOldest := false
|
||||
for key, entry := range d.proxyClients {
|
||||
lastUsed := int64(0)
|
||||
if entry != nil {
|
||||
lastUsed = entry.lastUsedUnixNano
|
||||
}
|
||||
if !hasOldest || lastUsed < oldestLastUsed {
|
||||
hasOldest = true
|
||||
oldestKey = key
|
||||
oldestLastUsed = lastUsed
|
||||
}
|
||||
}
|
||||
if !hasOldest {
|
||||
return
|
||||
}
|
||||
if entry := d.proxyClients[oldestKey]; entry != nil {
|
||||
closeOpenAIWSProxyClient(entry.client)
|
||||
}
|
||||
delete(d.proxyClients, oldestKey)
|
||||
}
|
||||
}
|
||||
|
||||
func closeOpenAIWSProxyClient(client *http.Client) {
|
||||
if client == nil || client.Transport == nil {
|
||||
return
|
||||
}
|
||||
if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
|
||||
if d == nil {
|
||||
return OpenAIWSTransportMetricsSnapshot{}
|
||||
}
|
||||
hits := d.proxyHits.Load()
|
||||
misses := d.proxyMisses.Load()
|
||||
total := hits + misses
|
||||
reuseRatio := 0.0
|
||||
if total > 0 {
|
||||
reuseRatio = float64(hits) / float64(total)
|
||||
}
|
||||
return OpenAIWSTransportMetricsSnapshot{
|
||||
ProxyClientCacheHits: hits,
|
||||
ProxyClientCacheMisses: misses,
|
||||
TransportReuseRatio: reuseRatio,
|
||||
}
|
||||
}
|
||||
|
||||
type coderOpenAIWSClientConn struct {
|
||||
conn *coderws.Conn
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
|
||||
if c == nil || c.conn == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return wsjson.Write(ctx, c.conn, value)
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) {
|
||||
if c == nil || c.conn == nil {
|
||||
return nil, errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
msgType, payload, err := c.conn.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch msgType {
|
||||
case coderws.MessageText, coderws.MessageBinary:
|
||||
return payload, nil
|
||||
default:
|
||||
return nil, errOpenAIWSConnClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
|
||||
if c == nil || c.conn == nil {
|
||||
return errOpenAIWSConnClosed
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return c.conn.Ping(ctx)
|
||||
}
|
||||
|
||||
// Close shuts down the WebSocket connection. It is idempotent and always
// returns nil; errors from a repeated close are deliberately discarded.
func (c *coderOpenAIWSClientConn) Close() error {
	if c == nil || c.conn == nil {
		return nil
	}
	// Close is idempotent; ignore duplicate-close errors.
	_ = c.conn.Close(coderws.StatusNormalClosure, "")
	_ = c.conn.CloseNow()
	return nil
}
|
||||
112
backend/internal/service/openai_ws_client_test.go
Normal file
112
backend/internal/service/openai_ws_client_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse(t *testing.T) {
|
||||
dialer := newDefaultOpenAIWSClientDialer()
|
||||
impl, ok := dialer.(*coderOpenAIWSClientDialer)
|
||||
require.True(t, ok)
|
||||
|
||||
c1, err := impl.proxyHTTPClient("http://127.0.0.1:8080")
|
||||
require.NoError(t, err)
|
||||
c2, err := impl.proxyHTTPClient("http://127.0.0.1:8080")
|
||||
require.NoError(t, err)
|
||||
require.Same(t, c1, c2, "同一代理地址应复用同一个 HTTP 客户端")
|
||||
|
||||
c3, err := impl.proxyHTTPClient("http://127.0.0.1:8081")
|
||||
require.NoError(t, err)
|
||||
require.NotSame(t, c1, c3, "不同代理地址应分离客户端")
|
||||
}
|
||||
|
||||
func TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL(t *testing.T) {
|
||||
dialer := newDefaultOpenAIWSClientDialer()
|
||||
impl, ok := dialer.(*coderOpenAIWSClientDialer)
|
||||
require.True(t, ok)
|
||||
|
||||
_, err := impl.proxyHTTPClient("://bad")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot(t *testing.T) {
|
||||
dialer := newDefaultOpenAIWSClientDialer()
|
||||
impl, ok := dialer.(*coderOpenAIWSClientDialer)
|
||||
require.True(t, ok)
|
||||
|
||||
_, err := impl.proxyHTTPClient("http://127.0.0.1:18080")
|
||||
require.NoError(t, err)
|
||||
_, err = impl.proxyHTTPClient("http://127.0.0.1:18080")
|
||||
require.NoError(t, err)
|
||||
_, err = impl.proxyHTTPClient("http://127.0.0.1:18081")
|
||||
require.NoError(t, err)
|
||||
|
||||
snapshot := impl.SnapshotTransportMetrics()
|
||||
require.Equal(t, int64(1), snapshot.ProxyClientCacheHits)
|
||||
require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses)
|
||||
require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001)
|
||||
}
|
||||
|
||||
func TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity(t *testing.T) {
|
||||
dialer := newDefaultOpenAIWSClientDialer()
|
||||
impl, ok := dialer.(*coderOpenAIWSClientDialer)
|
||||
require.True(t, ok)
|
||||
|
||||
total := openAIWSProxyClientCacheMaxEntries + 32
|
||||
for i := 0; i < total; i++ {
|
||||
_, err := impl.proxyHTTPClient(fmt.Sprintf("http://127.0.0.1:%d", 20000+i))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
impl.proxyMu.Lock()
|
||||
cacheSize := len(impl.proxyClients)
|
||||
impl.proxyMu.Unlock()
|
||||
|
||||
require.LessOrEqual(t, cacheSize, openAIWSProxyClientCacheMaxEntries, "代理客户端缓存应受容量上限约束")
|
||||
}
|
||||
|
||||
func TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL(t *testing.T) {
|
||||
dialer := newDefaultOpenAIWSClientDialer()
|
||||
impl, ok := dialer.(*coderOpenAIWSClientDialer)
|
||||
require.True(t, ok)
|
||||
|
||||
oldProxy := "http://127.0.0.1:28080"
|
||||
_, err := impl.proxyHTTPClient(oldProxy)
|
||||
require.NoError(t, err)
|
||||
|
||||
impl.proxyMu.Lock()
|
||||
oldEntry := impl.proxyClients[oldProxy]
|
||||
require.NotNil(t, oldEntry)
|
||||
oldEntry.lastUsedUnixNano = time.Now().Add(-openAIWSProxyClientCacheIdleTTL - time.Minute).UnixNano()
|
||||
impl.proxyMu.Unlock()
|
||||
|
||||
// 触发一次新的代理获取,驱动 TTL 清理。
|
||||
_, err = impl.proxyHTTPClient("http://127.0.0.1:28081")
|
||||
require.NoError(t, err)
|
||||
|
||||
impl.proxyMu.Lock()
|
||||
_, exists := impl.proxyClients[oldProxy]
|
||||
impl.proxyMu.Unlock()
|
||||
|
||||
require.False(t, exists, "超过空闲 TTL 的代理客户端应被回收")
|
||||
}
|
||||
|
||||
func TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout(t *testing.T) {
|
||||
dialer := newDefaultOpenAIWSClientDialer()
|
||||
impl, ok := dialer.(*coderOpenAIWSClientDialer)
|
||||
require.True(t, ok)
|
||||
|
||||
client, err := impl.proxyHTTPClient("http://127.0.0.1:38080")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
|
||||
transport, ok := client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, transport)
|
||||
require.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
|
||||
}
|
||||
251
backend/internal/service/openai_ws_fallback_test.go
Normal file
251
backend/internal/service/openai_ws_fallback_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClassifyOpenAIWSAcquireError(t *testing.T) {
|
||||
t.Run("dial_426_upgrade_required", func(t *testing.T) {
|
||||
err := &openAIWSDialError{StatusCode: 426, Err: errors.New("upgrade required")}
|
||||
require.Equal(t, "upgrade_required", classifyOpenAIWSAcquireError(err))
|
||||
})
|
||||
|
||||
t.Run("queue_full", func(t *testing.T) {
|
||||
require.Equal(t, "conn_queue_full", classifyOpenAIWSAcquireError(errOpenAIWSConnQueueFull))
|
||||
})
|
||||
|
||||
t.Run("preferred_conn_unavailable", func(t *testing.T) {
|
||||
require.Equal(t, "preferred_conn_unavailable", classifyOpenAIWSAcquireError(errOpenAIWSPreferredConnUnavailable))
|
||||
})
|
||||
|
||||
t.Run("acquire_timeout", func(t *testing.T) {
|
||||
require.Equal(t, "acquire_timeout", classifyOpenAIWSAcquireError(context.DeadlineExceeded))
|
||||
})
|
||||
|
||||
t.Run("auth_failed_401", func(t *testing.T) {
|
||||
err := &openAIWSDialError{StatusCode: 401, Err: errors.New("unauthorized")}
|
||||
require.Equal(t, "auth_failed", classifyOpenAIWSAcquireError(err))
|
||||
})
|
||||
|
||||
t.Run("upstream_rate_limited", func(t *testing.T) {
|
||||
err := &openAIWSDialError{StatusCode: 429, Err: errors.New("rate limited")}
|
||||
require.Equal(t, "upstream_rate_limited", classifyOpenAIWSAcquireError(err))
|
||||
})
|
||||
|
||||
t.Run("upstream_5xx", func(t *testing.T) {
|
||||
err := &openAIWSDialError{StatusCode: 502, Err: errors.New("bad gateway")}
|
||||
require.Equal(t, "upstream_5xx", classifyOpenAIWSAcquireError(err))
|
||||
})
|
||||
|
||||
t.Run("dial_failed_other_status", func(t *testing.T) {
|
||||
err := &openAIWSDialError{StatusCode: 418, Err: errors.New("teapot")}
|
||||
require.Equal(t, "dial_failed", classifyOpenAIWSAcquireError(err))
|
||||
})
|
||||
|
||||
t.Run("other", func(t *testing.T) {
|
||||
require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(errors.New("x")))
|
||||
})
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestClassifyOpenAIWSDialError(t *testing.T) {
|
||||
t.Run("handshake_not_finished", func(t *testing.T) {
|
||||
err := &openAIWSDialError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Err: errors.New("WebSocket protocol error: Handshake not finished"),
|
||||
}
|
||||
require.Equal(t, "handshake_not_finished", classifyOpenAIWSDialError(err))
|
||||
})
|
||||
|
||||
t.Run("context_deadline", func(t *testing.T) {
|
||||
err := &openAIWSDialError{
|
||||
StatusCode: 0,
|
||||
Err: context.DeadlineExceeded,
|
||||
}
|
||||
require.Equal(t, "ctx_deadline_exceeded", classifyOpenAIWSDialError(err))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSummarizeOpenAIWSDialError(t *testing.T) {
|
||||
err := &openAIWSDialError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
ResponseHeaders: http.Header{
|
||||
"Server": []string{"cloudflare"},
|
||||
"Via": []string{"1.1 example"},
|
||||
"Cf-Ray": []string{"abcd1234"},
|
||||
"X-Request-Id": []string{"req_123"},
|
||||
},
|
||||
Err: errors.New("WebSocket protocol error: Handshake not finished"),
|
||||
}
|
||||
|
||||
status, class, closeStatus, closeReason, server, via, cfRay, reqID := summarizeOpenAIWSDialError(err)
|
||||
require.Equal(t, http.StatusBadGateway, status)
|
||||
require.Equal(t, "handshake_not_finished", class)
|
||||
require.Equal(t, "-", closeStatus)
|
||||
require.Equal(t, "-", closeReason)
|
||||
require.Equal(t, "cloudflare", server)
|
||||
require.Equal(t, "1.1 example", via)
|
||||
require.Equal(t, "abcd1234", cfRay)
|
||||
require.Equal(t, "req_123", reqID)
|
||||
}
|
||||
|
||||
func TestClassifyOpenAIWSErrorEvent(t *testing.T) {
|
||||
reason, recoverable := classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"upgrade_required","message":"Upgrade required"}}`))
|
||||
require.Equal(t, "upgrade_required", reason)
|
||||
require.True(t, recoverable)
|
||||
|
||||
reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`))
|
||||
require.Equal(t, "previous_response_not_found", reason)
|
||||
require.True(t, recoverable)
|
||||
}
|
||||
|
||||
func TestClassifyOpenAIWSReconnectReason(t *testing.T) {
|
||||
reason, retryable := classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("policy_violation", errors.New("policy")))
|
||||
require.Equal(t, "policy_violation", reason)
|
||||
require.False(t, retryable)
|
||||
|
||||
reason, retryable = classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("read_event", errors.New("io")))
|
||||
require.Equal(t, "read_event", reason)
|
||||
require.True(t, retryable)
|
||||
}
|
||||
|
||||
func TestOpenAIWSErrorHTTPStatus(t *testing.T) {
|
||||
require.Equal(t, http.StatusBadRequest, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)))
|
||||
require.Equal(t, http.StatusUnauthorized, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"authentication_error","code":"invalid_api_key","message":"auth failed"}}`)))
|
||||
require.Equal(t, http.StatusForbidden, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"permission_error","code":"forbidden","message":"forbidden"}}`)))
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"rate limited"}}`)))
|
||||
require.Equal(t, http.StatusBadGateway, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"server_error","code":"server_error","message":"server"}}`)))
|
||||
}
|
||||
|
||||
func TestResolveOpenAIWSFallbackErrorResponse(t *testing.T) {
|
||||
t.Run("previous_response_not_found", func(t *testing.T) {
|
||||
statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
|
||||
wrapOpenAIWSFallback("previous_response_not_found", errors.New("previous response not found")),
|
||||
)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, http.StatusBadRequest, statusCode)
|
||||
require.Equal(t, "invalid_request_error", errType)
|
||||
require.Equal(t, "previous response not found", clientMessage)
|
||||
require.Equal(t, "previous response not found", upstreamMessage)
|
||||
})
|
||||
|
||||
t.Run("auth_failed_uses_dial_status", func(t *testing.T) {
|
||||
statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
|
||||
wrapOpenAIWSFallback("auth_failed", &openAIWSDialError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Err: errors.New("forbidden"),
|
||||
}),
|
||||
)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, http.StatusForbidden, statusCode)
|
||||
require.Equal(t, "upstream_error", errType)
|
||||
require.Equal(t, "forbidden", clientMessage)
|
||||
require.Equal(t, "forbidden", upstreamMessage)
|
||||
})
|
||||
|
||||
t.Run("non_fallback_error_not_resolved", func(t *testing.T) {
|
||||
_, _, _, _, ok := resolveOpenAIWSFallbackErrorResponse(errors.New("plain error"))
|
||||
require.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIWSFallbackCooling(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
svc.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
|
||||
|
||||
require.False(t, svc.isOpenAIWSFallbackCooling(1))
|
||||
svc.markOpenAIWSFallbackCooling(1, "upgrade_required")
|
||||
require.True(t, svc.isOpenAIWSFallbackCooling(1))
|
||||
|
||||
svc.clearOpenAIWSFallbackCooling(1)
|
||||
require.False(t, svc.isOpenAIWSFallbackCooling(1))
|
||||
|
||||
svc.markOpenAIWSFallbackCooling(2, "x")
|
||||
time.Sleep(1200 * time.Millisecond)
|
||||
require.False(t, svc.isOpenAIWSFallbackCooling(2))
|
||||
}
|
||||
|
||||
func TestOpenAIWSRetryBackoff(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
svc.cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 100
|
||||
svc.cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 400
|
||||
svc.cfg.Gateway.OpenAIWS.RetryJitterRatio = 0
|
||||
|
||||
require.Equal(t, time.Duration(100)*time.Millisecond, svc.openAIWSRetryBackoff(1))
|
||||
require.Equal(t, time.Duration(200)*time.Millisecond, svc.openAIWSRetryBackoff(2))
|
||||
require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(3))
|
||||
require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(4))
|
||||
}
|
||||
|
||||
func TestOpenAIWSRetryTotalBudget(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 1200
|
||||
require.Equal(t, 1200*time.Millisecond, svc.openAIWSRetryTotalBudget())
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 0
|
||||
require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget())
|
||||
}
|
||||
|
||||
func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) {
|
||||
require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation}))
|
||||
require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig}))
|
||||
require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io")))
|
||||
}
|
||||
|
||||
func TestOpenAIWSStoreDisabledConnMode(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
|
||||
require.Equal(t, openAIWSStoreDisabledConnModeStrict, svc.openAIWSStoreDisabledConnMode())
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "adaptive"
|
||||
require.Equal(t, openAIWSStoreDisabledConnModeAdaptive, svc.openAIWSStoreDisabledConnMode())
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = ""
|
||||
svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false
|
||||
require.Equal(t, openAIWSStoreDisabledConnModeOff, svc.openAIWSStoreDisabledConnMode())
|
||||
}
|
||||
|
||||
func TestShouldForceNewConnOnStoreDisabled(t *testing.T) {
|
||||
require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeStrict, ""))
|
||||
require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeOff, "policy_violation"))
|
||||
|
||||
require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "policy_violation"))
|
||||
require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "prewarm_message_too_big"))
|
||||
require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "read_event"))
|
||||
}
|
||||
|
||||
func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
svc.recordOpenAIWSRetryAttempt(150 * time.Millisecond)
|
||||
svc.recordOpenAIWSRetryAttempt(0)
|
||||
svc.recordOpenAIWSRetryExhausted()
|
||||
svc.recordOpenAIWSNonRetryableFastFallback()
|
||||
|
||||
snapshot := svc.SnapshotOpenAIWSRetryMetrics()
|
||||
require.Equal(t, int64(2), snapshot.RetryAttemptsTotal)
|
||||
require.Equal(t, int64(150), snapshot.RetryBackoffMsTotal)
|
||||
require.Equal(t, int64(1), snapshot.RetryExhaustedTotal)
|
||||
require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal)
|
||||
}
|
||||
|
||||
func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 0
|
||||
require.True(t, svc.shouldLogOpenAIWSPayloadSchema(1), "首次尝试应始终记录 payload_schema")
|
||||
require.False(t, svc.shouldLogOpenAIWSPayloadSchema(2))
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 1
|
||||
require.True(t, svc.shouldLogOpenAIWSPayloadSchema(2))
|
||||
}
|
||||
3955
backend/internal/service/openai_ws_forwarder.go
Normal file
3955
backend/internal/service/openai_ws_forwarder.go
Normal file
File diff suppressed because it is too large
Load Diff
127
backend/internal/service/openai_ws_forwarder_benchmark_test.go
Normal file
127
backend/internal/service/openai_ws_forwarder_benchmark_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// Package-level sinks receive benchmark results so the compiler cannot
// dead-code-eliminate the measured calls.
var (
	benchmarkOpenAIWSPayloadJSONSink string
	benchmarkOpenAIWSStringSink      string
	benchmarkOpenAIWSBoolSink        bool
	benchmarkOpenAIWSBytesSink       []byte
)
|
||||
|
||||
func BenchmarkOpenAIWSForwarderHotPath(b *testing.B) {
|
||||
cfg := &config.Config{}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
account := &Account{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
reqBody := benchmarkOpenAIWSHotPathRequest()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
payload := svc.buildOpenAIWSCreatePayload(reqBody, account)
|
||||
_, _ = applyOpenAIWSRetryPayloadStrategy(payload, 2)
|
||||
setOpenAIWSTurnMetadata(payload, `{"trace":"bench","turn":"1"}`)
|
||||
|
||||
benchmarkOpenAIWSStringSink = openAIWSPayloadString(payload, "previous_response_id")
|
||||
benchmarkOpenAIWSBoolSink = payload["tools"] != nil
|
||||
benchmarkOpenAIWSStringSink = summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)
|
||||
benchmarkOpenAIWSStringSink = summarizeOpenAIWSInput(payload["input"])
|
||||
benchmarkOpenAIWSPayloadJSONSink = payloadAsJSON(payload)
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkOpenAIWSHotPathRequest builds a representative response.create
// request: 24 function-tool schemas plus 16 user messages.
func benchmarkOpenAIWSHotPathRequest() map[string]any {
	toolSchemas := make([]map[string]any, 0, 24)
	for idx := 0; idx < 24; idx++ {
		schema := map[string]any{
			"type":        "function",
			"name":        fmt.Sprintf("tool_%02d", idx),
			"description": "benchmark tool schema",
			"parameters": map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{"type": "string"},
					"limit": map[string]any{"type": "number"},
				},
				"required": []string{"query"},
			},
		}
		toolSchemas = append(toolSchemas, schema)
	}

	messages := make([]map[string]any, 0, 16)
	for idx := 0; idx < 16; idx++ {
		messages = append(messages, map[string]any{
			"role":    "user",
			"type":    "message",
			"content": fmt.Sprintf("benchmark message %d", idx),
		})
	}

	return map[string]any{
		"type":                 "response.create",
		"model":                "gpt-5.3-codex",
		"input":                messages,
		"tools":                toolSchemas,
		"parallel_tool_calls":  true,
		"previous_response_id": "resp_benchmark_prev",
		"prompt_cache_key":     "bench-cache-key",
		"reasoning":            map[string]any{"effort": "medium"},
		"instructions":         "benchmark instructions",
		"store":                false,
	}
}
|
||||
|
||||
func BenchmarkOpenAIWSEventEnvelopeParse(b *testing.B) {
|
||||
event := []byte(`{"type":"response.completed","response":{"id":"resp_bench_1","model":"gpt-5.1","usage":{"input_tokens":12,"output_tokens":8}}}`)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
eventType, responseID, response := parseOpenAIWSEventEnvelope(event)
|
||||
benchmarkOpenAIWSStringSink = eventType
|
||||
benchmarkOpenAIWSStringSink = responseID
|
||||
benchmarkOpenAIWSBoolSink = response.Exists()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOpenAIWSErrorEventFieldReuse(b *testing.B) {
|
||||
event := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(event)
|
||||
benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw)
|
||||
code, errType, errMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw)
|
||||
benchmarkOpenAIWSStringSink = code
|
||||
benchmarkOpenAIWSStringSink = errType
|
||||
benchmarkOpenAIWSStringSink = errMsg
|
||||
benchmarkOpenAIWSBoolSink = openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) > 0
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath(b *testing.B) {
|
||||
event := []byte(`{"type":"response.output_text.delta","delta":"hello world"}`)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) {
|
||||
event := []byte(`{"type":"response.completed","model":"gpt-5.1","response":{"id":"resp_1","model":"gpt-5.1"}}`)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseOpenAIWSEventEnvelope(t *testing.T) {
|
||||
eventType, responseID, response := parseOpenAIWSEventEnvelope([]byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`))
|
||||
require.Equal(t, "response.completed", eventType)
|
||||
require.Equal(t, "resp_1", responseID)
|
||||
require.True(t, response.Exists())
|
||||
require.Equal(t, `{"id":"resp_1","model":"gpt-5.1"}`, response.Raw)
|
||||
|
||||
eventType, responseID, response = parseOpenAIWSEventEnvelope([]byte(`{"type":"response.delta","id":"evt_1"}`))
|
||||
require.Equal(t, "response.delta", eventType)
|
||||
require.Equal(t, "evt_1", responseID)
|
||||
require.False(t, response.Exists())
|
||||
}
|
||||
|
||||
func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
|
||||
usage := &OpenAIUsage{}
|
||||
parseOpenAIWSResponseUsageFromCompletedEvent(
|
||||
[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`),
|
||||
usage,
|
||||
)
|
||||
require.Equal(t, 11, usage.InputTokens)
|
||||
require.Equal(t, 7, usage.OutputTokens)
|
||||
require.Equal(t, 3, usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
|
||||
message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
|
||||
codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
|
||||
|
||||
wrappedReason, wrappedRecoverable := classifyOpenAIWSErrorEvent(message)
|
||||
rawReason, rawRecoverable := classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw)
|
||||
require.Equal(t, wrappedReason, rawReason)
|
||||
require.Equal(t, wrappedRecoverable, rawRecoverable)
|
||||
|
||||
wrappedStatus := openAIWSErrorHTTPStatus(message)
|
||||
rawStatus := openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
|
||||
require.Equal(t, wrappedStatus, rawStatus)
|
||||
require.Equal(t, http.StatusBadRequest, rawStatus)
|
||||
|
||||
wrappedCode, wrappedType, wrappedMsg := summarizeOpenAIWSErrorEventFields(message)
|
||||
rawCode, rawType, rawMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw)
|
||||
require.Equal(t, wrappedCode, rawCode)
|
||||
require.Equal(t, wrappedType, rawType)
|
||||
require.Equal(t, wrappedMsg, rawMsg)
|
||||
}
|
||||
|
||||
func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) {
|
||||
require.False(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)))
|
||||
require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"tool_calls":[{"id":"tc1"}]}}`)))
|
||||
require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"function_call"}}`)))
|
||||
}
|
||||
|
||||
func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) {
|
||||
noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`)
|
||||
require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, "gpt-5.1", "custom-model")))
|
||||
|
||||
rootOnly := []byte(`{"type":"response.created","model":"gpt-5.1"}`)
|
||||
require.Equal(t, `{"type":"response.created","model":"custom-model"}`, string(replaceOpenAIWSMessageModel(rootOnly, "gpt-5.1", "custom-model")))
|
||||
|
||||
responseOnly := []byte(`{"type":"response.completed","response":{"model":"gpt-5.1"}}`)
|
||||
require.Equal(t, `{"type":"response.completed","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(responseOnly, "gpt-5.1", "custom-model")))
|
||||
|
||||
both := []byte(`{"model":"gpt-5.1","response":{"model":"gpt-5.1"}}`)
|
||||
require.Equal(t, `{"model":"custom-model","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(both, "gpt-5.1", "custom-model")))
|
||||
}
|
||||
2483
backend/internal/service/openai_ws_forwarder_ingress_session_test.go
Normal file
2483
backend/internal/service/openai_ws_forwarder_ingress_session_test.go
Normal file
File diff suppressed because it is too large
Load Diff
714
backend/internal/service/openai_ws_forwarder_ingress_test.go
Normal file
714
backend/internal/service/openai_ws_forwarder_ingress_test.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestIsOpenAIWSClientDisconnectError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{name: "nil", err: nil, want: false},
|
||||
{name: "io_eof", err: io.EOF, want: true},
|
||||
{name: "net_closed", err: net.ErrClosed, want: true},
|
||||
{name: "context_canceled", err: context.Canceled, want: true},
|
||||
{name: "ws_normal_closure", err: coderws.CloseError{Code: coderws.StatusNormalClosure}, want: true},
|
||||
{name: "ws_going_away", err: coderws.CloseError{Code: coderws.StatusGoingAway}, want: true},
|
||||
{name: "ws_no_status", err: coderws.CloseError{Code: coderws.StatusNoStatusRcvd}, want: true},
|
||||
{name: "ws_abnormal_1006", err: coderws.CloseError{Code: coderws.StatusAbnormalClosure}, want: true},
|
||||
{name: "ws_policy_violation", err: coderws.CloseError{Code: coderws.StatusPolicyViolation}, want: false},
|
||||
{name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true},
|
||||
{name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true},
|
||||
{name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tt.want, isOpenAIWSClientDisconnectError(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.False(t, isOpenAIWSIngressPreviousResponseNotFound(nil))
|
||||
require.False(t, isOpenAIWSIngressPreviousResponseNotFound(errors.New("plain error")))
|
||||
require.False(t, isOpenAIWSIngressPreviousResponseNotFound(
|
||||
wrapOpenAIWSIngressTurnError("read_upstream", errors.New("upstream read failed"), false),
|
||||
))
|
||||
require.False(t, isOpenAIWSIngressPreviousResponseNotFound(
|
||||
wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), true),
|
||||
))
|
||||
require.True(t, isOpenAIWSIngressPreviousResponseNotFound(
|
||||
wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), false),
|
||||
))
|
||||
}
|
||||
|
||||
func TestOpenAIWSIngressPreviousResponseRecoveryEnabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var nilService *OpenAIGatewayService
|
||||
require.True(t, nilService.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil service should default to enabled")
|
||||
|
||||
svcWithNilCfg := &OpenAIGatewayService{}
|
||||
require.True(t, svcWithNilCfg.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil config should default to enabled")
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{},
|
||||
}
|
||||
require.False(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled(), "explicit config default should be false")
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
|
||||
require.True(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled())
|
||||
}
|
||||
|
||||
func TestDropPreviousResponseIDFromRawPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("empty_payload", func(t *testing.T) {
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayload(nil)
|
||||
require.NoError(t, err)
|
||||
require.False(t, removed)
|
||||
require.Empty(t, updated)
|
||||
})
|
||||
|
||||
t.Run("payload_without_previous_response_id", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
|
||||
require.NoError(t, err)
|
||||
require.False(t, removed)
|
||||
require.Equal(t, string(payload), string(updated))
|
||||
})
|
||||
|
||||
t.Run("normal_delete_success", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
|
||||
require.NoError(t, err)
|
||||
require.True(t, removed)
|
||||
require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
|
||||
})
|
||||
|
||||
t.Run("duplicate_keys_are_removed", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_a","input":[],"previous_response_id":"resp_b"}`)
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
|
||||
require.NoError(t, err)
|
||||
require.True(t, removed)
|
||||
require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
|
||||
})
|
||||
|
||||
t.Run("nil_delete_fn_uses_default_delete_logic", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, nil)
|
||||
require.NoError(t, err)
|
||||
require.True(t, removed)
|
||||
require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
|
||||
})
|
||||
|
||||
t.Run("delete_error", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, func(_ []byte, _ string) ([]byte, error) {
|
||||
return nil, errors.New("delete failed")
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.False(t, removed)
|
||||
require.Equal(t, string(payload), string(updated))
|
||||
})
|
||||
|
||||
t.Run("malformed_json_is_still_best_effort_deleted", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_abc"`)
|
||||
require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists())
|
||||
|
||||
updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
|
||||
require.NoError(t, err)
|
||||
require.True(t, removed)
|
||||
require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
|
||||
})
|
||||
}
|
||||
|
||||
func TestAlignStoreDisabledPreviousResponseID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("empty_payload", func(t *testing.T) {
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(nil, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.False(t, changed)
|
||||
require.Empty(t, updated)
|
||||
})
|
||||
|
||||
t.Run("empty_expected", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_old"}`)
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "")
|
||||
require.NoError(t, err)
|
||||
require.False(t, changed)
|
||||
require.Equal(t, string(payload), string(updated))
|
||||
})
|
||||
|
||||
t.Run("missing_previous_response_id", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.False(t, changed)
|
||||
require.Equal(t, string(payload), string(updated))
|
||||
})
|
||||
|
||||
t.Run("already_aligned", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_target"}`)
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.False(t, changed)
|
||||
require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
|
||||
})
|
||||
|
||||
t.Run("mismatch_rewrites_to_expected", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_old","input":[]}`)
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
|
||||
})
|
||||
|
||||
t.Run("duplicate_keys_rewrites_to_single_expected", func(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.create","previous_response_id":"resp_old_1","input":[],"previous_response_id":"resp_old_2"}`)
|
||||
updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
|
||||
})
|
||||
}
|
||||
|
||||
// TestSetPreviousResponseIDToRawPayload verifies setPreviousResponseIDToRawPayload:
// it leaves nil/empty payloads and empty target IDs untouched, and otherwise
// writes (or overwrites) the previous_response_id field in the raw JSON payload.
func TestSetPreviousResponseIDToRawPayload(t *testing.T) {
	t.Parallel()

	// Nil payload: nothing to rewrite, result stays empty without error.
	t.Run("empty_payload", func(t *testing.T) {
		updated, err := setPreviousResponseIDToRawPayload(nil, "resp_target")
		require.NoError(t, err)
		require.Empty(t, updated)
	})

	// Empty target ID: payload must pass through unchanged.
	t.Run("empty_previous_response_id", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
		updated, err := setPreviousResponseIDToRawPayload(payload, "")
		require.NoError(t, err)
		require.Equal(t, string(payload), string(updated))
	})

	// Missing field: it is added, and existing fields are preserved.
	t.Run("set_previous_response_id_when_missing", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
		updated, err := setPreviousResponseIDToRawPayload(payload, "resp_target")
		require.NoError(t, err)
		require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
		require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String())
	})

	// Existing field: the old value is replaced with the new target ID.
	t.Run("overwrite_existing_previous_response_id", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_old"}`)
		updated, err := setPreviousResponseIDToRawPayload(payload, "resp_new")
		require.NoError(t, err)
		require.Equal(t, "resp_new", gjson.GetBytes(updated, "previous_response_id").String())
	})
}
|
||||
|
||||
// TestShouldInferIngressFunctionCallOutputPreviousResponseID is a table test
// over shouldInferIngressFunctionCallOutputPreviousResponseID. Inference is
// expected only when: store is disabled, it is not the first turn, the request
// carries a function_call_output, the request itself has no
// previous_response_id, and the expected previous-turn ID is non-empty
// (after trimming whitespace).
func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name                    string
		storeDisabled           bool
		turn                    int
		hasFunctionCallOutput   bool
		currentPreviousResponse string
		expectedPrevious        string
		want                    bool
	}{
		{
			name:                  "infer_when_all_conditions_match",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      "resp_1",
			want:                  true,
		},
		{
			name:                  "skip_when_store_enabled",
			storeDisabled:         false,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      "resp_1",
			want:                  false,
		},
		{
			name:                  "skip_on_first_turn",
			storeDisabled:         true,
			turn:                  1,
			hasFunctionCallOutput: true,
			expectedPrevious:      "resp_1",
			want:                  false,
		},
		{
			name:                  "skip_without_function_call_output",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: false,
			expectedPrevious:      "resp_1",
			want:                  false,
		},
		{
			// A client-supplied previous_response_id must win over inference.
			name:                    "skip_when_request_already_has_previous_response_id",
			storeDisabled:           true,
			turn:                    2,
			hasFunctionCallOutput:   true,
			currentPreviousResponse: "resp_client",
			expectedPrevious:        "resp_1",
			want:                    false,
		},
		{
			name:                  "skip_when_last_turn_response_id_missing",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      "",
			want:                  false,
		},
		{
			// Surrounding whitespace must not make a present ID look empty.
			name:                  "trim_whitespace_before_judgement",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      " resp_2 ",
			want:                  true,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := shouldInferIngressFunctionCallOutputPreviousResponseID(
				tt.storeDisabled,
				tt.turn,
				tt.hasFunctionCallOutput,
				tt.currentPreviousResponse,
				tt.expectedPrevious,
			)
			require.Equal(t, tt.want, got)
		})
	}
}
|
||||
|
||||
// TestOpenAIWSInputIsPrefixExtended is a table test over
// openAIWSInputIsPrefixExtended, which reports whether the current payload's
// "input" sequence extends (is a superset prefix-wise of) the previous
// payload's sequence. Comparison appears order- and content-based with
// key-order-insensitive JSON item equality; invalid input JSON yields an error.
func TestOpenAIWSInputIsPrefixExtended(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name      string
		previous  []byte
		current   []byte
		want      bool
		expectErr bool
	}{
		{
			name:     "both_missing_input",
			previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
			current:  []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_1"}`),
			want:     true,
		},
		{
			name:     "previous_missing_current_empty_array",
			previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
			current:  []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`),
			want:     true,
		},
		{
			name:     "previous_missing_current_non_empty_array",
			previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
			current:  []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"input_text","text":"hello"}]}`),
			want:     false,
		},
		{
			// Same first item (different key order) plus an appended item
			// counts as a prefix extension.
			name:     "array_prefix_match",
			previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`),
			current:  []byte(`{"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]}`),
			want:     true,
		},
		{
			name:     "array_prefix_mismatch",
			previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`),
			current:  []byte(`{"input":[{"type":"input_text","text":"different"}]}`),
			want:     false,
		},
		{
			name:     "current_shorter_than_previous",
			previous: []byte(`{"input":[{"type":"input_text","text":"a"},{"type":"input_text","text":"b"}]}`),
			current:  []byte(`{"input":[{"type":"input_text","text":"a"}]}`),
			want:     false,
		},
		{
			name:     "previous_has_input_current_missing",
			previous: []byte(`{"input":[{"type":"input_text","text":"a"}]}`),
			current:  []byte(`{"model":"gpt-5.1"}`),
			want:     false,
		},
		{
			// A scalar string input is compared as a one-item sequence.
			name:     "input_string_treated_as_single_item",
			previous: []byte(`{"input":"hello"}`),
			current:  []byte(`{"input":"hello"}`),
			want:     true,
		},
		{
			name:      "current_invalid_input_json",
			previous:  []byte(`{"input":[]}`),
			current:   []byte(`{"input":[}`),
			expectErr: true,
		},
		{
			name:      "invalid_input_json",
			previous:  []byte(`{"input":[}`),
			current:   []byte(`{"input":[]}`),
			expectErr: true,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got, err := openAIWSInputIsPrefixExtended(tt.previous, tt.current)
			if tt.expectErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.Equal(t, tt.want, got)
		})
	}
}
|
||||
|
||||
// TestNormalizeOpenAIWSJSONForCompare verifies that normalization produces a
// canonical (key-sorted) JSON encoding, and that blank or truncated input is
// rejected with an error.
func TestNormalizeOpenAIWSJSONForCompare(t *testing.T) {
	t.Parallel()

	// Keys are re-emitted in sorted order so byte comparison is stable.
	normalized, err := normalizeOpenAIWSJSONForCompare([]byte(`{"b":2,"a":1}`))
	require.NoError(t, err)
	require.Equal(t, `{"a":1,"b":2}`, string(normalized))

	// Whitespace-only input is not valid JSON.
	_, err = normalizeOpenAIWSJSONForCompare([]byte(" "))
	require.Error(t, err)

	// Truncated JSON is rejected.
	_, err = normalizeOpenAIWSJSONForCompare([]byte(`{"a":`))
	require.Error(t, err)
}
|
||||
|
||||
// TestNormalizeOpenAIWSJSONForCompareOrRaw verifies the lenient variant:
// valid JSON is canonicalized, while invalid JSON falls back to the raw
// input bytes instead of failing.
func TestNormalizeOpenAIWSJSONForCompareOrRaw(t *testing.T) {
	t.Parallel()

	require.Equal(t, `{"a":1,"b":2}`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"b":2,"a":1}`))))
	require.Equal(t, `{"a":`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"a":`))))
}
|
||||
|
||||
// TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID verifies that
// normalization strips the "input" and "previous_response_id" fields while
// keeping (and canonicalizing) the remaining fields, and that nil or non-object
// payloads are rejected.
func TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(t *testing.T) {
	t.Parallel()

	normalized, err := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(
		[]byte(`{"model":"gpt-5.1","input":[1],"previous_response_id":"resp_x","metadata":{"b":2,"a":1}}`),
	)
	require.NoError(t, err)
	// Both volatile fields are removed...
	require.False(t, gjson.GetBytes(normalized, "input").Exists())
	require.False(t, gjson.GetBytes(normalized, "previous_response_id").Exists())
	// ...and nested content is preserved.
	require.Equal(t, float64(1), gjson.GetBytes(normalized, "metadata.a").Float())

	// Nil payload is an error.
	_, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(nil)
	require.Error(t, err)

	// A top-level array (non-object) is an error.
	_, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID([]byte(`[]`))
	require.Error(t, err)
}
|
||||
|
||||
// TestOpenAIWSExtractNormalizedInputSequence verifies extraction of the
// "input" field as a normalized item sequence: arrays keep their elements,
// every other JSON value (object, string, number, bool, null) becomes a
// single-item sequence, a missing field reports exists=false, and malformed
// array JSON returns an error with exists=true.
func TestOpenAIWSExtractNormalizedInputSequence(t *testing.T) {
	t.Parallel()

	t.Run("empty_payload", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence(nil)
		require.NoError(t, err)
		require.False(t, exists)
		require.Nil(t, items)
	})

	t.Run("input_missing", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"type":"response.create"}`))
		require.NoError(t, err)
		require.False(t, exists)
		require.Nil(t, items)
	})

	t.Run("input_array", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[{"type":"input_text","text":"hello"}]}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
	})

	t.Run("input_object", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":{"type":"input_text","text":"hello"}}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
	})

	t.Run("input_string", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":"hello"}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		// The scalar is preserved as its raw JSON encoding.
		require.Equal(t, `"hello"`, string(items[0]))
	})

	t.Run("input_number", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":42}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "42", string(items[0]))
	})

	t.Run("input_bool", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":true}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "true", string(items[0]))
	})

	t.Run("input_null", func(t *testing.T) {
		// Explicit null still counts as a present field with one item.
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":null}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "null", string(items[0]))
	})

	t.Run("input_invalid_array_json", func(t *testing.T) {
		// Malformed array JSON: exists is still reported, items are nil.
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[}`))
		require.Error(t, err)
		require.True(t, exists)
		require.Nil(t, items)
	})
}
|
||||
|
||||
// TestShouldKeepIngressPreviousResponseID verifies the decision on whether an
// ingress previous_response_id may be kept for a store-disabled conversation.
// Keeping requires a matching previous_response_id, an unchanged non-input
// payload versus the previous turn, and either a strict-incremental input or a
// function_call_output; each rejection path reports a distinct reason string.
func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
	t.Parallel()

	previousPayload := []byte(`{
		"type":"response.create",
		"model":"gpt-5.1",
		"store":false,
		"tools":[{"type":"function","name":"tool_a"}],
		"input":[{"type":"input_text","text":"hello"}]
	}`)
	// Same non-input fields (different key order) with an extended input.
	currentStrictPayload := []byte(`{
		"type":"response.create",
		"model":"gpt-5.1",
		"store":false,
		"tools":[{"name":"tool_a","type":"function"}],
		"previous_response_id":"resp_turn_1",
		"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]
	}`)

	t.Run("strict_incremental_keep", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false)
		require.NoError(t, err)
		require.True(t, keep)
		require.Equal(t, "strict_incremental_ok", reason)
	})

	t.Run("missing_previous_response_id", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "missing_previous_response_id", reason)
	})

	t.Run("missing_last_turn_response_id", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "missing_last_turn_response_id", reason)
	})

	t.Run("previous_response_id_mismatch", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "previous_response_id_mismatch", reason)
	})

	t.Run("missing_previous_turn_payload", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "missing_previous_turn_payload", reason)
	})

	t.Run("non_input_changed", func(t *testing.T) {
		// Model differs from the previous turn, so keeping is refused.
		payload := []byte(`{
		"type":"response.create",
		"model":"gpt-5.1-mini",
		"store":false,
		"tools":[{"type":"function","name":"tool_a"}],
		"previous_response_id":"resp_turn_1",
		"input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]
	}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "non_input_changed", reason)
	})

	t.Run("delta_input_keeps_previous_response_id", func(t *testing.T) {
		// A non-prefix input is still accepted as a delta when the
		// previous_response_id matches and non-input fields are unchanged.
		payload := []byte(`{
		"type":"response.create",
		"model":"gpt-5.1",
		"store":false,
		"tools":[{"type":"function","name":"tool_a"}],
		"previous_response_id":"resp_turn_1",
		"input":[{"type":"input_text","text":"different"}]
	}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
		require.NoError(t, err)
		require.True(t, keep)
		require.Equal(t, "strict_incremental_ok", reason)
	})

	t.Run("function_call_output_keeps_previous_response_id", func(t *testing.T) {
		// With hasFunctionCallOutput=true, keeping is allowed even though the
		// payload's previous_response_id does not match the last turn's ID.
		payload := []byte(`{
		"type":"response.create",
		"model":"gpt-5.1",
		"store":false,
		"previous_response_id":"resp_external",
		"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
	}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true)
		require.NoError(t, err)
		require.True(t, keep)
		require.Equal(t, "has_function_call_output", reason)
	})

	t.Run("non_input_compare_error", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false)
		require.Error(t, err)
		require.False(t, keep)
		require.Equal(t, "non_input_compare_error", reason)
	})

	t.Run("current_payload_compare_error", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false)
		require.Error(t, err)
		require.False(t, keep)
		require.Equal(t, "non_input_compare_error", reason)
	})
}
|
||||
|
||||
// TestBuildOpenAIWSReplayInputSequence verifies replay-sequence construction:
// without a previous_response_id the current payload's input is used as-is;
// with one, a short current input is treated as a delta and appended to the
// last full sequence, while a current input that already extends the last
// sequence replaces it outright.
func TestBuildOpenAIWSReplayInputSequence(t *testing.T) {
	t.Parallel()

	// The full input sequence recorded from the last turn.
	lastFull := []json.RawMessage{
		json.RawMessage(`{"type":"input_text","text":"hello"}`),
	}

	t.Run("no_previous_response_id_use_current", func(t *testing.T) {
		items, exists, err := buildOpenAIWSReplayInputSequence(
			lastFull,
			true,
			[]byte(`{"input":[{"type":"input_text","text":"new"}]}`),
			false,
		)
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "new", gjson.GetBytes(items[0], "text").String())
	})

	t.Run("previous_response_id_delta_append", func(t *testing.T) {
		// One new item on top of the recorded sequence -> appended after it.
		items, exists, err := buildOpenAIWSReplayInputSequence(
			lastFull,
			true,
			[]byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`),
			true,
		)
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 2)
		require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
		require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
	})

	t.Run("previous_response_id_full_input_replace", func(t *testing.T) {
		// The current input already contains the recorded prefix -> used verbatim
		// (no duplicate "hello").
		items, exists, err := buildOpenAIWSReplayInputSequence(
			lastFull,
			true,
			[]byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]}`),
			true,
		)
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 2)
		require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
		require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
	})
}
|
||||
|
||||
// TestSetOpenAIWSPayloadInputSequence verifies that an item sequence is
// written into the payload's "input" field, and that a nil sequence is
// encoded as an empty JSON array rather than null.
func TestSetOpenAIWSPayloadInputSequence(t *testing.T) {
	t.Parallel()

	t.Run("set_items", func(t *testing.T) {
		original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`)
		items := []json.RawMessage{
			json.RawMessage(`{"type":"input_text","text":"hello"}`),
			json.RawMessage(`{"type":"input_text","text":"world"}`),
		}
		updated, err := setOpenAIWSPayloadInputSequence(original, items, true)
		require.NoError(t, err)
		require.Equal(t, "hello", gjson.GetBytes(updated, "input.0.text").String())
		require.Equal(t, "world", gjson.GetBytes(updated, "input.1.text").String())
	})

	t.Run("preserve_empty_array_not_null", func(t *testing.T) {
		// nil items must produce "input":[] — not "input":null — so the
		// upstream sees a valid (empty) sequence.
		original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`)
		updated, err := setOpenAIWSPayloadInputSequence(original, nil, true)
		require.NoError(t, err)
		require.True(t, gjson.GetBytes(updated, "input").IsArray())
		require.Len(t, gjson.GetBytes(updated, "input").Array(), 0)
		require.False(t, gjson.GetBytes(updated, "input").Type == gjson.Null)
	})
}
|
||||
|
||||
// TestCloneOpenAIWSRawMessages verifies nil/empty handling of the raw-message
// clone helper: nil stays nil, while an empty slice clones to a distinct
// non-nil empty slice.
func TestCloneOpenAIWSRawMessages(t *testing.T) {
	t.Parallel()

	t.Run("nil_slice", func(t *testing.T) {
		cloned := cloneOpenAIWSRawMessages(nil)
		require.Nil(t, cloned)
	})

	t.Run("empty_slice", func(t *testing.T) {
		items := make([]json.RawMessage, 0)
		cloned := cloneOpenAIWSRawMessages(items)
		require.NotNil(t, cloned)
		require.Len(t, cloned, 0)
	})
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey verifies that the
// retry strategy at attempt 3 trims optional fields ("include") but preserves
// prompt_cache_key and semantic fields such as "text".
func TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey(t *testing.T) {
	payload := map[string]any{
		"model":            "gpt-5.3-codex",
		"prompt_cache_key": "pcache_123",
		"include":          []any{"reasoning.encrypted_content"},
		"text": map[string]any{
			"verbosity": "low",
		},
		"tools": []any{map[string]any{"type": "function"}},
	}

	strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 3)
	require.Equal(t, "trim_optional_fields", strategy)
	// "include" is trimmed and reported as removed...
	require.Contains(t, removed, "include")
	// ...but prompt_cache_key survives in place.
	require.NotContains(t, removed, "prompt_cache_key")
	require.Equal(t, "pcache_123", payload["prompt_cache_key"])
	require.NotContains(t, payload, "include")
	require.Contains(t, payload, "text")
}
|
||||
|
||||
// TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields verifies
// that even at attempt 6 the strategy stays "trim_optional_fields": only
// "include" is removed, while instructions, tools, tool_choice,
// parallel_tool_calls, text and prompt_cache_key are all retained.
func TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields(t *testing.T) {
	payload := map[string]any{
		"prompt_cache_key":    "pcache_456",
		"instructions":        "long instructions",
		"tools":               []any{map[string]any{"type": "function"}},
		"parallel_tool_calls": true,
		"tool_choice":         "auto",
		"include":             []any{"reasoning.encrypted_content"},
		"text":                map[string]any{"verbosity": "high"},
	}

	strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 6)
	require.Equal(t, "trim_optional_fields", strategy)
	require.Contains(t, removed, "include")
	require.NotContains(t, removed, "prompt_cache_key")
	require.Equal(t, "pcache_456", payload["prompt_cache_key"])
	// Semantic request shape must survive aggressive retries.
	require.Contains(t, payload, "instructions")
	require.Contains(t, payload, "tools")
	require.Contains(t, payload, "tool_choice")
	require.Contains(t, payload, "parallel_tool_calls")
	require.Contains(t, payload, "text")
}
|
||||
1306
backend/internal/service/openai_ws_forwarder_success_test.go
Normal file
1306
backend/internal/service/openai_ws_forwarder_success_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1706
backend/internal/service/openai_ws_pool.go
Normal file
1706
backend/internal/service/openai_ws_pool.go
Normal file
File diff suppressed because it is too large
Load Diff
58
backend/internal/service/openai_ws_pool_benchmark_test.go
Normal file
58
backend/internal/service/openai_ws_pool_benchmark_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// BenchmarkOpenAIWSPoolAcquire measures parallel Acquire/Release throughput on
// the WS connection pool using a counting test dialer (no real network). A
// warm-up acquire primes the pool before the timed region; transient
// errOpenAIWSConnClosed results are retried a bounded number of times.
func BenchmarkOpenAIWSPoolAcquire(b *testing.B) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1

	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(&openAIWSCountingDialer{})

	account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	req := openAIWSAcquireRequest{
		Account: account,
		WSURL:   "wss://example.com/v1/responses",
	}
	ctx := context.Background()

	// Warm acquire so the timed loop does not pay first-dial cost.
	lease, err := pool.Acquire(ctx, req)
	if err != nil {
		b.Fatalf("warm acquire failed: %v", err)
	}
	lease.Release()

	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			var (
				got        *openAIWSConnLease
				acquireErr error
			)
			// Retry only on conn-closed races; any other error is fatal.
			for retry := 0; retry < 3; retry++ {
				got, acquireErr = pool.Acquire(ctx, req)
				if acquireErr == nil {
					break
				}
				if !errors.Is(acquireErr, errOpenAIWSConnClosed) {
					break
				}
			}
			if acquireErr != nil {
				b.Fatalf("acquire failed: %v", acquireErr)
			}
			got.Release()
		}
	})
}
|
||||
1709
backend/internal/service/openai_ws_pool_test.go
Normal file
1709
backend/internal/service/openai_ws_pool_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1218
backend/internal/service/openai_ws_protocol_forward_test.go
Normal file
1218
backend/internal/service/openai_ws_protocol_forward_test.go
Normal file
File diff suppressed because it is too large
Load Diff
117
backend/internal/service/openai_ws_protocol_resolver.go
Normal file
117
backend/internal/service/openai_ws_protocol_resolver.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package service
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
// OpenAIUpstreamTransport identifies the transport protocol used towards the
// OpenAI upstream.
type OpenAIUpstreamTransport string

const (
	// OpenAIUpstreamTransportAny places no constraint on the transport.
	OpenAIUpstreamTransportAny OpenAIUpstreamTransport = ""
	// OpenAIUpstreamTransportHTTPSSE is plain HTTP with server-sent events.
	OpenAIUpstreamTransportHTTPSSE OpenAIUpstreamTransport = "http_sse"
	// OpenAIUpstreamTransportResponsesWebsocket is the v1 responses-over-WebSocket transport.
	OpenAIUpstreamTransportResponsesWebsocket OpenAIUpstreamTransport = "responses_websockets"
	// OpenAIUpstreamTransportResponsesWebsocketV2 is the v2 responses-over-WebSocket transport.
	OpenAIUpstreamTransportResponsesWebsocketV2 OpenAIUpstreamTransport = "responses_websockets_v2"
)

// OpenAIWSProtocolDecision is the outcome of protocol resolution: the chosen
// transport plus a machine-readable reason string for observability.
type OpenAIWSProtocolDecision struct {
	Transport OpenAIUpstreamTransport
	Reason    string
}

// OpenAIWSProtocolResolver decides which OpenAI upstream transport to use for
// a given account.
type OpenAIWSProtocolResolver interface {
	Resolve(account *Account) OpenAIWSProtocolDecision
}

// defaultOpenAIWSProtocolResolver resolves the transport from global gateway
// configuration combined with per-account switches.
type defaultOpenAIWSProtocolResolver struct {
	cfg *config.Config
}

// NewOpenAIWSProtocolResolver creates the default protocol resolver backed by
// the given configuration.
func NewOpenAIWSProtocolResolver(cfg *config.Config) OpenAIWSProtocolResolver {
	return &defaultOpenAIWSProtocolResolver{cfg: cfg}
}
|
||||
|
||||
// Resolve decides the upstream transport for the account. The decision ladder,
// in order: missing/non-OpenAI accounts, account-level force-HTTP, missing
// config, global force-HTTP/disable, per-auth-type gates, then either the
// mode-router-v2 path (per-account mode + concurrency check) or the legacy
// per-account boolean switch; in both paths WS v2 is preferred over v1, and
// every fallback carries a distinct reason string.
func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProtocolDecision {
	if account == nil {
		return openAIWSHTTPDecision("account_missing")
	}
	if !account.IsOpenAI() {
		return openAIWSHTTPDecision("platform_not_openai")
	}
	if account.IsOpenAIWSForceHTTPEnabled() {
		// Account-level opt-out wins over everything below.
		return openAIWSHTTPDecision("account_force_http")
	}
	if r == nil || r.cfg == nil {
		return openAIWSHTTPDecision("config_missing")
	}

	wsCfg := r.cfg.Gateway.OpenAIWS
	if wsCfg.ForceHTTP {
		return openAIWSHTTPDecision("global_force_http")
	}
	if !wsCfg.Enabled {
		return openAIWSHTTPDecision("global_disabled")
	}
	// Per-auth-type gates: OAuth and API-key accounts have separate switches;
	// any other auth type always falls back to HTTP.
	if account.IsOpenAIOAuth() {
		if !wsCfg.OAuthEnabled {
			return openAIWSHTTPDecision("oauth_disabled")
		}
	} else if account.IsOpenAIApiKey() {
		if !wsCfg.APIKeyEnabled {
			return openAIWSHTTPDecision("apikey_disabled")
		}
	} else {
		return openAIWSHTTPDecision("unknown_auth_type")
	}
	if wsCfg.ModeRouterV2Enabled {
		// Mode-router path: the account's resolved ingress mode drives routing.
		mode := account.ResolveOpenAIResponsesWebSocketV2Mode(wsCfg.IngressModeDefault)
		switch mode {
		case OpenAIWSIngressModeOff:
			return openAIWSHTTPDecision("account_mode_off")
		case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
			// continue
		default:
			// Unknown modes are treated the same as an explicit "off".
			return openAIWSHTTPDecision("account_mode_off")
		}
		if account.Concurrency <= 0 {
			return openAIWSHTTPDecision("account_concurrency_invalid")
		}
		if wsCfg.ResponsesWebsocketsV2 {
			return OpenAIWSProtocolDecision{
				Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
				Reason:    "ws_v2_mode_" + mode,
			}
		}
		if wsCfg.ResponsesWebsockets {
			return OpenAIWSProtocolDecision{
				Transport: OpenAIUpstreamTransportResponsesWebsocket,
				Reason:    "ws_v1_mode_" + mode,
			}
		}
		return openAIWSHTTPDecision("feature_disabled")
	}
	// Legacy path: a single per-account boolean enables WS.
	if !account.IsOpenAIResponsesWebSocketV2Enabled() {
		return openAIWSHTTPDecision("account_disabled")
	}
	if wsCfg.ResponsesWebsocketsV2 {
		return OpenAIWSProtocolDecision{
			Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
			Reason:    "ws_v2_enabled",
		}
	}
	if wsCfg.ResponsesWebsockets {
		return OpenAIWSProtocolDecision{
			Transport: OpenAIUpstreamTransportResponsesWebsocket,
			Reason:    "ws_v1_enabled",
		}
	}
	return openAIWSHTTPDecision("feature_disabled")
}
|
||||
|
||||
func openAIWSHTTPDecision(reason string) OpenAIWSProtocolDecision {
|
||||
return OpenAIWSProtocolDecision{
|
||||
Transport: OpenAIUpstreamTransportHTTPSSE,
|
||||
Reason: reason,
|
||||
}
|
||||
}
|
||||
203
backend/internal/service/openai_ws_protocol_resolver_test.go
Normal file
203
backend/internal/service/openai_ws_protocol_resolver_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOpenAIWSProtocolResolver_Resolve exercises the legacy (non-mode-router)
// resolver path: v2 preferred over v1, account/global force-HTTP and disable
// switches, per-auth-type gates, legacy extra-key compatibility, and the
// unknown-auth-type fallback. Each case also pins the decision reason string.
func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) {
	baseCfg := &config.Config{}
	baseCfg.Gateway.OpenAIWS.Enabled = true
	baseCfg.Gateway.OpenAIWS.OAuthEnabled = true
	baseCfg.Gateway.OpenAIWS.APIKeyEnabled = true
	baseCfg.Gateway.OpenAIWS.ResponsesWebsockets = false
	baseCfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true

	// OAuth account with the v2 per-account switch turned on.
	openAIOAuthEnabled := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
		},
	}

	t.Run("v2优先", func(t *testing.T) {
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})

	t.Run("v2关闭时回退v1", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
		cfg.Gateway.OpenAIWS.ResponsesWebsockets = true

		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocket, decision.Transport)
		require.Equal(t, "ws_v1_enabled", decision.Reason)
	})

	t.Run("透传开关不影响WS协议判定", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
			"openai_passthrough":                           true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})

	t.Run("账号级强制HTTP", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
			"openai_ws_force_http":                         true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_force_http", decision.Reason)
	})

	t.Run("全局关闭保持HTTP", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.Enabled = false
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "global_disabled", decision.Reason)
	})

	t.Run("账号开关关闭保持HTTP", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": false,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_disabled", decision.Reason)
	})

	t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
		// An API-key-only extra key must not enable WS for an OAuth account.
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_disabled", decision.Reason)
	})

	t.Run("兼容旧键openai_ws_enabled", func(t *testing.T) {
		// Legacy extra key still enables the WS v2 transport.
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_ws_enabled": true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})

	t.Run("按账号类型开关控制", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.OAuthEnabled = false
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "oauth_disabled", decision.Reason)
	})

	t.Run("API Key 账号关闭开关时回退HTTP", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.APIKeyEnabled = false
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "apikey_disabled", decision.Reason)
	})

	t.Run("未知认证类型回退HTTP", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     "unknown_type",
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "unknown_auth_type", decision.Reason)
	})
}
|
||||
|
||||
// TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2 verifies transport
// selection when the v2 mode router is enabled: per-account ingress mode keys
// override the default, legacy boolean flags map to shared mode, and invalid
// account concurrency falls back to HTTP/SSE.
func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
	// Config with all WS v2 switches on and "shared" as the default ingress mode.
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
	cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared

	// OAuth account that explicitly requests dedicated mode via its Extra map.
	account := &Account{
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
		},
	}

	t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
	})

	t.Run("off mode routes to http", func(t *testing.T) {
		// Explicit "off" on the account must win over the shared default.
		offAccount := &Account{
			Platform:    PlatformOpenAI,
			Type:        AccountTypeOAuth,
			Concurrency: 1,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(offAccount)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_mode_off", decision.Reason)
	})

	t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
		// Old-style *_enabled boolean should be interpreted as shared mode.
		legacyAccount := &Account{
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Concurrency: 1,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_mode_shared", decision.Reason)
	})

	t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
		// Concurrency is left at the zero value here, which must disqualify WS v2.
		invalidConcurrency := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_concurrency_invalid", decision.Reason)
	})
}
|
||||
440
backend/internal/service/openai_ws_state_store.go
Normal file
440
backend/internal/service/openai_ws_state_store.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	// openAIWSResponseAccountCachePrefix prefixes the (hashed) response-id key
	// used for the Redis-backed response->account binding.
	openAIWSResponseAccountCachePrefix = "openai:response:"
	// openAIWSStateStoreCleanupInterval is the minimum time between two
	// incremental cleanup passes over the in-process maps.
	openAIWSStateStoreCleanupInterval = time.Minute
	// openAIWSStateStoreCleanupMaxPerMap bounds how many entries a single
	// cleanup pass scans in each map.
	openAIWSStateStoreCleanupMaxPerMap = 512
	// openAIWSStateStoreMaxEntriesPerMap is a hard per-map size cap; inserts
	// beyond it evict an arbitrary entry to keep memory bounded.
	openAIWSStateStoreMaxEntriesPerMap = 65536
	// openAIWSStateStoreRedisTimeout is the independent timeout applied to
	// every Redis (GatewayCache) round trip.
	openAIWSStateStoreRedisTimeout = 3 * time.Second
)
|
||||
|
||||
// openAIWSAccountBinding records which account a response id was served by,
// valid until expiresAt.
type openAIWSAccountBinding struct {
	accountID int64
	expiresAt time.Time
}

// openAIWSConnBinding records which upstream connection a response id belongs
// to, valid until expiresAt.
type openAIWSConnBinding struct {
	connID    string
	expiresAt time.Time
}

// openAIWSTurnStateBinding records the opaque turn state for a session hash,
// valid until expiresAt.
type openAIWSTurnStateBinding struct {
	turnState string
	expiresAt time.Time
}

// openAIWSSessionConnBinding records the connection pinned to a session hash,
// valid until expiresAt.
type openAIWSSessionConnBinding struct {
	connID    string
	expiresAt time.Time
}
|
||||
|
||||
// OpenAIWSStateStore manages WS v2 stickiness state.
// - response_id -> account_id is used for continuation routing
// - response_id -> conn_id is used for in-connection context reuse
//
// response_id -> account_id prefers GatewayCache (Redis) while also keeping a
// local hot cache. response_id -> conn_id is valid only within this process.
type OpenAIWSStateStore interface {
	// Response -> account bindings (Redis-backed with a local hot cache).
	BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error
	GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error)
	DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error

	// Response -> connection bindings (process-local only).
	BindResponseConn(responseID, connID string, ttl time.Duration)
	GetResponseConn(responseID string) (string, bool)
	DeleteResponseConn(responseID string)

	// Session -> turn-state bindings, keyed per group (process-local only).
	BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
	GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
	DeleteSessionTurnState(groupID int64, sessionHash string)

	// Session -> connection bindings, keyed per group (process-local only).
	BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
	GetSessionConn(groupID int64, sessionHash string) (string, bool)
	DeleteSessionConn(groupID int64, sessionHash string)
}
|
||||
|
||||
// defaultOpenAIWSStateStore is the standard OpenAIWSStateStore implementation:
// four independently-locked in-process maps plus an optional Redis-backed
// GatewayCache for the response->account binding.
type defaultOpenAIWSStateStore struct {
	// cache may be nil; when set it provides cross-process
	// response->account lookups via Redis.
	cache GatewayCache

	// Each map has its own RWMutex so the four binding kinds do not
	// contend with each other.
	responseToAccountMu sync.RWMutex
	responseToAccount   map[string]openAIWSAccountBinding
	responseToConnMu    sync.RWMutex
	responseToConn      map[string]openAIWSConnBinding
	sessionToTurnStateMu sync.RWMutex
	sessionToTurnState   map[string]openAIWSTurnStateBinding
	sessionToConnMu sync.RWMutex
	sessionToConn   map[string]openAIWSSessionConnBinding

	// lastCleanupUnixNano gates incremental cleanup passes; see maybeCleanup.
	lastCleanupUnixNano atomic.Int64
}
|
||||
|
||||
// NewOpenAIWSStateStore 创建默认 WS 状态存储。
|
||||
func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
|
||||
store := &defaultOpenAIWSStateStore{
|
||||
cache: cache,
|
||||
responseToAccount: make(map[string]openAIWSAccountBinding, 256),
|
||||
responseToConn: make(map[string]openAIWSConnBinding, 256),
|
||||
sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256),
|
||||
sessionToConn: make(map[string]openAIWSSessionConnBinding, 256),
|
||||
}
|
||||
store.lastCleanupUnixNano.Store(time.Now().UnixNano())
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
ttl = normalizeOpenAIWSTTL(ttl)
|
||||
s.maybeCleanup()
|
||||
|
||||
expiresAt := time.Now().Add(ttl)
|
||||
s.responseToAccountMu.Lock()
|
||||
ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap)
|
||||
s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt}
|
||||
s.responseToAccountMu.Unlock()
|
||||
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
cacheKey := openAIWSResponseAccountCacheKey(id)
|
||||
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
|
||||
}
|
||||
|
||||
// GetResponseAccount resolves the account bound to responseID. It first
// consults the local hot cache and, on a miss, falls back to Redis. A Redis
// hit is deliberately NOT written back into the local map, so a later Redis
// deletion is never masked by a stale local entry.
// Returns (0, nil) for a blank id, a miss, or a degraded cache read.
func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return 0, nil
	}
	s.maybeCleanup()

	now := time.Now()
	s.responseToAccountMu.RLock()
	if binding, ok := s.responseToAccount[id]; ok {
		// Only an unexpired local entry counts as a hit; expired entries are
		// left for the incremental cleaner.
		if now.Before(binding.expiresAt) {
			accountID := binding.accountID
			s.responseToAccountMu.RUnlock()
			return accountID, nil
		}
	}
	s.responseToAccountMu.RUnlock()

	if s.cache == nil {
		return 0, nil
	}

	cacheKey := openAIWSResponseAccountCacheKey(id)
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey)
	if err != nil || accountID <= 0 {
		// A cache read failure must not block the main flow; degrade to a miss.
		return 0, nil
	}
	return accountID, nil
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" {
|
||||
return nil
|
||||
}
|
||||
s.responseToAccountMu.Lock()
|
||||
delete(s.responseToAccount, id)
|
||||
s.responseToAccountMu.Unlock()
|
||||
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
|
||||
defer cancel()
|
||||
return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id))
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
conn := strings.TrimSpace(connID)
|
||||
if id == "" || conn == "" {
|
||||
return
|
||||
}
|
||||
ttl = normalizeOpenAIWSTTL(ttl)
|
||||
s.maybeCleanup()
|
||||
|
||||
s.responseToConnMu.Lock()
|
||||
ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap)
|
||||
s.responseToConn[id] = openAIWSConnBinding{
|
||||
connID: conn,
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
s.responseToConnMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" {
|
||||
return "", false
|
||||
}
|
||||
s.maybeCleanup()
|
||||
|
||||
now := time.Now()
|
||||
s.responseToConnMu.RLock()
|
||||
binding, ok := s.responseToConn[id]
|
||||
s.responseToConnMu.RUnlock()
|
||||
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
|
||||
return "", false
|
||||
}
|
||||
return binding.connID, true
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
|
||||
id := normalizeOpenAIWSResponseID(responseID)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
s.responseToConnMu.Lock()
|
||||
delete(s.responseToConn, id)
|
||||
s.responseToConnMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
state := strings.TrimSpace(turnState)
|
||||
if key == "" || state == "" {
|
||||
return
|
||||
}
|
||||
ttl = normalizeOpenAIWSTTL(ttl)
|
||||
s.maybeCleanup()
|
||||
|
||||
s.sessionToTurnStateMu.Lock()
|
||||
ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap)
|
||||
s.sessionToTurnState[key] = openAIWSTurnStateBinding{
|
||||
turnState: state,
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
s.sessionToTurnStateMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
if key == "" {
|
||||
return "", false
|
||||
}
|
||||
s.maybeCleanup()
|
||||
|
||||
now := time.Now()
|
||||
s.sessionToTurnStateMu.RLock()
|
||||
binding, ok := s.sessionToTurnState[key]
|
||||
s.sessionToTurnStateMu.RUnlock()
|
||||
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.turnState) == "" {
|
||||
return "", false
|
||||
}
|
||||
return binding.turnState, true
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
s.sessionToTurnStateMu.Lock()
|
||||
delete(s.sessionToTurnState, key)
|
||||
s.sessionToTurnStateMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
conn := strings.TrimSpace(connID)
|
||||
if key == "" || conn == "" {
|
||||
return
|
||||
}
|
||||
ttl = normalizeOpenAIWSTTL(ttl)
|
||||
s.maybeCleanup()
|
||||
|
||||
s.sessionToConnMu.Lock()
|
||||
ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap)
|
||||
s.sessionToConn[key] = openAIWSSessionConnBinding{
|
||||
connID: conn,
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
s.sessionToConnMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
if key == "" {
|
||||
return "", false
|
||||
}
|
||||
s.maybeCleanup()
|
||||
|
||||
now := time.Now()
|
||||
s.sessionToConnMu.RLock()
|
||||
binding, ok := s.sessionToConn[key]
|
||||
s.sessionToConnMu.RUnlock()
|
||||
if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
|
||||
return "", false
|
||||
}
|
||||
return binding.connID, true
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) {
|
||||
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
s.sessionToConnMu.Lock()
|
||||
delete(s.sessionToConn, key)
|
||||
s.sessionToConnMu.Unlock()
|
||||
}
|
||||
|
||||
// maybeCleanup runs at most one incremental expiry pass per
// openAIWSStateStoreCleanupInterval across all four maps. The CAS on
// lastCleanupUnixNano ensures that when several goroutines race past the
// interval check, exactly one of them performs the pass.
func (s *defaultOpenAIWSStateStore) maybeCleanup() {
	if s == nil {
		return
	}
	now := time.Now()
	last := time.Unix(0, s.lastCleanupUnixNano.Load())
	if now.Sub(last) < openAIWSStateStoreCleanupInterval {
		return
	}
	// Losing the CAS means another goroutine claimed this pass; back off.
	if !s.lastCleanupUnixNano.CompareAndSwap(last.UnixNano(), now.UnixNano()) {
		return
	}

	// Incremental, bounded cleanup: avoid a full scan that could block for a
	// long time at large scale.
	s.responseToAccountMu.Lock()
	cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap)
	s.responseToAccountMu.Unlock()

	s.responseToConnMu.Lock()
	cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap)
	s.responseToConnMu.Unlock()

	s.sessionToTurnStateMu.Lock()
	cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap)
	s.sessionToTurnStateMu.Unlock()

	s.sessionToConnMu.Lock()
	cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap)
	s.sessionToConnMu.Unlock()
}
|
||||
|
||||
func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
|
||||
if len(bindings) == 0 || maxScan <= 0 {
|
||||
return
|
||||
}
|
||||
scanned := 0
|
||||
for key, binding := range bindings {
|
||||
if now.After(binding.expiresAt) {
|
||||
delete(bindings, key)
|
||||
}
|
||||
scanned++
|
||||
if scanned >= maxScan {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureBindingCapacity enforces a hard size cap before inserting incomingKey:
// when the map is at or above maxEntries and incomingKey is new, one arbitrary
// entry is evicted so the map stays bounded. maxEntries <= 0 disables the cap.
// Caller must hold the map's lock.
func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
	if maxEntries <= 0 || len(bindings) < maxEntries {
		return
	}
	// Updating an existing key does not grow the map; no eviction needed.
	if _, alreadyTracked := bindings[incomingKey]; alreadyTracked {
		return
	}
	// Bounded-memory guarantee: drop one arbitrary victim.
	for victim := range bindings {
		delete(bindings, victim)
		break
	}
}
|
||||
|
||||
// normalizeOpenAIWSResponseID canonicalizes a response id by stripping
// surrounding whitespace; an all-whitespace id becomes "".
func normalizeOpenAIWSResponseID(responseID string) string {
	trimmed := strings.TrimSpace(responseID)
	return trimmed
}
|
||||
|
||||
func openAIWSResponseAccountCacheKey(responseID string) string {
|
||||
sum := sha256.Sum256([]byte(responseID))
|
||||
return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
|
||||
if ttl <= 0 {
|
||||
return time.Hour
|
||||
}
|
||||
return ttl
|
||||
}
|
||||
|
||||
// openAIWSSessionTurnStateKey builds the per-group map key "<groupID>:<hash>".
// A blank session hash yields "", which callers treat as "skip".
func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
	if hash := strings.TrimSpace(sessionHash); hash != "" {
		return fmt.Sprintf("%d:%s", groupID, hash)
	}
	return ""
}
|
||||
|
||||
func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithTimeout(ctx, openAIWSStateStoreRedisTimeout)
|
||||
}
|
||||
235
backend/internal/service/openai_ws_state_store_test.go
Normal file
235
backend/internal/service/openai_ws_state_store_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOpenAIWSStateStore_BindGetDeleteResponseAccount covers the full
// bind -> get -> delete -> miss lifecycle of the response->account binding
// with a stub cache attached.
func TestOpenAIWSStateStore_BindGetDeleteResponseAccount(t *testing.T) {
	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	ctx := context.Background()
	groupID := int64(7)

	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_abc", 101, time.Minute))

	// Bound account must be readable back.
	accountID, err := store.GetResponseAccount(ctx, groupID, "resp_abc")
	require.NoError(t, err)
	require.Equal(t, int64(101), accountID)

	// After deletion the lookup degrades to a zero-value miss, not an error.
	require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_abc"))
	accountID, err = store.GetResponseAccount(ctx, groupID, "resp_abc")
	require.NoError(t, err)
	require.Zero(t, accountID)
}
|
||||
|
||||
// TestOpenAIWSStateStore_ResponseConnTTL checks that a response->conn binding
// is visible before its TTL elapses and gone afterwards (real-clock sleep).
func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)
	store.BindResponseConn("resp_conn", "conn_1", 30*time.Millisecond)

	connID, ok := store.GetResponseConn("resp_conn")
	require.True(t, ok)
	require.Equal(t, "conn_1", connID)

	// Sleep past the 30ms TTL; the binding must now read as expired.
	time.Sleep(60 * time.Millisecond)
	_, ok = store.GetResponseConn("resp_conn")
	require.False(t, ok)
}
|
||||
|
||||
// TestOpenAIWSStateStore_SessionTurnStateTTL checks group isolation and TTL
// expiry of the session->turn-state binding (real-clock sleep).
func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)
	store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond)

	state, ok := store.GetSessionTurnState(9, "session_hash_1")
	require.True(t, ok)
	require.Equal(t, "turn_state_1", state)

	// Group isolation: the same hash under a different group must miss.
	_, ok = store.GetSessionTurnState(10, "session_hash_1")
	require.False(t, ok)

	// Sleep past the 30ms TTL; the binding must now read as expired.
	time.Sleep(60 * time.Millisecond)
	_, ok = store.GetSessionTurnState(9, "session_hash_1")
	require.False(t, ok)
}
|
||||
|
||||
// TestOpenAIWSStateStore_SessionConnTTL checks group isolation and TTL expiry
// of the session->connection binding (real-clock sleep).
func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)
	store.BindSessionConn(9, "session_hash_conn_1", "conn_1", 30*time.Millisecond)

	connID, ok := store.GetSessionConn(9, "session_hash_conn_1")
	require.True(t, ok)
	require.Equal(t, "conn_1", connID)

	// Group isolation: the same hash under a different group must miss.
	_, ok = store.GetSessionConn(10, "session_hash_conn_1")
	require.False(t, ok)

	// Sleep past the 30ms TTL; the binding must now read as expired.
	time.Sleep(60 * time.Millisecond)
	_, ok = store.GetSessionConn(9, "session_hash_conn_1")
	require.False(t, ok)
}
|
||||
|
||||
// TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss pins the
// design decision that a Redis hit is NOT written back into the local hot
// cache: once the upstream entry disappears, the lookup must miss rather than
// return a stale locally-retained value.
func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) {
	cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
	store := NewOpenAIWSStateStore(cache)
	ctx := context.Background()
	groupID := int64(17)
	responseID := "resp_cache_stale"
	cacheKey := openAIWSResponseAccountCacheKey(responseID)

	// Seed the binding only in the (stub) upstream cache, not via Bind.
	cache.sessionBindings[cacheKey] = 501
	accountID, err := store.GetResponseAccount(ctx, groupID, responseID)
	require.NoError(t, err)
	require.Equal(t, int64(501), accountID)

	// Remove the upstream entry; the next read must degrade to a miss.
	delete(cache.sessionBindings, cacheKey)
	accountID, err = store.GetResponseAccount(ctx, groupID, responseID)
	require.NoError(t, err)
	require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射")
}
|
||||
|
||||
// TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally seeds 2048
// expired conn bindings (above the 512-per-pass scan cap) and verifies that a
// single cleanup pass makes bounded progress while repeated passes eventually
// drain everything.
func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) {
	raw := NewOpenAIWSStateStore(nil)
	store, ok := raw.(*defaultOpenAIWSStateStore)
	require.True(t, ok)

	expiredAt := time.Now().Add(-time.Minute)
	total := 2048
	store.responseToConnMu.Lock()
	for i := 0; i < total; i++ {
		store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{
			connID:    "conn_incremental",
			expiresAt: expiredAt,
		}
	}
	store.responseToConnMu.Unlock()

	// Backdate the cleanup clock so maybeCleanup's interval gate passes.
	store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
	store.maybeCleanup()

	store.responseToConnMu.RLock()
	remainingAfterFirst := len(store.responseToConn)
	store.responseToConnMu.RUnlock()
	// One pass scans at most 512 entries: progress, but not a full drain.
	require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展")
	require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键")

	// 8 more backdated passes give enough scan budget to drain the rest.
	for i := 0; i < 8; i++ {
		store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
		store.maybeCleanup()
	}

	store.responseToConnMu.RLock()
	remaining := len(store.responseToConn)
	store.responseToConnMu.RUnlock()
	require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键")
}
|
||||
|
||||
func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) {
|
||||
bindings := map[string]int{
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
}
|
||||
|
||||
ensureBindingCapacity(bindings, "c", 2)
|
||||
bindings["c"] = 3
|
||||
|
||||
require.Len(t, bindings, 2)
|
||||
require.Equal(t, 3, bindings["c"])
|
||||
}
|
||||
|
||||
func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) {
|
||||
bindings := map[string]int{
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
}
|
||||
|
||||
ensureBindingCapacity(bindings, "a", 2)
|
||||
bindings["a"] = 9
|
||||
|
||||
require.Len(t, bindings, 2)
|
||||
require.Equal(t, 9, bindings["a"])
|
||||
}
|
||||
|
||||
// openAIWSStateStoreTimeoutProbeCache is a GatewayCache test double that
// records, per operation, whether the incoming context carried a deadline and
// how far in the future that deadline was. Used to assert the store attaches
// its own short Redis timeout.
type openAIWSStateStoreTimeoutProbeCache struct {
	setHasDeadline    bool
	getHasDeadline    bool
	deleteHasDeadline bool
	setDeadlineDelta  time.Duration
	getDeadlineDelta  time.Duration
	delDeadlineDelta  time.Duration
}

// GetSessionAccountID records the context deadline and always reports account 123.
func (c *openAIWSStateStoreTimeoutProbeCache) GetSessionAccountID(ctx context.Context, _ int64, _ string) (int64, error) {
	if deadline, ok := ctx.Deadline(); ok {
		c.getHasDeadline = true
		c.getDeadlineDelta = time.Until(deadline)
	}
	return 123, nil
}

// SetSessionAccountID records the context deadline and always fails, so tests
// can observe the store propagating the write error.
func (c *openAIWSStateStoreTimeoutProbeCache) SetSessionAccountID(ctx context.Context, _ int64, _ string, _ int64, _ time.Duration) error {
	if deadline, ok := ctx.Deadline(); ok {
		c.setHasDeadline = true
		c.setDeadlineDelta = time.Until(deadline)
	}
	return errors.New("set failed")
}

// RefreshSessionTTL is a no-op needed only to satisfy the interface.
func (c *openAIWSStateStoreTimeoutProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error {
	return nil
}

// DeleteSessionAccountID records the context deadline and succeeds.
func (c *openAIWSStateStoreTimeoutProbeCache) DeleteSessionAccountID(ctx context.Context, _ int64, _ string) error {
	if deadline, ok := ctx.Deadline(); ok {
		c.deleteHasDeadline = true
		c.delDeadlineDelta = time.Until(deadline)
	}
	return nil
}
|
||||
|
||||
// TestOpenAIWSStateStore_RedisOpsUseShortTimeout asserts that every Redis
// round trip (Set/Get/Delete) is wrapped in the store's independent ~3s
// timeout, and that a local-cache hit short-circuits the Redis read entirely.
func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) {
	probe := &openAIWSStateStoreTimeoutProbeCache{}
	store := NewOpenAIWSStateStore(probe)
	ctx := context.Background()
	groupID := int64(5)

	// The probe's Set always fails; the store must surface that error.
	err := store.BindResponseAccount(ctx, groupID, "resp_timeout_probe", 11, time.Minute)
	require.Error(t, err)

	// Even though the Redis write failed, the local binding succeeded and
	// must satisfy the read without touching Redis.
	accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe")
	require.NoError(t, getErr)
	require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号")

	require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe"))

	require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文")
	require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文")
	require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取")
	// Deadlines should land in (2s, 3s]: the 3s budget minus scheduling slack.
	require.Greater(t, probe.setDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second)
	require.Greater(t, probe.delDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe.delDeadlineDelta, 3*time.Second)

	// Fresh store with no local binding: the read must go to Redis with the
	// same short timeout attached.
	probe2 := &openAIWSStateStoreTimeoutProbeCache{}
	store2 := NewOpenAIWSStateStore(probe2)
	accountID2, err2 := store2.GetResponseAccount(ctx, groupID, "resp_cache_only")
	require.NoError(t, err2)
	require.Equal(t, int64(123), accountID2)
	require.True(t, probe2.getHasDeadline, "GetSessionAccountID 在缓存未命中时应携带独立超时上下文")
	require.Greater(t, probe2.getDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second)
}
|
||||
|
||||
// TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext verifies that
// wrapping a deadline-free parent context yields a context that does carry a
// deadline (the store's short Redis timeout).
func TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) {
	ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
	defer cancel()
	require.NotNil(t, ctx)
	_, ok := ctx.Deadline()
	require.True(t, ok, "应附加短超时")
}
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lib/pq"
|
||||
@@ -480,7 +479,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq
|
||||
|
||||
attemptCtx := ctx
|
||||
if switches > 0 {
|
||||
attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches)
|
||||
attemptCtx = WithAccountSwitchCount(attemptCtx, switches, false)
|
||||
}
|
||||
exec := func() *opsRetryExecution {
|
||||
defer selection.ReleaseFunc()
|
||||
@@ -675,6 +674,7 @@ func newOpsRetryContext(ctx context.Context, errorLog *OpsErrorLogDetail) (*gin.
|
||||
}
|
||||
|
||||
c.Request = req
|
||||
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
|
||||
return c, w
|
||||
}
|
||||
|
||||
|
||||
47
backend/internal/service/ops_retry_context_test.go
Normal file
47
backend/internal/service/ops_retry_context_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders verifies that a
// synthetic retry request rebuilt from an ops error log keeps the original
// path and user agent, replays only allow-listed headers (anthropic-* are
// kept, case-insensitively), drops sensitive headers such as authorization,
// and is marked as HTTP client transport.
func TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders(t *testing.T) {
	errorLog := &OpsErrorLogDetail{
		OpsErrorLog: OpsErrorLog{
			RequestPath: "/openai/v1/responses",
		},
		UserAgent: "ops-retry-agent/1.0",
		RequestHeaders: `{
"anthropic-beta":"beta-v1",
"ANTHROPIC-VERSION":"2023-06-01",
"authorization":"Bearer should-not-forward"
}`,
	}

	c, w := newOpsRetryContext(context.Background(), errorLog)
	require.NotNil(t, c)
	require.NotNil(t, w)
	require.NotNil(t, c.Request)

	require.Equal(t, "/openai/v1/responses", c.Request.URL.Path)
	require.Equal(t, "application/json", c.Request.Header.Get("Content-Type"))
	require.Equal(t, "ops-retry-agent/1.0", c.Request.Header.Get("User-Agent"))
	// Header names are canonicalized: the upper-case key is still retrievable.
	require.Equal(t, "beta-v1", c.Request.Header.Get("anthropic-beta"))
	require.Equal(t, "2023-06-01", c.Request.Header.Get("anthropic-version"))
	require.Empty(t, c.Request.Header.Get("authorization"), "未在白名单内的敏感头不应被重放")
	require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
}
|
||||
|
||||
// TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport verifies
// graceful degradation: malformed stored headers must not prevent building a
// usable retry context — the path falls back to "/" and the HTTP transport
// marker is still set.
func TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport(t *testing.T) {
	errorLog := &OpsErrorLogDetail{
		RequestHeaders: "{invalid-json",
	}

	c, _ := newOpsRetryContext(context.Background(), errorLog)
	require.NotNil(t, c)
	require.NotNil(t, c.Request)
	require.Equal(t, "/", c.Request.URL.Path)
	require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
}
|
||||
@@ -27,6 +27,11 @@ const (
|
||||
OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms"
|
||||
OpsResponseLatencyMsKey = "ops_response_latency_ms"
|
||||
OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms"
|
||||
// OpenAI WS 关键观测字段
|
||||
OpsOpenAIWSQueueWaitMsKey = "ops_openai_ws_queue_wait_ms"
|
||||
OpsOpenAIWSConnPickMsKey = "ops_openai_ws_conn_pick_ms"
|
||||
OpsOpenAIWSConnReusedKey = "ops_openai_ws_conn_reused"
|
||||
OpsOpenAIWSConnIDKey = "ops_openai_ws_conn_id"
|
||||
|
||||
// OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。
|
||||
// ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// RateLimitService 处理限流和过载状态管理
|
||||
@@ -33,6 +34,10 @@ type geminiUsageCacheEntry struct {
|
||||
totals GeminiUsageTotals
|
||||
}
|
||||
|
||||
type geminiUsageTotalsBatchProvider interface {
|
||||
GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]GeminiUsageTotals, error)
|
||||
}
|
||||
|
||||
const geminiPrecheckCacheTTL = time.Minute
|
||||
|
||||
// NewRateLimitService 创建RateLimitService实例
|
||||
@@ -162,6 +167,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
if upstreamMsg != "" {
|
||||
msg = "Access forbidden (403): " + upstreamMsg
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.ratelimit",
|
||||
"[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s",
|
||||
account.ID,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
strings.TrimSpace(headers.Get("x-request-id")),
|
||||
strings.TrimSpace(headers.Get("cf-ray")),
|
||||
upstreamMsg,
|
||||
truncateForLog(responseBody, 1024),
|
||||
)
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
case 429:
|
||||
@@ -225,7 +241,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
start := geminiDailyWindowStart(now)
|
||||
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
|
||||
if !ok {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
@@ -272,7 +288,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
|
||||
if limit > 0 {
|
||||
start := now.Truncate(time.Minute)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
@@ -302,6 +318,218 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// PreCheckUsageBatch performs quota precheck for multiple accounts in one request.
// Returned map value=false means the account should be skipped.
//
// Only Gemini-platform accounts with a resolvable quota are checked; all other
// accounts pass through as schedulable (true). The check runs in two stages:
// a daily-window precheck (local cache first, then one batched DB read for the
// misses) and a minute-window precheck (one batched DB read for the survivors).
// Errors return the partially-filled result so callers can still schedule.
func (s *RateLimitService) PreCheckUsageBatch(ctx context.Context, accounts []*Account, requestedModel string) (map[int64]bool, error) {
	// Default every non-nil account to schedulable; stages below only flip to false.
	result := make(map[int64]bool, len(accounts))
	for _, account := range accounts {
		if account == nil {
			continue
		}
		result[account.ID] = true
	}

	if len(accounts) == 0 || requestedModel == "" {
		return result, nil
	}
	// Without a usage repo or quota service there is nothing to check against.
	if s.usageRepo == nil || s.geminiQuotaService == nil {
		return result, nil
	}

	modelClass := geminiModelClassFromName(requestedModel)
	now := time.Now()
	dailyStart := geminiDailyWindowStart(now)
	minuteStart := now.Truncate(time.Minute)

	// Pair each Gemini account with its resolved quota once, up front.
	type quotaAccount struct {
		account *Account
		quota   GeminiQuota
	}
	quotaAccounts := make([]quotaAccount, 0, len(accounts))
	for _, account := range accounts {
		if account == nil || account.Platform != PlatformGemini {
			continue
		}
		quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
		if !ok {
			continue
		}
		quotaAccounts = append(quotaAccounts, quotaAccount{
			account: account,
			quota:   quota,
		})
	}
	if len(quotaAccounts) == 0 {
		return result, nil
	}

	// 1) Daily precheck (cached + batch DB fallback)
	dailyTotalsByID := make(map[int64]GeminiUsageTotals, len(quotaAccounts))
	dailyMissIDs := make([]int64, 0, len(quotaAccounts))
	for _, item := range quotaAccounts {
		limit := geminiDailyLimit(item.quota, modelClass)
		if limit <= 0 {
			// No daily limit configured for this class — nothing to check.
			continue
		}
		accountID := item.account.ID
		if totals, ok := s.getGeminiUsageTotals(accountID, dailyStart, now); ok {
			dailyTotalsByID[accountID] = totals
			continue
		}
		dailyMissIDs = append(dailyMissIDs, accountID)
	}
	if len(dailyMissIDs) > 0 {
		totalsBatch, err := s.getGeminiUsageTotalsBatch(ctx, dailyMissIDs, dailyStart, now)
		if err != nil {
			return result, err
		}
		for _, accountID := range dailyMissIDs {
			totals := totalsBatch[accountID]
			dailyTotalsByID[accountID] = totals
			// Warm the cache so subsequent single-account prechecks hit it.
			s.setGeminiUsageTotals(accountID, dailyStart, now, totals)
		}
	}
	for _, item := range quotaAccounts {
		limit := geminiDailyLimit(item.quota, modelClass)
		if limit <= 0 {
			continue
		}
		accountID := item.account.ID
		used := geminiUsedRequests(item.quota, modelClass, dailyTotalsByID[accountID], true)
		if used >= limit {
			resetAt := geminiDailyResetTime(now)
			slog.Info("gemini_precheck_daily_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt)
			result[accountID] = false
		}
	}

	// 2) Minute precheck (batch DB)
	// Only accounts that survived the daily stage and have a minute limit.
	minuteIDs := make([]int64, 0, len(quotaAccounts))
	for _, item := range quotaAccounts {
		accountID := item.account.ID
		if !result[accountID] {
			continue
		}
		if geminiMinuteLimit(item.quota, modelClass) <= 0 {
			continue
		}
		minuteIDs = append(minuteIDs, accountID)
	}
	if len(minuteIDs) == 0 {
		return result, nil
	}

	minuteTotalsByID, err := s.getGeminiUsageTotalsBatch(ctx, minuteIDs, minuteStart, now)
	if err != nil {
		return result, err
	}
	for _, item := range quotaAccounts {
		accountID := item.account.ID
		if !result[accountID] {
			continue
		}

		limit := geminiMinuteLimit(item.quota, modelClass)
		if limit <= 0 {
			continue
		}

		used := geminiUsedRequests(item.quota, modelClass, minuteTotalsByID[accountID], false)
		if used >= limit {
			resetAt := minuteStart.Add(time.Minute)
			slog.Info("gemini_precheck_minute_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt)
			result[accountID] = false
		}
	}

	return result, nil
}
|
||||
|
||||
func (s *RateLimitService) getGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, start, end time.Time) (map[int64]GeminiUsageTotals, error) {
|
||||
result := make(map[int64]GeminiUsageTotals, len(accountIDs))
|
||||
if len(accountIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(accountIDs))
|
||||
seen := make(map[int64]struct{}, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
if accountID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[accountID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[accountID] = struct{}{}
|
||||
ids = append(ids, accountID)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if batchReader, ok := s.usageRepo.(geminiUsageTotalsBatchProvider); ok {
|
||||
stats, err := batchReader.GetGeminiUsageTotalsBatch(ctx, ids, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, accountID := range ids {
|
||||
result[accountID] = stats[accountID]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for _, accountID := range ids {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, end, 0, 0, accountID, 0, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[accountID] = geminiAggregateUsage(stats)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func geminiDailyLimit(quota GeminiQuota, modelClass geminiModelClass) int64 {
|
||||
if quota.SharedRPD > 0 {
|
||||
return quota.SharedRPD
|
||||
}
|
||||
switch modelClass {
|
||||
case geminiModelFlash:
|
||||
return quota.FlashRPD
|
||||
default:
|
||||
return quota.ProRPD
|
||||
}
|
||||
}
|
||||
|
||||
func geminiMinuteLimit(quota GeminiQuota, modelClass geminiModelClass) int64 {
|
||||
if quota.SharedRPM > 0 {
|
||||
return quota.SharedRPM
|
||||
}
|
||||
switch modelClass {
|
||||
case geminiModelFlash:
|
||||
return quota.FlashRPM
|
||||
default:
|
||||
return quota.ProRPM
|
||||
}
|
||||
}
|
||||
|
||||
func geminiUsedRequests(quota GeminiQuota, modelClass geminiModelClass, totals GeminiUsageTotals, daily bool) int64 {
|
||||
if daily {
|
||||
if quota.SharedRPD > 0 {
|
||||
return totals.ProRequests + totals.FlashRequests
|
||||
}
|
||||
} else {
|
||||
if quota.SharedRPM > 0 {
|
||||
return totals.ProRequests + totals.FlashRequests
|
||||
}
|
||||
}
|
||||
switch modelClass {
|
||||
case geminiModelFlash:
|
||||
return totals.FlashRequests
|
||||
default:
|
||||
return totals.ProRequests
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) {
|
||||
s.usageCacheMu.RLock()
|
||||
defer s.usageCacheMu.RUnlock()
|
||||
|
||||
216
backend/internal/service/request_metadata.go
Normal file
216
backend/internal/service/request_metadata.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
)
|
||||
|
||||
// requestMetadataContextKey is the private context key type; a dedicated
// unexported struct type guarantees no collision with keys from other packages.
type requestMetadataContextKey struct{}

// requestMetadataKey is the single context key under which *RequestMetadata is stored.
var requestMetadataKey = requestMetadataContextKey{}

// RequestMetadata aggregates per-request scheduling flags that previously
// lived under individual ctxkey entries. Pointer fields distinguish
// "not set" (nil) from an explicitly stored value.
type RequestMetadata struct {
	IsMaxTokensOneHaikuRequest *bool  // presumably marks max_tokens=1 haiku probe requests — confirm with writers
	ThinkingEnabled            *bool  // whether thinking mode is enabled for this request
	PrefetchedStickyAccountID  *int64 // account ID prefetched for sticky-session routing
	PrefetchedStickyGroupID    *int64 // group ID prefetched for sticky-session routing
	SingleAccountRetry         *bool  // retry restricted to a single account
	AccountSwitchCount         *int   // number of account switches recorded so far
}

// Fallback counters: each is incremented when a reader had to fall back to the
// legacy ctxkey value because no *RequestMetadata entry carried the field.
// Exposed via RequestMetadataFallbackStats for migration monitoring.
var (
	requestMetadataFallbackIsMaxTokensOneHaikuTotal atomic.Int64
	requestMetadataFallbackThinkingEnabledTotal     atomic.Int64
	requestMetadataFallbackPrefetchedStickyAccount  atomic.Int64
	requestMetadataFallbackPrefetchedStickyGroup    atomic.Int64
	requestMetadataFallbackSingleAccountRetryTotal  atomic.Int64
	requestMetadataFallbackAccountSwitchCountTotal  atomic.Int64
)
|
||||
|
||||
// RequestMetadataFallbackStats reports, per field, how many reads were served
// from the legacy ctxkey values instead of the RequestMetadata struct. Useful
// for verifying that the legacy keys can eventually be removed.
func RequestMetadataFallbackStats() (isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount int64) {
	return requestMetadataFallbackIsMaxTokensOneHaikuTotal.Load(),
		requestMetadataFallbackThinkingEnabledTotal.Load(),
		requestMetadataFallbackPrefetchedStickyAccount.Load(),
		requestMetadataFallbackPrefetchedStickyGroup.Load(),
		requestMetadataFallbackSingleAccountRetryTotal.Load(),
		requestMetadataFallbackAccountSwitchCountTotal.Load()
}
|
||||
|
||||
func metadataFromContext(ctx context.Context) *RequestMetadata {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
md, _ := ctx.Value(requestMetadataKey).(*RequestMetadata)
|
||||
return md
|
||||
}
|
||||
|
||||
// updateRequestMetadata applies update to a copy of the metadata currently in
// ctx and stores the copy back under requestMetadataKey. The copy-on-write
// step means sibling contexts derived from the same parent never observe the
// mutation. When bridgeOldKeys is true and legacyBridge is non-nil, the same
// value is additionally mirrored into the legacy ctxkey entries so readers
// that have not migrated yet still see it. A nil ctx is returned as nil.
func updateRequestMetadata(
	ctx context.Context,
	bridgeOldKeys bool,
	update func(md *RequestMetadata),
	legacyBridge func(ctx context.Context) context.Context,
) context.Context {
	if ctx == nil {
		return nil
	}
	current := metadataFromContext(ctx)
	// Copy-on-write: never mutate metadata reachable from the parent context.
	next := &RequestMetadata{}
	if current != nil {
		*next = *current
	}
	update(next)
	ctx = context.WithValue(ctx, requestMetadataKey, next)
	if bridgeOldKeys && legacyBridge != nil {
		ctx = legacyBridge(ctx)
	}
	return ctx
}
|
||||
|
||||
func WithIsMaxTokensOneHaikuRequest(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
|
||||
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
|
||||
v := value
|
||||
md.IsMaxTokensOneHaikuRequest = &v
|
||||
}, func(base context.Context) context.Context {
|
||||
return context.WithValue(base, ctxkey.IsMaxTokensOneHaikuRequest, value)
|
||||
})
|
||||
}
|
||||
|
||||
func WithThinkingEnabled(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
|
||||
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
|
||||
v := value
|
||||
md.ThinkingEnabled = &v
|
||||
}, func(base context.Context) context.Context {
|
||||
return context.WithValue(base, ctxkey.ThinkingEnabled, value)
|
||||
})
|
||||
}
|
||||
|
||||
func WithPrefetchedStickySession(ctx context.Context, accountID, groupID int64, bridgeOldKeys bool) context.Context {
|
||||
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
|
||||
account := accountID
|
||||
group := groupID
|
||||
md.PrefetchedStickyAccountID = &account
|
||||
md.PrefetchedStickyGroupID = &group
|
||||
}, func(base context.Context) context.Context {
|
||||
bridged := context.WithValue(base, ctxkey.PrefetchedStickyAccountID, accountID)
|
||||
return context.WithValue(bridged, ctxkey.PrefetchedStickyGroupID, groupID)
|
||||
})
|
||||
}
|
||||
|
||||
func WithSingleAccountRetry(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
|
||||
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
|
||||
v := value
|
||||
md.SingleAccountRetry = &v
|
||||
}, func(base context.Context) context.Context {
|
||||
return context.WithValue(base, ctxkey.SingleAccountRetry, value)
|
||||
})
|
||||
}
|
||||
|
||||
func WithAccountSwitchCount(ctx context.Context, value int, bridgeOldKeys bool) context.Context {
|
||||
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
|
||||
v := value
|
||||
md.AccountSwitchCount = &v
|
||||
}, func(base context.Context) context.Context {
|
||||
return context.WithValue(base, ctxkey.AccountSwitchCount, value)
|
||||
})
|
||||
}
|
||||
|
||||
func IsMaxTokensOneHaikuRequestFromContext(ctx context.Context) (bool, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.IsMaxTokensOneHaikuRequest != nil {
|
||||
return *md.IsMaxTokensOneHaikuRequest, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return false, false
|
||||
}
|
||||
if value, ok := ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok {
|
||||
requestMetadataFallbackIsMaxTokensOneHaikuTotal.Add(1)
|
||||
return value, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func ThinkingEnabledFromContext(ctx context.Context) (bool, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.ThinkingEnabled != nil {
|
||||
return *md.ThinkingEnabled, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return false, false
|
||||
}
|
||||
if value, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
|
||||
requestMetadataFallbackThinkingEnabledTotal.Add(1)
|
||||
return value, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func PrefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyGroupID != nil {
|
||||
return *md.PrefetchedStickyGroupID, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return 0, false
|
||||
}
|
||||
v := ctx.Value(ctxkey.PrefetchedStickyGroupID)
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
requestMetadataFallbackPrefetchedStickyGroup.Add(1)
|
||||
return t, true
|
||||
case int:
|
||||
requestMetadataFallbackPrefetchedStickyGroup.Add(1)
|
||||
return int64(t), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func PrefetchedStickyAccountIDFromContext(ctx context.Context) (int64, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyAccountID != nil {
|
||||
return *md.PrefetchedStickyAccountID, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return 0, false
|
||||
}
|
||||
v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
requestMetadataFallbackPrefetchedStickyAccount.Add(1)
|
||||
return t, true
|
||||
case int:
|
||||
requestMetadataFallbackPrefetchedStickyAccount.Add(1)
|
||||
return int64(t), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func SingleAccountRetryFromContext(ctx context.Context) (bool, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.SingleAccountRetry != nil {
|
||||
return *md.SingleAccountRetry, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return false, false
|
||||
}
|
||||
if value, ok := ctx.Value(ctxkey.SingleAccountRetry).(bool); ok {
|
||||
requestMetadataFallbackSingleAccountRetryTotal.Add(1)
|
||||
return value, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func AccountSwitchCountFromContext(ctx context.Context) (int, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.AccountSwitchCount != nil {
|
||||
return *md.AccountSwitchCount, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return 0, false
|
||||
}
|
||||
v := ctx.Value(ctxkey.AccountSwitchCount)
|
||||
switch t := v.(type) {
|
||||
case int:
|
||||
requestMetadataFallbackAccountSwitchCountTotal.Add(1)
|
||||
return t, true
|
||||
case int64:
|
||||
requestMetadataFallbackAccountSwitchCountTotal.Add(1)
|
||||
return int(t), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
119
backend/internal/service/request_metadata_test.go
Normal file
119
backend/internal/service/request_metadata_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRequestMetadataWriteAndRead_NoBridge verifies that values written with
// bridgeOldKeys=false round-trip through the metadata readers while the legacy
// ctxkey entries remain unset.
func TestRequestMetadataWriteAndRead_NoBridge(t *testing.T) {
	ctx := context.Background()
	ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, false)
	ctx = WithThinkingEnabled(ctx, true, false)
	ctx = WithPrefetchedStickySession(ctx, 123, 456, false)
	ctx = WithSingleAccountRetry(ctx, true, false)
	ctx = WithAccountSwitchCount(ctx, 2, false)

	isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx)
	require.True(t, ok)
	require.True(t, isHaiku)

	thinking, ok := ThinkingEnabledFromContext(ctx)
	require.True(t, ok)
	require.True(t, thinking)

	accountID, ok := PrefetchedStickyAccountIDFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, int64(123), accountID)

	groupID, ok := PrefetchedStickyGroupIDFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, int64(456), groupID)

	singleRetry, ok := SingleAccountRetryFromContext(ctx)
	require.True(t, ok)
	require.True(t, singleRetry)

	switchCount, ok := AccountSwitchCountFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, 2, switchCount)

	// Without bridging, none of the legacy keys may be populated.
	require.Nil(t, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest))
	require.Nil(t, ctx.Value(ctxkey.ThinkingEnabled))
	require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyAccountID))
	require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyGroupID))
	require.Nil(t, ctx.Value(ctxkey.SingleAccountRetry))
	require.Nil(t, ctx.Value(ctxkey.AccountSwitchCount))
}
|
||||
|
||||
// TestRequestMetadataWrite_BridgeLegacyKeys verifies that writes with
// bridgeOldKeys=true mirror every value into the corresponding legacy ctxkey.
func TestRequestMetadataWrite_BridgeLegacyKeys(t *testing.T) {
	ctx := context.Background()
	ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, true)
	ctx = WithThinkingEnabled(ctx, true, true)
	ctx = WithPrefetchedStickySession(ctx, 123, 456, true)
	ctx = WithSingleAccountRetry(ctx, true, true)
	ctx = WithAccountSwitchCount(ctx, 2, true)

	require.Equal(t, true, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest))
	require.Equal(t, true, ctx.Value(ctxkey.ThinkingEnabled))
	require.Equal(t, int64(123), ctx.Value(ctxkey.PrefetchedStickyAccountID))
	require.Equal(t, int64(456), ctx.Value(ctxkey.PrefetchedStickyGroupID))
	require.Equal(t, true, ctx.Value(ctxkey.SingleAccountRetry))
	require.Equal(t, 2, ctx.Value(ctxkey.AccountSwitchCount))
}
|
||||
|
||||
// TestRequestMetadataRead_LegacyFallbackAndStats verifies that readers fall
// back to legacy ctxkey values when no RequestMetadata is present, and that
// each fallback read increments its corresponding counter exactly once.
func TestRequestMetadataRead_LegacyFallbackAndStats(t *testing.T) {
	beforeHaiku, beforeThinking, beforeAccount, beforeGroup, beforeSingleRetry, beforeSwitchCount := RequestMetadataFallbackStats()

	// Populate only the legacy keys; no metadata struct is written.
	ctx := context.Background()
	ctx = context.WithValue(ctx, ctxkey.IsMaxTokensOneHaikuRequest, true)
	ctx = context.WithValue(ctx, ctxkey.ThinkingEnabled, true)
	ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyAccountID, int64(321))
	ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(654))
	ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true)
	ctx = context.WithValue(ctx, ctxkey.AccountSwitchCount, int64(3))

	isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx)
	require.True(t, ok)
	require.True(t, isHaiku)

	thinking, ok := ThinkingEnabledFromContext(ctx)
	require.True(t, ok)
	require.True(t, thinking)

	accountID, ok := PrefetchedStickyAccountIDFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, int64(321), accountID)

	groupID, ok := PrefetchedStickyGroupIDFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, int64(654), groupID)

	singleRetry, ok := SingleAccountRetryFromContext(ctx)
	require.True(t, ok)
	require.True(t, singleRetry)

	switchCount, ok := AccountSwitchCountFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, 3, switchCount)

	// Each reader above should have bumped its fallback counter by one.
	afterHaiku, afterThinking, afterAccount, afterGroup, afterSingleRetry, afterSwitchCount := RequestMetadataFallbackStats()
	require.Equal(t, beforeHaiku+1, afterHaiku)
	require.Equal(t, beforeThinking+1, afterThinking)
	require.Equal(t, beforeAccount+1, afterAccount)
	require.Equal(t, beforeGroup+1, afterGroup)
	require.Equal(t, beforeSingleRetry+1, afterSingleRetry)
	require.Equal(t, beforeSwitchCount+1, afterSwitchCount)
}
|
||||
|
||||
// TestRequestMetadataRead_PreferMetadataOverLegacy verifies that when both a
// metadata entry and a conflicting legacy ctxkey value exist, readers return
// the metadata value and the legacy value is left untouched.
func TestRequestMetadataRead_PreferMetadataOverLegacy(t *testing.T) {
	ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, false)
	ctx = WithThinkingEnabled(ctx, true, false)

	thinking, ok := ThinkingEnabledFromContext(ctx)
	require.True(t, ok)
	require.True(t, thinking)
	require.Equal(t, false, ctx.Value(ctxkey.ThinkingEnabled))
}
|
||||
13
backend/internal/service/response_header_filter.go
Normal file
13
backend/internal/service/response_header_filter.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
)
|
||||
|
||||
// compileResponseHeaderFilter builds the compiled response-header filter from
// the security section of the application config. Returns nil when cfg is nil.
func compileResponseHeaderFilter(cfg *config.Config) *responseheaders.CompiledHeaderFilter {
	if cfg == nil {
		return nil
	}
	return responseheaders.CompileHeaderFilter(cfg.Security.ResponseHeaders)
}
|
||||
@@ -305,13 +305,78 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p
|
||||
if payload == nil {
|
||||
return nil
|
||||
}
|
||||
ids := parseInt64Slice(payload["account_ids"])
|
||||
for _, id := range ids {
|
||||
if err := s.handleAccountEvent(ctx, &id, payload); err != nil {
|
||||
return err
|
||||
if s.accountRepo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawIDs := parseInt64Slice(payload["account_ids"])
|
||||
if len(rawIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(rawIDs))
|
||||
seen := make(map[int64]struct{}, len(rawIDs))
|
||||
for _, id := range rawIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[id]; exists {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
preloadGroupIDs := parseInt64Slice(payload["group_ids"])
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
found := make(map[int64]struct{}, len(accounts))
|
||||
rebuildGroupSet := make(map[int64]struct{}, len(preloadGroupIDs))
|
||||
for _, gid := range preloadGroupIDs {
|
||||
if gid > 0 {
|
||||
rebuildGroupSet[gid] = struct{}{}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
for _, account := range accounts {
|
||||
if account == nil || account.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
found[account.ID] = struct{}{}
|
||||
if s.cache != nil {
|
||||
if err := s.cache.SetAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, gid := range account.GroupIDs {
|
||||
if gid > 0 {
|
||||
rebuildGroupSet[gid] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.cache != nil {
|
||||
for _, id := range ids {
|
||||
if _, ok := found[id]; ok {
|
||||
continue
|
||||
}
|
||||
if err := s.cache.DeleteAccount(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rebuildGroupIDs := make([]int64, 0, len(rebuildGroupSet))
|
||||
for gid := range rebuildGroupSet {
|
||||
rebuildGroupIDs = append(rebuildGroupIDs, gid)
|
||||
}
|
||||
return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change")
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error {
|
||||
|
||||
@@ -9,14 +9,17 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
|
||||
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
|
||||
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
|
||||
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
|
||||
)
|
||||
|
||||
type SettingRepository interface {
|
||||
@@ -34,6 +37,7 @@ type SettingService struct {
|
||||
settingRepo SettingRepository
|
||||
cfg *config.Config
|
||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||
onS3Update func() // Callback when Sora S3 settings are updated
|
||||
version string // Application version
|
||||
}
|
||||
|
||||
@@ -76,6 +80,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyHideCcsImportButton,
|
||||
SettingKeyPurchaseSubscriptionEnabled,
|
||||
SettingKeyPurchaseSubscriptionURL,
|
||||
SettingKeySoraClientEnabled,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
}
|
||||
|
||||
@@ -114,6 +119,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
}, nil
|
||||
}
|
||||
@@ -124,6 +130,11 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
|
||||
s.onUpdate = callback
|
||||
}
|
||||
|
||||
// SetOnS3UpdateCallback registers the callback invoked when Sora S3 settings
// change (used to refresh the S3 client cache).
func (s *SettingService) SetOnS3UpdateCallback(callback func()) {
	s.onS3Update = callback
}
|
||||
|
||||
// SetVersion sets the application version for injection into public settings
|
||||
func (s *SettingService) SetVersion(version string) {
|
||||
s.version = version
|
||||
@@ -157,6 +168,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}{
|
||||
@@ -178,6 +190,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: s.version,
|
||||
}, nil
|
||||
@@ -232,6 +245,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
|
||||
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
|
||||
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
|
||||
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
|
||||
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
@@ -383,6 +397,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyPurchaseSubscriptionEnabled: "false",
|
||||
SettingKeyPurchaseSubscriptionURL: "",
|
||||
SettingKeySoraClientEnabled: "false",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeySMTPPort: "587",
|
||||
@@ -436,6 +451,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
|
||||
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
|
||||
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
@@ -854,3 +870,607 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
|
||||
|
||||
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
|
||||
}
|
||||
|
||||
// soraS3ProfilesStore is the persisted JSON layout for the multi-profile
// Sora S3 configuration: the list of profiles plus which one is active.
type soraS3ProfilesStore struct {
	ActiveProfileID string                  `json:"active_profile_id"`
	Items           []soraS3ProfileStoreItem `json:"items"`
}

// soraS3ProfileStoreItem is one persisted Sora S3 profile, including the
// stored secret access key (kept server-side; not exposed verbatim to clients).
type soraS3ProfileStoreItem struct {
	ProfileID                string `json:"profile_id"`
	Name                     string `json:"name"`
	Enabled                  bool   `json:"enabled"`
	Endpoint                 string `json:"endpoint"`
	Region                   string `json:"region"`
	Bucket                   string `json:"bucket"`
	AccessKeyID              string `json:"access_key_id"`
	SecretAccessKey          string `json:"secret_access_key"`
	Prefix                   string `json:"prefix"`
	ForcePathStyle           bool   `json:"force_path_style"`
	CDNURL                   string `json:"cdn_url"`
	DefaultStorageQuotaBytes int64  `json:"default_storage_quota_bytes"`
	UpdatedAt                string `json:"updated_at"` // RFC 3339 UTC timestamp of last modification
}
|
||||
|
||||
// GetSoraS3Settings returns the Sora S3 storage settings using the legacy
// single-config semantics: it resolves the currently active profile from the
// multi-profile list and projects it onto SoraS3Settings. When no profile is
// active, an empty (zero-value) settings struct is returned rather than an error.
func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
	profiles, err := s.ListSoraS3Profiles(ctx)
	if err != nil {
		return nil, err
	}

	activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
	if activeProfile == nil {
		return &SoraS3Settings{}, nil
	}

	return &SoraS3Settings{
		Enabled:                   activeProfile.Enabled,
		Endpoint:                  activeProfile.Endpoint,
		Region:                    activeProfile.Region,
		Bucket:                    activeProfile.Bucket,
		AccessKeyID:               activeProfile.AccessKeyID,
		SecretAccessKey:           activeProfile.SecretAccessKey,
		SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured,
		Prefix:                    activeProfile.Prefix,
		ForcePathStyle:            activeProfile.ForcePathStyle,
		CDNURL:                    activeProfile.CDNURL,
		DefaultStorageQuotaBytes:  activeProfile.DefaultStorageQuotaBytes,
	}, nil
}
|
||||
|
||||
// SetSoraS3Settings updates the Sora S3 storage settings using the legacy
// single-config semantics: it writes into the currently active profile of the
// multi-profile store, creating a "Default" profile first when none is active.
// String fields are trimmed; the stored secret key is only overwritten when a
// non-empty one is supplied (empty means "keep the existing secret").
func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error {
	if settings == nil {
		return fmt.Errorf("settings cannot be nil")
	}

	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return err
	}

	now := time.Now().UTC().Format(time.RFC3339)
	activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID)
	if activeIndex < 0 {
		// No active profile yet: create one, avoiding an ID collision with any
		// pre-existing "default" profile by suffixing a unix timestamp.
		activeID := "default"
		if hasSoraS3ProfileID(store.Items, activeID) {
			activeID = fmt.Sprintf("default-%d", time.Now().Unix())
		}
		store.Items = append(store.Items, soraS3ProfileStoreItem{
			ProfileID: activeID,
			Name:      "Default",
			UpdatedAt: now,
		})
		store.ActiveProfileID = activeID
		activeIndex = len(store.Items) - 1
	}

	// Mutate a copy of the slice element, then write it back.
	active := store.Items[activeIndex]
	active.Enabled = settings.Enabled
	active.Endpoint = strings.TrimSpace(settings.Endpoint)
	active.Region = strings.TrimSpace(settings.Region)
	active.Bucket = strings.TrimSpace(settings.Bucket)
	active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID)
	active.Prefix = strings.TrimSpace(settings.Prefix)
	active.ForcePathStyle = settings.ForcePathStyle
	active.CDNURL = strings.TrimSpace(settings.CDNURL)
	active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0)
	if settings.SecretAccessKey != "" {
		// Empty secret means "unchanged" so clients need not echo it back.
		active.SecretAccessKey = settings.SecretAccessKey
	}
	active.UpdatedAt = now
	store.Items[activeIndex] = active

	return s.persistSoraS3ProfilesStore(ctx, store)
}
|
||||
|
||||
// ListSoraS3Profiles 获取 Sora S3 多配置列表
|
||||
func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) {
|
||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return convertSoraS3ProfilesStore(store), nil
|
||||
}
|
||||
|
||||
// CreateSoraS3Profile 创建 Sora S3 配置
|
||||
func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) {
|
||||
if profile == nil {
|
||||
return nil, fmt.Errorf("profile cannot be nil")
|
||||
}
|
||||
|
||||
profileID := strings.TrimSpace(profile.ProfileID)
|
||||
if profileID == "" {
|
||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
||||
}
|
||||
name := strings.TrimSpace(profile.Name)
|
||||
if name == "" {
|
||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
|
||||
}
|
||||
|
||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hasSoraS3ProfileID(store.Items, profileID) {
|
||||
return nil, ErrSoraS3ProfileExists
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
store.Items = append(store.Items, soraS3ProfileStoreItem{
|
||||
ProfileID: profileID,
|
||||
Name: name,
|
||||
Enabled: profile.Enabled,
|
||||
Endpoint: strings.TrimSpace(profile.Endpoint),
|
||||
Region: strings.TrimSpace(profile.Region),
|
||||
Bucket: strings.TrimSpace(profile.Bucket),
|
||||
AccessKeyID: strings.TrimSpace(profile.AccessKeyID),
|
||||
SecretAccessKey: profile.SecretAccessKey,
|
||||
Prefix: strings.TrimSpace(profile.Prefix),
|
||||
ForcePathStyle: profile.ForcePathStyle,
|
||||
CDNURL: strings.TrimSpace(profile.CDNURL),
|
||||
DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0),
|
||||
UpdatedAt: now,
|
||||
})
|
||||
|
||||
if setActive || store.ActiveProfileID == "" {
|
||||
store.ActiveProfileID = profileID
|
||||
}
|
||||
|
||||
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := convertSoraS3ProfilesStore(store)
|
||||
created := findSoraS3ProfileByID(profiles.Items, profileID)
|
||||
if created == nil {
|
||||
return nil, ErrSoraS3ProfileNotFound
|
||||
}
|
||||
return created, nil
|
||||
}
|
||||
|
||||
// UpdateSoraS3Profile 更新 Sora S3 配置
|
||||
func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) {
|
||||
if profile == nil {
|
||||
return nil, fmt.Errorf("profile cannot be nil")
|
||||
}
|
||||
|
||||
targetID := strings.TrimSpace(profileID)
|
||||
if targetID == "" {
|
||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
||||
}
|
||||
|
||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
|
||||
if targetIndex < 0 {
|
||||
return nil, ErrSoraS3ProfileNotFound
|
||||
}
|
||||
|
||||
target := store.Items[targetIndex]
|
||||
name := strings.TrimSpace(profile.Name)
|
||||
if name == "" {
|
||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
|
||||
}
|
||||
target.Name = name
|
||||
target.Enabled = profile.Enabled
|
||||
target.Endpoint = strings.TrimSpace(profile.Endpoint)
|
||||
target.Region = strings.TrimSpace(profile.Region)
|
||||
target.Bucket = strings.TrimSpace(profile.Bucket)
|
||||
target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID)
|
||||
target.Prefix = strings.TrimSpace(profile.Prefix)
|
||||
target.ForcePathStyle = profile.ForcePathStyle
|
||||
target.CDNURL = strings.TrimSpace(profile.CDNURL)
|
||||
target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0)
|
||||
if profile.SecretAccessKey != "" {
|
||||
target.SecretAccessKey = profile.SecretAccessKey
|
||||
}
|
||||
target.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
store.Items[targetIndex] = target
|
||||
|
||||
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := convertSoraS3ProfilesStore(store)
|
||||
updated := findSoraS3ProfileByID(profiles.Items, targetID)
|
||||
if updated == nil {
|
||||
return nil, ErrSoraS3ProfileNotFound
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// DeleteSoraS3Profile 删除 Sora S3 配置
|
||||
func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error {
|
||||
targetID := strings.TrimSpace(profileID)
|
||||
if targetID == "" {
|
||||
return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
||||
}
|
||||
|
||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
|
||||
if targetIndex < 0 {
|
||||
return ErrSoraS3ProfileNotFound
|
||||
}
|
||||
|
||||
store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...)
|
||||
if store.ActiveProfileID == targetID {
|
||||
store.ActiveProfileID = ""
|
||||
if len(store.Items) > 0 {
|
||||
store.ActiveProfileID = store.Items[0].ProfileID
|
||||
}
|
||||
}
|
||||
|
||||
return s.persistSoraS3ProfilesStore(ctx, store)
|
||||
}
|
||||
|
||||
// SetActiveSoraS3Profile 设置激活的 Sora S3 配置
|
||||
func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) {
|
||||
targetID := strings.TrimSpace(profileID)
|
||||
if targetID == "" {
|
||||
return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
|
||||
}
|
||||
|
||||
store, err := s.loadSoraS3ProfilesStore(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
|
||||
if targetIndex < 0 {
|
||||
return nil, ErrSoraS3ProfileNotFound
|
||||
}
|
||||
|
||||
store.ActiveProfileID = targetID
|
||||
store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := convertSoraS3ProfilesStore(store)
|
||||
active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
|
||||
if active == nil {
|
||||
return nil, ErrSoraS3ProfileNotFound
|
||||
}
|
||||
return active, nil
|
||||
}
|
||||
|
||||
func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles)
|
||||
if err == nil {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return &soraS3ProfilesStore{}, nil
|
||||
}
|
||||
var store soraS3ProfilesStore
|
||||
if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil {
|
||||
legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
|
||||
if legacyErr != nil {
|
||||
return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr)
|
||||
}
|
||||
if isEmptyLegacySoraS3Settings(legacy) {
|
||||
return &soraS3ProfilesStore{}, nil
|
||||
}
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
return &soraS3ProfilesStore{
|
||||
ActiveProfileID: "default",
|
||||
Items: []soraS3ProfileStoreItem{
|
||||
{
|
||||
ProfileID: "default",
|
||||
Name: "Default",
|
||||
Enabled: legacy.Enabled,
|
||||
Endpoint: strings.TrimSpace(legacy.Endpoint),
|
||||
Region: strings.TrimSpace(legacy.Region),
|
||||
Bucket: strings.TrimSpace(legacy.Bucket),
|
||||
AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
|
||||
SecretAccessKey: legacy.SecretAccessKey,
|
||||
Prefix: strings.TrimSpace(legacy.Prefix),
|
||||
ForcePathStyle: legacy.ForcePathStyle,
|
||||
CDNURL: strings.TrimSpace(legacy.CDNURL),
|
||||
DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
|
||||
UpdatedAt: now,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
normalized := normalizeSoraS3ProfilesStore(store)
|
||||
return &normalized, nil
|
||||
}
|
||||
|
||||
if !errors.Is(err, ErrSettingNotFound) {
|
||||
return nil, fmt.Errorf("get sora s3 profiles: %w", err)
|
||||
}
|
||||
|
||||
legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
|
||||
if legacyErr != nil {
|
||||
return nil, legacyErr
|
||||
}
|
||||
if isEmptyLegacySoraS3Settings(legacy) {
|
||||
return &soraS3ProfilesStore{}, nil
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
return &soraS3ProfilesStore{
|
||||
ActiveProfileID: "default",
|
||||
Items: []soraS3ProfileStoreItem{
|
||||
{
|
||||
ProfileID: "default",
|
||||
Name: "Default",
|
||||
Enabled: legacy.Enabled,
|
||||
Endpoint: strings.TrimSpace(legacy.Endpoint),
|
||||
Region: strings.TrimSpace(legacy.Region),
|
||||
Bucket: strings.TrimSpace(legacy.Bucket),
|
||||
AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
|
||||
SecretAccessKey: legacy.SecretAccessKey,
|
||||
Prefix: strings.TrimSpace(legacy.Prefix),
|
||||
ForcePathStyle: legacy.ForcePathStyle,
|
||||
CDNURL: strings.TrimSpace(legacy.CDNURL),
|
||||
DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
|
||||
UpdatedAt: now,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error {
|
||||
if store == nil {
|
||||
return fmt.Errorf("sora s3 profiles store cannot be nil")
|
||||
}
|
||||
|
||||
normalized := normalizeSoraS3ProfilesStore(*store)
|
||||
data, err := json.Marshal(normalized)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal sora s3 profiles: %w", err)
|
||||
}
|
||||
|
||||
updates := map[string]string{
|
||||
SettingKeySoraS3Profiles: string(data),
|
||||
}
|
||||
|
||||
active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID)
|
||||
if active == nil {
|
||||
updates[SettingKeySoraS3Enabled] = "false"
|
||||
updates[SettingKeySoraS3Endpoint] = ""
|
||||
updates[SettingKeySoraS3Region] = ""
|
||||
updates[SettingKeySoraS3Bucket] = ""
|
||||
updates[SettingKeySoraS3AccessKeyID] = ""
|
||||
updates[SettingKeySoraS3Prefix] = ""
|
||||
updates[SettingKeySoraS3ForcePathStyle] = "false"
|
||||
updates[SettingKeySoraS3CDNURL] = ""
|
||||
updates[SettingKeySoraDefaultStorageQuotaBytes] = "0"
|
||||
updates[SettingKeySoraS3SecretAccessKey] = ""
|
||||
} else {
|
||||
updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled)
|
||||
updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint)
|
||||
updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region)
|
||||
updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket)
|
||||
updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID)
|
||||
updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix)
|
||||
updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle)
|
||||
updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL)
|
||||
updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10)
|
||||
updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey
|
||||
}
|
||||
|
||||
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.onUpdate != nil {
|
||||
s.onUpdate()
|
||||
}
|
||||
if s.onS3Update != nil {
|
||||
s.onS3Update()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
|
||||
keys := []string{
|
||||
SettingKeySoraS3Enabled,
|
||||
SettingKeySoraS3Endpoint,
|
||||
SettingKeySoraS3Region,
|
||||
SettingKeySoraS3Bucket,
|
||||
SettingKeySoraS3AccessKeyID,
|
||||
SettingKeySoraS3SecretAccessKey,
|
||||
SettingKeySoraS3Prefix,
|
||||
SettingKeySoraS3ForcePathStyle,
|
||||
SettingKeySoraS3CDNURL,
|
||||
SettingKeySoraDefaultStorageQuotaBytes,
|
||||
}
|
||||
|
||||
values, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get legacy sora s3 settings: %w", err)
|
||||
}
|
||||
|
||||
result := &SoraS3Settings{
|
||||
Enabled: values[SettingKeySoraS3Enabled] == "true",
|
||||
Endpoint: values[SettingKeySoraS3Endpoint],
|
||||
Region: values[SettingKeySoraS3Region],
|
||||
Bucket: values[SettingKeySoraS3Bucket],
|
||||
AccessKeyID: values[SettingKeySoraS3AccessKeyID],
|
||||
SecretAccessKey: values[SettingKeySoraS3SecretAccessKey],
|
||||
SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "",
|
||||
Prefix: values[SettingKeySoraS3Prefix],
|
||||
ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true",
|
||||
CDNURL: values[SettingKeySoraS3CDNURL],
|
||||
}
|
||||
if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil {
|
||||
result.DefaultStorageQuotaBytes = v
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore {
|
||||
seen := make(map[string]struct{}, len(store.Items))
|
||||
normalized := soraS3ProfilesStore{
|
||||
ActiveProfileID: strings.TrimSpace(store.ActiveProfileID),
|
||||
Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)),
|
||||
}
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
for idx := range store.Items {
|
||||
item := store.Items[idx]
|
||||
item.ProfileID = strings.TrimSpace(item.ProfileID)
|
||||
if item.ProfileID == "" {
|
||||
item.ProfileID = fmt.Sprintf("profile-%d", idx+1)
|
||||
}
|
||||
if _, exists := seen[item.ProfileID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[item.ProfileID] = struct{}{}
|
||||
|
||||
item.Name = strings.TrimSpace(item.Name)
|
||||
if item.Name == "" {
|
||||
item.Name = item.ProfileID
|
||||
}
|
||||
item.Endpoint = strings.TrimSpace(item.Endpoint)
|
||||
item.Region = strings.TrimSpace(item.Region)
|
||||
item.Bucket = strings.TrimSpace(item.Bucket)
|
||||
item.AccessKeyID = strings.TrimSpace(item.AccessKeyID)
|
||||
item.Prefix = strings.TrimSpace(item.Prefix)
|
||||
item.CDNURL = strings.TrimSpace(item.CDNURL)
|
||||
item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0)
|
||||
item.UpdatedAt = strings.TrimSpace(item.UpdatedAt)
|
||||
if item.UpdatedAt == "" {
|
||||
item.UpdatedAt = now
|
||||
}
|
||||
normalized.Items = append(normalized.Items, item)
|
||||
}
|
||||
|
||||
if len(normalized.Items) == 0 {
|
||||
normalized.ActiveProfileID = ""
|
||||
return normalized
|
||||
}
|
||||
|
||||
if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 {
|
||||
return normalized
|
||||
}
|
||||
|
||||
normalized.ActiveProfileID = normalized.Items[0].ProfileID
|
||||
return normalized
|
||||
}
|
||||
|
||||
func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList {
|
||||
if store == nil {
|
||||
return &SoraS3ProfileList{}
|
||||
}
|
||||
items := make([]SoraS3Profile, 0, len(store.Items))
|
||||
for idx := range store.Items {
|
||||
item := store.Items[idx]
|
||||
items = append(items, SoraS3Profile{
|
||||
ProfileID: item.ProfileID,
|
||||
Name: item.Name,
|
||||
IsActive: item.ProfileID == store.ActiveProfileID,
|
||||
Enabled: item.Enabled,
|
||||
Endpoint: item.Endpoint,
|
||||
Region: item.Region,
|
||||
Bucket: item.Bucket,
|
||||
AccessKeyID: item.AccessKeyID,
|
||||
SecretAccessKey: item.SecretAccessKey,
|
||||
SecretAccessKeyConfigured: item.SecretAccessKey != "",
|
||||
Prefix: item.Prefix,
|
||||
ForcePathStyle: item.ForcePathStyle,
|
||||
CDNURL: item.CDNURL,
|
||||
DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes,
|
||||
UpdatedAt: item.UpdatedAt,
|
||||
})
|
||||
}
|
||||
return &SoraS3ProfileList{
|
||||
ActiveProfileID: store.ActiveProfileID,
|
||||
Items: items,
|
||||
}
|
||||
}
|
||||
|
||||
func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == activeProfileID {
|
||||
return &items[idx]
|
||||
}
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &items[0]
|
||||
}
|
||||
|
||||
func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == profileID {
|
||||
return &items[idx]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == activeProfileID {
|
||||
return &items[idx]
|
||||
}
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &items[0]
|
||||
}
|
||||
|
||||
func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == profileID {
|
||||
return idx
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool {
|
||||
return findSoraS3ProfileIndex(items, profileID) >= 0
|
||||
}
|
||||
|
||||
func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool {
|
||||
if settings == nil {
|
||||
return true
|
||||
}
|
||||
if settings.Enabled {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.Endpoint) != "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.Region) != "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.Bucket) != "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.AccessKeyID) != "" {
|
||||
return false
|
||||
}
|
||||
if settings.SecretAccessKey != "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.Prefix) != "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(settings.CDNURL) != "" {
|
||||
return false
|
||||
}
|
||||
return settings.DefaultStorageQuotaBytes == 0
|
||||
}
|
||||
|
||||
// maxInt64 returns value, clamped from below by min.
func maxInt64(value int64, min int64) int64 {
	if value >= min {
		return value
	}
	return min
}
|
||||
|
||||
@@ -39,6 +39,7 @@ type SystemSettings struct {
|
||||
HideCcsImportButton bool
|
||||
PurchaseSubscriptionEnabled bool
|
||||
PurchaseSubscriptionURL string
|
||||
SoraClientEnabled bool
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
@@ -81,11 +82,52 @@ type PublicSettings struct {
|
||||
|
||||
PurchaseSubscriptionEnabled bool
|
||||
PurchaseSubscriptionURL string
|
||||
SoraClientEnabled bool
|
||||
|
||||
LinuxDoOAuthEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
// SoraS3Settings Sora S3 存储配置
|
||||
type SoraS3Settings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// SoraS3Profile Sora S3 多配置项(服务内部模型)
|
||||
type SoraS3Profile struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// SoraS3ProfileList Sora S3 多配置列表
|
||||
type SoraS3ProfileList struct {
|
||||
ActiveProfileID string `json:"active_profile_id"`
|
||||
Items []SoraS3Profile `json:"items"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
|
||||
type StreamTimeoutSettings struct {
|
||||
// Enabled 是否启用流超时处理
|
||||
|
||||
@@ -43,6 +43,7 @@ type SoraVideoRequest struct {
|
||||
Frames int
|
||||
Model string
|
||||
Size string
|
||||
VideoCount int
|
||||
MediaID string
|
||||
RemixTargetID string
|
||||
CameoIDs []string
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -63,8 +64,8 @@ var soraBlockedCIDRs = mustParseCIDRs([]string{
|
||||
// SoraGatewayService handles forwarding requests to Sora upstream.
|
||||
type SoraGatewayService struct {
|
||||
soraClient SoraClient
|
||||
mediaStorage *SoraMediaStorage
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
@@ -100,14 +101,14 @@ type soraPreflightChecker interface {
|
||||
|
||||
func NewSoraGatewayService(
|
||||
soraClient SoraClient,
|
||||
mediaStorage *SoraMediaStorage,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
) *SoraGatewayService {
|
||||
return &SoraGatewayService{
|
||||
soraClient: soraClient,
|
||||
mediaStorage: mediaStorage,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
@@ -115,6 +116,15 @@ func NewSoraGatewayService(
|
||||
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient
|
||||
if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" {
|
||||
if s.httpUpstream == nil {
|
||||
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream)
|
||||
return nil, errors.New("httpUpstream not configured for sora apikey forwarding")
|
||||
}
|
||||
return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime)
|
||||
}
|
||||
|
||||
if s.soraClient == nil || !s.soraClient.Enabled() {
|
||||
if c != nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
@@ -296,6 +306,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
|
||||
taskID := ""
|
||||
var err error
|
||||
videoCount := parseSoraVideoCount(reqBody)
|
||||
switch modelCfg.Type {
|
||||
case "image":
|
||||
taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
|
||||
@@ -321,6 +332,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
VideoCount: videoCount,
|
||||
MediaID: mediaID,
|
||||
RemixTargetID: remixTargetID,
|
||||
CameoIDs: extractSoraCameoIDs(reqBody),
|
||||
@@ -378,16 +390,9 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
}
|
||||
}
|
||||
|
||||
// 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。
|
||||
// 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。
|
||||
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
|
||||
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
||||
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
||||
if storeErr != nil {
|
||||
// 存储失败时降级使用原始 URL,不中断用户请求
|
||||
log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr)
|
||||
} else {
|
||||
finalURLs = s.normalizeSoraMediaURLs(stored)
|
||||
}
|
||||
}
|
||||
if watermarkPostID != "" && watermarkOpts.DeletePost {
|
||||
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
|
||||
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
|
||||
@@ -463,6 +468,20 @@ func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func parseSoraVideoCount(body map[string]any) int {
|
||||
if body == nil {
|
||||
return 1
|
||||
}
|
||||
keys := []string{"video_count", "videos", "n_variants"}
|
||||
for _, key := range keys {
|
||||
count := parseIntWithDefault(body, key, 0)
|
||||
if count > 0 {
|
||||
return clampInt(count, 1, 3)
|
||||
}
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
|
||||
if body == nil {
|
||||
return def
|
||||
@@ -508,6 +527,42 @@ func parseStringWithDefault(body map[string]any, key, def string) string {
|
||||
return def
|
||||
}
|
||||
|
||||
// parseIntWithDefault reads body[key] as an int, accepting int/int32/int64,
// JSON float64, or a numeric string (surrounding whitespace ignored).
// Missing keys or unconvertible values yield def.
func parseIntWithDefault(body map[string]any, key string, def int) int {
	if body == nil {
		return def
	}
	raw, present := body[key]
	if !present {
		return def
	}
	switch v := raw.(type) {
	case int:
		return v
	case int32:
		return int(v)
	case int64:
		return int(v)
	case float64:
		return int(v) // JSON numbers decode as float64; fractional part truncates
	case string:
		if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
			return parsed
		}
	}
	return def
}
|
||||
|
||||
// clampInt restricts v to the inclusive range [minVal, maxVal].
func clampInt(v, minVal, maxVal int) int {
	switch {
	case v < minVal:
		return minVal
	case v > maxVal:
		return maxVal
	default:
		return v
	}
}
|
||||
|
||||
func extractSoraCameoIDs(body map[string]any) []string {
|
||||
if body == nil {
|
||||
return nil
|
||||
@@ -904,6 +959,21 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
|
||||
}
|
||||
var upstreamErr *SoraUpstreamError
|
||||
if errors.As(err, &upstreamErr) {
|
||||
accountID := int64(0)
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.sora",
|
||||
"[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s",
|
||||
accountID,
|
||||
model,
|
||||
upstreamErr.StatusCode,
|
||||
strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")),
|
||||
strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")),
|
||||
strings.TrimSpace(upstreamErr.Message),
|
||||
truncateForLog(upstreamErr.Body, 1024),
|
||||
)
|
||||
if s.rateLimitService != nil && account != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
||||
}
|
||||
|
||||
@@ -179,6 +179,31 @@ func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
|
||||
require.True(t, client.storyboard)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardVideoCount(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/v.mp4"},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 3, client.videoReq.VideoCount)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{}
|
||||
cfg := &config.Config{
|
||||
@@ -524,3 +549,10 @@ func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
|
||||
require.True(t, opts.Enabled)
|
||||
require.False(t, opts.FallbackOnFailure)
|
||||
}
|
||||
|
||||
func TestParseSoraVideoCount(t *testing.T) {
|
||||
require.Equal(t, 1, parseSoraVideoCount(nil))
|
||||
require.Equal(t, 2, parseSoraVideoCount(map[string]any{"video_count": float64(2)}))
|
||||
require.Equal(t, 3, parseSoraVideoCount(map[string]any{"videos": "5"}))
|
||||
require.Equal(t, 1, parseSoraVideoCount(map[string]any{"n_variants": 0}))
|
||||
}
|
||||
|
||||
63
backend/internal/service/sora_generation.go
Normal file
63
backend/internal/service/sora_generation.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SoraGeneration 代表一条 Sora 客户端生成记录。
|
||||
type SoraGeneration struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MediaType string `json:"media_type"` // video / image
|
||||
Status string `json:"status"` // pending / generating / completed / failed / cancelled
|
||||
MediaURL string `json:"media_url"` // 主媒体 URL(预签名或 CDN)
|
||||
MediaURLs []string `json:"media_urls"` // 多图时的 URL 数组
|
||||
FileSizeBytes int64 `json:"file_size_bytes"`
|
||||
StorageType string `json:"storage_type"` // s3 / local / upstream / none
|
||||
S3ObjectKeys []string `json:"s3_object_keys"` // S3 object key 数组
|
||||
UpstreamTaskID string `json:"upstream_task_id"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty"`
|
||||
}
|
||||
|
||||
// Lifecycle states of a Sora generation record.
const (
	SoraGenStatusPending    = "pending"
	SoraGenStatusGenerating = "generating"
	SoraGenStatusCompleted  = "completed"
	SoraGenStatusFailed     = "failed"
	SoraGenStatusCancelled  = "cancelled"
)
|
||||
|
||||
// Sora 存储类型常量
|
||||
const (
|
||||
SoraStorageTypeS3 = "s3"
|
||||
SoraStorageTypeLocal = "local"
|
||||
SoraStorageTypeUpstream = "upstream"
|
||||
SoraStorageTypeNone = "none"
|
||||
)
|
||||
|
||||
// SoraGenerationListParams carries the filters and paging options used when
// listing generation records.
type SoraGenerationListParams struct {
	UserID      int64
	Status      string // optional filter
	StorageType string // optional filter
	MediaType   string // optional filter
	Page        int
	PageSize    int
}
|
||||
|
||||
// SoraGenerationRepository 生成记录持久化接口。
|
||||
type SoraGenerationRepository interface {
|
||||
Create(ctx context.Context, gen *SoraGeneration) error
|
||||
GetByID(ctx context.Context, id int64) (*SoraGeneration, error)
|
||||
Update(ctx context.Context, gen *SoraGeneration) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error)
|
||||
CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error)
|
||||
}
|
||||
332
backend/internal/service/sora_generation_service.go
Normal file
332
backend/internal/service/sora_generation_service.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// Sentinel errors reported by SoraGenerationService.
var (
	// ErrSoraGenerationConcurrencyLimit is returned when the user already has
	// too many in-flight generation tasks.
	ErrSoraGenerationConcurrencyLimit = errors.New("sora generation concurrent limit exceeded")

	// ErrSoraGenerationStateConflict is returned when the record's status has
	// changed underneath the caller (e.g. the task was already cancelled).
	ErrSoraGenerationStateConflict = errors.New("sora generation state conflict")

	// ErrSoraGenerationNotActive is returned when the task is not in a
	// cancellable state.
	ErrSoraGenerationNotActive = errors.New("sora generation is not active")
)
|
||||
|
||||
const soraGenerationActiveLimit = 3
|
||||
|
||||
type soraGenerationRepoAtomicCreator interface {
|
||||
CreatePendingWithLimit(ctx context.Context, gen *SoraGeneration, activeStatuses []string, maxActive int64) error
|
||||
}
|
||||
|
||||
type soraGenerationRepoConditionalUpdater interface {
|
||||
UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error)
|
||||
UpdateCompletedIfActive(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64, completedAt time.Time) (bool, error)
|
||||
UpdateFailedIfActive(ctx context.Context, id int64, errMsg string, completedAt time.Time) (bool, error)
|
||||
UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error)
|
||||
UpdateStorageIfCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) (bool, error)
|
||||
}
|
||||
|
||||
// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。
|
||||
type SoraGenerationService struct {
|
||||
genRepo SoraGenerationRepository
|
||||
s3Storage *SoraS3Storage
|
||||
quotaService *SoraQuotaService
|
||||
}
|
||||
|
||||
// NewSoraGenerationService 创建生成记录服务。
|
||||
func NewSoraGenerationService(
|
||||
genRepo SoraGenerationRepository,
|
||||
s3Storage *SoraS3Storage,
|
||||
quotaService *SoraQuotaService,
|
||||
) *SoraGenerationService {
|
||||
return &SoraGenerationService{
|
||||
genRepo: genRepo,
|
||||
s3Storage: s3Storage,
|
||||
quotaService: quotaService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePending 创建一条 pending 状态的生成记录。
|
||||
func (s *SoraGenerationService) CreatePending(ctx context.Context, userID int64, apiKeyID *int64, model, prompt, mediaType string) (*SoraGeneration, error) {
|
||||
gen := &SoraGeneration{
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
MediaType: mediaType,
|
||||
Status: SoraGenStatusPending,
|
||||
StorageType: SoraStorageTypeNone,
|
||||
}
|
||||
if atomicCreator, ok := s.genRepo.(soraGenerationRepoAtomicCreator); ok {
|
||||
if err := atomicCreator.CreatePendingWithLimit(
|
||||
ctx,
|
||||
gen,
|
||||
[]string{SoraGenStatusPending, SoraGenStatusGenerating},
|
||||
soraGenerationActiveLimit,
|
||||
); err != nil {
|
||||
if errors.Is(err, ErrSoraGenerationConcurrencyLimit) {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("create generation: %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
if err := s.genRepo.Create(ctx, gen); err != nil {
|
||||
return nil, fmt.Errorf("create generation: %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
// MarkGenerating 标记为生成中。
|
||||
func (s *SoraGenerationService) MarkGenerating(ctx context.Context, id int64, upstreamTaskID string) error {
|
||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
||||
updated, err := updater.UpdateGeneratingIfPending(ctx, id, upstreamTaskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.Status != SoraGenStatusPending {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
gen.Status = SoraGenStatusGenerating
|
||||
gen.UpstreamTaskID = upstreamTaskID
|
||||
return s.genRepo.Update(ctx, gen)
|
||||
}
|
||||
|
||||
// MarkCompleted 标记为已完成。
|
||||
func (s *SoraGenerationService) MarkCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) error {
|
||||
now := time.Now()
|
||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
||||
updated, err := updater.UpdateCompletedIfActive(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
gen.Status = SoraGenStatusCompleted
|
||||
gen.MediaURL = mediaURL
|
||||
gen.MediaURLs = mediaURLs
|
||||
gen.StorageType = storageType
|
||||
gen.S3ObjectKeys = s3Keys
|
||||
gen.FileSizeBytes = fileSizeBytes
|
||||
gen.CompletedAt = &now
|
||||
return s.genRepo.Update(ctx, gen)
|
||||
}
|
||||
|
||||
// MarkFailed 标记为失败。
|
||||
func (s *SoraGenerationService) MarkFailed(ctx context.Context, id int64, errMsg string) error {
|
||||
now := time.Now()
|
||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
||||
updated, err := updater.UpdateFailedIfActive(ctx, id, errMsg, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
gen.Status = SoraGenStatusFailed
|
||||
gen.ErrorMessage = errMsg
|
||||
gen.CompletedAt = &now
|
||||
return s.genRepo.Update(ctx, gen)
|
||||
}
|
||||
|
||||
// MarkCancelled 标记为已取消。
|
||||
func (s *SoraGenerationService) MarkCancelled(ctx context.Context, id int64) error {
|
||||
now := time.Now()
|
||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
||||
updated, err := updater.UpdateCancelledIfActive(ctx, id, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
return ErrSoraGenerationNotActive
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
|
||||
return ErrSoraGenerationNotActive
|
||||
}
|
||||
gen.Status = SoraGenStatusCancelled
|
||||
gen.CompletedAt = &now
|
||||
return s.genRepo.Update(ctx, gen)
|
||||
}
|
||||
|
||||
// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。
|
||||
func (s *SoraGenerationService) UpdateStorageForCompleted(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
mediaURL string,
|
||||
mediaURLs []string,
|
||||
storageType string,
|
||||
s3Keys []string,
|
||||
fileSizeBytes int64,
|
||||
) error {
|
||||
if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
|
||||
updated, err := updater.UpdateStorageIfCompleted(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.Status != SoraGenStatusCompleted {
|
||||
return ErrSoraGenerationStateConflict
|
||||
}
|
||||
gen.MediaURL = mediaURL
|
||||
gen.MediaURLs = mediaURLs
|
||||
gen.StorageType = storageType
|
||||
gen.S3ObjectKeys = s3Keys
|
||||
gen.FileSizeBytes = fileSizeBytes
|
||||
return s.genRepo.Update(ctx, gen)
|
||||
}
|
||||
|
||||
// GetByID 获取记录详情(含权限校验)。
|
||||
func (s *SoraGenerationService) GetByID(ctx context.Context, id, userID int64) (*SoraGeneration, error) {
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if gen.UserID != userID {
|
||||
return nil, fmt.Errorf("无权访问此生成记录")
|
||||
}
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
// List 查询生成记录列表(分页 + 筛选)。
|
||||
func (s *SoraGenerationService) List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
|
||||
if params.Page <= 0 {
|
||||
params.Page = 1
|
||||
}
|
||||
if params.PageSize <= 0 {
|
||||
params.PageSize = 20
|
||||
}
|
||||
if params.PageSize > 100 {
|
||||
params.PageSize = 100
|
||||
}
|
||||
return s.genRepo.List(ctx, params)
|
||||
}
|
||||
|
||||
// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。
|
||||
func (s *SoraGenerationService) Delete(ctx context.Context, id, userID int64) error {
|
||||
gen, err := s.genRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if gen.UserID != userID {
|
||||
return fmt.Errorf("无权删除此生成记录")
|
||||
}
|
||||
|
||||
// 清理 S3 文件
|
||||
if gen.StorageType == SoraStorageTypeS3 && len(gen.S3ObjectKeys) > 0 && s.s3Storage != nil {
|
||||
if err := s.s3Storage.DeleteObjects(ctx, gen.S3ObjectKeys); err != nil {
|
||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] S3 清理失败 id=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 释放配额(S3/本地均释放)
|
||||
if gen.FileSizeBytes > 0 && (gen.StorageType == SoraStorageTypeS3 || gen.StorageType == SoraStorageTypeLocal) && s.quotaService != nil {
|
||||
if err := s.quotaService.ReleaseUsage(ctx, userID, gen.FileSizeBytes); err != nil {
|
||||
logger.LegacyPrintf("service.sora_gen", "[SoraGen] 配额释放失败 id=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
return s.genRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。
|
||||
func (s *SoraGenerationService) CountActiveByUser(ctx context.Context, userID int64) (int64, error) {
|
||||
return s.genRepo.CountByUserAndStatus(ctx, userID, []string{SoraGenStatusPending, SoraGenStatusGenerating})
|
||||
}
|
||||
|
||||
// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。
|
||||
func (s *SoraGenerationService) ResolveMediaURLs(ctx context.Context, gen *SoraGeneration) error {
|
||||
if gen == nil || gen.StorageType != SoraStorageTypeS3 || s.s3Storage == nil {
|
||||
return nil
|
||||
}
|
||||
if len(gen.S3ObjectKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
urls := make([]string, len(gen.S3ObjectKeys))
|
||||
var wg sync.WaitGroup
|
||||
var firstErr error
|
||||
var errMu sync.Mutex
|
||||
|
||||
for idx, key := range gen.S3ObjectKeys {
|
||||
wg.Add(1)
|
||||
go func(i int, objectKey string) {
|
||||
defer wg.Done()
|
||||
url, err := s.s3Storage.GetAccessURL(ctx, objectKey)
|
||||
if err != nil {
|
||||
errMu.Lock()
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
errMu.Unlock()
|
||||
return
|
||||
}
|
||||
urls[i] = url
|
||||
}(idx, key)
|
||||
}
|
||||
wg.Wait()
|
||||
if firstErr != nil {
|
||||
return firstErr
|
||||
}
|
||||
|
||||
gen.MediaURL = urls[0]
|
||||
gen.MediaURLs = urls
|
||||
|
||||
return nil
|
||||
}
|
||||
875
backend/internal/service/sora_generation_service_test.go
Normal file
875
backend/internal/service/sora_generation_service_test.go
Normal file
@@ -0,0 +1,875 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ==================== Stub: SoraGenerationRepository ====================
|
||||
|
||||
var _ SoraGenerationRepository = (*stubGenRepo)(nil)
|
||||
|
||||
type stubGenRepo struct {
|
||||
gens map[int64]*SoraGeneration
|
||||
nextID int64
|
||||
createErr error
|
||||
getErr error
|
||||
updateErr error
|
||||
deleteErr error
|
||||
listErr error
|
||||
countErr error
|
||||
countValue int64
|
||||
}
|
||||
|
||||
func newStubGenRepo() *stubGenRepo {
|
||||
return &stubGenRepo{gens: make(map[int64]*SoraGeneration), nextID: 1}
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) Create(_ context.Context, gen *SoraGeneration) error {
|
||||
if r.createErr != nil {
|
||||
return r.createErr
|
||||
}
|
||||
gen.ID = r.nextID
|
||||
gen.CreatedAt = time.Now()
|
||||
r.nextID++
|
||||
r.gens[gen.ID] = gen
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) GetByID(_ context.Context, id int64) (*SoraGeneration, error) {
|
||||
if r.getErr != nil {
|
||||
return nil, r.getErr
|
||||
}
|
||||
if gen, ok := r.gens[id]; ok {
|
||||
return gen, nil
|
||||
}
|
||||
return nil, fmt.Errorf("not found")
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) Update(_ context.Context, gen *SoraGeneration) error {
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.gens[gen.ID] = gen
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) Delete(_ context.Context, id int64) error {
|
||||
if r.deleteErr != nil {
|
||||
return r.deleteErr
|
||||
}
|
||||
delete(r.gens, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) List(_ context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
|
||||
if r.listErr != nil {
|
||||
return nil, 0, r.listErr
|
||||
}
|
||||
var result []*SoraGeneration
|
||||
for _, gen := range r.gens {
|
||||
if gen.UserID != params.UserID {
|
||||
continue
|
||||
}
|
||||
if params.Status != "" && gen.Status != params.Status {
|
||||
continue
|
||||
}
|
||||
if params.StorageType != "" && gen.StorageType != params.StorageType {
|
||||
continue
|
||||
}
|
||||
if params.MediaType != "" && gen.MediaType != params.MediaType {
|
||||
continue
|
||||
}
|
||||
result = append(result, gen)
|
||||
}
|
||||
return result, int64(len(result)), nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) CountByUserAndStatus(_ context.Context, userID int64, statuses []string) (int64, error) {
|
||||
if r.countErr != nil {
|
||||
return 0, r.countErr
|
||||
}
|
||||
if r.countValue > 0 {
|
||||
return r.countValue, nil
|
||||
}
|
||||
var count int64
|
||||
statusSet := make(map[string]struct{})
|
||||
for _, s := range statuses {
|
||||
statusSet[s] = struct{}{}
|
||||
}
|
||||
for _, gen := range r.gens {
|
||||
if gen.UserID == userID {
|
||||
if _, ok := statusSet[gen.Status]; ok {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
|
||||
|
||||
var _ UserRepository = (*stubUserRepoForQuota)(nil)
|
||||
|
||||
type stubUserRepoForQuota struct {
|
||||
users map[int64]*User
|
||||
updateErr error
|
||||
}
|
||||
|
||||
func newStubUserRepoForQuota() *stubUserRepoForQuota {
|
||||
return &stubUserRepoForQuota{users: make(map[int64]*User)}
|
||||
}
|
||||
|
||||
func (r *stubUserRepoForQuota) GetByID(_ context.Context, id int64) (*User, error) {
|
||||
if u, ok := r.users[id]; ok {
|
||||
return u, nil
|
||||
}
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
func (r *stubUserRepoForQuota) Update(_ context.Context, user *User) error {
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.users[user.ID] = user
|
||||
return nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) Create(context.Context, *User) error { return nil }
|
||||
func (r *stubUserRepoForQuota) GetByEmail(context.Context, string) (*User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) GetFirstAdmin(context.Context) (*User, error) { return nil, nil }
|
||||
func (r *stubUserRepoForQuota) Delete(context.Context, int64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
||||
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
|
||||
|
||||
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
|
||||
|
||||
// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage,
|
||||
// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。
|
||||
func newS3StorageWithCDN(cdnURL string) *SoraS3Storage {
|
||||
storage := &SoraS3Storage{}
|
||||
storage.cfg = &SoraS3Settings{
|
||||
Enabled: true,
|
||||
Bucket: "test-bucket",
|
||||
CDNURL: cdnURL,
|
||||
}
|
||||
// 需要 non-nil client 使 getClient 命中缓存
|
||||
storage.client = s3.New(s3.Options{})
|
||||
return storage
|
||||
}
|
||||
|
||||
// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage,
|
||||
// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。
|
||||
func newS3StorageFailingDelete() *SoraS3Storage {
|
||||
return &SoraS3Storage{} // settingService 为 nil → getConfig 返回 error
|
||||
}
|
||||
|
||||
// ==================== CreatePending ====================
|
||||
|
||||
func TestCreatePending_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "一只猫跳舞", "video")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), gen.ID)
|
||||
require.Equal(t, int64(1), gen.UserID)
|
||||
require.Equal(t, "sora2-landscape-10s", gen.Model)
|
||||
require.Equal(t, "一只猫跳舞", gen.Prompt)
|
||||
require.Equal(t, "video", gen.MediaType)
|
||||
require.Equal(t, SoraGenStatusPending, gen.Status)
|
||||
require.Equal(t, SoraStorageTypeNone, gen.StorageType)
|
||||
require.Nil(t, gen.APIKeyID)
|
||||
}
|
||||
|
||||
func TestCreatePending_WithAPIKeyID(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
apiKeyID := int64(42)
|
||||
gen, err := svc.CreatePending(context.Background(), 1, &apiKeyID, "gpt-image", "画一朵花", "image")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gen.APIKeyID)
|
||||
require.Equal(t, int64(42), *gen.APIKeyID)
|
||||
}
|
||||
|
||||
func TestCreatePending_RepoError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.createErr = fmt.Errorf("db write error")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gen)
|
||||
require.Contains(t, err.Error(), "create generation")
|
||||
}
|
||||
|
||||
// ==================== MarkGenerating ====================
|
||||
|
||||
func TestMarkGenerating_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkGenerating(context.Background(), 1, "upstream-task-123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusGenerating, repo.gens[1].Status)
|
||||
require.Equal(t, "upstream-task-123", repo.gens[1].UpstreamTaskID)
|
||||
}
|
||||
|
||||
func TestMarkGenerating_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkGenerating(context.Background(), 999, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkGenerating_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkGenerating(context.Background(), 1, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== MarkCompleted ====================
|
||||
|
||||
func TestMarkCompleted_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCompleted(context.Background(), 1,
|
||||
"https://cdn.example.com/video.mp4",
|
||||
[]string{"https://cdn.example.com/video.mp4"},
|
||||
SoraStorageTypeS3,
|
||||
[]string{"sora/1/2024/01/01/uuid.mp4"},
|
||||
1048576,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
gen := repo.gens[1]
|
||||
require.Equal(t, SoraGenStatusCompleted, gen.Status)
|
||||
require.Equal(t, "https://cdn.example.com/video.mp4", gen.MediaURL)
|
||||
require.Equal(t, []string{"https://cdn.example.com/video.mp4"}, gen.MediaURLs)
|
||||
require.Equal(t, SoraStorageTypeS3, gen.StorageType)
|
||||
require.Equal(t, []string{"sora/1/2024/01/01/uuid.mp4"}, gen.S3ObjectKeys)
|
||||
require.Equal(t, int64(1048576), gen.FileSizeBytes)
|
||||
require.NotNil(t, gen.CompletedAt)
|
||||
}
|
||||
|
||||
func TestMarkCompleted_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCompleted(context.Background(), 999, "", nil, "", nil, 0)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCompleted_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCompleted(context.Background(), 1, "url", nil, SoraStorageTypeUpstream, nil, 0)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== MarkFailed ====================
|
||||
|
||||
func TestMarkFailed_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkFailed(context.Background(), 1, "上游返回 500 错误")
|
||||
require.NoError(t, err)
|
||||
gen := repo.gens[1]
|
||||
require.Equal(t, SoraGenStatusFailed, gen.Status)
|
||||
require.Equal(t, "上游返回 500 错误", gen.ErrorMessage)
|
||||
require.NotNil(t, gen.CompletedAt)
|
||||
}
|
||||
|
||||
func TestMarkFailed_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkFailed(context.Background(), 999, "error")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkFailed_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkFailed(context.Background(), 1, "err")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== MarkCancelled ====================
|
||||
|
||||
func TestMarkCancelled_Pending(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
|
||||
require.NotNil(t, repo.gens[1].CompletedAt)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_Generating(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_Completed(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrSoraGenerationNotActive)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_Failed(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusFailed}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_AlreadyCancelled(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCancelled}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 999)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== GetByID ====================
|
||||
|
||||
func TestGetByID_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, Model: "sora2-landscape-10s"}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.GetByID(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), gen.ID)
|
||||
require.Equal(t, "sora2-landscape-10s", gen.Model)
|
||||
}
|
||||
|
||||
func TestGetByID_WrongUser(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.GetByID(context.Background(), 1, 1)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gen)
|
||||
require.Contains(t, err.Error(), "无权访问")
|
||||
}
|
||||
|
||||
func TestGetByID_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.GetByID(context.Background(), 999, 1)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gen)
|
||||
}
|
||||
|
||||
// ==================== List ====================
|
||||
|
||||
func TestList_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, MediaType: "video"}
|
||||
repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusPending, MediaType: "image"}
|
||||
repo.gens[3] = &SoraGeneration{ID: 3, UserID: 2, Status: SoraGenStatusCompleted, MediaType: "video"}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gens, total, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 20})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, gens, 2) // 只有 userID=1 的
|
||||
require.Equal(t, int64(2), total)
|
||||
}
|
||||
|
||||
func TestList_DefaultPagination(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
// page=0, pageSize=0 → 应修正为 page=1, pageSize=20
|
||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestList_MaxPageSize(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
// pageSize > 100 → 应限制为 100
|
||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 200})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestList_Error(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.listErr = fmt.Errorf("db error")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== Delete ====================
|
||||
|
||||
func TestDelete_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, StorageType: SoraStorageTypeUpstream}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
_, exists := repo.gens[1]
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func TestDelete_WrongUser(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "无权删除")
|
||||
}
|
||||
|
||||
// TestDelete_NotFound verifies that deleting a non-existent record fails.
func TestDelete_NotFound(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)

	err := svc.Delete(context.Background(), 999, 1)
	require.Error(t, err)
}
|
||||
|
||||
// TestDelete_S3Cleanup_NilS3 verifies that S3 cleanup is skipped (not fatal)
// when no S3 storage backend is configured.
func TestDelete_S3Cleanup_NilS3(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
	svc := NewSoraGenerationService(repo, nil, nil)

	err := svc.Delete(context.Background(), 1, 1)
	require.NoError(t, err) // s3Storage is nil, cleanup is skipped
}
|
||||
|
||||
// TestDelete_QuotaRelease_NilQuota verifies that quota release is skipped
// (not fatal) when no quota service is configured.
func TestDelete_QuotaRelease_NilQuota(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, FileSizeBytes: 1024}
	svc := NewSoraGenerationService(repo, nil, nil)

	err := svc.Delete(context.Background(), 1, 1)
	require.NoError(t, err) // quotaService is nil, release is skipped
}
|
||||
|
||||
// TestDelete_NonS3NoCleanup verifies that non-S3 records delete cleanly with
// no S3 cleanup attempted.
func TestDelete_NonS3NoCleanup(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeLocal, FileSizeBytes: 1024}
	svc := NewSoraGenerationService(repo, nil, nil)

	err := svc.Delete(context.Background(), 1, 1)
	require.NoError(t, err)
}
|
||||
|
||||
// TestDelete_DeleteRepoError verifies that a repository delete failure is propagated.
func TestDelete_DeleteRepoError(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeUpstream}
	repo.deleteErr = fmt.Errorf("delete failed")
	svc := NewSoraGenerationService(repo, nil, nil)

	err := svc.Delete(context.Background(), 1, 1)
	require.Error(t, err)
}
|
||||
|
||||
// ==================== CountActiveByUser ====================
|
||||
|
||||
// TestCountActiveByUser_Success verifies that only pending/generating records
// count as active; completed records do not.
func TestCountActiveByUser_Success(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
	repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusGenerating}
	repo.gens[3] = &SoraGeneration{ID: 3, UserID: 1, Status: SoraGenStatusCompleted} // not counted
	svc := NewSoraGenerationService(repo, nil, nil)

	count, err := svc.CountActiveByUser(context.Background(), 1)
	require.NoError(t, err)
	require.Equal(t, int64(2), count)
}
|
||||
|
||||
// TestCountActiveByUser_NoActive verifies a zero count when the user has no
// active generations.
func TestCountActiveByUser_NoActive(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
	svc := NewSoraGenerationService(repo, nil, nil)

	count, err := svc.CountActiveByUser(context.Background(), 1)
	require.NoError(t, err)
	require.Equal(t, int64(0), count)
}
|
||||
|
||||
// TestCountActiveByUser_Error verifies that a repository count failure is propagated.
func TestCountActiveByUser_Error(t *testing.T) {
	repo := newStubGenRepo()
	repo.countErr = fmt.Errorf("db error")
	svc := NewSoraGenerationService(repo, nil, nil)

	_, err := svc.CountActiveByUser(context.Background(), 1)
	require.Error(t, err)
}
|
||||
|
||||
// ==================== ResolveMediaURLs ====================
|
||||
|
||||
// TestResolveMediaURLs_NilGen verifies that a nil generation is a no-op.
func TestResolveMediaURLs_NilGen(t *testing.T) {
	svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
	require.NoError(t, svc.ResolveMediaURLs(context.Background(), nil))
}
|
||||
|
||||
// TestResolveMediaURLs_NonS3 verifies that upstream-hosted URLs are left untouched.
func TestResolveMediaURLs_NonS3(t *testing.T) {
	svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
	gen := &SoraGeneration{StorageType: SoraStorageTypeUpstream, MediaURL: "https://original.com/v.mp4"}
	require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
	require.Equal(t, "https://original.com/v.mp4", gen.MediaURL) // unchanged
}
|
||||
|
||||
// TestResolveMediaURLs_S3NilStorage verifies that S3 records resolve without
// error when no S3 storage backend is configured.
func TestResolveMediaURLs_S3NilStorage(t *testing.T) {
	svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
	gen := &SoraGeneration{StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
	require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
}
|
||||
|
||||
// TestResolveMediaURLs_Local verifies that local-storage URLs are left untouched.
func TestResolveMediaURLs_Local(t *testing.T) {
	svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
	gen := &SoraGeneration{StorageType: SoraStorageTypeLocal, MediaURL: "/video/2024/01/01/file.mp4"}
	require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
	require.Equal(t, "/video/2024/01/01/file.mp4", gen.MediaURL) // unchanged
}
|
||||
|
||||
// ==================== 状态流转完整测试 ====================
|
||||
|
||||
// TestStatusTransition_PendingToCompletedFlow walks the happy path:
// pending → generating → completed.
func TestStatusTransition_PendingToCompletedFlow(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)

	// 1. Create in pending state.
	gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
	require.NoError(t, err)
	require.Equal(t, SoraGenStatusPending, gen.Status)

	// 2. Mark as generating.
	err = svc.MarkGenerating(context.Background(), gen.ID, "task-123")
	require.NoError(t, err)
	require.Equal(t, SoraGenStatusGenerating, repo.gens[gen.ID].Status)

	// 3. Mark as completed.
	err = svc.MarkCompleted(context.Background(), gen.ID, "https://s3.com/video.mp4", nil, SoraStorageTypeS3, []string{"key"}, 1024)
	require.NoError(t, err)
	require.Equal(t, SoraGenStatusCompleted, repo.gens[gen.ID].Status)
}
|
||||
|
||||
// TestStatusTransition_PendingToFailedFlow verifies pending → generating → failed,
// including that the error message is recorded.
func TestStatusTransition_PendingToFailedFlow(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)

	gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
	_ = svc.MarkGenerating(context.Background(), gen.ID, "")

	err := svc.MarkFailed(context.Background(), gen.ID, "上游超时")
	require.NoError(t, err)
	require.Equal(t, SoraGenStatusFailed, repo.gens[gen.ID].Status)
	require.Equal(t, "上游超时", repo.gens[gen.ID].ErrorMessage)
}
|
||||
|
||||
// TestStatusTransition_PendingToCancelledFlow verifies a pending record can be cancelled.
func TestStatusTransition_PendingToCancelledFlow(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)

	gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
	err := svc.MarkCancelled(context.Background(), gen.ID)
	require.NoError(t, err)
	require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
}
|
||||
|
||||
// TestStatusTransition_GeneratingToCancelledFlow verifies a generating record
// can be cancelled.
func TestStatusTransition_GeneratingToCancelledFlow(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)

	gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
	_ = svc.MarkGenerating(context.Background(), gen.ID, "")
	err := svc.MarkCancelled(context.Background(), gen.ID)
	require.NoError(t, err)
	require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
}
|
||||
|
||||
// ==================== 权限隔离测试 ====================
|
||||
|
||||
// TestUserIsolation_CannotAccessOthersRecord verifies that users cannot read
// each other's generation records.
func TestUserIsolation_CannotAccessOthersRecord(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)

	gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")

	// User 2 tries to access user 1's record.
	_, err := svc.GetByID(context.Background(), gen.ID, 2)
	require.Error(t, err)
	require.Contains(t, err.Error(), "无权访问")
}
|
||||
|
||||
// TestUserIsolation_CannotDeleteOthersRecord verifies that users cannot delete
// each other's generation records.
func TestUserIsolation_CannotDeleteOthersRecord(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)

	gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")

	err := svc.Delete(context.Background(), gen.ID, 2)
	require.Error(t, err)
	require.Contains(t, err.Error(), "无权删除")
}
|
||||
|
||||
// ==================== Delete: S3 清理 + 配额释放路径 ====================
|
||||
|
||||
func TestDelete_S3Cleanup_WithS3Storage(t *testing.T) {
	// The S3 backend exists but deleteObjects will fail (settingService=nil);
	// verify Delete still succeeds — S3 errors are only logged.
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{
		ID: 1, UserID: 1,
		StorageType:  SoraStorageTypeS3,
		S3ObjectKeys: []string{"sora/1/2024/01/01/abc.mp4"},
	}
	s3Storage := newS3StorageFailingDelete()
	svc := NewSoraGenerationService(repo, s3Storage, nil)

	err := svc.Delete(context.Background(), 1, 1)
	require.NoError(t, err) // an S3 cleanup failure must not block deletion
	_, exists := repo.gens[1]
	require.False(t, exists)
}
|
||||
|
||||
func TestDelete_QuotaRelease_WithQuotaService(t *testing.T) {
	// With a quota service configured, deleting an S3-backed record releases quota.
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{
		ID: 1, UserID: 1,
		StorageType:   SoraStorageTypeS3,
		FileSizeBytes: 1048576, // 1MB
	}

	userRepo := newStubUserRepoForQuota()
	userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2097152} // 2MB
	quotaService := NewSoraQuotaService(userRepo, nil, nil)

	svc := NewSoraGenerationService(repo, nil, quotaService)
	err := svc.Delete(context.Background(), 1, 1)
	require.NoError(t, err)
	// Quota should be released: 2MB - 1MB = 1MB.
	require.Equal(t, int64(1048576), userRepo.users[1].SoraStorageUsedBytes)
}
|
||||
|
||||
func TestDelete_S3Cleanup_And_QuotaRelease(t *testing.T) {
	// S3 cleanup and quota release triggered together on the same delete.
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{
		ID: 1, UserID: 1,
		StorageType:   SoraStorageTypeS3,
		S3ObjectKeys:  []string{"key1"},
		FileSizeBytes: 512,
	}

	userRepo := newStubUserRepoForQuota()
	userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
	quotaService := NewSoraQuotaService(userRepo, nil, nil)
	s3Storage := newS3StorageFailingDelete()

	svc := NewSoraGenerationService(repo, s3Storage, quotaService)
	err := svc.Delete(context.Background(), 1, 1)
	require.NoError(t, err)
	_, exists := repo.gens[1]
	require.False(t, exists)
	require.Equal(t, int64(512), userRepo.users[1].SoraStorageUsedBytes)
}
|
||||
|
||||
func TestDelete_QuotaRelease_LocalStorage(t *testing.T) {
	// Local storage also releases quota on delete.
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{
		ID: 1, UserID: 1,
		StorageType:   SoraStorageTypeLocal,
		FileSizeBytes: 1024,
	}

	userRepo := newStubUserRepoForQuota()
	userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2048}
	quotaService := NewSoraQuotaService(userRepo, nil, nil)

	svc := NewSoraGenerationService(repo, nil, quotaService)
	err := svc.Delete(context.Background(), 1, 1)
	require.NoError(t, err)
	require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
}
|
||||
|
||||
func TestDelete_QuotaRelease_ZeroFileSize(t *testing.T) {
	// FileSizeBytes=0 skips quota release entirely.
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{
		ID: 1, UserID: 1,
		StorageType:   SoraStorageTypeS3,
		FileSizeBytes: 0,
	}

	userRepo := newStubUserRepoForQuota()
	userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
	quotaService := NewSoraQuotaService(userRepo, nil, nil)

	svc := NewSoraGenerationService(repo, nil, quotaService)
	err := svc.Delete(context.Background(), 1, 1)
	require.NoError(t, err)
	require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
}
|
||||
|
||||
// ==================== ResolveMediaURLs: S3 + CDN 路径 ====================
|
||||
|
||||
// TestResolveMediaURLs_S3_CDN_SingleKey verifies that a single S3 key is
// rewritten to its CDN URL.
func TestResolveMediaURLs_S3_CDN_SingleKey(t *testing.T) {
	s3Storage := newS3StorageWithCDN("https://cdn.example.com")
	svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)

	gen := &SoraGeneration{
		StorageType:  SoraStorageTypeS3,
		S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
		MediaURL:     "original",
	}
	err := svc.ResolveMediaURLs(context.Background(), gen)
	require.NoError(t, err)
	require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", gen.MediaURL)
}
|
||||
|
||||
// TestResolveMediaURLs_S3_CDN_MultipleKeys verifies that all keys are rewritten
// to CDN URLs and the primary URL points at the first key. Note the CDN base
// here has a trailing slash, exercising URL joining.
func TestResolveMediaURLs_S3_CDN_MultipleKeys(t *testing.T) {
	s3Storage := newS3StorageWithCDN("https://cdn.example.com/")
	svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)

	gen := &SoraGeneration{
		StorageType: SoraStorageTypeS3,
		S3ObjectKeys: []string{
			"sora/1/2024/01/01/img1.png",
			"sora/1/2024/01/01/img2.png",
			"sora/1/2024/01/01/img3.png",
		},
		MediaURL: "original",
	}
	err := svc.ResolveMediaURLs(context.Background(), gen)
	require.NoError(t, err)
	// Primary URL is updated to the first key's CDN URL.
	require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURL)
	// Every multi-image URL is updated.
	require.Len(t, gen.MediaURLs, 3)
	require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURLs[0])
	require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img2.png", gen.MediaURLs[1])
	require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img3.png", gen.MediaURLs[2])
}
|
||||
|
||||
// TestResolveMediaURLs_S3_EmptyKeys verifies that an S3 record with no object
// keys keeps its original URL.
func TestResolveMediaURLs_S3_EmptyKeys(t *testing.T) {
	s3Storage := newS3StorageWithCDN("https://cdn.example.com")
	svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)

	gen := &SoraGeneration{
		StorageType:  SoraStorageTypeS3,
		S3ObjectKeys: []string{},
		MediaURL:     "original",
	}
	err := svc.ResolveMediaURLs(context.Background(), gen)
	require.NoError(t, err)
	require.Equal(t, "original", gen.MediaURL) // unchanged
}
|
||||
|
||||
func TestResolveMediaURLs_S3_GetAccessURL_Error(t *testing.T) {
	// With an S3 storage lacking a settingService, getClient fails,
	// so GetAccessURL fails as well.
	s3Storage := newS3StorageFailingDelete() // GetAccessURL fails too
	svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)

	gen := &SoraGeneration{
		StorageType:  SoraStorageTypeS3,
		S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
		MediaURL:     "original",
	}
	err := svc.ResolveMediaURLs(context.Background(), gen)
	require.Error(t, err) // a GetAccessURL failure must propagate
}
|
||||
|
||||
func TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond(t *testing.T) {
	// A single key covers the primary-URL path; the multi-key error path also
	// needs coverage.
	s3Storage := newS3StorageFailingDelete()
	svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)

	gen := &SoraGeneration{
		StorageType: SoraStorageTypeS3,
		S3ObjectKeys: []string{
			"sora/1/2024/01/01/img1.png",
			"sora/1/2024/01/01/img2.png",
		},
		MediaURL: "original",
	}
	err := svc.ResolveMediaURLs(context.Background(), gen)
	require.Error(t, err) // GetAccessURL already fails on the first key
}
|
||||
@@ -157,6 +157,64 @@ func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string,
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// TotalSizeByRelativePaths 统计本地存储路径总大小(仅统计 /image 和 /video 路径)。
|
||||
func (s *SoraMediaStorage) TotalSizeByRelativePaths(paths []string) (int64, error) {
|
||||
if s == nil || len(paths) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var total int64
|
||||
for _, p := range paths {
|
||||
localPath, err := s.resolveLocalPath(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
info, err := os.Stat(localPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
if info.Mode().IsRegular() {
|
||||
total += info.Size()
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// DeleteByRelativePaths 删除本地媒体路径(仅删除 /image 和 /video 路径)。
|
||||
func (s *SoraMediaStorage) DeleteByRelativePaths(paths []string) error {
|
||||
if s == nil || len(paths) == 0 {
|
||||
return nil
|
||||
}
|
||||
var lastErr error
|
||||
for _, p := range paths {
|
||||
localPath, err := s.resolveLocalPath(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if err := os.Remove(localPath); err != nil && !os.IsNotExist(err) {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (s *SoraMediaStorage) resolveLocalPath(relativePath string) (string, error) {
|
||||
if s == nil || strings.TrimSpace(relativePath) == "" {
|
||||
return "", errors.New("empty path")
|
||||
}
|
||||
cleaned := path.Clean(relativePath)
|
||||
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
|
||||
return "", errors.New("not a local media path")
|
||||
}
|
||||
if strings.TrimSpace(s.root) == "" {
|
||||
return "", errors.New("storage root not configured")
|
||||
}
|
||||
relative := strings.TrimPrefix(cleaned, "/")
|
||||
return filepath.Join(s.root, filepath.FromSlash(relative)), nil
|
||||
}
|
||||
|
||||
func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL string) (string, error) {
|
||||
if strings.TrimSpace(rawURL) == "" {
|
||||
return "", errors.New("empty url")
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -247,6 +250,218 @@ func GetSoraModelConfig(model string) (SoraModelConfig, bool) {
|
||||
return cfg, ok
|
||||
}
|
||||
|
||||
// SoraModelFamily is a model-family descriptor consumed by the front-end Sora client.
type SoraModelFamily struct {
	ID           string   `json:"id"`           // family identifier, e.g. "sora2"
	Name         string   `json:"name"`         // display name (falls back to ID when uncurated)
	Type         string   `json:"type"`         // "video" or "image"
	Orientations []string `json:"orientations"` // supported orientations, sorted
	Durations    []int    `json:"durations,omitempty"` // supported video durations in seconds, sorted
}
|
||||
|
||||
var (
	// videoSuffixRe matches video model IDs of the form "{family}-{orientation}-{N}s".
	videoSuffixRe = regexp.MustCompile(`-(landscape|portrait)-(\d+)s$`)
	// imageSuffixRe matches image model IDs of the form "{family}-{orientation}".
	imageSuffixRe = regexp.MustCompile(`-(landscape|portrait)$`)

	// soraFamilyNames maps family IDs to curated display names; families not
	// listed here fall back to their raw ID.
	soraFamilyNames = map[string]string{
		"sora2":       "Sora 2",
		"sora2pro":    "Sora 2 Pro",
		"sora2pro-hd": "Sora 2 Pro HD",
		"gpt-image":   "GPT Image",
	}
)
|
||||
|
||||
// BuildSoraModelFamilies aggregates model families (with the orientations and
// durations they support) from the static soraModelConfigs table. Video
// families sort before image families; within a type, families sort by ID.
func BuildSoraModelFamilies() []SoraModelFamily {
	// familyData accumulates per-family attributes while scanning configs.
	type familyData struct {
		modelType    string
		orientations map[string]bool
		durations    map[int]bool
	}
	families := make(map[string]*familyData)

	for id, cfg := range soraModelConfigs {
		// Prompt-enhance pseudo-models are not user-selectable families.
		if cfg.Type == "prompt_enhance" {
			continue
		}
		var famID, orientation string
		var duration int

		switch cfg.Type {
		case "video":
			// Video IDs follow "{family}-{orientation}-{N}s".
			if m := videoSuffixRe.FindStringSubmatch(id); m != nil {
				famID = id[:len(id)-len(m[0])]
				orientation = m[1]
				duration, _ = strconv.Atoi(m[2])
			}
		case "image":
			// Image IDs follow "{family}-{orientation}"; IDs without the
			// suffix are treated as square-orientation families.
			if m := imageSuffixRe.FindStringSubmatch(id); m != nil {
				famID = id[:len(id)-len(m[0])]
				orientation = m[1]
			} else {
				famID = id
				orientation = "square"
			}
		}
		if famID == "" {
			// Unrecognized ID shape (e.g. a video ID without the expected suffix).
			continue
		}

		fd, ok := families[famID]
		if !ok {
			fd = &familyData{
				modelType:    cfg.Type,
				orientations: make(map[string]bool),
				durations:    make(map[int]bool),
			}
			families[famID] = fd
		}
		if orientation != "" {
			fd.orientations[orientation] = true
		}
		if duration > 0 {
			fd.durations[duration] = true
		}
	}

	// Sort: video families first, image families after; same type sorts by ID.
	famIDs := make([]string, 0, len(families))
	for id := range families {
		famIDs = append(famIDs, id)
	}
	sort.Slice(famIDs, func(i, j int) bool {
		fi, fj := families[famIDs[i]], families[famIDs[j]]
		if fi.modelType != fj.modelType {
			return fi.modelType == "video"
		}
		return famIDs[i] < famIDs[j]
	})

	result := make([]SoraModelFamily, 0, len(famIDs))
	for _, famID := range famIDs {
		fd := families[famID]
		fam := SoraModelFamily{
			ID:   famID,
			Name: soraFamilyNames[famID],
			Type: fd.modelType,
		}
		if fam.Name == "" {
			// No curated display name: fall back to the raw family ID.
			fam.Name = famID
		}
		for o := range fd.orientations {
			fam.Orientations = append(fam.Orientations, o)
		}
		sort.Strings(fam.Orientations)
		for d := range fd.durations {
			fam.Durations = append(fam.Durations, d)
		}
		sort.Ints(fam.Durations)
		result = append(result, fam)
	}
	return result
}
|
||||
|
||||
// BuildSoraModelFamiliesFromIDs aggregates model families from an arbitrary
// list of model IDs (used to parse model lists returned by the upstream).
// Video/image models are recognized and grouped purely by naming convention.
func BuildSoraModelFamiliesFromIDs(modelIDs []string) []SoraModelFamily {
	// familyData accumulates per-family attributes while scanning IDs.
	type familyData struct {
		modelType    string
		orientations map[string]bool
		durations    map[int]bool
	}
	families := make(map[string]*familyData)

	for _, id := range modelIDs {
		id = strings.ToLower(strings.TrimSpace(id))
		if id == "" || strings.HasPrefix(id, "prompt-enhance") {
			continue
		}

		var famID, orientation, modelType string
		var duration int

		if m := videoSuffixRe.FindStringSubmatch(id); m != nil {
			// Video model: {family}-{orientation}-{duration}s
			famID = id[:len(id)-len(m[0])]
			orientation = m[1]
			duration, _ = strconv.Atoi(m[2])
			modelType = "video"
		} else if m := imageSuffixRe.FindStringSubmatch(id); m != nil {
			// Image model with an orientation suffix: {family}-{orientation}
			famID = id[:len(id)-len(m[0])]
			orientation = m[1]
			modelType = "image"
		} else if cfg, ok := soraModelConfigs[id]; ok && cfg.Type == "image" {
			// Known suffix-less image model (e.g. gpt-image).
			famID = id
			orientation = "square"
			modelType = "image"
		} else if strings.Contains(id, "image") {
			// Unknown model whose name contains "image": assume an image model.
			famID = id
			orientation = "square"
			modelType = "image"
		} else {
			// Unrecognized ID — ignore.
			continue
		}

		if famID == "" {
			continue
		}

		fd, ok := families[famID]
		if !ok {
			fd = &familyData{
				modelType:    modelType,
				orientations: make(map[string]bool),
				durations:    make(map[int]bool),
			}
			families[famID] = fd
		}
		if orientation != "" {
			fd.orientations[orientation] = true
		}
		if duration > 0 {
			fd.durations[duration] = true
		}
	}

	// Video families first, image families after; same type sorts by ID.
	famIDs := make([]string, 0, len(families))
	for id := range families {
		famIDs = append(famIDs, id)
	}
	sort.Slice(famIDs, func(i, j int) bool {
		fi, fj := families[famIDs[i]], families[famIDs[j]]
		if fi.modelType != fj.modelType {
			return fi.modelType == "video"
		}
		return famIDs[i] < famIDs[j]
	})

	result := make([]SoraModelFamily, 0, len(famIDs))
	for _, famID := range famIDs {
		fd := families[famID]
		fam := SoraModelFamily{
			ID:   famID,
			Name: soraFamilyNames[famID],
			Type: fd.modelType,
		}
		if fam.Name == "" {
			// No curated display name: fall back to the raw family ID.
			fam.Name = famID
		}
		for o := range fd.orientations {
			fam.Orientations = append(fam.Orientations, o)
		}
		sort.Strings(fam.Orientations)
		for d := range fd.durations {
			fam.Durations = append(fam.Durations, d)
		}
		sort.Ints(fam.Durations)
		result = append(result, fam)
	}
	return result
}
|
||||
|
||||
// DefaultSoraModels returns the default Sora model list.
|
||||
func DefaultSoraModels(cfg *config.Config) []openai.Model {
|
||||
models := make([]openai.Model, 0, len(soraModelIDs))
|
||||
|
||||
257
backend/internal/service/sora_quota_service.go
Normal file
257
backend/internal/service/sora_quota_service.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// SoraQuotaService manages per-user Sora storage quotas.
// Quota resolution priority: user level → group level → system default.
type SoraQuotaService struct {
	userRepo       UserRepository
	groupRepo      GroupRepository
	settingService *SettingService
}
|
||||
|
||||
// NewSoraQuotaService creates a quota service instance. groupRepo and
// settingService may be nil in callers that only need user-level quotas.
func NewSoraQuotaService(
	userRepo UserRepository,
	groupRepo GroupRepository,
	settingService *SettingService,
) *SoraQuotaService {
	return &SoraQuotaService{
		userRepo:       userRepo,
		groupRepo:      groupRepo,
		settingService: settingService,
	}
}
|
||||
|
||||
// QuotaInfo is the quota information returned to clients.
type QuotaInfo struct {
	QuotaBytes     int64  `json:"quota_bytes"`      // total quota (0 means unlimited)
	UsedBytes      int64  `json:"used_bytes"`       // bytes already used
	AvailableBytes int64  `json:"available_bytes"`  // remaining bytes (0 when unlimited)
	QuotaSource    string `json:"quota_source"`     // quota origin: user / group / system / unlimited
	Source         string `json:"source,omitempty"` // legacy field kept in sync with QuotaSource
}
|
||||
|
||||
// ErrSoraStorageQuotaExceeded indicates the storage quota is insufficient.
var ErrSoraStorageQuotaExceeded = errors.New("sora storage quota exceeded")
|
||||
|
||||
// QuotaExceededError carries context about an exhausted storage quota.
// It implements the error interface; a nil receiver yields a generic message.
type QuotaExceededError struct {
	QuotaBytes int64 // effective quota at the time of the failure
	UsedBytes  int64 // usage at the time of the failure
}

// Error formats the quota failure, guarding against a nil receiver.
func (e *QuotaExceededError) Error() string {
	if e == nil {
		return "存储配额不足"
	}
	return "存储配额不足(已用 " + strconv.FormatInt(e.UsedBytes, 10) +
		" / 配额 " + strconv.FormatInt(e.QuotaBytes, 10) + " 字节)"
}
|
||||
|
||||
// soraQuotaAtomicUserRepository is an optional capability interface: when the
// concrete UserRepository implements it, usage updates are applied atomically
// in the store instead of via read-modify-write.
type soraQuotaAtomicUserRepository interface {
	AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error)
	ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error)
}
|
||||
|
||||
// GetQuota 获取用户的存储配额信息。
|
||||
// 优先级:用户级 > 用户所属分组级 > 系统默认值。
|
||||
func (s *SoraQuotaService) GetQuota(ctx context.Context, userID int64) (*QuotaInfo, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
info := &QuotaInfo{
|
||||
UsedBytes: user.SoraStorageUsedBytes,
|
||||
}
|
||||
|
||||
// 1. 用户级配额
|
||||
if user.SoraStorageQuotaBytes > 0 {
|
||||
info.QuotaBytes = user.SoraStorageQuotaBytes
|
||||
info.QuotaSource = "user"
|
||||
info.Source = info.QuotaSource
|
||||
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// 2. 分组级配额(取用户可用分组中最大的配额)
|
||||
if len(user.AllowedGroups) > 0 {
|
||||
var maxGroupQuota int64
|
||||
for _, gid := range user.AllowedGroups {
|
||||
group, err := s.groupRepo.GetByID(ctx, gid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if group.SoraStorageQuotaBytes > maxGroupQuota {
|
||||
maxGroupQuota = group.SoraStorageQuotaBytes
|
||||
}
|
||||
}
|
||||
if maxGroupQuota > 0 {
|
||||
info.QuotaBytes = maxGroupQuota
|
||||
info.QuotaSource = "group"
|
||||
info.Source = info.QuotaSource
|
||||
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
|
||||
return info, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 系统默认值
|
||||
defaultQuota := s.getSystemDefaultQuota(ctx)
|
||||
if defaultQuota > 0 {
|
||||
info.QuotaBytes = defaultQuota
|
||||
info.QuotaSource = "system"
|
||||
info.Source = info.QuotaSource
|
||||
info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// 无配额限制
|
||||
info.QuotaSource = "unlimited"
|
||||
info.Source = info.QuotaSource
|
||||
info.AvailableBytes = 0
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// CheckQuota 检查用户是否有足够的存储配额。
|
||||
// 返回 nil 表示配额充足或无限制。
|
||||
func (s *SoraQuotaService) CheckQuota(ctx context.Context, userID int64, additionalBytes int64) error {
|
||||
quota, err := s.GetQuota(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 0 表示无限制
|
||||
if quota.QuotaBytes == 0 {
|
||||
return nil
|
||||
}
|
||||
if quota.UsedBytes+additionalBytes > quota.QuotaBytes {
|
||||
return &QuotaExceededError{
|
||||
QuotaBytes: quota.QuotaBytes,
|
||||
UsedBytes: quota.UsedBytes,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddUsage adds bytes to a user's storage usage (called after a successful
// upload). The quota is pre-checked first; when the repository supports atomic
// updates, the quota is additionally re-validated inside the atomic operation
// to avoid racing concurrent uploads.
func (s *SoraQuotaService) AddUsage(ctx context.Context, userID int64, bytes int64) error {
	if bytes <= 0 {
		return nil
	}

	quota, err := s.GetQuota(ctx, userID)
	if err != nil {
		return err
	}

	// Pre-check against the effective quota (0 means unlimited).
	if quota.QuotaBytes > 0 && quota.UsedBytes+bytes > quota.QuotaBytes {
		return &QuotaExceededError{
			QuotaBytes: quota.QuotaBytes,
			UsedBytes:  quota.UsedBytes,
		}
	}

	// Preferred path: atomic add-with-quota in the repository.
	if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
		newUsed, err := repo.AddSoraStorageUsageWithQuota(ctx, userID, bytes, quota.QuotaBytes)
		if err != nil {
			if errors.Is(err, ErrSoraStorageQuotaExceeded) {
				return &QuotaExceededError{
					QuotaBytes: quota.QuotaBytes,
					UsedBytes:  quota.UsedBytes,
				}
			}
			return fmt.Errorf("update user quota usage (atomic): %w", err)
		}
		logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, newUsed)
		return nil
	}

	// Fallback: non-atomic read-modify-write for repositories without the
	// atomic capability (subject to races between concurrent uploads).
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return fmt.Errorf("get user for quota update: %w", err)
	}
	user.SoraStorageUsedBytes += bytes
	if err := s.userRepo.Update(ctx, user); err != nil {
		return fmt.Errorf("update user quota usage: %w", err)
	}
	logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
	return nil
}
|
||||
|
||||
// ReleaseUsage subtracts bytes from a user's storage usage (called after a
// file is deleted). On the non-atomic fallback path usage is clamped at zero.
func (s *SoraQuotaService) ReleaseUsage(ctx context.Context, userID int64, bytes int64) error {
	if bytes <= 0 {
		return nil
	}

	// Preferred path: atomic decrement in the repository.
	if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
		newUsed, err := repo.ReleaseSoraStorageUsageAtomic(ctx, userID, bytes)
		if err != nil {
			return fmt.Errorf("update user quota release (atomic): %w", err)
		}
		logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, newUsed)
		return nil
	}

	// Fallback: non-atomic read-modify-write, clamped so usage never goes negative.
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return fmt.Errorf("get user for quota release: %w", err)
	}
	user.SoraStorageUsedBytes -= bytes
	if user.SoraStorageUsedBytes < 0 {
		user.SoraStorageUsedBytes = 0
	}
	if err := s.userRepo.Update(ctx, user); err != nil {
		return fmt.Errorf("update user quota release: %w", err)
	}
	logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
	return nil
}
|
||||
|
||||
// calcAvailableBytes returns how many bytes of quota remain. A non-positive
// quota means "unlimited" and reports 0 remaining; usage at or above the
// quota also reports 0 (never negative).
func calcAvailableBytes(quotaBytes, usedBytes int64) int64 {
	if quotaBytes <= 0 || usedBytes >= quotaBytes {
		return 0
	}
	return quotaBytes - usedBytes
}
|
||||
|
||||
// getSystemDefaultQuota reads the system-wide default quota from settings.
// Returns 0 (treated as "no default") when no setting service is configured
// or the settings read fails — callers then fall through to "unlimited".
func (s *SoraQuotaService) getSystemDefaultQuota(ctx context.Context) int64 {
	if s.settingService == nil {
		return 0
	}
	settings, err := s.settingService.GetSoraS3Settings(ctx)
	if err != nil {
		// Best-effort: fall back to "no default" rather than failing the caller.
		return 0
	}
	return settings.DefaultStorageQuotaBytes
}
|
||||
|
||||
// GetQuotaFromSettings exposes the system default quota for external callers.
func (s *SoraQuotaService) GetQuotaFromSettings(ctx context.Context) int64 {
	return s.getSystemDefaultQuota(ctx)
}
|
||||
|
||||
// SetUserSoraQuota sets a user-level storage quota (admin operation).
func SetUserSoraQuota(ctx context.Context, userRepo UserRepository, userID int64, quotaBytes int64) error {
	user, err := userRepo.GetByID(ctx, userID)
	if err != nil {
		return err
	}
	user.SoraStorageQuotaBytes = quotaBytes
	return userRepo.Update(ctx, user)
}
|
||||
|
||||
// ParseQuotaBytes parses a base-10 quota string into a byte count.
// The parse error is deliberately discarded, so strconv's fallback value is
// returned for bad input (0 for syntax errors, the clamped extreme for
// out-of-range values).
func ParseQuotaBytes(s string) int64 {
	parsed, _ := strconv.ParseInt(s, 10, 64)
	return parsed
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user