feat(sync): full code sync from release

This commit is contained in:
yangjianbo
2026-02-28 15:01:20 +08:00
parent bfc7b339f7
commit bb664d9bbf
338 changed files with 54513 additions and 2011 deletions

View File

@@ -3,6 +3,8 @@ package service
import (
"encoding/json"
"hash/fnv"
"reflect"
"sort"
"strconv"
"strings"
@@ -50,6 +52,14 @@ type Account struct {
AccountGroups []AccountGroup
GroupIDs []int64
Groups []*Group
// model_mapping 热路径缓存(非持久化字段)
modelMappingCache map[string]string
modelMappingCacheReady bool
modelMappingCacheCredentialsPtr uintptr
modelMappingCacheRawPtr uintptr
modelMappingCacheRawLen int
modelMappingCacheRawSig uint64
}
type TempUnschedulableRule struct {
@@ -349,6 +359,39 @@ func parseTempUnschedInt(value any) int {
}
// GetModelMapping returns the account's effective model mapping, caching the
// resolved result on the Account to avoid re-resolving credentials on the hot
// path. The cache is keyed on the identity of the Credentials map, the
// identity and length of the raw "model_mapping" sub-map, and an FNV content
// signature, so both map replacement and in-place edits invalidate it.
//
// NOTE(review): the cache fields are written without synchronization; this is
// only safe if a given *Account is not shared across goroutines — confirm.
func (a *Account) GetModelMapping() map[string]string {
	credentialsPtr := mapPtr(a.Credentials)
	rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
	rawPtr := mapPtr(rawMapping)
	rawLen := len(rawMapping)
	rawSig := uint64(0)
	rawSigReady := false
	// Fast path: map identities and length unchanged — verify the content
	// signature before trusting the cache (detects in-place value edits).
	if a.modelMappingCacheReady &&
		a.modelMappingCacheCredentialsPtr == credentialsPtr &&
		a.modelMappingCacheRawPtr == rawPtr &&
		a.modelMappingCacheRawLen == rawLen {
		rawSig = modelMappingSignature(rawMapping)
		rawSigReady = true
		if a.modelMappingCacheRawSig == rawSig {
			return a.modelMappingCache
		}
	}
	// Slow path: resolve from scratch and refresh every cache key component.
	mapping := a.resolveModelMapping(rawMapping)
	if !rawSigReady {
		// Signature was not computed on the fast path; compute it now so the
		// refreshed cache entry is complete.
		rawSig = modelMappingSignature(rawMapping)
	}
	a.modelMappingCache = mapping
	a.modelMappingCacheReady = true
	a.modelMappingCacheCredentialsPtr = credentialsPtr
	a.modelMappingCacheRawPtr = rawPtr
	a.modelMappingCacheRawLen = rawLen
	a.modelMappingCacheRawSig = rawSig
	return mapping
}
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
if a.Credentials == nil {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
@@ -356,32 +399,31 @@ func (a *Account) GetModelMapping() map[string]string {
}
return nil
}
raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil {
if len(rawMapping) == 0 {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
}
return nil
}
if m, ok := raw.(map[string]any); ok {
result := make(map[string]string)
for k, v := range m {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
if a.Platform == domain.PlatformAntigravity {
ensureAntigravityDefaultPassthroughs(result, []string{
"gemini-3-flash",
"gemini-3.1-pro-high",
"gemini-3.1-pro-low",
})
}
return result
result := make(map[string]string)
for k, v := range rawMapping {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
if a.Platform == domain.PlatformAntigravity {
ensureAntigravityDefaultPassthroughs(result, []string{
"gemini-3-flash",
"gemini-3.1-pro-high",
"gemini-3.1-pro-low",
})
}
return result
}
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
@@ -389,6 +431,37 @@ func (a *Account) GetModelMapping() map[string]string {
return nil
}
// mapPtr returns the identity of a map header as a uintptr, or 0 for a nil
// map. It is used purely as a cheap identity check to detect map replacement,
// never dereferenced.
func mapPtr(m map[string]any) uintptr {
	if m != nil {
		return reflect.ValueOf(m).Pointer()
	}
	return 0
}
// modelMappingSignature computes an order-independent FNV-1a signature over a
// raw model-mapping map. Keys are visited in sorted order; each entry is
// framed as key, 0x00 separator, value bytes, 0xff terminator. Non-string
// values contribute a fixed 0x01 marker instead of their content, so only
// string values participate in change detection. An empty or nil map yields 0.
func modelMappingSignature(rawMapping map[string]any) uint64 {
	if len(rawMapping) == 0 {
		return 0
	}
	keys := make([]string, 0, len(rawMapping))
	for key := range rawMapping {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	digest := fnv.New64a()
	feed := func(b []byte) {
		// fnv's Write never fails; ignore the error by contract.
		_, _ = digest.Write(b)
	}
	for _, key := range keys {
		feed([]byte(key))
		feed([]byte{0})
		switch value := rawMapping[key].(type) {
		case string:
			feed([]byte(value))
		default:
			feed([]byte{1})
		}
		feed([]byte{0xff})
	}
	return digest.Sum64()
}
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
if mapping == nil || model == "" {
return
@@ -742,6 +815,159 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool {
return false
}
// IsOpenAIResponsesWebSocketV2Enabled reports whether Responses WebSocket v2
// is enabled for this OpenAI account.
//
// Type-specific keys:
//   - OAuth accounts:   accounts.extra.openai_oauth_responses_websockets_v2_enabled
//   - API key accounts: accounts.extra.openai_apikey_responses_websockets_v2_enabled
//
// Legacy fallback keys:
//   - accounts.extra.responses_websockets_v2_enabled
//   - accounts.extra.openai_ws_enabled (historical switch)
//
// Resolution order: the type-specific key wins; legacy keys are consulted
// only when it is absent. Non-OpenAI accounts are always off.
func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
	if a == nil || !a.IsOpenAI() || a.Extra == nil {
		return false
	}
	// readBool reports the value and presence of a boolean extra field.
	readBool := func(key string) (bool, bool) {
		v, ok := a.Extra[key].(bool)
		return v, ok
	}
	if a.IsOpenAIOAuth() {
		if v, ok := readBool("openai_oauth_responses_websockets_v2_enabled"); ok {
			return v
		}
	}
	if a.IsOpenAIApiKey() {
		if v, ok := readBool("openai_apikey_responses_websockets_v2_enabled"); ok {
			return v
		}
	}
	// Legacy compatibility keys, checked in historical priority order.
	for _, key := range []string{"responses_websockets_v2_enabled", "openai_ws_enabled"} {
		if v, ok := readBool(key); ok {
			return v
		}
	}
	return false
}
// Supported OpenAI WS ingress modes.
const (
	OpenAIWSIngressModeOff       = "off"
	OpenAIWSIngressModeShared    = "shared"
	OpenAIWSIngressModeDedicated = "dedicated"
)

// normalizeOpenAIWSIngressMode canonicalizes mode (trimmed and lowercased)
// and returns the canonical form, or "" when it is not a known ingress mode.
func normalizeOpenAIWSIngressMode(mode string) string {
	switch canonical := strings.ToLower(strings.TrimSpace(mode)); canonical {
	case OpenAIWSIngressModeOff, OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
		return canonical
	}
	return ""
}

// normalizeOpenAIWSIngressDefaultMode is like normalizeOpenAIWSIngressMode
// but falls back to the shared mode for unknown or empty input.
func normalizeOpenAIWSIngressDefaultMode(mode string) string {
	normalized := normalizeOpenAIWSIngressMode(mode)
	if normalized == "" {
		return OpenAIWSIngressModeShared
	}
	return normalized
}
// ResolveOpenAIResponsesWebSocketV2Mode returns the account's effective mode
// under WSv2 ingress (off/shared/dedicated).
//
// Resolution order:
//  1. type-specific mode key (string)
//  2. type-specific enabled key (bool, legacy)
//  3. shared legacy enabled keys (bool)
//  4. defaultMode (invalid values fall back to shared)
//
// Non-OpenAI accounts always resolve to off.
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
	fallback := normalizeOpenAIWSIngressDefaultMode(defaultMode)
	if a == nil || !a.IsOpenAI() {
		return OpenAIWSIngressModeOff
	}
	if a.Extra == nil {
		return fallback
	}
	// stringMode reads key as a string and accepts only known mode values.
	stringMode := func(key string) (string, bool) {
		raw, ok := a.Extra[key].(string)
		if !ok {
			return "", false
		}
		normalized := normalizeOpenAIWSIngressMode(raw)
		return normalized, normalized != ""
	}
	// boolMode maps a legacy boolean switch onto shared (on) or off.
	boolMode := func(key string) (string, bool) {
		enabled, ok := a.Extra[key].(bool)
		switch {
		case !ok:
			return "", false
		case enabled:
			return OpenAIWSIngressModeShared, true
		default:
			return OpenAIWSIngressModeOff, true
		}
	}
	if a.IsOpenAIOAuth() {
		if mode, ok := stringMode("openai_oauth_responses_websockets_v2_mode"); ok {
			return mode
		}
		if mode, ok := boolMode("openai_oauth_responses_websockets_v2_enabled"); ok {
			return mode
		}
	}
	if a.IsOpenAIApiKey() {
		if mode, ok := stringMode("openai_apikey_responses_websockets_v2_mode"); ok {
			return mode
		}
		if mode, ok := boolMode("openai_apikey_responses_websockets_v2_enabled"); ok {
			return mode
		}
	}
	// Legacy compatibility keys, in historical priority order.
	for _, key := range []string{"responses_websockets_v2_enabled", "openai_ws_enabled"} {
		if mode, ok := boolMode(key); ok {
			return mode
		}
	}
	return fallback
}
// IsOpenAIWSForceHTTPEnabled reports the account-level "force HTTP" switch.
// Field: accounts.extra.openai_ws_force_http.
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
	if a == nil || !a.IsOpenAI() || a.Extra == nil {
		return false
	}
	if forced, ok := a.Extra["openai_ws_force_http"].(bool); ok {
		return forced
	}
	return false
}
// IsOpenAIWSAllowStoreRecoveryEnabled reports the account-level store
// recovery switch. Field: accounts.extra.openai_ws_allow_store_recovery.
func (a *Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool {
	if a == nil || !a.IsOpenAI() || a.Extra == nil {
		return false
	}
	if allowed, ok := a.Extra["openai_ws_allow_store_recovery"].(bool); ok {
		return allowed
	}
	return false
}
// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。
func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()

View File

@@ -134,3 +134,161 @@ func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
require.False(t, otherPlatform.IsCodexCLIOnlyEnabled())
})
}
// TestAccount_IsOpenAIResponsesWebSocketV2Enabled verifies the resolution
// order of the WSv2 enable switch: type-specific keys win, type-specific keys
// are not read cross-type, legacy keys act as fallback, and non-OpenAI
// platforms are always off.
func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
	t.Run("OAuth使用OAuth专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
	t.Run("API Key使用API Key专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
	t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
	t.Run("分类型新键优先于兼容键", func(t *testing.T) {
		// Explicit false on the type-specific key must mask both legacy keys.
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":              true,
				"openai_ws_enabled":                            true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
	t.Run("分类型键缺失时回退兼容键", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
	t.Run("非OpenAI账号默认关闭", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
	})
}
// TestAccount_ResolveOpenAIResponsesWebSocketV2Mode verifies mode resolution:
// the shared fallback for empty/invalid defaults, the priority of the
// type-specific mode key, the mapping of legacy booleans to shared/off, and
// the always-off behavior for non-OpenAI platforms.
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
	t.Run("default fallback to shared", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra:    map[string]any{},
		}
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
	})
	t.Run("oauth mode field has highest priority", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode":    OpenAIWSIngressModeDedicated,
				"openai_oauth_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":              false,
			},
		}
		require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
	})
	t.Run("legacy enabled maps to shared", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
	})
	t.Run("legacy disabled maps to off", func(t *testing.T) {
		// Type-specific false must win over the shared legacy true.
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": false,
				"responses_websockets_v2_enabled":               true,
			},
		}
		require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
	})
	t.Run("non openai always off", func(t *testing.T) {
		account := &Account{
			Platform: PlatformAnthropic,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
			},
		}
		require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
	})
}
// TestAccount_OpenAIWSExtraFlags covers the force-http and store-recovery
// extra flags, including absent-key defaults, the nil receiver, and
// non-OpenAI platforms.
func TestAccount_OpenAIWSExtraFlags(t *testing.T) {
	account := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_ws_force_http":           true,
			"openai_ws_allow_store_recovery": true,
		},
	}
	require.True(t, account.IsOpenAIWSForceHTTPEnabled())
	require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled())
	// Flags default to off when the keys are absent.
	off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
	require.False(t, off.IsOpenAIWSForceHTTPEnabled())
	require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled())
	// A nil receiver must be tolerated and report off.
	var nilAccount *Account
	require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled())
	// Non-OpenAI platforms never enable the flags, even when set.
	nonOpenAI := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_ws_allow_store_recovery": true,
		},
	}
	require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled())
}

View File

@@ -119,6 +119,10 @@ type AccountService struct {
groupRepo GroupRepository
}
// groupExistenceBatchChecker is an optional capability of GroupRepository
// implementations that can verify the existence of many group IDs in a
// single query. Implementations return a presence map keyed by group ID.
type groupExistenceBatchChecker interface {
	ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAccountService 创建账号服务实例
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
return &AccountService{
@@ -131,11 +135,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
// 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 {
for _, groupID := range req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil {
return nil, err
}
}
@@ -256,11 +257,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
// 先验证分组是否存在(在任何写操作之前)
if req.GroupIDs != nil {
for _, groupID := range *req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil {
return nil, err
}
}
@@ -300,6 +298,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
return nil
}
// validateGroupIDsExist verifies that every ID in groupIDs refers to an
// existing group. When the repository supports batch existence checks, a
// single query is used and missing or non-positive IDs map to
// ErrGroupNotFound; otherwise each group is fetched individually and
// repository errors are wrapped as-is.
func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
	if len(groupIDs) == 0 {
		return nil
	}
	if s.groupRepo == nil {
		return fmt.Errorf("group repository not configured")
	}
	batchChecker, supportsBatch := s.groupRepo.(groupExistenceBatchChecker)
	if !supportsBatch {
		// Fallback: one lookup per group; GetByID errors carry the detail.
		for _, groupID := range groupIDs {
			if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
				return fmt.Errorf("get group: %w", err)
			}
		}
		return nil
	}
	existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs)
	if err != nil {
		return fmt.Errorf("check groups exists: %w", err)
	}
	for _, groupID := range groupIDs {
		if groupID <= 0 || !existsByID[groupID] {
			return fmt.Errorf("get group: %w", ErrGroupNotFound)
		}
	}
	return nil
}
// UpdateStatus 更新账号状态
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id)

View File

@@ -598,9 +598,102 @@ func ceilSeconds(d time.Duration) int {
return sec
}
// testSoraAPIKeyAccountConnection tests connectivity for a Sora account of
// the apikey type. It sends a lightweight prompt-enhance request to the
// upstream base_url to verify both reachability and API key validity, and
// streams progress back to the caller as SSE events.
func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
	ctx := c.Request.Context()
	apiKey := account.GetCredential("api_key")
	if apiKey == "" {
		return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
	}
	baseURL := account.GetBaseURL()
	if baseURL == "" {
		return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
	}
	// Validate the base_url format before contacting the upstream.
	normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
	}
	upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
	// Set SSE headers and flush so the client starts receiving events.
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.Flush()
	// Rate-limit connectivity tests per account.
	if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
		msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
		return s.sendErrorAndEnd(c, msg)
	}
	s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
	// Build a lightweight prompt-enhance request as the connectivity probe.
	testPayload := map[string]any{
		"model":    "prompt-enhance-short-10s",
		"messages": []map[string]string{{"role": "user", "content": "test"}},
		"stream":   false,
	}
	payloadBytes, _ := json.Marshal(testPayload)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
	if err != nil {
		return s.sendErrorAndEnd(c, "构建测试请求失败")
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)
	// Resolve the account's proxy URL, if one is configured.
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
	}
	defer func() { _ = resp.Body.Close() }()
	// Cap the body read at 64 KiB; it is only used for error reporting.
	respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
	if resp.StatusCode == http.StatusOK {
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
		s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
		return nil
	}
	if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
		return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
	}
	// A 400 proves the upstream is reachable and the key was accepted, so the
	// connectivity test still counts as a pass.
	if resp.StatusCode == http.StatusBadRequest {
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
		s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d参数校验错误属正常", resp.StatusCode)})
		s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
		return nil
	}
	return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
}
// testSoraAccountConnection 测试 Sora 账号的连接
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token
// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
// apikey 类型走独立测试流程
if account.Type == AccountTypeAPIKey {
return s.testSoraAPIKeyAccountConnection(c, account)
}
ctx := c.Request.Context()
recorder := &soraProbeRecorder{}

View File

@@ -9,7 +9,9 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"golang.org/x/sync/errgroup"
)
type UsageLogRepository interface {
@@ -33,8 +35,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
@@ -62,6 +64,10 @@ type UsageLogRepository interface {
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
}
// accountWindowStatsBatchReader is an optional capability of
// UsageLogRepository implementations that can load window stats for many
// accounts in a single query, keyed by account ID.
type accountWindowStatsBatchReader interface {
	GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error)
}
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据utilization, resets_at
type apiUsageCache struct {
response *ClaudeUsageResponse
@@ -297,7 +303,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
@@ -319,7 +325,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}
@@ -440,6 +446,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
}, nil
}
// GetTodayStatsBatch fetches today's window stats for a set of accounts.
// It prefers a single batch SQL query when the repository supports it and
// falls back to bounded-concurrency per-account queries when the batch read
// is unavailable or fails. Every requested (positive, de-duplicated) account
// ID is guaranteed an entry in the result, zero-valued if its query failed.
func (s *AccountUsageService) GetTodayStatsBatch(ctx context.Context, accountIDs []int64) (map[int64]*WindowStats, error) {
	// De-duplicate and drop non-positive IDs, preserving first-seen order.
	uniqueIDs := make([]int64, 0, len(accountIDs))
	seen := make(map[int64]struct{}, len(accountIDs))
	for _, accountID := range accountIDs {
		if accountID <= 0 {
			continue
		}
		if _, exists := seen[accountID]; exists {
			continue
		}
		seen[accountID] = struct{}{}
		uniqueIDs = append(uniqueIDs, accountID)
	}
	result := make(map[int64]*WindowStats, len(uniqueIDs))
	if len(uniqueIDs) == 0 {
		return result, nil
	}
	startTime := timezone.Today()
	// Fast path: one batch query when the repo implements the optional reader.
	if batchReader, ok := s.usageLogRepo.(accountWindowStatsBatchReader); ok {
		statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, uniqueIDs, startTime)
		if err == nil {
			for _, accountID := range uniqueIDs {
				result[accountID] = windowStatsFromAccountStats(statsByAccount[accountID])
			}
			return result, nil
		}
		// Batch read failed — fall through to per-account queries below.
	}
	// Fallback: query each account concurrently, at most 8 in flight.
	var mu sync.Mutex
	g, gctx := errgroup.WithContext(ctx)
	g.SetLimit(8)
	for _, accountID := range uniqueIDs {
		id := accountID // per-iteration copy for the closure
		g.Go(func() error {
			stats, err := s.usageLogRepo.GetAccountWindowStats(gctx, id, startTime)
			if err != nil {
				// Best-effort: a failed account is skipped here and
				// zero-filled after Wait below.
				return nil
			}
			mu.Lock()
			result[id] = windowStatsFromAccountStats(stats)
			mu.Unlock()
			return nil
		})
	}
	// Workers never return errors; Wait is only used for synchronization.
	_ = g.Wait()
	// Ensure every requested account has an entry, even if its query failed.
	for _, accountID := range uniqueIDs {
		if _, ok := result[accountID]; !ok {
			result[accountID] = &WindowStats{}
		}
	}
	return result, nil
}
// windowStatsFromAccountStats converts repository-level account stats into a
// WindowStats. A nil input yields an all-zero WindowStats rather than nil so
// callers never have to nil-check the result.
func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
	ws := &WindowStats{}
	if stats != nil {
		ws.Requests = stats.Requests
		ws.Tokens = stats.Tokens
		ws.Cost = stats.Cost
		ws.StandardCost = stats.StandardCost
		ws.UserCost = stats.UserCost
	}
	return ws
}
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
if err != nil {

View File

@@ -314,3 +314,72 @@ func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
}
}
// TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace checks that
// replacing the whole Credentials map (new map identity) invalidates the
// model-mapping cache and the new mapping is returned.
func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) {
	account := &Account{
		Credentials: map[string]any{
			"model_mapping": map[string]any{
				"claude-3-5-sonnet": "upstream-a",
			},
		},
	}
	first := account.GetModelMapping()
	if first["claude-3-5-sonnet"] != "upstream-a" {
		t.Fatalf("unexpected first mapping: %v", first)
	}
	// Swap in a brand-new Credentials map — cache must not survive.
	account.Credentials = map[string]any{
		"model_mapping": map[string]any{
			"claude-3-5-sonnet": "upstream-b",
		},
	}
	second := account.GetModelMapping()
	if second["claude-3-5-sonnet"] != "upstream-b" {
		t.Fatalf("expected cache invalidated after credentials replace, got: %v", second)
	}
}
// TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange checks that
// adding an entry to the raw mapping in place (same map identity, new length)
// invalidates the model-mapping cache.
func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) {
	rawMapping := map[string]any{
		"claude-sonnet": "sonnet-a",
	}
	account := &Account{
		Credentials: map[string]any{
			"model_mapping": rawMapping,
		},
	}
	first := account.GetModelMapping()
	if len(first) != 1 {
		t.Fatalf("unexpected first mapping length: %d", len(first))
	}
	// Grow the same map in place — the length check must catch this.
	rawMapping["claude-opus"] = "opus-b"
	second := account.GetModelMapping()
	if second["claude-opus"] != "opus-b" {
		t.Fatalf("expected cache invalidated after mapping len change, got: %v", second)
	}
}
// TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange checks that
// rewriting an existing string value in place (same map identity and length)
// still invalidates the cache via the content signature.
func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) {
	rawMapping := map[string]any{
		"claude-sonnet": "sonnet-a",
	}
	account := &Account{
		Credentials: map[string]any{
			"model_mapping": rawMapping,
		},
	}
	first := account.GetModelMapping()
	if first["claude-sonnet"] != "sonnet-a" {
		t.Fatalf("unexpected first mapping: %v", first)
	}
	// Same key, same length, different value — only the signature changes.
	rawMapping["claude-sonnet"] = "sonnet-b"
	second := account.GetModelMapping()
	if second["claude-sonnet"] != "sonnet-b" {
		t.Fatalf("expected cache invalidated after in-place value change, got: %v", second)
	}
}

View File

@@ -83,13 +83,14 @@ type AdminService interface {
// CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct {
Email string
Password string
Username string
Notes string
Balance float64
Concurrency int
AllowedGroups []int64
Email string
Password string
Username string
Notes string
Balance float64
Concurrency int
AllowedGroups []int64
SoraStorageQuotaBytes int64
}
type UpdateUserInput struct {
@@ -103,7 +104,8 @@ type UpdateUserInput struct {
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
GroupRates map[int64]*float64
SoraStorageQuotaBytes *int64
}
type CreateGroupInput struct {
@@ -135,6 +137,8 @@ type CreateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// Sora 存储配额
SoraStorageQuotaBytes int64
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
@@ -169,6 +173,8 @@ type UpdateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// Sora 存储配额
SoraStorageQuotaBytes *int64
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
@@ -402,6 +408,14 @@ type adminServiceImpl struct {
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// userGroupRateBatchReader is an optional repository capability for loading
// per-user group rates for many users in one query. The result maps
// user ID -> (group ID -> rate).
type userGroupRateBatchReader interface {
	GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
}
// groupExistenceBatchReader is an optional repository capability for checking
// many group IDs for existence in a single query, returning a presence map
// keyed by group ID.
type groupExistenceBatchReader interface {
	ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAdminService creates a new AdminService
func NewAdminService(
userRepo UserRepository,
@@ -442,18 +456,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
}
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
userIDs := make([]int64, 0, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
}
users[i].GroupRates = rates
ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs)
if err != nil {
logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err)
s.loadUserGroupRatesOneByOne(ctx, users)
} else {
for i := range users {
if rates, ok := ratesByUser[users[i].ID]; ok {
users[i].GroupRates = rates
}
}
}
} else {
s.loadUserGroupRatesOneByOne(ctx, users)
}
}
return users, result.Total, nil
}
// loadUserGroupRatesOneByOne populates GroupRates on each user with a
// per-user repository lookup. Failures are logged and skipped so one bad
// user does not block the rest; it is the fallback when no batch reader is
// available.
func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) {
	if s.userGroupRateRepo == nil {
		return
	}
	for i := range users {
		user := &users[i]
		rates, err := s.userGroupRateRepo.GetByUserID(ctx, user.ID)
		if err != nil {
			logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", user.ID, err)
			continue
		}
		user.GroupRates = rates
	}
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
@@ -473,14 +512,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &User{
Email: input.Email,
Username: input.Username,
Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
Email: input.Email,
Username: input.Username,
Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
@@ -534,6 +574,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.AllowedGroups = *input.AllowedGroups
}
if input.SoraStorageQuotaBytes != nil {
user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
@@ -820,6 +864,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes,
SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
@@ -982,6 +1027,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
}
if input.SoraStorageQuotaBytes != nil {
group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
}
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
@@ -1188,6 +1236,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
// Sora apikey 账号的 base_url 必填校验
if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
baseURL, _ := input.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
account := &Account{
Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
@@ -1301,12 +1361,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
// Sora apikey 账号的 base_url 必填校验
if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
baseURL, _ := account.Credentials["base_url"].(string)
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, errors.New("sora apikey 账号必须设置 base_url")
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
for _, groupID := range *input.GroupIDs {
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
return nil, err
}
// 检查混合渠道风险(除非用户已确认)
@@ -1348,11 +1418,18 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if len(input.AccountIDs) == 0 {
return result, nil
}
if input.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
return nil, err
}
}
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
platformByID := map[int64]string{}
groupAccountsByID := map[int64][]Account{}
groupNameByID := map[int64]string{}
if needMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
@@ -1366,6 +1443,13 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
}
}
loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs)
if err != nil {
return nil, err
}
groupAccountsByID = loadedAccounts
groupNameByID = loadedNames
}
if input.RateMultiplier != nil {
@@ -1409,11 +1493,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
// Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID}
platform := ""
if input.GroupIDs != nil {
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
platform := platformByID[accountID]
platform = platformByID[accountID]
if platform == "" {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
@@ -1426,7 +1511,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
platform = account.Platform
}
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
@@ -1444,6 +1529,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result.Results = append(result.Results, entry)
continue
}
if !input.SkipMixedChannelCheck && platform != "" {
updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform)
}
}
entry.Success = true
@@ -2115,6 +2203,135 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
// preloadMixedChannelRiskData loads, for each unique positive group ID, the
// accounts currently bound to that group plus (best-effort) the group's
// display name. A failure to list a group's accounts aborts the whole
// preload; a failed name lookup is skipped so callers fall back to a
// generated "Group N" label.
func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) {
	accountsByGroup := map[int64][]Account{}
	namesByGroup := map[int64]string{}
	if len(groupIDs) == 0 {
		return accountsByGroup, namesByGroup, nil
	}
	visited := make(map[int64]struct{}, len(groupIDs))
	for _, id := range groupIDs {
		if id <= 0 {
			continue
		}
		if _, dup := visited[id]; dup {
			continue
		}
		visited[id] = struct{}{}
		rows, err := s.accountRepo.ListByGroup(ctx, id)
		if err != nil {
			return nil, nil, fmt.Errorf("get accounts in group %d: %w", id, err)
		}
		accountsByGroup[id] = rows
		// The group name is cosmetic (used only in error messages), so
		// lookup failures are deliberately ignored.
		if group, err := s.groupRepo.GetByID(ctx, id); err == nil && group != nil {
			namesByGroup[id] = group.Name
		}
	}
	return accountsByGroup, namesByGroup, nil
}
// validateGroupIDsExist verifies that every ID in groupIDs refers to an
// existing group, preferring a single batched existence query when the
// repository supports it.
//
// Returns nil for an empty list; otherwise returns an error wrapping
// ErrGroupNotFound (batch / invalid-ID cases) or the repository error
// (fallback path) for the first missing group.
func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
	if len(groupIDs) == 0 {
		return nil
	}
	if s.groupRepo == nil {
		return errors.New("group repository not configured")
	}
	// Fast path: one round-trip for all IDs when the repo supports batching.
	if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok {
		existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs)
		if err != nil {
			return fmt.Errorf("check groups exists: %w", err)
		}
		for _, groupID := range groupIDs {
			if groupID <= 0 || !existsByID[groupID] {
				return fmt.Errorf("get group: %w", ErrGroupNotFound)
			}
		}
		return nil
	}
	// Fallback: look up each group individually.
	for _, groupID := range groupIDs {
		// Reject non-positive IDs up front for consistency with the batch
		// path (also avoids a pointless repository round-trip).
		if groupID <= 0 {
			return fmt.Errorf("get group: %w", ErrGroupNotFound)
		}
		if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
			return fmt.Errorf("get group: %w", err)
		}
	}
	return nil
}
// checkMixedChannelRiskWithPreloaded checks, against preloaded per-group
// account snapshots, whether binding the given account into groupIDs would
// mix accounts from different channel platforms inside one group. It returns
// a *MixedChannelError describing the first conflict found, or nil when the
// account's platform is unclassified or no conflict exists.
func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error {
	current := getAccountPlatform(currentAccountPlatform)
	if current == "" {
		return nil
	}
	for _, groupID := range groupIDs {
		for _, other := range accountsByGroup[groupID] {
			// Skip the account being updated itself.
			if currentAccountID > 0 && other.ID == currentAccountID {
				continue
			}
			otherPlatform := getAccountPlatform(other.Platform)
			if otherPlatform == "" || otherPlatform == current {
				continue
			}
			// Conflict: prefer the preloaded display name, falling back to
			// a generated label when it is missing or blank.
			label := strings.TrimSpace(groupNameByID[groupID])
			if label == "" {
				label = fmt.Sprintf("Group %d", groupID)
			}
			return &MixedChannelError{
				GroupID:         groupID,
				GroupName:       label,
				CurrentPlatform: current,
				OtherPlatform:   otherPlatform,
			}
		}
	}
	return nil
}
// updateMixedChannelPreloadedAccounts keeps the preloaded per-group account
// snapshots in sync after a successful binding: the bound account is updated
// in place when already present in a group's snapshot, or appended otherwise,
// so later mixed-channel checks in the same batch see it as a member.
func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) {
	if accountID <= 0 || platform == "" || len(groupIDs) == 0 {
		return
	}
	for _, groupID := range groupIDs {
		if groupID <= 0 {
			continue
		}
		rows := accountsByGroup[groupID]
		updated := false
		for i := range rows {
			if rows[i].ID == accountID {
				rows[i].Platform = platform
				updated = true
				break
			}
		}
		if !updated {
			rows = append(rows, Account{ID: accountID, Platform: platform})
		}
		accountsByGroup[groupID] = rows
	}
}
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)

View File

@@ -15,6 +15,7 @@ type accountRepoStubForBulkUpdate struct {
bulkUpdateErr error
bulkUpdateIDs []int64
bindGroupErrByID map[int64]error
bindGroupsCalls []int64
getByIDsAccounts []*Account
getByIDsErr error
getByIDsCalled bool
@@ -22,6 +23,8 @@ type accountRepoStubForBulkUpdate struct {
getByIDAccounts map[int64]*Account
getByIDErrByID map[int64]error
getByIDCalled []int64
listByGroupData map[int64][]Account
listByGroupErr map[int64]error
}
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
@@ -33,6 +36,7 @@ func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64
}
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
s.bindGroupsCalls = append(s.bindGroupsCalls, accountID)
if err, ok := s.bindGroupErrByID[accountID]; ok {
return err
}
@@ -59,6 +63,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
return nil, errors.New("account not found")
}
// ListByGroup returns the configured error or rows for groupID; groups with
// no configured data yield a nil slice and no error.
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
	if err, ok := s.listByGroupErr[groupID]; ok {
		return nil, err
	}
	rows, ok := s.listByGroupData[groupID]
	if !ok {
		return nil, nil
	}
	return rows, nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
@@ -86,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
2: errors.New("bind failed"),
},
}
svc := &adminServiceImpl{accountRepo: repo}
svc := &adminServiceImpl{
accountRepo: repo,
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}},
}
groupIDs := []int64{10}
schedulable := false
@@ -105,3 +122,51 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 3)
}
// TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError verifies that
// a bulk update carrying group bindings fails fast with a clear error when
// the service has no group repository configured.
func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) {
	repo := &accountRepoStubForBulkUpdate{}
	// Note: groupRepo is deliberately left nil.
	svc := &adminServiceImpl{accountRepo: repo}
	groupIDs := []int64{10}
	input := &BulkUpdateAccountsInput{
		AccountIDs: []int64{1},
		GroupIDs:   &groupIDs,
	}
	result, err := svc.BulkUpdateAccounts(context.Background(), input)
	require.Nil(t, result)
	require.Error(t, err)
	require.Contains(t, err.Error(), "group repository not configured")
}
// TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot
// verifies that the mixed-channel check uses the in-memory snapshot updated
// after each successful binding: group 10 starts empty, account 1
// (anthropic) binds successfully and is recorded in the snapshot, so account
// 2 (antigravity) is then rejected as a mixed-channel conflict without any
// re-fetch from the repository.
func TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot(t *testing.T) {
	repo := &accountRepoStubForBulkUpdate{
		getByIDsAccounts: []*Account{
			{ID: 1, Platform: PlatformAnthropic},
			{ID: 2, Platform: PlatformAntigravity},
		},
		// Target group is initially empty.
		listByGroupData: map[int64][]Account{
			10: {},
		},
	}
	svc := &adminServiceImpl{
		accountRepo: repo,
		groupRepo:   &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "目标分组"}},
	}
	groupIDs := []int64{10}
	input := &BulkUpdateAccountsInput{
		AccountIDs: []int64{1, 2},
		GroupIDs:   &groupIDs,
	}
	result, err := svc.BulkUpdateAccounts(context.Background(), input)
	require.NoError(t, err)
	require.Equal(t, 1, result.Success)
	require.Equal(t, 1, result.Failed)
	require.ElementsMatch(t, []int64{1}, result.SuccessIDs)
	require.ElementsMatch(t, []int64{2}, result.FailedIDs)
	require.Len(t, result.Results, 2)
	require.Contains(t, result.Results[1].Error, "mixed channel")
	// Only the first account should have reached the binding step.
	require.Equal(t, []int64{1}, repo.bindGroupsCalls)
}

View File

@@ -0,0 +1,106 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// userRepoStubForListUsers overrides ListWithFilters on the base user repo
// stub so ListUsers tests can inject a canned user list or a failure.
type userRepoStubForListUsers struct {
	userRepoStub
	users []User // users returned (as a copy) by ListWithFilters
	err   error  // when set, ListWithFilters fails with this error
}
// ListWithFilters returns a copy of the stubbed users (or the configured
// error) together with a pagination result echoing the request parameters.
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
	if s.err != nil {
		return nil, nil, s.err
	}
	// Copy so callers cannot mutate the stub's backing slice.
	out := make([]User, len(s.users))
	copy(out, s.users)
	return out, &pagination.PaginationResult{
		Total:    int64(len(out)),
		Page:     params.Page,
		PageSize: params.PageSize,
	}, nil
}
// userGroupRateRepoStubForListUsers fakes the user-group-rate repository for
// ListUsers tests: the batch load can be forced to fail so the per-user
// fallback path is exercised, and all calls are recorded for assertions.
type userGroupRateRepoStubForListUsers struct {
	batchCalls int                         // number of GetByUserIDs invocations
	singleCall []int64                     // user IDs passed to GetByUserID, in call order
	batchErr   error                       // when set, GetByUserIDs fails
	batchData  map[int64]map[int64]float64 // batch result: userID -> groupID -> rate
	singleErr  map[int64]error             // per-user GetByUserID errors
	singleData map[int64]map[int64]float64 // per-user GetByUserID results
}
// GetByUserIDs counts the call and returns the canned batch data or error.
func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) {
	s.batchCalls++
	if s.batchErr != nil {
		return nil, s.batchErr
	}
	return s.batchData, nil
}
// GetByUserID records the requested user ID and returns its configured error
// or rates; unknown users get an empty (non-nil) rate map.
func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) {
	s.singleCall = append(s.singleCall, userID)
	if err, ok := s.singleErr[userID]; ok {
		return nil, err
	}
	rates, ok := s.singleData[userID]
	if !ok {
		return map[int64]float64{}, nil
	}
	return rates, nil
}
// GetByUserAndGroup is not expected in these tests; panic to surface misuse.
func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) {
	panic("unexpected GetByUserAndGroup call")
}
// SyncUserGroupRates is not expected in these tests; panic to surface misuse.
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
	panic("unexpected SyncUserGroupRates call")
}
// DeleteByGroupID is not expected in these tests; panic to surface misuse.
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
	panic("unexpected DeleteByGroupID call")
}
// DeleteByUserID is not expected in these tests; panic to surface misuse.
func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error {
	panic("unexpected DeleteByUserID call")
}
// TestAdminService_ListUsers_BatchRateFallbackToSingle verifies that when
// the batched group-rate load fails, ListUsers falls back to one per-user
// lookup each and still attaches the correct rates to every returned user.
func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
	userRepo := &userRepoStubForListUsers{
		users: []User{
			{ID: 101, Username: "u1"},
			{ID: 202, Username: "u2"},
		},
	}
	rateRepo := &userGroupRateRepoStubForListUsers{
		// Force the batch path to fail so the single-user path is used.
		batchErr: errors.New("batch unavailable"),
		singleData: map[int64]map[int64]float64{
			101: {11: 1.1},
			202: {22: 2.2},
		},
	}
	svc := &adminServiceImpl{
		userRepo:          userRepo,
		userGroupRateRepo: rateRepo,
	}
	users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{})
	require.NoError(t, err)
	require.Equal(t, int64(2), total)
	require.Len(t, users, 2)
	// Exactly one batch attempt, then one fallback call per user.
	require.Equal(t, 1, rateRepo.batchCalls)
	require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall)
	require.Equal(t, 1.1, users[0].GroupRates[11])
	require.Equal(t, 2.2, users[1].GroupRates[22])
}

View File

@@ -21,7 +21,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -2291,7 +2290,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
// isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记
func isSingleAccountRetry(ctx context.Context) bool {
v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool)
v, _ := SingleAccountRetryFromContext(ctx)
return v
}

View File

@@ -1,6 +1,10 @@
package service
import "time"
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
)
// API Key status constants
const (
@@ -19,11 +23,14 @@ type APIKey struct {
Status string
IPWhitelist []string
IPBlacklist []string
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
// 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。
CompiledIPWhitelist *ip.CompiledIPRules `json:"-"`
CompiledIPBlacklist *ip.CompiledIPRules `json:"-"`
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
// Quota fields
Quota float64 // Quota limit in USD (0 = unlimited)

View File

@@ -298,5 +298,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
}
}
s.compileAPIKeyIPRules(apiKey)
return apiKey
}

View File

@@ -158,6 +158,14 @@ func NewAPIKeyService(
return svc
}
// compileAPIKeyIPRules precompiles the key's IP whitelist/blacklist into
// matcher structures so the auth hot path avoids re-parsing IPs/CIDRs on
// every request. Safe to call with a nil key (no-op).
func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) {
	if apiKey == nil {
		return
	}
	apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist)
	apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist)
}
// GenerateKey 生成随机API Key
func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
@@ -332,6 +340,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
@@ -363,6 +372,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
@@ -375,6 +385,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
}
@@ -391,6 +402,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
} else {
@@ -402,6 +414,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
}
@@ -411,6 +424,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
@@ -510,6 +524,7 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}

View File

@@ -308,6 +308,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
}, nil
}
// VerifyTurnstileForRegister verifies Turnstile in the registration flow.
// When email verification is enabled and a verification code was submitted,
// the Turnstile check already happened when the code was sent, so the second
// check is skipped here — otherwise the one-time token would be consumed
// twice and the register submit would falsely fail.
func (s *AuthService) VerifyTurnstileForRegister(ctx context.Context, token, remoteIP, verifyCode string) error {
	if s.IsEmailVerifyEnabled(ctx) && strings.TrimSpace(verifyCode) != "" {
		logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verify flow detected, skip duplicate Turnstile check on register")
		return nil
	}
	return s.VerifyTurnstile(ctx, token, remoteIP)
}
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required

View File

@@ -0,0 +1,96 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// turnstileVerifierSpy records VerifyToken calls and returns canned results.
type turnstileVerifierSpy struct {
	called    int                      // number of VerifyToken invocations
	lastToken string                   // token passed to the most recent call
	result    *TurnstileVerifyResponse // canned response (nil => success)
	err       error                    // canned error (takes precedence over result)
}
// VerifyToken records the invocation and the token it received, then returns
// the configured error or response; by default it reports success.
func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) {
	s.called++
	s.lastToken = token
	switch {
	case s.err != nil:
		return nil, s.err
	case s.result != nil:
		return s.result, nil
	default:
		return &TurnstileVerifyResponse{Success: true}, nil
	}
}
// newAuthServiceForRegisterTurnstileTest builds an AuthService wired with a
// release-mode config that requires Turnstile, the given settings map, and
// the provided verifier; dependencies unrelated to Turnstile are left nil.
func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService {
	cfg := &config.Config{
		Server: config.ServerConfig{
			Mode: "release",
		},
		Turnstile: config.TurnstileConfig{
			Required: true,
		},
	}
	settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
	turnstileService := NewTurnstileService(settingService, verifier)
	return NewAuthService(
		&userRepoStub{},
		nil, // redeemRepo
		nil, // refreshTokenCache
		cfg,
		settingService,
		nil, // emailService
		turnstileService,
		nil, // emailQueueService
		nil, // promoService
	)
}
// TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided
// verifies that when email verification is enabled and a code is submitted,
// the register-time Turnstile check is skipped (verifier never called).
func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) {
	verifier := &turnstileVerifierSpy{}
	service := newAuthServiceForRegisterTurnstileTest(map[string]string{
		SettingKeyEmailVerifyEnabled:  "true",
		SettingKeyTurnstileEnabled:    "true",
		SettingKeyTurnstileSecretKey:  "secret",
		SettingKeyRegistrationEnabled: "true",
	}, verifier)
	err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456")
	require.NoError(t, err)
	require.Equal(t, 0, verifier.called)
}
// TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing
// verifies that without a verification code the Turnstile check is enforced:
// an empty token must fail with ErrTurnstileVerificationFailed.
func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) {
	verifier := &turnstileVerifierSpy{}
	service := newAuthServiceForRegisterTurnstileTest(map[string]string{
		SettingKeyEmailVerifyEnabled: "true",
		SettingKeyTurnstileEnabled:   "true",
		SettingKeyTurnstileSecretKey: "secret",
	}, verifier)
	err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "")
	require.ErrorIs(t, err, ErrTurnstileVerificationFailed)
}
// TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled
// verifies that the skip only applies when email verification is enabled:
// with it disabled, the Turnstile token is verified even if a code was sent.
func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) {
	verifier := &turnstileVerifierSpy{}
	service := newAuthServiceForRegisterTurnstileTest(map[string]string{
		SettingKeyEmailVerifyEnabled: "false",
		SettingKeyTurnstileEnabled:   "true",
		SettingKeyTurnstileSecretKey: "secret",
	}, verifier)
	err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456")
	require.NoError(t, err)
	require.Equal(t, 1, verifier.called)
	require.Equal(t, "turnstile-token", verifier.lastToken)
}

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -10,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"golang.org/x/sync/singleflight"
)
// 错误定义
@@ -58,6 +60,7 @@ const (
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
balanceLoadTimeout = 3 * time.Second
)
// cacheWriteTask 缓存写入任务
@@ -82,6 +85,9 @@ type BillingCacheService struct {
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once
cacheWriteMu sync.RWMutex
stopped atomic.Bool
balanceLoadSF singleflight.Group
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64
@@ -105,35 +111,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
// Stop 关闭缓存写入工作池
func (s *BillingCacheService) Stop() {
s.cacheWriteStopOnce.Do(func() {
if s.cacheWriteChan == nil {
s.stopped.Store(true)
s.cacheWriteMu.Lock()
ch := s.cacheWriteChan
if ch != nil {
close(ch)
}
s.cacheWriteMu.Unlock()
if ch == nil {
return
}
close(s.cacheWriteChan)
s.cacheWriteWg.Wait()
s.cacheWriteChan = nil
s.cacheWriteMu.Lock()
if s.cacheWriteChan == ch {
s.cacheWriteChan = nil
}
s.cacheWriteMu.Unlock()
})
}
func (s *BillingCacheService) startCacheWriteWorkers() {
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
ch := make(chan cacheWriteTask, cacheWriteBufferSize)
s.cacheWriteChan = ch
for i := 0; i < cacheWriteWorkerCount; i++ {
s.cacheWriteWg.Add(1)
go s.cacheWriteWorker()
go s.cacheWriteWorker(ch)
}
}
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false并记录告警
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
if s.cacheWriteChan == nil {
if s.stopped.Load() {
s.logCacheWriteDrop(task, "closed")
return false
}
defer func() {
if recovered := recover(); recovered != nil {
// 队列已关闭时可能触发 panic记录后静默失败。
s.logCacheWriteDrop(task, "closed")
enqueued = false
}
}()
s.cacheWriteMu.RLock()
defer s.cacheWriteMu.RUnlock()
if s.cacheWriteChan == nil {
s.logCacheWriteDrop(task, "closed")
return false
}
select {
case s.cacheWriteChan <- task:
return true
@@ -144,9 +167,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b
}
}
func (s *BillingCacheService) cacheWriteWorker() {
func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
defer s.cacheWriteWg.Done()
for task := range s.cacheWriteChan {
for task := range ch {
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
switch task.kind {
case cacheWriteSetBalance:
@@ -243,19 +266,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
return balance, nil
}
// 缓存未命中,从数据库读取
balance, err = s.getUserBalanceFromDB(ctx, userID)
// 缓存未命中singleflight 合并同一 userID 的并发回源请求。
value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) {
loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout)
defer cancel()
balance, err := s.getUserBalanceFromDB(loadCtx, userID)
if err != nil {
return nil, err
}
// 异步建立缓存
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetBalance,
userID: userID,
balance: balance,
})
return balance, nil
})
if err != nil {
return 0, err
}
// 异步建立缓存
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetBalance,
userID: userID,
balance: balance,
})
balance, ok := value.(float64)
if !ok {
return 0, fmt.Errorf("unexpected balance type: %T", value)
}
return balance, nil
}

View File

@@ -0,0 +1,115 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// billingCacheMissStub simulates a cache that always misses reads and counts
// balance writes, forcing GetUserBalance to fall back to the repository.
type billingCacheMissStub struct {
	setBalanceCalls atomic.Int64 // number of SetUserBalance calls (async cache writes)
}
// GetUserBalance always misses so the service must load from the repository.
func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
	return 0, errors.New("cache miss")
}
// SetUserBalance counts the write and succeeds.
func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
	s.setBalanceCalls.Add(1)
	return nil
}
// DeductUserBalance is a no-op success.
func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
	return nil
}
// InvalidateUserBalance is a no-op success.
func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
	return nil
}
// GetSubscriptionCache always misses, matching the balance-read behavior.
func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
	return nil, errors.New("cache miss")
}
// SetSubscriptionCache is a no-op success.
func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
	return nil
}
// UpdateSubscriptionUsage is a no-op success.
func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
	return nil
}
// InvalidateSubscriptionCache is a no-op success.
func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
	return nil
}
// balanceLoadUserRepoStub counts GetByID calls and can delay each load to
// widen the window in which concurrent lookups overlap.
type balanceLoadUserRepoStub struct {
	mockUserRepo
	calls   atomic.Int64  // number of GetByID invocations
	delay   time.Duration // artificial per-load latency (0 = none)
	balance float64       // balance returned for every user
}
// GetByID counts the call, waits out the configured delay (respecting
// context cancellation), then returns a user carrying the stubbed balance.
func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
	s.calls.Add(1)
	if s.delay > 0 {
		select {
		case <-time.After(s.delay):
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
	return &User{ID: id, Balance: s.balance}, nil
}
// TestBillingCacheServiceGetUserBalance_Singleflight verifies that when the
// cache misses, concurrent GetUserBalance calls for the same user are merged
// into a single repository load (singleflight), every caller receives the
// same balance, and the cache is eventually populated asynchronously.
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
	cache := &billingCacheMissStub{}
	userRepo := &balanceLoadUserRepoStub{
		// Slow load keeps all goroutines inside the singleflight window.
		delay:   80 * time.Millisecond,
		balance: 12.34,
	}
	svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{})
	t.Cleanup(svc.Stop)
	const goroutines = 16
	start := make(chan struct{})
	var wg sync.WaitGroup
	errCh := make(chan error, goroutines)
	balCh := make(chan float64, goroutines)
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Block until all goroutines are ready, so calls overlap.
			<-start
			bal, err := svc.GetUserBalance(context.Background(), 99)
			errCh <- err
			balCh <- bal
		}()
	}
	close(start)
	wg.Wait()
	close(errCh)
	close(balCh)
	for err := range errCh {
		require.NoError(t, err)
	}
	for bal := range balCh {
		require.Equal(t, 12.34, bal)
	}
	require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并")
	// Cache population happens via the async write queue.
	require.Eventually(t, func() bool {
		return cache.setBalanceCalls.Load() >= 1
	}, time.Second, 10*time.Millisecond)
}

View File

@@ -73,3 +73,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
}
// TestBillingCacheServiceEnqueueAfterStopReturnsFalse verifies that enqueuing
// a cache-write task after Stop() neither panics nor reports success.
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
	cache := &billingCacheWorkerStub{}
	svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
	svc.Stop()
	enqueued := svc.enqueueCacheWrite(cacheWriteTask{
		kind:   cacheWriteDeductBalance,
		userID: 1,
		amount: 1,
	})
	require.False(t, enqueued)
}

View File

@@ -63,7 +63,7 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) {
// 费率倍数 1.5x
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5
// 费率倍数 2.0x

View File

@@ -78,7 +78,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku {
if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku {
return true // 绕过 system prompt 检查UA 已在 Step 1 验证
}

View File

@@ -3,8 +3,10 @@ package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"encoding/binary"
"os"
"strconv"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -18,6 +20,7 @@ type ConcurrencyCache interface {
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
// 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
@@ -42,15 +45,25 @@ type ConcurrencyCache interface {
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
}
// generateRequestID generates a unique request ID for concurrency slot tracking
// Uses 8 random bytes (16 hex chars) for uniqueness
func generateRequestID() string {
var (
requestIDPrefix = initRequestIDPrefix()
requestIDCounter atomic.Uint64
)
func initRequestIDPrefix() string {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
// Fallback to nanosecond timestamp (extremely rare case)
return fmt.Sprintf("%x", time.Now().UnixNano())
if _, err := rand.Read(b); err == nil {
return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36)
}
return hex.EncodeToString(b)
fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
return "r" + strconv.FormatUint(fallback, 36)
}
// generateRequestID generates a unique request ID for concurrency slot tracking.
// Format: {process_random_prefix}-{base36_counter}
func generateRequestID() string {
seq := requestIDCounter.Add(1)
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
}
const (
@@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int)
for _, accountID := range accountIDs {
count, err := s.cache.GetAccountConcurrency(ctx, accountID)
if err != nil {
// If key doesn't exist in Redis, count is 0
count = 0
}
result[accountID] = count
if len(accountIDs) == 0 {
return map[int64]int{}, nil
}
return result, nil
if s.cache == nil {
result := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
result[accountID] = 0
}
return result, nil
}
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
}

View File

@@ -5,6 +5,8 @@ package service
import (
"context"
"errors"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/require"
@@ -12,20 +14,20 @@ import (
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
type stubConcurrencyCacheForTest struct {
acquireResult bool
acquireErr error
releaseErr error
concurrency int
acquireResult bool
acquireErr error
releaseErr error
concurrency int
concurrencyErr error
waitAllowed bool
waitErr error
waitCount int
waitCountErr error
loadBatch map[int64]*AccountLoadInfo
loadBatchErr error
waitAllowed bool
waitErr error
waitCount int
waitCountErr error
loadBatch map[int64]*AccountLoadInfo
loadBatchErr error
usersLoadBatch map[int64]*UserLoadInfo
usersLoadErr error
cleanupErr error
cleanupErr error
// 记录调用
releasedAccountIDs []int64
@@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco
func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) {
return c.concurrency, c.concurrencyErr
}
// GetAccountConcurrencyBatch mirrors GetAccountConcurrency for a batch of
// IDs: every requested account reports the same stubbed concurrency value,
// or the stubbed error when one is configured.
func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
	// concurrencyErr is loop-invariant, so check it once up front. An empty
	// ID list still succeeds with an empty map, matching the original
	// per-item semantics (the loop body never ran for empty input).
	if len(accountIDs) > 0 && c.concurrencyErr != nil {
		return nil, c.concurrencyErr
	}
	result := make(map[int64]int, len(accountIDs))
	for _, accountID := range accountIDs {
		result[accountID] = c.concurrency
	}
	return result, nil
}
func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
return c.waitAllowed, c.waitErr
}
@@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
require.True(t, result.Acquired)
}
// TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter verifies the
// request-ID format "{prefix}-{base36 counter}": consecutive IDs share the
// same process-wide prefix and their counters increase by exactly one.
func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) {
	id1 := generateRequestID()
	id2 := generateRequestID()
	require.NotEmpty(t, id1)
	require.NotEmpty(t, id2)
	p1 := strings.Split(id1, "-")
	p2 := strings.Split(id2, "-")
	require.Len(t, p1, 2)
	require.Len(t, p2, 2)
	require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致")
	// The counter segment is base36; parse both and compare.
	n1, err := strconv.ParseUint(p1[1], 36, 64)
	require.NoError(t, err)
	n2, err := strconv.ParseUint(p2[1], 36, 64)
	require.NoError(t, err)
	require.Equal(t, n1+1, n2, "计数器应单调递增")
}
func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) {
expected := map[int64]*AccountLoadInfo{
1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60},

View File

@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil
}
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}

View File

@@ -0,0 +1,252 @@
package service
import "context"
// DataManagementPostgresConfig describes the PostgreSQL source used by the
// (deprecated) data management feature. Password is omitted from serialized
// output when empty; PasswordConfigured reports whether one is stored.
type DataManagementPostgresConfig struct {
	Host               string `json:"host"`
	Port               int32  `json:"port"`
	User               string `json:"user"`
	Password           string `json:"password,omitempty"`
	PasswordConfigured bool   `json:"password_configured"`
	Database           string `json:"database"`
	SSLMode            string `json:"ssl_mode"`
	ContainerName      string `json:"container_name"`
}

// DataManagementRedisConfig describes the Redis source used by the
// (deprecated) data management feature.
type DataManagementRedisConfig struct {
	Addr               string `json:"addr"`
	Username           string `json:"username"`
	Password           string `json:"password,omitempty"`
	PasswordConfigured bool   `json:"password_configured"`
	DB                 int32  `json:"db"`
	ContainerName      string `json:"container_name"`
}

// DataManagementS3Config describes an S3-compatible upload target for backup
// artifacts. SecretAccessKey is omitted when empty;
// SecretAccessKeyConfigured reports whether one is stored.
type DataManagementS3Config struct {
	Enabled                   bool   `json:"enabled"`
	Endpoint                  string `json:"endpoint"`
	Region                    string `json:"region"`
	Bucket                    string `json:"bucket"`
	AccessKeyID               string `json:"access_key_id"`
	SecretAccessKey           string `json:"secret_access_key,omitempty"`
	SecretAccessKeyConfigured bool   `json:"secret_access_key_configured"`
	Prefix                    string `json:"prefix"`
	ForcePathStyle            bool   `json:"force_path_style"`
	UseSSL                    bool   `json:"use_ssl"`
}

// DataManagementConfig is the top-level configuration of the (deprecated)
// data management feature: source selection, retention policy, and the
// active source/S3 profiles.
type DataManagementConfig struct {
	SourceMode        string                       `json:"source_mode"`
	BackupRoot        string                       `json:"backup_root"`
	SQLitePath        string                       `json:"sqlite_path,omitempty"`
	RetentionDays     int32                        `json:"retention_days"`
	KeepLast          int32                        `json:"keep_last"`
	ActivePostgresID  string                       `json:"active_postgres_profile_id"`
	ActiveRedisID     string                       `json:"active_redis_profile_id"`
	Postgres          DataManagementPostgresConfig `json:"postgres"`
	Redis             DataManagementRedisConfig    `json:"redis"`
	S3                DataManagementS3Config       `json:"s3"`
	ActiveS3ProfileID string                       `json:"active_s3_profile_id"`
}

// DataManagementTestS3Result is the outcome of an S3 connectivity check.
type DataManagementTestS3Result struct {
	OK      bool   `json:"ok"`
	Message string `json:"message"`
}

// DataManagementCreateBackupJobInput carries the parameters for launching a
// backup job.
type DataManagementCreateBackupJobInput struct {
	BackupType     string
	UploadToS3     bool
	TriggeredBy    string
	IdempotencyKey string
	S3ProfileID    string
	PostgresID     string
	RedisID        string
}

// DataManagementListBackupJobsInput carries pagination and filter parameters
// for listing backup jobs.
type DataManagementListBackupJobsInput struct {
	PageSize   int32
	PageToken  string
	Status     string
	BackupType string
}

// DataManagementArtifactInfo describes a locally stored backup artifact.
type DataManagementArtifactInfo struct {
	LocalPath string `json:"local_path"`
	SizeBytes int64  `json:"size_bytes"`
	SHA256    string `json:"sha256"`
}

// DataManagementS3ObjectInfo identifies a backup artifact uploaded to S3.
type DataManagementS3ObjectInfo struct {
	Bucket string `json:"bucket"`
	Key    string `json:"key"`
	ETag   string `json:"etag"`
}
// DataManagementBackupJob is the full record of one backup job, including
// its trigger, lifecycle timestamps, local artifact, and S3 upload target.
type DataManagementBackupJob struct {
	JobID          string                     `json:"job_id"`
	BackupType     string                     `json:"backup_type"`
	Status         string                     `json:"status"`
	TriggeredBy    string                     `json:"triggered_by"`
	IdempotencyKey string                     `json:"idempotency_key,omitempty"`
	UploadToS3     bool                       `json:"upload_to_s3"`
	S3ProfileID    string                     `json:"s3_profile_id,omitempty"`
	PostgresID     string                     `json:"postgres_profile_id,omitempty"`
	RedisID        string                     `json:"redis_profile_id,omitempty"`
	StartedAt      string                     `json:"started_at,omitempty"`
	FinishedAt     string                     `json:"finished_at,omitempty"`
	ErrorMessage   string                     `json:"error_message,omitempty"`
	Artifact       DataManagementArtifactInfo `json:"artifact"`
	S3Object       DataManagementS3ObjectInfo `json:"s3"`
}

// DataManagementSourceProfile is a named, possibly active, backup source
// (PostgreSQL or Redis, distinguished by SourceType).
type DataManagementSourceProfile struct {
	SourceType         string                     `json:"source_type"`
	ProfileID          string                     `json:"profile_id"`
	Name               string                     `json:"name"`
	IsActive           bool                       `json:"is_active"`
	Config             DataManagementSourceConfig `json:"config"`
	PasswordConfigured bool                       `json:"password_configured"`
	CreatedAt          string                     `json:"created_at,omitempty"`
	UpdatedAt          string                     `json:"updated_at,omitempty"`
}

// DataManagementSourceConfig is the union of PostgreSQL and Redis connection
// fields; which subset applies depends on the owning profile's SourceType.
type DataManagementSourceConfig struct {
	Host          string `json:"host"`
	Port          int32  `json:"port"`
	User          string `json:"user"`
	Password      string `json:"password,omitempty"`
	Database      string `json:"database"`
	SSLMode       string `json:"ssl_mode"`
	Addr          string `json:"addr"`
	Username      string `json:"username"`
	DB            int32  `json:"db"`
	ContainerName string `json:"container_name"`
}

// DataManagementCreateSourceProfileInput carries the parameters for creating
// a source profile; SetActive also makes it the active profile.
type DataManagementCreateSourceProfileInput struct {
	SourceType string
	ProfileID  string
	Name       string
	Config     DataManagementSourceConfig
	SetActive  bool
}

// DataManagementUpdateSourceProfileInput carries the parameters for updating
// an existing source profile.
type DataManagementUpdateSourceProfileInput struct {
	SourceType string
	ProfileID  string
	Name       string
	Config     DataManagementSourceConfig
}

// DataManagementS3Profile is a named, possibly active, S3 upload target.
type DataManagementS3Profile struct {
	ProfileID                 string                 `json:"profile_id"`
	Name                      string                 `json:"name"`
	IsActive                  bool                   `json:"is_active"`
	S3                        DataManagementS3Config `json:"s3"`
	SecretAccessKeyConfigured bool                   `json:"secret_access_key_configured"`
	CreatedAt                 string                 `json:"created_at,omitempty"`
	UpdatedAt                 string                 `json:"updated_at,omitempty"`
}

// DataManagementCreateS3ProfileInput carries the parameters for creating an
// S3 profile; SetActive also makes it the active profile.
type DataManagementCreateS3ProfileInput struct {
	ProfileID string
	Name      string
	S3        DataManagementS3Config
	SetActive bool
}

// DataManagementUpdateS3ProfileInput carries the parameters for updating an
// existing S3 profile.
type DataManagementUpdateS3ProfileInput struct {
	ProfileID string
	Name      string
	S3        DataManagementS3Config
}

// DataManagementListBackupJobsResult is one page of backup jobs plus the
// token for fetching the next page (empty when exhausted).
type DataManagementListBackupJobsResult struct {
	Items         []DataManagementBackupJob `json:"items"`
	NextPageToken string                    `json:"next_page_token,omitempty"`
}
// GetConfig is deprecated; it always returns the deprecation error.
func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) {
	_ = ctx
	return DataManagementConfig{}, s.deprecatedError()
}

// UpdateConfig is deprecated; it always returns the deprecation error.
func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) {
	_, _ = ctx, cfg
	return DataManagementConfig{}, s.deprecatedError()
}

// ListSourceProfiles is deprecated; it always returns the deprecation error.
func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) {
	_, _ = ctx, sourceType
	return nil, s.deprecatedError()
}

// CreateSourceProfile is deprecated; it always returns the deprecation error.
func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) {
	_, _ = ctx, input
	return DataManagementSourceProfile{}, s.deprecatedError()
}

// UpdateSourceProfile is deprecated; it always returns the deprecation error.
func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) {
	_, _ = ctx, input
	return DataManagementSourceProfile{}, s.deprecatedError()
}

// DeleteSourceProfile is deprecated; it always returns the deprecation error.
func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error {
	_, _, _ = ctx, sourceType, profileID
	return s.deprecatedError()
}

// SetActiveSourceProfile is deprecated; it always returns the deprecation error.
func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) {
	_, _, _ = ctx, sourceType, profileID
	return DataManagementSourceProfile{}, s.deprecatedError()
}

// ValidateS3 is deprecated; it always returns the deprecation error.
func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) {
	_, _ = ctx, cfg
	return DataManagementTestS3Result{}, s.deprecatedError()
}

// ListS3Profiles is deprecated; it always returns the deprecation error.
func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) {
	_ = ctx
	return nil, s.deprecatedError()
}

// CreateS3Profile is deprecated; it always returns the deprecation error.
func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) {
	_, _ = ctx, input
	return DataManagementS3Profile{}, s.deprecatedError()
}

// UpdateS3Profile is deprecated; it always returns the deprecation error.
func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) {
	_, _ = ctx, input
	return DataManagementS3Profile{}, s.deprecatedError()
}

// DeleteS3Profile is deprecated; it always returns the deprecation error.
func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error {
	_, _ = ctx, profileID
	return s.deprecatedError()
}

// SetActiveS3Profile is deprecated; it always returns the deprecation error.
func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) {
	_, _ = ctx, profileID
	return DataManagementS3Profile{}, s.deprecatedError()
}

// CreateBackupJob is deprecated; it always returns the deprecation error.
func (s *DataManagementService) CreateBackupJob(ctx context.Context, input DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) {
	_, _ = ctx, input
	return DataManagementBackupJob{}, s.deprecatedError()
}

// ListBackupJobs is deprecated; it always returns the deprecation error.
func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) {
	_, _ = ctx, input
	return DataManagementListBackupJobsResult{}, s.deprecatedError()
}

// GetBackupJob is deprecated; it always returns the deprecation error.
func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) {
	_, _ = ctx, jobID
	return DataManagementBackupJob{}, s.deprecatedError()
}

// deprecatedError builds the shared "feature deprecated" error, annotated
// with the configured agent socket path for diagnostics.
func (s *DataManagementService) deprecatedError() error {
	return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
}

View File

@@ -0,0 +1,36 @@
package service
import (
"context"
"path/filepath"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
// TestDataManagementService_DeprecatedRPCMethods checks that representative
// RPC entry points (read, create, delete) all fail with the shared
// deprecation error carrying the configured socket path.
func TestDataManagementService_DeprecatedRPCMethods(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "datamanagement.sock")
	svc := NewDataManagementServiceWithOptions(socketPath, 0)
	_, err := svc.GetConfig(context.Background())
	assertDeprecatedDataManagementError(t, err, socketPath)
	_, err = svc.CreateBackupJob(context.Background(), DataManagementCreateBackupJobInput{BackupType: "full"})
	assertDeprecatedDataManagementError(t, err, socketPath)
	err = svc.DeleteS3Profile(context.Background(), "s3-default")
	assertDeprecatedDataManagementError(t, err, socketPath)
}

// assertDeprecatedDataManagementError asserts err maps to HTTP 503 with the
// deprecation reason and the socket path in its error metadata.
func assertDeprecatedDataManagementError(t *testing.T, err error, socketPath string) {
	t.Helper()
	require.Error(t, err)
	statusCode, status := infraerrors.ToHTTP(err)
	require.Equal(t, 503, statusCode)
	require.Equal(t, DataManagementDeprecatedReason, status.Reason)
	require.Equal(t, socketPath, status.Metadata["socket_path"])
}

View File

@@ -0,0 +1,99 @@
package service
import (
"context"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
	// DefaultDataManagementAgentSocketPath is the unix socket the (now
	// deprecated) data management agent used to listen on.
	DefaultDataManagementAgentSocketPath = "/tmp/sub2api-datamanagement.sock"
	// LegacyBackupAgentSocketPath is the pre-rename socket path.
	LegacyBackupAgentSocketPath = "/tmp/sub2api-backup.sock"

	// Machine-readable error reasons surfaced to API clients.
	DataManagementDeprecatedReason         = "DATA_MANAGEMENT_DEPRECATED"
	DataManagementAgentSocketMissingReason = "DATA_MANAGEMENT_AGENT_SOCKET_MISSING"
	DataManagementAgentUnavailableReason   = "DATA_MANAGEMENT_AGENT_UNAVAILABLE"

	// Deprecated: keep old names for compatibility.
	DefaultBackupAgentSocketPath   = DefaultDataManagementAgentSocketPath
	BackupAgentSocketMissingReason = DataManagementAgentSocketMissingReason
	BackupAgentUnavailableReason   = DataManagementAgentUnavailableReason
)

var (
	// ErrDataManagementDeprecated is returned by every data management RPC
	// now that the feature is retired.
	ErrDataManagementDeprecated = infraerrors.ServiceUnavailable(
		DataManagementDeprecatedReason,
		"data management feature is deprecated",
	)
	// ErrDataManagementAgentSocketMissing reports a missing agent socket.
	ErrDataManagementAgentSocketMissing = infraerrors.ServiceUnavailable(
		DataManagementAgentSocketMissingReason,
		"data management agent socket is missing",
	)
	// ErrDataManagementAgentUnavailable reports an unreachable agent.
	ErrDataManagementAgentUnavailable = infraerrors.ServiceUnavailable(
		DataManagementAgentUnavailableReason,
		"data management agent is unavailable",
	)

	// Deprecated: keep old names for compatibility.
	ErrBackupAgentSocketMissing = ErrDataManagementAgentSocketMissing
	ErrBackupAgentUnavailable   = ErrDataManagementAgentUnavailable
)
// DataManagementAgentHealth is the health snapshot reported for the backup
// agent; Reason explains why Enabled is false, and Agent is nil when no
// agent information is available.
type DataManagementAgentHealth struct {
	Enabled    bool
	Reason     string
	SocketPath string
	Agent      *DataManagementAgentInfo
}

// DataManagementAgentInfo describes a reachable agent process.
type DataManagementAgentInfo struct {
	Status        string
	Version       string
	UptimeSeconds int64
}

// DataManagementService is the (deprecated) entry point for data management
// RPCs; it retains the socket path and dial timeout for error metadata.
type DataManagementService struct {
	socketPath  string
	dialTimeout time.Duration
}
// NewDataManagementService constructs the service with the default socket
// path and a 500ms dial timeout.
func NewDataManagementService() *DataManagementService {
	return NewDataManagementServiceWithOptions(DefaultDataManagementAgentSocketPath, 500*time.Millisecond)
}
// NewDataManagementServiceWithOptions builds a DataManagementService,
// substituting the default socket path for a blank input and a 500ms dial
// timeout for any non-positive value.
func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService {
	trimmed := strings.TrimSpace(socketPath)
	if trimmed == "" {
		trimmed = DefaultDataManagementAgentSocketPath
	}
	timeout := dialTimeout
	if timeout <= 0 {
		timeout = 500 * time.Millisecond
	}
	return &DataManagementService{
		socketPath:  trimmed,
		dialTimeout: timeout,
	}
}
// SocketPath returns the configured agent socket path, falling back to the
// default when the receiver is nil or the stored path is blank.
func (s *DataManagementService) SocketPath() string {
	if s != nil {
		if trimmed := strings.TrimSpace(s.socketPath); trimmed != "" {
			// Return the stored value untrimmed, exactly as configured.
			return s.socketPath
		}
	}
	return DefaultDataManagementAgentSocketPath
}
// GetAgentHealth reports the agent as disabled with the deprecation reason;
// no socket probing is performed anymore.
func (s *DataManagementService) GetAgentHealth(ctx context.Context) DataManagementAgentHealth {
	_ = ctx
	return DataManagementAgentHealth{
		Enabled:    false,
		Reason:     DataManagementDeprecatedReason,
		SocketPath: s.SocketPath(),
	}
}

// EnsureAgentEnabled always fails with the deprecation error, annotated with
// the configured socket path.
func (s *DataManagementService) EnsureAgentEnabled(ctx context.Context) error {
	_ = ctx
	return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
}

View File

@@ -0,0 +1,37 @@
package service
import (
"context"
"path/filepath"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
// TestDataManagementService_GetAgentHealth_Deprecated verifies the health
// snapshot reports the agent as disabled with the deprecation reason and the
// configured socket path, and with no agent info attached.
func TestDataManagementService_GetAgentHealth_Deprecated(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "unused.sock")
	svc := NewDataManagementServiceWithOptions(socketPath, 0)
	health := svc.GetAgentHealth(context.Background())
	require.False(t, health.Enabled)
	require.Equal(t, DataManagementDeprecatedReason, health.Reason)
	require.Equal(t, socketPath, health.SocketPath)
	require.Nil(t, health.Agent)
}

// TestDataManagementService_EnsureAgentEnabled_Deprecated verifies the guard
// maps to HTTP 503 with the deprecation reason and socket-path metadata.
func TestDataManagementService_EnsureAgentEnabled_Deprecated(t *testing.T) {
	t.Parallel()
	socketPath := filepath.Join(t.TempDir(), "unused.sock")
	svc := NewDataManagementServiceWithOptions(socketPath, 100)
	err := svc.EnsureAgentEnabled(context.Background())
	require.Error(t, err)
	statusCode, status := infraerrors.ToHTTP(err)
	require.Equal(t, 503, statusCode)
	require.Equal(t, DataManagementDeprecatedReason, status.Reason)
	require.Equal(t, socketPath, status.Metadata["socket_path"])
}

View File

@@ -104,6 +104,7 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置
SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
@@ -170,6 +171,27 @@ const (
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
// =========================
// Sora S3 存储配置
// =========================
SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
SettingKeySoraS3Region = "sora_s3_region" // S3 区域
SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key加密存储
SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style兼容 MinIO 等)
SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL可选
SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置JSON
// =========================
// Sora 用户存储配额
// =========================
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -279,10 +279,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotE
wantPassthrough: true,
},
{
name: "404 generic not found passes through as 404",
name: "404 generic not found does not passthrough",
statusCode: http.StatusNotFound,
respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
wantPassthrough: true,
wantPassthrough: false,
},
{
name: "400 Invalid URL does not passthrough",

View File

@@ -136,3 +136,67 @@ func TestDroppedBetaSet(t *testing.T) {
require.Contains(t, extended, claude.BetaClaudeCode)
require.Len(t, extended, len(claude.DroppedBetas)+1)
}
func TestBuildBetaTokenSet(t *testing.T) {
got := buildBetaTokenSet([]string{"foo", "", "bar", "foo"})
require.Len(t, got, 2)
require.Contains(t, got, "foo")
require.Contains(t, got, "bar")
require.NotContains(t, got, "")
empty := buildBetaTokenSet(nil)
require.Empty(t, empty)
}
func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) {
header := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
got := stripBetaTokensWithSet(header, map[string]struct{}{})
require.Equal(t, header, got)
}
func TestIsCountTokensUnsupported404(t *testing.T) {
tests := []struct {
name string
statusCode int
body string
want bool
}{
{
name: "exact endpoint not found",
statusCode: 404,
body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
want: true,
},
{
name: "contains count_tokens and not found",
statusCode: 404,
body: `{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`,
want: true,
},
{
name: "generic 404",
statusCode: 404,
body: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
want: false,
},
{
name: "404 with empty error message",
statusCode: 404,
body: `{"error":{"message":"","type":"not_found_error"}}`,
want: false,
},
{
name: "non-404 status",
statusCode: 400,
body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isCountTokensUnsupported404(tt.statusCode, []byte(tt.body))
require.Equal(t, tt.want, got)
})
}
}

View File

@@ -1892,6 +1892,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun
return 0, nil
}
// GetAccountConcurrencyBatch reports zero in-flight requests for every
// requested account; the mock never fails.
func (m *mockConcurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
	counts := make(map[int64]int, len(accountIDs))
	for i := range accountIDs {
		counts[accountIDs[i]] = 0
	}
	return counts, nil
}
// IncrementAccountWaitCount always admits the caller into the wait queue;
// the mock never enforces maxWait and never fails.
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
	return true, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,141 @@
package service
import (
"context"
"strings"
"testing"
"time"
)
// TestCollectSelectionFailureStats builds exactly one account per failure
// category (excluded, unschedulable, platform-filtered, model-unsupported,
// model-rate-limited) plus one eligible account, and verifies
// collectSelectionFailureStats buckets each of them exactly once.
func TestCollectSelectionFailureStats(t *testing.T) {
	svc := &GatewayService{}
	model := "sora2-landscape-10s"
	// A model-scoped rate limit still active 2 minutes from now.
	resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339)
	accounts := []Account{
		// excluded
		{
			ID:          1,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
		},
		// unschedulable
		{
			ID:          2,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: false,
		},
		// platform filtered
		{
			ID:          3,
			Platform:    PlatformOpenAI,
			Status:      StatusActive,
			Schedulable: true,
		},
		// model unsupported: mapping only allows gpt-image, not the video model
		{
			ID:          4,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
			Credentials: map[string]any{
				"model_mapping": map[string]any{
					"gpt-image": "gpt-image",
				},
			},
		},
		// model rate limited: per-model reset timestamp lies in the future
		{
			ID:          5,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
			Extra: map[string]any{
				"model_rate_limits": map[string]any{
					model: map[string]any{
						"rate_limit_reset_at": resetAt,
					},
				},
			},
		},
		// eligible
		{
			ID:          6,
			Platform:    PlatformSora,
			Status:      StatusActive,
			Schedulable: true,
		},
	}
	// Account 1 is removed from consideration up front.
	excluded := map[int64]struct{}{1: {}}
	stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false)
	if stats.Total != 6 {
		t.Fatalf("total=%d want=6", stats.Total)
	}
	if stats.Excluded != 1 {
		t.Fatalf("excluded=%d want=1", stats.Excluded)
	}
	if stats.Unschedulable != 1 {
		t.Fatalf("unschedulable=%d want=1", stats.Unschedulable)
	}
	if stats.PlatformFiltered != 1 {
		t.Fatalf("platform_filtered=%d want=1", stats.PlatformFiltered)
	}
	if stats.ModelUnsupported != 1 {
		t.Fatalf("model_unsupported=%d want=1", stats.ModelUnsupported)
	}
	if stats.ModelRateLimited != 1 {
		t.Fatalf("model_rate_limited=%d want=1", stats.ModelRateLimited)
	}
	if stats.Eligible != 1 {
		t.Fatalf("eligible=%d want=1", stats.Eligible)
	}
}
// TestDiagnoseSelectionFailure_SoraUnschedulableDetail verifies that a Sora
// account with Schedulable=false is diagnosed as "unschedulable" with a
// detail string naming the flag.
func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) {
	service := &GatewayService{}
	account := &Account{
		ID:          7,
		Platform:    PlatformSora,
		Status:      StatusActive,
		Schedulable: false,
	}
	got := service.diagnoseSelectionFailure(context.Background(), account, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
	if got.Category != "unschedulable" {
		t.Fatalf("category=%s want=unschedulable", got.Category)
	}
	if got.Detail != "schedulable=false" {
		t.Fatalf("detail=%s want=schedulable=false", got.Detail)
	}
}
// TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail verifies that a
// Sora account with an active per-model rate limit is diagnosed as
// "model_rate_limited" and that the detail includes the remaining window.
func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
	svc := &GatewayService{}
	model := "sora2-landscape-10s"
	// Per-model reset timestamp 2 minutes in the future.
	resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
	acc := &Account{
		ID:          8,
		Platform:    PlatformSora,
		Status:      StatusActive,
		Schedulable: true,
		Extra: map[string]any{
			"model_rate_limits": map[string]any{
				model: map[string]any{
					"rate_limit_reset_at": resetAt,
				},
			},
		},
	}
	diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false)
	if diagnosis.Category != "model_rate_limited" {
		t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category)
	}
	if !strings.Contains(diagnosis.Detail, "remaining=") {
		t.Fatalf("detail=%s want contains remaining=", diagnosis.Detail)
	}
}

View File

@@ -0,0 +1,79 @@
package service
import "testing"
// TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll checks
// that a Sora account without any model_mapping entry supports Sora models.
func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) {
	service := &GatewayService{}
	acc := &Account{
		Platform:    PlatformSora,
		Credentials: map[string]any{},
	}
	if ok := service.isModelSupportedByAccount(acc, "sora2-landscape-10s"); !ok {
		t.Fatalf("expected sora model to be supported when model_mapping is empty")
	}
}
// TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock
// checks that a legacy mapping containing only non-Sora selectors does not
// block Sora models.
func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) {
	svc := &GatewayService{}
	account := &Account{
		Platform: PlatformSora,
		Credentials: map[string]any{
			"model_mapping": map[string]any{
				"gpt-4o": "gpt-4o",
			},
		},
	}
	if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
		t.Fatalf("expected sora model to be supported when mapping has no sora selectors")
	}
}

// TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias checks that a
// family-level selector ("sora2") matches concrete variants of that family.
func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) {
	svc := &GatewayService{}
	account := &Account{
		Platform: PlatformSora,
		Credentials: map[string]any{
			"model_mapping": map[string]any{
				"sora2": "sora2",
			},
		},
	}
	if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") {
		t.Fatalf("expected family selector sora2 to support sora2-landscape-15s")
	}
}
// TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias checks
// that an underlying-model selector ("sy_8") matches the corresponding Sora
// video model.
func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) {
	svc := &GatewayService{}
	account := &Account{
		Platform: PlatformSora,
		Credentials: map[string]any{
			"model_mapping": map[string]any{
				"sy_8": "sy_8",
			},
		},
	}
	if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
		t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s")
	}
}

// TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo
// checks that a mapping explicitly limited to gpt-image blocks video models.
func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) {
	svc := &GatewayService{}
	account := &Account{
		Platform: PlatformSora,
		Credentials: map[string]any{
			"model_mapping": map[string]any{
				"gpt-image": "gpt-image",
			},
		},
	}
	if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
		t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image")
	}
}

View File

@@ -0,0 +1,89 @@
package service
import (
"context"
"testing"
"time"
)
// TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows
// checks that a Sora account remains selectable even when the generic
// expiry/overload/rate-limit windows would normally exclude it.
func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) {
	svc := &GatewayService{}
	now := time.Now()
	past := now.Add(-1 * time.Minute)
	future := now.Add(5 * time.Minute)
	acc := &Account{
		Platform:           PlatformSora,
		Status:             StatusActive,
		Schedulable:        true,
		AutoPauseOnExpired: true,
		// Already expired, and both generic windows are still active.
		ExpiresAt:        &past,
		OverloadUntil:    &future,
		RateLimitResetAt: &future,
	}
	if !svc.isAccountSchedulableForSelection(acc) {
		t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows")
	}
}

// TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic
// checks that a non-Sora account is still excluded by an active generic
// rate-limit window.
func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) {
	svc := &GatewayService{}
	future := time.Now().Add(5 * time.Minute)
	acc := &Account{
		Platform:         PlatformAnthropic,
		Status:           StatusActive,
		Schedulable:      true,
		RateLimitResetAt: &future,
	}
	if svc.isAccountSchedulableForSelection(acc) {
		t.Fatalf("expected non-sora account to keep generic schedulable checks")
	}
}
// TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly
// checks that a Sora account is blocked by a model-scoped rate limit even
// though the global rate-limit window is ignored for Sora.
func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) {
	svc := &GatewayService{}
	model := "sora2-landscape-10s"
	// Per-model and global windows both expire 2 minutes from now.
	resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
	globalResetAt := time.Now().Add(2 * time.Minute)
	acc := &Account{
		Platform:         PlatformSora,
		Status:           StatusActive,
		Schedulable:      true,
		RateLimitResetAt: &globalResetAt,
		Extra: map[string]any{
			"model_rate_limits": map[string]any{
				model: map[string]any{
					"rate_limit_reset_at": resetAt,
				},
			},
		},
	}
	if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) {
		t.Fatalf("expected sora account to be blocked by model scope rate limit")
	}
}

// TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows
// checks that the failure-stats collector counts a Sora account with only a
// generic rate-limit window as eligible rather than unschedulable.
func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) {
	svc := &GatewayService{}
	future := time.Now().Add(3 * time.Minute)
	accounts := []Account{
		{
			ID:               1,
			Platform:         PlatformSora,
			Status:           StatusActive,
			Schedulable:      true,
			RateLimitResetAt: &future,
		},
	}
	stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
	if stats.Unschedulable != 0 || stats.Eligible != 1 {
		t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible)
	}
}

View File

@@ -105,12 +105,12 @@ func TestCalculateMaxWait_Scenarios(t *testing.T) {
concurrency int
expected int
}{
{5, 25}, // 5 + 20
{10, 30}, // 10 + 20
{1, 21}, // 1 + 20
{0, 21}, // min(1) + 20
{-1, 21}, // min(1) + 20
{-10, 21}, // min(1) + 20
{5, 25}, // 5 + 20
{10, 30}, // 10 + 20
{1, 21}, // 1 + 20
{0, 21}, // min(1) + 20
{-1, 21}, // min(1) + 20
{-10, 21}, // min(1) + 20
{100, 120}, // 100 + 20
}
for _, tt := range tests {

View File

@@ -53,6 +53,7 @@ type GeminiMessagesCompatService struct {
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
cfg *config.Config
responseHeaderFilter *responseheaders.CompiledHeaderFilter
}
func NewGeminiMessagesCompatService(
@@ -76,6 +77,7 @@ func NewGeminiMessagesCompatService(
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
}
}
@@ -229,6 +231,16 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
account *Account,
requestedModel, platform string,
useMixedScheduling bool,
) bool {
return s.isAccountUsableForRequestWithPrecheck(ctx, account, requestedModel, platform, useMixedScheduling, nil)
}
func (s *GeminiMessagesCompatService) isAccountUsableForRequestWithPrecheck(
ctx context.Context,
account *Account,
requestedModel, platform string,
useMixedScheduling bool,
precheckResult map[int64]bool,
) bool {
// 检查模型调度能力
// Check model scheduling capability
@@ -250,7 +262,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
// 速率限制预检
// Rate limit precheck
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
if !s.passesRateLimitPreCheckWithCache(ctx, account, requestedModel, precheckResult) {
return false
}
@@ -272,15 +284,17 @@ func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account
return false
}
// passesRateLimitPreCheck 执行速率限制预检。
// 返回 true 表示通过预检或无需预检。
//
// passesRateLimitPreCheck performs rate limit precheck.
// Returns true if passed or precheck not required.
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
func (s *GeminiMessagesCompatService) passesRateLimitPreCheckWithCache(ctx context.Context, account *Account, requestedModel string, precheckResult map[int64]bool) bool {
if s.rateLimitService == nil || requestedModel == "" {
return true
}
if precheckResult != nil {
if ok, exists := precheckResult[account.ID]; exists {
return ok
}
}
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
@@ -302,6 +316,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
useMixedScheduling bool,
) *Account {
var selected *Account
precheckResult := s.buildPreCheckUsageResultMap(ctx, accounts, requestedModel)
for i := range accounts {
acc := &accounts[i]
@@ -312,7 +327,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
}
// 检查账号是否可用于当前请求
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
if !s.isAccountUsableForRequestWithPrecheck(ctx, acc, requestedModel, platform, useMixedScheduling, precheckResult) {
continue
}
@@ -330,6 +345,23 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
return selected
}
func (s *GeminiMessagesCompatService) buildPreCheckUsageResultMap(ctx context.Context, accounts []Account, requestedModel string) map[int64]bool {
if s.rateLimitService == nil || requestedModel == "" || len(accounts) == 0 {
return nil
}
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
candidates = append(candidates, &accounts[i])
}
result, err := s.rateLimitService.PreCheckUsageBatch(ctx, candidates, requestedModel)
if err != nil {
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheckBatch] failed: %v", err)
}
return result
}
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
// 规则优先级更高数值更小优先同优先级时未使用过的优先OAuth > 非 OAuth其次是最久未使用的。
//
@@ -2390,7 +2422,7 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
}
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
@@ -2415,8 +2447,8 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
}
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Status(resp.StatusCode)
@@ -2557,7 +2589,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.responseHeaderFilter)
if wwwAuthenticate != "" {
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
}

View File

@@ -32,6 +32,9 @@ type Group struct {
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
// Sora 存储配额
SoraStorageQuotaBytes int64
// Claude Code 客户端限制
ClaudeCodeOnly bool
FallbackGroupID *int64

View File

@@ -4,8 +4,6 @@ import (
"context"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
const modelRateLimitsKey = "model_rate_limits"
@@ -73,7 +71,7 @@ func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requ
return ""
}
// thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking
if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
if enabled, ok := ThinkingEnabledFromContext(ctx); ok {
modelKey = applyThinkingModelSuffix(modelKey, enabled)
}
return modelKey

View File

@@ -12,7 +12,7 @@ import (
// OpenAIOAuthClient interface for OpenAI OAuth operations
type OpenAIOAuthClient interface {
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
}

View File

@@ -14,10 +14,10 @@ import (
// --- mock: ClaudeOAuthClient ---
type mockClaudeOAuthClient struct {
getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
}
func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
@@ -437,9 +437,9 @@ func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
// 无 refresh_token 的账号
account := &Account{
ID: 1,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
ID: 1,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "some-token",
},
@@ -460,9 +460,9 @@ func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) {
defer svc.Stop()
account := &Account{
ID: 2,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
ID: 2,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "some-token",
"refresh_token": "",

View File

@@ -0,0 +1,909 @@
package service
import (
"container/heap"
"context"
"errors"
"hash/fnv"
"math"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// Scheduling layer identifiers recorded in OpenAIAccountScheduleDecision.Layer.
const (
	openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
	openAIAccountScheduleLayerSessionSticky    = "session_hash"
	openAIAccountScheduleLayerLoadBalance      = "load_balance"
)

// OpenAIAccountScheduleRequest bundles every input the scheduler needs to
// choose an account for one request.
type OpenAIAccountScheduleRequest struct {
	GroupID            *int64
	SessionHash        string
	StickyAccountID    int64 // pre-resolved session binding; 0 when unknown
	PreviousResponseID string
	RequestedModel     string
	RequiredTransport  OpenAIUpstreamTransport
	ExcludedIDs        map[int64]struct{} // account IDs to skip (e.g. already failed this request)
}

// OpenAIAccountScheduleDecision describes how a selection was made, for
// metrics and diagnostics.
type OpenAIAccountScheduleDecision struct {
	Layer               string
	StickyPreviousHit   bool
	StickySessionHit    bool
	CandidateCount      int
	TopK                int
	LatencyMs           int64
	LoadSkew            float64 // standard deviation of candidate load rates
	SelectedAccountID   int64
	SelectedAccountType string
}

// OpenAIAccountSchedulerMetricsSnapshot is a point-in-time export of the
// scheduler counters plus derived averages/ratios.
type OpenAIAccountSchedulerMetricsSnapshot struct {
	SelectTotal              int64
	StickyPreviousHitTotal   int64
	StickySessionHitTotal    int64
	LoadBalanceSelectTotal   int64
	AccountSwitchTotal       int64
	SchedulerLatencyMsTotal  int64
	SchedulerLatencyMsAvg    float64
	StickyHitRatio           float64
	AccountSwitchRate        float64
	LoadSkewAvg              float64
	RuntimeStatsAccountCount int
}

// OpenAIAccountScheduler selects accounts for requests and collects
// scheduling feedback (request outcomes, account switches, metrics).
type OpenAIAccountScheduler interface {
	Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error)
	ReportResult(accountID int64, success bool, firstTokenMs *int)
	ReportSwitch()
	SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot
}
// openAIAccountSchedulerMetrics aggregates scheduler counters. All fields are
// atomics so they can be updated from concurrent request goroutines without a
// lock.
type openAIAccountSchedulerMetrics struct {
	selectTotal            atomic.Int64
	stickyPreviousHitTotal atomic.Int64
	stickySessionHitTotal  atomic.Int64
	loadBalanceSelectTotal atomic.Int64
	accountSwitchTotal     atomic.Int64
	latencyMsTotal         atomic.Int64
	loadSkewMilliTotal     atomic.Int64 // load skew accumulated in 1/1000 units to stay integral
}
// recordSelect folds one scheduling decision into the aggregate counters.
// Safe to call on a nil receiver (no-op). All updates are independent atomic
// adds, so their order does not matter.
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
	if m == nil {
		return
	}
	if decision.Layer == openAIAccountScheduleLayerLoadBalance {
		m.loadBalanceSelectTotal.Add(1)
	}
	if decision.StickySessionHit {
		m.stickySessionHitTotal.Add(1)
	}
	if decision.StickyPreviousHit {
		m.stickyPreviousHitTotal.Add(1)
	}
	// Skew is tracked in milli-units so the counter stays an integer.
	skewMilli := int64(math.Round(decision.LoadSkew * 1000))
	m.loadSkewMilliTotal.Add(skewMilli)
	m.latencyMsTotal.Add(decision.LatencyMs)
	m.selectTotal.Add(1)
}
// recordSwitch counts one account switch; no-op on a nil receiver.
func (m *openAIAccountSchedulerMetrics) recordSwitch() {
	if m != nil {
		m.accountSwitchTotal.Add(1)
	}
}
// openAIAccountRuntimeStats holds per-account EWMA runtime statistics keyed
// by account ID; concurrency-safe via sync.Map plus an atomic size counter.
type openAIAccountRuntimeStats struct {
	accounts     sync.Map // map[int64]*openAIAccountRuntimeStat
	accountCount atomic.Int64
}

// openAIAccountRuntimeStat stores two EWMAs as float64 bit patterns so they
// can be updated lock-free with compare-and-swap.
type openAIAccountRuntimeStat struct {
	errorRateEWMABits atomic.Uint64 // zero value decodes to 0.0 (no errors observed yet)
	ttftEWMABits      atomic.Uint64 // seeded with NaN bits meaning "no TTFT sample yet"
}

// newOpenAIAccountRuntimeStats returns an empty runtime-stats container.
func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats {
	return &openAIAccountRuntimeStats{}
}
// loadOrCreate returns the stat entry for accountID, creating it on first
// use. The TTFT EWMA is seeded with NaN so the first real sample replaces it
// directly. LoadOrStore resolves races between concurrent creators; the size
// counter is only incremented by the goroutine whose entry won the race.
func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat {
	// Fast path: entry already exists.
	if value, ok := s.accounts.Load(accountID); ok {
		stat, _ := value.(*openAIAccountRuntimeStat)
		if stat != nil {
			return stat
		}
	}
	stat := &openAIAccountRuntimeStat{}
	stat.ttftEWMABits.Store(math.Float64bits(math.NaN()))
	actual, loaded := s.accounts.LoadOrStore(accountID, stat)
	if !loaded {
		// Our entry won; count it exactly once.
		s.accountCount.Add(1)
		return stat
	}
	existing, _ := actual.(*openAIAccountRuntimeStat)
	if existing != nil {
		return existing
	}
	return stat
}
func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
for {
oldBits := target.Load()
oldValue := math.Float64frombits(oldBits)
newValue := alpha*sample + (1-alpha)*oldValue
if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
return
}
}
}
// report folds one request outcome into the account's EWMA stats (alpha 0.2).
// The error-rate EWMA treats success as sample 0 and failure as sample 1.
// TTFT is only updated for positive samples; the first observation replaces
// the NaN sentinel verbatim so the EWMA is seeded with a real value.
func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) {
	if s == nil || accountID <= 0 {
		return
	}
	const alpha = 0.2
	stat := s.loadOrCreate(accountID)
	errorSample := 1.0
	if success {
		errorSample = 0.0
	}
	updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha)
	if firstTokenMs != nil && *firstTokenMs > 0 {
		ttft := float64(*firstTokenMs)
		ttftBits := math.Float64bits(ttft)
		for {
			oldBits := stat.ttftEWMABits.Load()
			oldValue := math.Float64frombits(oldBits)
			if math.IsNaN(oldValue) {
				// First sample: install it directly instead of averaging with NaN.
				if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) {
					break
				}
				continue
			}
			newValue := alpha*ttft + (1-alpha)*oldValue
			if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
				break
			}
		}
	}
}
// snapshot returns the current EWMA error rate and TTFT for accountID.
// hasTTFT is false when the account has no stats entry or the TTFT EWMA still
// holds its NaN "no sample" sentinel; the error rate is clamped to [0, 1].
func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) {
	if s == nil || accountID <= 0 {
		return 0, 0, false
	}
	value, ok := s.accounts.Load(accountID)
	if !ok {
		return 0, 0, false
	}
	stat, _ := value.(*openAIAccountRuntimeStat)
	if stat == nil {
		return 0, 0, false
	}
	errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load()))
	ttftValue := math.Float64frombits(stat.ttftEWMABits.Load())
	if math.IsNaN(ttftValue) {
		return errorRate, 0, false
	}
	return errorRate, ttftValue, true
}
// size returns the number of accounts that currently have a stats entry.
func (s *openAIAccountRuntimeStats) size() int {
	if s != nil {
		return int(s.accountCount.Load())
	}
	return 0
}
// defaultOpenAIAccountScheduler implements OpenAIAccountScheduler on top of
// the gateway service, with shared metrics and per-account runtime stats.
type defaultOpenAIAccountScheduler struct {
	service *OpenAIGatewayService
	metrics openAIAccountSchedulerMetrics
	stats   *openAIAccountRuntimeStats
}

// newDefaultOpenAIAccountScheduler wires a scheduler to service; a nil stats
// argument gets a fresh empty container so the scheduler is always usable.
func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler {
	if stats == nil {
		stats = newOpenAIAccountRuntimeStats()
	}
	return &defaultOpenAIAccountScheduler{
		service: service,
		stats:   stats,
	}
}
// Select picks an account in three layers, in order:
//  1. previous_response_id sticky binding,
//  2. session-hash sticky binding,
//  3. weighted load balancing.
// The decision (layer, sticky hits, latency) is recorded into metrics on
// return via the deferred recordSelect.
func (s *defaultOpenAIAccountScheduler) Select(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
	decision := OpenAIAccountScheduleDecision{}
	start := time.Now()
	defer func() {
		decision.LatencyMs = time.Since(start).Milliseconds()
		s.metrics.recordSelect(decision)
	}()
	previousResponseID := strings.TrimSpace(req.PreviousResponseID)
	if previousResponseID != "" {
		selection, err := s.service.SelectAccountByPreviousResponseID(
			ctx,
			req.GroupID,
			previousResponseID,
			req.RequestedModel,
			req.ExcludedIDs,
		)
		if err != nil {
			return nil, decision, err
		}
		if selection != nil && selection.Account != nil {
			// Discard the sticky hit if the account cannot serve the required transport.
			if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
				selection = nil
			}
		}
		if selection != nil && selection.Account != nil {
			decision.Layer = openAIAccountScheduleLayerPreviousResponse
			decision.StickyPreviousHit = true
			decision.SelectedAccountID = selection.Account.ID
			decision.SelectedAccountType = selection.Account.Type
			if req.SessionHash != "" {
				// Also bind the session hash so follow-up requests stay on this account.
				_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID)
			}
			return selection, decision, nil
		}
	}
	selection, err := s.selectBySessionHash(ctx, req)
	if err != nil {
		return nil, decision, err
	}
	if selection != nil && selection.Account != nil {
		decision.Layer = openAIAccountScheduleLayerSessionSticky
		decision.StickySessionHit = true
		decision.SelectedAccountID = selection.Account.ID
		decision.SelectedAccountType = selection.Account.Type
		return selection, decision, nil
	}
	// Final layer: weighted load balancing across all eligible candidates.
	selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req)
	decision.Layer = openAIAccountScheduleLayerLoadBalance
	decision.CandidateCount = candidateCount
	decision.TopK = topK
	decision.LoadSkew = loadSkew
	if err != nil {
		return nil, decision, err
	}
	if selection != nil && selection.Account != nil {
		decision.SelectedAccountID = selection.Account.ID
		decision.SelectedAccountType = selection.Account.Type
	}
	return selection, decision, nil
}
// selectBySessionHash implements the second scheduling layer: resume on the
// account previously bound to this session hash. It returns (nil, nil) when
// the layer does not apply, letting the caller fall through to load
// balancing. Stale bindings (account missing, unschedulable, wrong
// platform/transport) are deleted so future requests rebind cleanly.
func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, error) {
	sessionHash := strings.TrimSpace(req.SessionHash)
	if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil {
		return nil, nil
	}
	accountID := req.StickyAccountID
	if accountID <= 0 {
		var err error
		accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash)
		if err != nil || accountID <= 0 {
			// Lookup failures are non-fatal: just skip this layer.
			return nil, nil
		}
	}
	// (Removed an unreachable duplicate `accountID <= 0` check: both branches
	// above already guarantee a positive accountID at this point.)
	if req.ExcludedIDs != nil {
		if _, excluded := req.ExcludedIDs[accountID]; excluded {
			return nil, nil
		}
	}
	account, err := s.service.getSchedulableAccount(ctx, accountID)
	if err != nil || account == nil {
		// Binding points at a missing/unschedulable account: drop it.
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
		// Keep the binding (other models may still use it) but skip this layer.
		return nil, nil
	}
	if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
	// BUG FIX: guard result against nil before dereferencing .Acquired, matching
	// the nil-safety selectByLoadBalance already applies to the same call.
	if acquireErr == nil && result != nil && result.Acquired {
		_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
		return &AccountSelectionResult{
			Account:     account,
			Acquired:    true,
			ReleaseFunc: result.ReleaseFunc,
		}, nil
	}
	// Account saturated: stay sticky by returning a wait plan on the same
	// account rather than switching accounts mid-session.
	if s.service.concurrencyService != nil {
		cfg := s.service.schedulingConfig()
		return &AccountSelectionResult{
			Account: account,
			WaitPlan: &AccountWaitPlan{
				AccountID:      accountID,
				MaxConcurrency: account.Concurrency,
				Timeout:        cfg.StickySessionWaitTimeout,
				MaxWaiting:     cfg.StickySessionMaxWaiting,
			},
		}, nil
	}
	return nil, nil
}
// openAIAccountCandidateScore pairs an account with its load info and the
// composite score used for top-K weighted selection.
type openAIAccountCandidateScore struct {
	account   *Account
	loadInfo  *AccountLoadInfo
	score     float64
	errorRate float64 // EWMA error rate snapshot at scoring time
	ttft      float64 // EWMA time-to-first-token in milliseconds
	hasTTFT   bool    // false when the account has no TTFT sample yet
}
// openAIAccountCandidateHeap implements heap.Interface so that the heap root
// is the WORST retained candidate, letting top-K selection evict the worst in
// O(log k) per scanned element.
type openAIAccountCandidateHeap []openAIAccountCandidateScore

func (h openAIAccountCandidateHeap) Len() int {
	return len(h)
}

func (h openAIAccountCandidateHeap) Less(i, j int) bool {
	// Min-heap root keeps the "worst" candidate, enabling O(log k) top-K maintenance.
	return isOpenAIAccountCandidateBetter(h[j], h[i])
}

func (h openAIAccountCandidateHeap) Swap(i, j int) {
	h[i], h[j] = h[j], h[i]
}

// Push appends an element; only called by package heap.
func (h *openAIAccountCandidateHeap) Push(x any) {
	candidate, ok := x.(openAIAccountCandidateScore)
	if !ok {
		panic("openAIAccountCandidateHeap: invalid element type")
	}
	*h = append(*h, candidate)
}

// Pop removes and returns the last element; only called by package heap.
func (h *openAIAccountCandidateHeap) Pop() any {
	old := *h
	n := len(old)
	last := old[n-1]
	*h = old[:n-1]
	return last
}
// isOpenAIAccountCandidateBetter reports whether left should be preferred
// over right: higher composite score first, then lower priority value, lower
// load rate, fewer waiters, and finally the smaller account ID as a
// deterministic tie-breaker.
func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool {
	switch {
	case left.score != right.score:
		return left.score > right.score
	case left.account.Priority != right.account.Priority:
		return left.account.Priority < right.account.Priority
	case left.loadInfo.LoadRate != right.loadInfo.LoadRate:
		return left.loadInfo.LoadRate < right.loadInfo.LoadRate
	case left.loadInfo.WaitingCount != right.loadInfo.WaitingCount:
		return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount
	default:
		return left.account.ID < right.account.ID
	}
}
// selectTopKOpenAICandidates returns the topK best candidates sorted
// best-first. When topK covers the whole input it simply sorts a copy;
// otherwise it maintains a bounded min-heap whose root holds the worst
// retained candidate, giving O(n log k) selection.
func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
	if len(candidates) == 0 {
		return nil
	}
	if topK <= 0 {
		topK = 1
	}
	if topK >= len(candidates) {
		ranked := append([]openAIAccountCandidateScore(nil), candidates...)
		sort.Slice(ranked, func(i, j int) bool {
			return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
		})
		return ranked
	}
	best := make(openAIAccountCandidateHeap, 0, topK)
	for _, candidate := range candidates {
		if len(best) < topK {
			heap.Push(&best, candidate)
			continue
		}
		// Replace the current worst retained candidate when beaten.
		if isOpenAIAccountCandidateBetter(candidate, best[0]) {
			best[0] = candidate
			heap.Fix(&best, 0)
		}
	}
	ranked := make([]openAIAccountCandidateScore, len(best))
	copy(ranked, best)
	sort.Slice(ranked, func(i, j int) bool {
		return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
	})
	return ranked
}
// openAISelectionRNG is a tiny deterministic xorshift64* generator used for
// weighted candidate sampling without touching global math/rand state.
type openAISelectionRNG struct {
	state uint64
}

// newOpenAISelectionRNG seeds the generator. A zero seed is replaced with a
// fixed odd constant because xorshift cannot leave the all-zero state.
func newOpenAISelectionRNG(seed uint64) openAISelectionRNG {
	if seed == 0 {
		seed = 0x9e3779b97f4a7c15
	}
	return openAISelectionRNG{state: seed}
}

// nextUint64 advances the state and returns the next pseudo-random value.
func (r *openAISelectionRNG) nextUint64() uint64 {
	// xorshift64*
	x := r.state
	x ^= x >> 12
	x ^= x << 25
	x ^= x >> 27
	r.state = x
	return x * 2685821657736338717
}

// nextFloat64 returns a uniform float64 built from the top 53 bits.
func (r *openAISelectionRNG) nextFloat64() float64 {
	// [0,1)
	return float64(r.nextUint64()>>11) / (1 << 53)
}
// deriveOpenAISelectionSeed hashes the request's stable anchors (session
// hash, previous response ID, model, group) into an RNG seed so retries of
// the same anchored request sample candidates in the same order.
func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 {
	hasher := fnv.New64a()
	writeValue := func(value string) {
		trimmed := strings.TrimSpace(value)
		if trimmed == "" {
			return
		}
		_, _ = hasher.Write([]byte(trimmed))
		_, _ = hasher.Write([]byte{0}) // field separator avoids concatenation collisions
	}
	writeValue(req.SessionHash)
	writeValue(req.PreviousResponseID)
	writeValue(req.RequestedModel)
	if req.GroupID != nil {
		_, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10)))
	}
	seed := hasher.Sum64()
	// Inject time entropy for pure load-balance requests with no session
	// anchor, so they do not always land on the same account.
	if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" {
		seed ^= uint64(time.Now().UnixNano())
	}
	if seed == 0 {
		seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15
	}
	return seed
}
// buildOpenAIWeightedSelectionOrder produces a weighted-random permutation of
// candidates (weights derived from score), so better-scored accounts are
// tried first most of the time without starving the rest. The order is
// deterministic for a given request seed (see deriveOpenAISelectionSeed).
func buildOpenAIWeightedSelectionOrder(
	candidates []openAIAccountCandidateScore,
	req OpenAIAccountScheduleRequest,
) []openAIAccountCandidateScore {
	if len(candidates) <= 1 {
		return append([]openAIAccountCandidateScore(nil), candidates...)
	}
	pool := append([]openAIAccountCandidateScore(nil), candidates...)
	weights := make([]float64, len(pool))
	minScore := pool[0].score
	for i := 1; i < len(pool); i++ {
		if pool[i].score < minScore {
			minScore = pool[i].score
		}
	}
	for i := range pool {
		// Shift top-K scores into a positive range so a single highest-scoring
		// account cannot monopolize selection long-term.
		weight := (pool[i].score - minScore) + 1.0
		if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 {
			weight = 1.0
		}
		weights[i] = weight
	}
	// Repeated weighted sampling without replacement.
	order := make([]openAIAccountCandidateScore, 0, len(pool))
	rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
	for len(pool) > 0 {
		total := 0.0
		for _, w := range weights {
			total += w
		}
		selectedIdx := 0
		if total > 0 {
			r := rng.nextFloat64() * total
			acc := 0.0
			for i, w := range weights {
				acc += w
				if r <= acc {
					selectedIdx = i
					break
				}
			}
		} else {
			// Degenerate weights: fall back to a uniform pick.
			selectedIdx = int(rng.nextUint64() % uint64(len(pool)))
		}
		order = append(order, pool[selectedIdx])
		pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...)
		weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...)
	}
	return order
}
// selectByLoadBalance implements the final scheduling layer: score every
// eligible candidate, keep the topK by composite score, and try to acquire a
// concurrency slot in weighted-random order. Returns (selection,
// candidateCount, topK, loadSkew, error); when every ranked candidate is
// saturated it returns a wait plan on the weighted-first choice.
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, int, int, float64, error) {
	accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID)
	if err != nil {
		return nil, 0, 0, 0, err
	}
	if len(accounts) == 0 {
		return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
	}
	// Filter out excluded / unschedulable / model- or transport-incompatible accounts.
	filtered := make([]*Account, 0, len(accounts))
	loadReq := make([]AccountWithConcurrency, 0, len(accounts))
	for i := range accounts {
		account := &accounts[i]
		if req.ExcludedIDs != nil {
			if _, excluded := req.ExcludedIDs[account.ID]; excluded {
				continue
			}
		}
		if !account.IsSchedulable() || !account.IsOpenAI() {
			continue
		}
		if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
			continue
		}
		if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
			continue
		}
		filtered = append(filtered, account)
		loadReq = append(loadReq, AccountWithConcurrency{
			ID:             account.ID,
			MaxConcurrency: account.Concurrency,
		})
	}
	if len(filtered) == 0 {
		return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
	}
	// Best-effort batch load lookup; an error just leaves the map empty.
	loadMap := map[int64]*AccountLoadInfo{}
	if s.service.concurrencyService != nil {
		if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
			loadMap = batchLoad
		}
	}
	// First pass: collect normalization ranges (priority, waiting, TTFT) and
	// the first two moments of the load-rate distribution.
	minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
	maxWaiting := 1
	loadRateSum := 0.0
	loadRateSumSquares := 0.0
	minTTFT, maxTTFT := 0.0, 0.0
	hasTTFTSample := false
	candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
	for _, account := range filtered {
		loadInfo := loadMap[account.ID]
		if loadInfo == nil {
			loadInfo = &AccountLoadInfo{AccountID: account.ID}
		}
		if account.Priority < minPriority {
			minPriority = account.Priority
		}
		if account.Priority > maxPriority {
			maxPriority = account.Priority
		}
		if loadInfo.WaitingCount > maxWaiting {
			maxWaiting = loadInfo.WaitingCount
		}
		errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
		if hasTTFT && ttft > 0 {
			if !hasTTFTSample {
				minTTFT, maxTTFT = ttft, ttft
				hasTTFTSample = true
			} else {
				if ttft < minTTFT {
					minTTFT = ttft
				}
				if ttft > maxTTFT {
					maxTTFT = ttft
				}
			}
		}
		loadRate := float64(loadInfo.LoadRate)
		loadRateSum += loadRate
		loadRateSumSquares += loadRate * loadRate
		candidates = append(candidates, openAIAccountCandidateScore{
			account:   account,
			loadInfo:  loadInfo,
			errorRate: errorRate,
			ttft:      ttft,
			hasTTFT:   hasTTFT,
		})
	}
	loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
	// Second pass: combine normalized factors into the weighted composite score.
	weights := s.service.openAIWSSchedulerWeights()
	for i := range candidates {
		item := &candidates[i]
		priorityFactor := 1.0
		if maxPriority > minPriority {
			priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
		}
		loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
		queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
		errorFactor := 1 - clamp01(item.errorRate)
		ttftFactor := 0.5 // neutral factor when no comparable TTFT sample exists
		if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
			ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
		}
		item.score = weights.Priority*priorityFactor +
			weights.Load*loadFactor +
			weights.Queue*queueFactor +
			weights.ErrorRate*errorFactor +
			weights.TTFT*ttftFactor
	}
	topK := s.service.openAIWSLBTopK()
	if topK > len(candidates) {
		topK = len(candidates)
	}
	if topK <= 0 {
		topK = 1
	}
	rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
	selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
	for i := 0; i < len(selectionOrder); i++ {
		candidate := selectionOrder[i]
		result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
		if acquireErr != nil {
			return nil, len(candidates), topK, loadSkew, acquireErr
		}
		if result != nil && result.Acquired {
			if req.SessionHash != "" {
				_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
			}
			return &AccountSelectionResult{
				Account:     candidate.account,
				Acquired:    true,
				ReleaseFunc: result.ReleaseFunc,
			}, len(candidates), topK, loadSkew, nil
		}
	}
	// All ranked candidates are saturated: fall back to a wait plan on the
	// weighted-first candidate.
	cfg := s.service.schedulingConfig()
	candidate := selectionOrder[0]
	return &AccountSelectionResult{
		Account: candidate.account,
		WaitPlan: &AccountWaitPlan{
			AccountID:      candidate.account.ID,
			MaxConcurrency: candidate.account.Concurrency,
			Timeout:        cfg.FallbackWaitTimeout,
			MaxWaiting:     cfg.FallbackMaxWaiting,
		},
	}, len(candidates), topK, loadSkew, nil
}
// isAccountTransportCompatible reports whether account can serve the required
// upstream transport. Inbound HTTP can always fall back to the HTTP line, so
// only strict transports are checked against the protocol resolver.
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
	switch requiredTransport {
	case OpenAIUpstreamTransportAny, OpenAIUpstreamTransportHTTPSSE:
		return true
	}
	if s == nil || s.service == nil || account == nil {
		return false
	}
	resolved := s.service.getOpenAIWSProtocolResolver().Resolve(account)
	return resolved.Transport == requiredTransport
}
// ReportResult feeds one request outcome into the per-account EWMA stats.
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
	if s != nil && s.stats != nil {
		s.stats.report(accountID, success, firstTokenMs)
	}
}

// ReportSwitch counts one account switch in the scheduler metrics.
func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
	if s != nil {
		s.metrics.recordSwitch()
	}
}
// SnapshotMetrics exports the scheduler counters plus derived averages and
// ratios. Counters are read individually (not atomically as a group), which
// is acceptable for monitoring; derived values are only computed when at
// least one selection has happened.
func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
	if s == nil {
		return OpenAIAccountSchedulerMetricsSnapshot{}
	}
	selectTotal := s.metrics.selectTotal.Load()
	prevHit := s.metrics.stickyPreviousHitTotal.Load()
	sessionHit := s.metrics.stickySessionHitTotal.Load()
	switchTotal := s.metrics.accountSwitchTotal.Load()
	latencyTotal := s.metrics.latencyMsTotal.Load()
	loadSkewTotal := s.metrics.loadSkewMilliTotal.Load()
	snapshot := OpenAIAccountSchedulerMetricsSnapshot{
		SelectTotal:              selectTotal,
		StickyPreviousHitTotal:   prevHit,
		StickySessionHitTotal:    sessionHit,
		LoadBalanceSelectTotal:   s.metrics.loadBalanceSelectTotal.Load(),
		AccountSwitchTotal:       switchTotal,
		SchedulerLatencyMsTotal:  latencyTotal,
		RuntimeStatsAccountCount: s.stats.size(),
	}
	if selectTotal > 0 {
		snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal)
		snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal)
		snapshot.AccountSwitchRate = float64(switchTotal) / float64(selectTotal)
		// Skew was accumulated in milli-units; convert back before averaging.
		snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal)
	}
	return snapshot
}
// getOpenAIAccountScheduler lazily builds the default scheduler (and its
// runtime stats container) exactly once via sync.Once. Returns nil on a nil
// service so callers can degrade gracefully.
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
	if s == nil {
		return nil
	}
	s.openaiSchedulerOnce.Do(func() {
		if s.openaiAccountStats == nil {
			s.openaiAccountStats = newOpenAIAccountRuntimeStats()
		}
		if s.openaiScheduler == nil {
			s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats)
		}
	})
	return s.openaiScheduler
}
// SelectAccountWithScheduler is the public entry point for layered account
// scheduling. It pre-resolves any sticky-session binding and delegates to the
// scheduler; when no scheduler is available it falls back to plain load-aware
// selection and reports the load-balance layer.
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
	ctx context.Context,
	groupID *int64,
	previousResponseID string,
	sessionHash string,
	requestedModel string,
	excludedIDs map[int64]struct{},
	requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
	decision := OpenAIAccountScheduleDecision{}
	scheduler := s.getOpenAIAccountScheduler()
	if scheduler == nil {
		selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
		decision.Layer = openAIAccountScheduleLayerLoadBalance
		return selection, decision, err
	}
	// Best-effort sticky lookup; errors simply leave StickyAccountID at 0.
	var stickyAccountID int64
	if sessionHash != "" && s.cache != nil {
		if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {
			stickyAccountID = accountID
		}
	}
	return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
		GroupID:            groupID,
		SessionHash:        sessionHash,
		StickyAccountID:    stickyAccountID,
		PreviousResponseID: previousResponseID,
		RequestedModel:     requestedModel,
		RequiredTransport:  requiredTransport,
		ExcludedIDs:        excludedIDs,
	})
}
// ReportOpenAIAccountScheduleResult forwards one request outcome to the
// scheduler's per-account runtime stats (no-op without a scheduler).
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		scheduler.ReportResult(accountID, success, firstTokenMs)
	}
}

// RecordOpenAIAccountSwitch counts one account switch in scheduler metrics.
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		scheduler.ReportSwitch()
	}
}

// SnapshotOpenAIAccountSchedulerMetrics exports scheduler counters; a
// zero-value snapshot is returned when no scheduler is available.
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
	if scheduler := s.getOpenAIAccountScheduler(); scheduler != nil {
		return scheduler.SnapshotMetrics()
	}
	return OpenAIAccountSchedulerMetricsSnapshot{}
}
// openAIWSSessionStickyTTL returns the configured sticky-session TTL, falling
// back to the package default when unset or non-positive.
func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration {
	if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 {
		return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
	}
	return openaiStickySessionTTL
}

// openAIWSLBTopK returns the configured load-balance top-K, defaulting to 7.
func (s *OpenAIGatewayService) openAIWSLBTopK() int {
	if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 {
		return s.cfg.Gateway.OpenAIWS.LBTopK
	}
	return 7
}
// openAIWSSchedulerWeights returns the configured composite-score weights, or
// built-in defaults when no configuration is present.
func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
	if s != nil && s.cfg != nil {
		return GatewayOpenAIWSSchedulerScoreWeightsView{
			Priority:  s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority,
			Load:      s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load,
			Queue:     s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue,
			ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate,
			TTFT:      s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT,
		}
	}
	return GatewayOpenAIWSSchedulerScoreWeightsView{
		Priority:  1.0,
		Load:      1.0,
		Queue:     0.7,
		ErrorRate: 0.8,
		TTFT:      0.5,
	}
}

// GatewayOpenAIWSSchedulerScoreWeightsView is a flat copy of the scheduler
// score weights, decoupling the scheduler from the config package types.
type GatewayOpenAIWSSchedulerScoreWeightsView struct {
	Priority  float64
	Load      float64
	Queue     float64
	ErrorRate float64
	TTFT      float64
}
// clamp01 limits value to the closed interval [0, 1]; NaN passes through
// unchanged (both comparisons are false).
func clamp01(value float64) float64 {
	if value < 0 {
		return 0
	}
	if value > 1 {
		return 1
	}
	return value
}
// calcLoadSkewByMoments returns the population standard deviation of a
// sample set from its first two moments (sum and sum of squares); 0 for
// count <= 1.
func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
	if count <= 1 {
		return 0
	}
	n := float64(count)
	mean := sum / n
	variance := sumSquares/n - mean*mean
	// Guard against tiny negative variance from floating-point cancellation.
	return math.Sqrt(math.Max(variance, 0))
}

View File

@@ -0,0 +1,83 @@
package service
import (
"sort"
"testing"
)
// buildOpenAISchedulerBenchmarkCandidates fabricates size deterministic
// candidates with varied priority, load, score, and TTFT so the top-K
// selection benchmarks exercise a realistic spread of values.
func buildOpenAISchedulerBenchmarkCandidates(size int) []openAIAccountCandidateScore {
	if size <= 0 {
		return nil
	}
	candidates := make([]openAIAccountCandidateScore, 0, size)
	for i := 0; i < size; i++ {
		accountID := int64(10_000 + i)
		candidates = append(candidates, openAIAccountCandidateScore{
			account: &Account{
				ID:       accountID,
				Priority: i % 7,
			},
			loadInfo: &AccountLoadInfo{
				AccountID:    accountID,
				LoadRate:     (i * 17) % 100,
				WaitingCount: (i * 11) % 13,
			},
			score: float64((i*29)%1000) / 100,
			// BUG FIX: convert to float64 BEFORE dividing. The original
			// `float64((i * 5) % 100 / 100)` performed integer division, so
			// errorRate was always 0 and the benchmark never varied it.
			errorRate: float64((i*5)%100) / 100,
			ttft:      float64(30 + (i*3)%500),
			hasTTFT:   i%3 != 0,
		})
	}
	return candidates
}
// selectTopKOpenAICandidatesBySortBenchmark is the full-sort baseline used to
// compare against the heap-based top-K selection in benchmarks. It sorts a
// copy of the input best-first and truncates to topK (minimum 1).
func selectTopKOpenAICandidatesBySortBenchmark(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
	total := len(candidates)
	if total == 0 {
		return nil
	}
	if topK <= 0 {
		topK = 1
	}
	if topK > total {
		topK = total
	}
	ranked := make([]openAIAccountCandidateScore, total)
	copy(ranked, candidates)
	sort.Slice(ranked, func(i, j int) bool {
		return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
	})
	return ranked[:topK]
}
// BenchmarkOpenAIAccountSchedulerSelectTopK compares the heap-based top-K
// selection against the full-sort baseline at several candidate-set sizes.
func BenchmarkOpenAIAccountSchedulerSelectTopK(b *testing.B) {
	cases := []struct {
		name string
		size int
		topK int
	}{
		{name: "n_16_k_3", size: 16, topK: 3},
		{name: "n_64_k_3", size: 64, topK: 3},
		{name: "n_256_k_5", size: 256, topK: 5},
	}
	for _, tc := range cases {
		candidates := buildOpenAISchedulerBenchmarkCandidates(tc.size)
		b.Run(tc.name+"/heap_topk", func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				result := selectTopKOpenAICandidates(candidates, tc.topK)
				if len(result) == 0 {
					b.Fatal("unexpected empty result")
				}
			}
		})
		b.Run(tc.name+"/full_sort", func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				result := selectTopKOpenAICandidatesBySortBenchmark(candidates, tc.topK)
				if len(result) == 0 {
					b.Fatal("unexpected empty result")
				}
			}
		})
	}
}

View File

@@ -0,0 +1,841 @@
package service
import (
"context"
"fmt"
"math"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky
// verifies that a previous_response_id binding wins the first scheduling
// layer and that the session hash gets (re)bound to the same account.
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(9)
	account := Account{
		ID:          1001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		},
	}
	cache := &stubGatewayCache{}
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
	cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}
	// Pre-bind the response ID so layer 1 can hit.
	store := svc.getOpenAIWSStateStore()
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_001", account.ID, time.Hour))
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"resp_prev_001",
		"session_hash_001",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
	require.True(t, decision.StickyPreviousHit)
	require.Equal(t, account.ID, cache.sessionBindings["openai:session_hash_001"])
	// Release the acquired slot so the stub concurrency state stays clean.
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky verifies
// that an existing session-hash binding is honored by the second scheduling
// layer when no previous_response_id is supplied.
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(10)
	account := Account{
		ID:          2001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_abc": account.ID,
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                &config.Config{},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_abc",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
	require.True(t, decision.StickySessionHit)
	// Release the acquired slot so subsequent assertions/tests start clean.
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsSticky
// verifies that when the sticky-bound account is saturated, the scheduler keeps
// the sticky account and returns a wait plan rather than switching to a less
// loaded account.
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(10100)
	accounts := []Account{
		{
			ID:          21001,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
		{
			ID:          21002,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    9,
		},
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_sticky_busy": 21001,
		},
	}
	cfg := &config.Config{}
	cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2
	cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	concurrencyCache := stubConcurrencyCache{
		acquireResults: map[int64]bool{
			21001: false, // the sticky account is already saturated
			21002: true,  // a load-balance fallback would hit this account (the test requires no switch)
		},
		waitCounts: map[int64]int{
			21001: 999,
		},
		loadMap: map[int64]*AccountLoadInfo{
			21001: {AccountID: 21001, LoadRate: 90, WaitingCount: 9},
			21002: {AccountID: 21002, LoadRate: 1, WaitingCount: 0},
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
	}
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_sticky_busy",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(21001), selection.Account.ID, "busy sticky account should remain selected")
	require.False(t, selection.Acquired)
	require.NotNil(t, selection.WaitPlan)
	require.Equal(t, int64(21001), selection.WaitPlan.AccountID)
	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
	require.True(t, decision.StickySessionHit)
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
// verifies that a sticky-bound account flagged with openai_ws_force_http is
// still selected through the session-sticky layer for an "any" transport request.
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP(t *testing.T) {
	ctx := context.Background()
	groupID := int64(1010)
	account := Account{
		ID:          2101,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_ws_force_http": true,
		},
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_force_http": account.ID,
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                &config.Config{},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_force_http",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
	require.True(t, decision.StickySessionHit)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStickyHTTPAccount
// verifies that a required WS-v2 transport filters out the sticky-bound
// HTTP-only account and falls through to the load-balance layer, even though
// the HTTP-only account has lower load.
func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStickyHTTPAccount(t *testing.T) {
	ctx := context.Background()
	groupID := int64(1011)
	accounts := []Account{
		{
			ID:          2201,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
		{
			ID:          2202,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    5,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		},
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_ws_only": 2201,
		},
	}
	cfg := newOpenAIWSV2TestConfig()
	// Construct a "HTTP-only account has lower load" scenario to verify that the
	// required transport forces filtering regardless of load.
	concurrencyCache := stubConcurrencyCache{
		loadMap: map[int64]*AccountLoadInfo{
			2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
			2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
	}
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"session_hash_ws_only",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportResponsesWebsocketV2,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(2202), selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
	require.False(t, decision.StickySessionHit)
	require.Equal(t, 1, decision.CandidateCount)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailableAccount
// verifies that requiring the WS-v2 transport when no account supports it
// yields an error, a nil selection, and zero candidates in the decision.
func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailableAccount(t *testing.T) {
	ctx := context.Background()
	groupID := int64(1012)
	accounts := []Account{
		{
			ID:          2301,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeOAuth,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              &stubGatewayCache{},
		cfg:                newOpenAIWSV2TestConfig(),
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportResponsesWebsocketV2,
	)
	require.Error(t, err)
	require.Nil(t, selection)
	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
	require.Equal(t, 0, decision.CandidateCount)
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
// verifies that when concurrency acquisition fails on the top-ranked
// load-balance candidate, the scheduler falls back to the next candidate
// within the configured top-K window.
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback(t *testing.T) {
	ctx := context.Background()
	groupID := int64(11)
	accounts := []Account{
		{
			ID:          3001,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
		{
			ID:          3002,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
		{
			ID:          3003,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
		},
	}
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.LBTopK = 2
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
	concurrencyCache := stubConcurrencyCache{
		loadMap: map[int64]*AccountLoadInfo{
			3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
			3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
			3003: {AccountID: 3003, LoadRate: 10, WaitingCount: 0},
		},
		acquireResults: map[int64]bool{
			3003: false, // top-1 acquisition fails; must fall back to the next top-K candidate
			3002: true,
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              &stubGatewayCache{},
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
	}
	selection, decision, err := svc.SelectAccountWithScheduler(
		ctx,
		&groupID,
		"",
		"",
		"gpt-5.1",
		nil,
		OpenAIUpstreamTransportAny,
	)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(3002), selection.Account.ID)
	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
	require.Equal(t, 3, decision.CandidateCount)
	require.Equal(t, 2, decision.TopK)
	require.Greater(t, decision.LoadSkew, 0.0)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics exercises the
// scheduler metrics pipeline end-to-end: a sticky selection, a result report,
// and a switch record must all be visible in the metrics snapshot.
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
	ctx := context.Background()
	groupID := int64(12)
	account := Account{
		ID:          4001,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
	}
	cache := &stubGatewayCache{
		sessionBindings: map[string]int64{
			"openai:session_hash_metrics": account.ID,
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                &config.Config{},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}
	selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
	require.NoError(t, err)
	require.NotNil(t, selection)
	svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
	svc.RecordOpenAIAccountSwitch()
	snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
	require.GreaterOrEqual(t, snapshot.SelectTotal, int64(1))
	require.GreaterOrEqual(t, snapshot.StickySessionHitTotal, int64(1))
	require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
	require.GreaterOrEqual(t, snapshot.SchedulerLatencyMsAvg, float64(0))
	require.GreaterOrEqual(t, snapshot.StickyHitRatio, 0.0)
	require.GreaterOrEqual(t, snapshot.RuntimeStatsAccountCount, 1)
}
// intPtrForTest returns a pointer to a copy of v, for test fixtures that need *int.
func intPtrForTest(v int) *int {
	value := v
	return &value
}
// TestOpenAIAccountRuntimeStats_ReportAndSnapshot verifies that reported
// results are folded into the per-account error rate and TTFT averages
// (the exact 0.36 / 120.0 expectations pin the EWMA-style aggregation).
func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) {
	stats := newOpenAIAccountRuntimeStats()
	stats.report(1001, true, nil)
	firstTTFT := 100
	stats.report(1001, false, &firstTTFT)
	secondTTFT := 200
	stats.report(1001, false, &secondTTFT)
	errorRate, ttft, hasTTFT := stats.snapshot(1001)
	require.True(t, hasTTFT)
	require.InDelta(t, 0.36, errorRate, 1e-9)
	require.InDelta(t, 120.0, ttft, 1e-9)
	require.Equal(t, 1, stats.size())
}
// TestOpenAIAccountRuntimeStats_ReportConcurrent hammers report() from many
// goroutines across a few accounts and checks the snapshots stay within valid
// ranges; run with -race to catch unsynchronized access.
func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) {
	stats := newOpenAIAccountRuntimeStats()
	const (
		accountCount = 4
		workers      = 16
		iterations   = 800
	)
	var wg sync.WaitGroup
	wg.Add(workers)
	for worker := 0; worker < workers; worker++ {
		worker := worker
		go func() {
			defer wg.Done()
			for i := 0; i < iterations; i++ {
				accountID := int64(i%accountCount + 1)
				success := (i+worker)%3 != 0
				ttft := 80 + (i+worker)%40
				stats.report(accountID, success, &ttft)
			}
		}()
	}
	wg.Wait()
	require.Equal(t, accountCount, stats.size())
	for accountID := int64(1); accountID <= accountCount; accountID++ {
		errorRate, ttft, hasTTFT := stats.snapshot(accountID)
		require.GreaterOrEqual(t, errorRate, 0.0)
		require.LessOrEqual(t, errorRate, 1.0)
		require.True(t, hasTTFT)
		require.Greater(t, ttft, 0.0)
	}
}
// TestSelectTopKOpenAICandidates verifies top-K selection ordering, including
// the tie-break between candidates 11 and 13 (equal score 10.0) and that a K
// larger than the candidate count returns everything, still ordered.
func TestSelectTopKOpenAICandidates(t *testing.T) {
	candidates := []openAIAccountCandidateScore{
		{
			account:  &Account{ID: 11, Priority: 2},
			loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 1},
			score:    10.0,
		},
		{
			account:  &Account{ID: 12, Priority: 1},
			loadInfo: &AccountLoadInfo{LoadRate: 20, WaitingCount: 1},
			score:    9.5,
		},
		{
			account:  &Account{ID: 13, Priority: 1},
			loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 0},
			score:    10.0,
		},
		{
			account:  &Account{ID: 14, Priority: 0},
			loadInfo: &AccountLoadInfo{LoadRate: 40, WaitingCount: 0},
			score:    8.0,
		},
	}
	top2 := selectTopKOpenAICandidates(candidates, 2)
	require.Len(t, top2, 2)
	require.Equal(t, int64(13), top2[0].account.ID)
	require.Equal(t, int64(11), top2[1].account.ID)
	topAll := selectTopKOpenAICandidates(candidates, 8)
	require.Len(t, topAll, len(candidates))
	require.Equal(t, int64(13), topAll[0].account.ID)
	require.Equal(t, int64(11), topAll[1].account.ID)
	require.Equal(t, int64(12), topAll[2].account.ID)
	require.Equal(t, int64(14), topAll[3].account.ID)
}
// TestBuildOpenAIWeightedSelectionOrder_DeterministicBySessionSeed verifies
// that the weighted ordering is deterministic for a fixed session-derived
// seed: two invocations with the same request must produce the same order.
func TestBuildOpenAIWeightedSelectionOrder_DeterministicBySessionSeed(t *testing.T) {
	candidates := []openAIAccountCandidateScore{
		{
			account:  &Account{ID: 101},
			loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 0},
			score:    4.2,
		},
		{
			account:  &Account{ID: 102},
			loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 1},
			score:    3.5,
		},
		{
			account:  &Account{ID: 103},
			loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2},
			score:    2.1,
		},
	}
	req := OpenAIAccountScheduleRequest{
		GroupID:        int64PtrForTest(99),
		SessionHash:    "session_seed_fixed",
		RequestedModel: "gpt-5.1",
	}
	first := buildOpenAIWeightedSelectionOrder(candidates, req)
	second := buildOpenAIWeightedSelectionOrder(candidates, req)
	require.Len(t, first, len(candidates))
	require.Len(t, second, len(candidates))
	for i := range first {
		require.Equal(t, first[i].account.ID, second[i].account.ID)
	}
}
// TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesAcrossSessions
// verifies that distinct session hashes spread load-balanced selections across
// multiple equally-scored accounts instead of always hitting one.
func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesAcrossSessions(t *testing.T) {
	ctx := context.Background()
	groupID := int64(15)
	accounts := []Account{
		{
			ID:          5101,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 3,
			Priority:    0,
		},
		{
			ID:          5102,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 3,
			Priority:    0,
		},
		{
			ID:          5103,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 3,
			Priority:    0,
		},
	}
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.LBTopK = 3
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
	concurrencyCache := stubConcurrencyCache{
		loadMap: map[int64]*AccountLoadInfo{
			5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
			5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
			5103: {AccountID: 5103, LoadRate: 20, WaitingCount: 1},
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              &stubGatewayCache{sessionBindings: map[string]int64{}},
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
	}
	selected := make(map[int64]int, len(accounts))
	for i := 0; i < 60; i++ {
		sessionHash := fmt.Sprintf("session_hash_lb_%d", i)
		selection, decision, err := svc.SelectAccountWithScheduler(
			ctx,
			&groupID,
			"",
			sessionHash,
			"gpt-5.1",
			nil,
			OpenAIUpstreamTransportAny,
		)
		require.NoError(t, err)
		require.NotNil(t, selection)
		require.NotNil(t, selection.Account)
		require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
		selected[selection.Account.ID]++
		if selection.ReleaseFunc != nil {
			selection.ReleaseFunc()
		}
	}
	// Multiple sessions should spread across several accounts, avoiding a
	// constant single-account hit.
	require.GreaterOrEqual(t, len(selected), 2)
}
// TestDeriveOpenAISelectionSeed_NoAffinityAddsEntropy verifies that without a
// session hash the seed derivation mixes in time-based entropy, so two calls
// separated by a short sleep yield different non-zero seeds.
func TestDeriveOpenAISelectionSeed_NoAffinityAddsEntropy(t *testing.T) {
	req := OpenAIAccountScheduleRequest{
		RequestedModel: "gpt-5.1",
	}
	seed1 := deriveOpenAISelectionSeed(req)
	time.Sleep(1 * time.Millisecond)
	seed2 := deriveOpenAISelectionSeed(req)
	require.NotZero(t, seed1)
	require.NotZero(t, seed2)
	require.NotEqual(t, seed1, seed2)
}
// TestBuildOpenAIWeightedSelectionOrder_HandlesInvalidScores verifies that
// NaN, +Inf, and negative scores do not break the weighted ordering: every
// candidate still appears exactly once in the result.
func TestBuildOpenAIWeightedSelectionOrder_HandlesInvalidScores(t *testing.T) {
	candidates := []openAIAccountCandidateScore{
		{
			account:  &Account{ID: 901},
			loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
			score:    math.NaN(),
		},
		{
			account:  &Account{ID: 902},
			loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
			score:    math.Inf(1),
		},
		{
			account:  &Account{ID: 903},
			loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
			score:    -1,
		},
	}
	req := OpenAIAccountScheduleRequest{
		SessionHash: "seed_invalid_scores",
	}
	order := buildOpenAIWeightedSelectionOrder(candidates, req)
	require.Len(t, order, len(candidates))
	seen := map[int64]struct{}{}
	for _, item := range order {
		seen[item.account.ID] = struct{}{}
	}
	require.Len(t, seen, len(candidates))
}
// TestOpenAISelectionRNG_SeedZeroStillWorks verifies that a zero seed does not
// degenerate the RNG: successive draws differ and floats stay in [0, 1).
func TestOpenAISelectionRNG_SeedZeroStillWorks(t *testing.T) {
	rng := newOpenAISelectionRNG(0)
	v1 := rng.nextUint64()
	v2 := rng.nextUint64()
	require.NotEqual(t, v1, v2)
	require.GreaterOrEqual(t, rng.nextFloat64(), 0.0)
	require.Less(t, rng.nextFloat64(), 1.0)
}
// TestOpenAIAccountCandidateHeap_PushPopAndInvalidType covers the heap's
// Push/Pop round trip and the panic branch when a non-candidate value is pushed.
func TestOpenAIAccountCandidateHeap_PushPopAndInvalidType(t *testing.T) {
	h := openAIAccountCandidateHeap{}
	h.Push(openAIAccountCandidateScore{
		account:  &Account{ID: 7001},
		loadInfo: &AccountLoadInfo{LoadRate: 0, WaitingCount: 0},
		score:    1.0,
	})
	require.Equal(t, 1, h.Len())
	popped, ok := h.Pop().(openAIAccountCandidateScore)
	require.True(t, ok)
	require.Equal(t, int64(7001), popped.account.ID)
	require.Equal(t, 0, h.Len())
	require.Panics(t, func() {
		h.Push("bad_element_type")
	})
}
// TestClamp01_AllBranches covers clamp01's below-range, above-range, and
// in-range branches.
func TestClamp01_AllBranches(t *testing.T) {
	require.Equal(t, 0.0, clamp01(-0.2))
	require.Equal(t, 1.0, clamp01(1.3))
	require.Equal(t, 0.5, clamp01(0.5))
}
// TestCalcLoadSkewByMoments_Branches covers the zero-variance, clamped
// negative-variance, and regular positive-skew branches.
func TestCalcLoadSkewByMoments_Branches(t *testing.T) {
	require.Equal(t, 0.0, calcLoadSkewByMoments(1, 1, 1))
	// variance < 0 branch: when sumSquares/count - mean^2 goes negative it must
	// be clamped to 0.
	require.Equal(t, 0.0, calcLoadSkewByMoments(1, 0, 2))
	require.GreaterOrEqual(t, calcLoadSkewByMoments(6, 20, 3), 0.0)
}
// TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot verifies that
// result reports, switch records, and recorded select decisions are all
// aggregated into the scheduler's metrics snapshot.
func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
	schedulerAny := newDefaultOpenAIAccountScheduler(&OpenAIGatewayService{}, nil)
	scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
	require.True(t, ok)
	ttft := 100
	scheduler.ReportResult(1001, true, &ttft)
	scheduler.ReportSwitch()
	scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
		Layer:             openAIAccountScheduleLayerLoadBalance,
		LatencyMs:         8,
		LoadSkew:          0.5,
		StickyPreviousHit: true,
	})
	scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
		Layer:            openAIAccountScheduleLayerSessionSticky,
		LatencyMs:        6,
		LoadSkew:         0.2,
		StickySessionHit: true,
	})
	snapshot := scheduler.SnapshotMetrics()
	require.Equal(t, int64(2), snapshot.SelectTotal)
	require.Equal(t, int64(1), snapshot.StickyPreviousHitTotal)
	require.Equal(t, int64(1), snapshot.StickySessionHitTotal)
	require.Equal(t, int64(1), snapshot.LoadBalanceSelectTotal)
	require.Equal(t, int64(1), snapshot.AccountSwitchTotal)
	require.Greater(t, snapshot.SchedulerLatencyMsAvg, 0.0)
	require.Greater(t, snapshot.StickyHitRatio, 0.0)
	require.Greater(t, snapshot.LoadSkewAvg, 0.0)
}
// TestOpenAIGatewayService_SchedulerWrappersAndDefaults verifies the
// service-level scheduler wrappers and the config fallbacks: default
// top-K/TTL/weights with a nil config, and pass-through of configured values.
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
	svc := &OpenAIGatewayService{}
	ttft := 120
	svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
	svc.RecordOpenAIAccountSwitch()
	snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
	require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
	require.Equal(t, 7, svc.openAIWSLBTopK())
	require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
	defaultWeights := svc.openAIWSSchedulerWeights()
	require.Equal(t, 1.0, defaultWeights.Priority)
	require.Equal(t, 1.0, defaultWeights.Load)
	require.Equal(t, 0.7, defaultWeights.Queue)
	require.Equal(t, 0.8, defaultWeights.ErrorRate)
	require.Equal(t, 0.5, defaultWeights.TTFT)
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.LBTopK = 9
	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 180
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.2
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.3
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.4
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.5
	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.6
	svcWithCfg := &OpenAIGatewayService{cfg: cfg}
	require.Equal(t, 9, svcWithCfg.openAIWSLBTopK())
	require.Equal(t, 180*time.Second, svcWithCfg.openAIWSSessionStickyTTL())
	customWeights := svcWithCfg.openAIWSSchedulerWeights()
	require.Equal(t, 0.2, customWeights.Priority)
	require.Equal(t, 0.3, customWeights.Load)
	require.Equal(t, 0.4, customWeights.Queue)
	require.Equal(t, 0.5, customWeights.ErrorRate)
	require.Equal(t, 0.6, customWeights.TTFT)
}
// TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches
// covers the nil-account branches (compatible with Any/HTTP, not with WS-v2)
// and a WS-v2-enabled API-key account once a service config is attached.
func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *testing.T) {
	scheduler := &defaultOpenAIAccountScheduler{}
	require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportAny))
	require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
	require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
	cfg := newOpenAIWSV2TestConfig()
	scheduler.service = &OpenAIGatewayService{cfg: cfg}
	account := &Account{
		ID:          8801,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		},
	}
	require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2))
}
// int64PtrForTest returns a pointer to a copy of v, for test fixtures that need *int64.
func int64PtrForTest(v int64) *int64 {
	value := v
	return &value
}

View File

@@ -0,0 +1,71 @@
package service
import (
"strings"
"github.com/gin-gonic/gin"
)
// OpenAIClientTransport identifies the inbound client protocol type.
type OpenAIClientTransport string

const (
	// OpenAIClientTransportUnknown means the transport was not (or could not be) determined.
	OpenAIClientTransportUnknown OpenAIClientTransport = ""
	// OpenAIClientTransportHTTP marks a plain HTTP (incl. SSE) client connection.
	OpenAIClientTransportHTTP OpenAIClientTransport = "http"
	// OpenAIClientTransportWS marks a websocket client connection.
	OpenAIClientTransportWS OpenAIClientTransport = "ws"
)

// openAIClientTransportContextKey is the gin context key under which the
// normalized client transport string is stored.
const openAIClientTransportContextKey = "openai_client_transport"
// SetOpenAIClientTransport records the inbound client protocol for the current
// request. Unknown/unnormalizable values are dropped so the context key is
// only ever set to a canonical transport.
func SetOpenAIClientTransport(c *gin.Context, transport OpenAIClientTransport) {
	if c == nil {
		return
	}
	if canonical := normalizeOpenAIClientTransport(transport); canonical != OpenAIClientTransportUnknown {
		c.Set(openAIClientTransportContextKey, string(canonical))
	}
}
// GetOpenAIClientTransport reads the inbound client protocol recorded on the
// request context, normalizing either the typed or the string representation;
// anything else yields OpenAIClientTransportUnknown.
func GetOpenAIClientTransport(c *gin.Context) OpenAIClientTransport {
	if c == nil {
		return OpenAIClientTransportUnknown
	}
	raw, found := c.Get(openAIClientTransportContextKey)
	if !found || raw == nil {
		return OpenAIClientTransportUnknown
	}
	if typed, ok := raw.(OpenAIClientTransport); ok {
		return normalizeOpenAIClientTransport(typed)
	}
	if s, ok := raw.(string); ok {
		return normalizeOpenAIClientTransport(OpenAIClientTransport(s))
	}
	return OpenAIClientTransportUnknown
}
// normalizeOpenAIClientTransport canonicalizes free-form transport names
// (any casing, surrounding whitespace, and the aliases "http_sse"/"sse" and
// "websocket") into one of the known OpenAIClientTransport constants.
func normalizeOpenAIClientTransport(transport OpenAIClientTransport) OpenAIClientTransport {
	normalized := strings.ToLower(strings.TrimSpace(string(transport)))
	switch normalized {
	case "http", "http_sse", "sse":
		return OpenAIClientTransportHTTP
	case "ws", "websocket":
		return OpenAIClientTransportWS
	}
	return OpenAIClientTransportUnknown
}
// resolveOpenAIWSDecisionByClientTransport downgrades the upstream protocol
// decision to plain HTTP when the client itself connected over HTTP; any other
// client transport (ws or unknown) leaves the decision untouched.
func resolveOpenAIWSDecisionByClientTransport(
	decision OpenAIWSProtocolDecision,
	clientTransport OpenAIClientTransport,
) OpenAIWSProtocolDecision {
	if clientTransport != OpenAIClientTransportHTTP {
		return decision
	}
	return openAIWSHTTPDecision("client_protocol_http")
}

View File

@@ -0,0 +1,107 @@
package service
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestOpenAIClientTransport_SetAndGet verifies the set/get round trip on a gin
// context, including that a later Set overwrites the earlier value.
func TestOpenAIClientTransport_SetAndGet(t *testing.T) {
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(c))
	SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
	require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
	SetOpenAIClientTransport(c, OpenAIClientTransportWS)
	require.Equal(t, OpenAIClientTransportWS, GetOpenAIClientTransport(c))
}
// TestOpenAIClientTransport_GetNormalizesRawContextValue verifies that Get
// normalizes whatever was stored under the context key: typed values, string
// aliases in any casing, and invalid strings/types.
func TestOpenAIClientTransport_GetNormalizesRawContextValue(t *testing.T) {
	gin.SetMode(gin.TestMode)
	tests := []struct {
		name     string
		rawValue any
		want     OpenAIClientTransport
	}{
		{
			name:     "type_value_ws",
			rawValue: OpenAIClientTransportWS,
			want:     OpenAIClientTransportWS,
		},
		{
			name:     "http_sse_alias",
			rawValue: "http_sse",
			want:     OpenAIClientTransportHTTP,
		},
		{
			name:     "sse_alias",
			rawValue: "sSe",
			want:     OpenAIClientTransportHTTP,
		},
		{
			name:     "websocket_alias",
			rawValue: "WebSocket",
			want:     OpenAIClientTransportWS,
		},
		{
			name:     "invalid_string",
			rawValue: "tcp",
			want:     OpenAIClientTransportUnknown,
		},
		{
			name:     "invalid_type",
			rawValue: 123,
			want:     OpenAIClientTransportUnknown,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)
			c.Set(openAIClientTransportContextKey, tt.rawValue)
			require.Equal(t, tt.want, GetOpenAIClientTransport(c))
		})
	}
}
// TestOpenAIClientTransport_NilAndUnknownInput verifies that a nil context is
// tolerated and that unknown/blank transports never set the context key.
func TestOpenAIClientTransport_NilAndUnknownInput(t *testing.T) {
	SetOpenAIClientTransport(nil, OpenAIClientTransportHTTP)
	require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(nil))
	gin.SetMode(gin.TestMode)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	SetOpenAIClientTransport(c, OpenAIClientTransportUnknown)
	_, exists := c.Get(openAIClientTransportContextKey)
	require.False(t, exists)
	SetOpenAIClientTransport(c, OpenAIClientTransport(" "))
	_, exists = c.Get(openAIClientTransportContextKey)
	require.False(t, exists)
}
// TestResolveOpenAIWSDecisionByClientTransport verifies that an HTTP client
// downgrades a WS-v2 decision to HTTP, while ws/unknown clients keep the
// original decision unchanged.
func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) {
	base := OpenAIWSProtocolDecision{
		Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
		Reason:    "ws_v2_enabled",
	}
	httpDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportHTTP)
	require.Equal(t, OpenAIUpstreamTransportHTTPSSE, httpDecision.Transport)
	require.Equal(t, "client_protocol_http", httpDecision.Reason)
	wsDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportWS)
	require.Equal(t, base, wsDecision)
	unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown)
	require.Equal(t, base, unknownDecision)
}

File diff suppressed because it is too large Load Diff

View File

@@ -123,3 +123,19 @@ func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "parse request")
}
// TestGetOpenAIRequestBodyMap_WriteBackContextCache verifies that a successful
// parse writes the resulting map back into the gin context under
// OpenAIParsedRequestBodyKey for reuse by later handlers.
func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	got, err := getOpenAIRequestBodyMap(c, []byte(`{"model":"gpt-5","stream":true}`))
	require.NoError(t, err)
	require.Equal(t, "gpt-5", got["model"])
	cached, ok := c.Get(OpenAIParsedRequestBodyKey)
	require.True(t, ok)
	cachedMap, ok := cached.(map[string]any)
	require.True(t, ok)
	require.Equal(t, got, cachedMap)
}

View File

@@ -5,6 +5,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -13,6 +14,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/cespare/xxhash/v2"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -166,6 +168,54 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
}
}
// TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64 pins the hash
// algorithm: the session_id header must hash to a 16-hex-digit xxhash64 value.
func TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	c.Request.Header.Set("session_id", "sess-fixed-value")
	svc := &OpenAIGatewayService{}
	got := svc.GenerateSessionHash(c, nil)
	want := fmt.Sprintf("%016x", xxhash.Sum64String("sess-fixed-value"))
	require.Equal(t, want, got)
}
// TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext
// verifies that generating a session hash also attaches the legacy-format hash
// to the request context for backward compatibility.
func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	c.Request.Header.Set("session_id", "sess-legacy-check")
	svc := &OpenAIGatewayService{}
	sessionHash := svc.GenerateSessionHash(c, nil)
	require.NotEmpty(t, sessionHash)
	require.NotNil(t, c.Request)
	require.NotNil(t, c.Request.Context())
	require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
}
// TestOpenAIGatewayService_GenerateSessionHashWithFallback verifies that the
// fallback seed is xxhash64-hashed when no session identity is present, and
// that a blank seed yields an empty hash.
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	svc := &OpenAIGatewayService{}
	seed := "openai_ws_ingress:9:100:200"
	got := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), seed)
	want := fmt.Sprintf("%016x", xxhash.Sum64String(seed))
	require.Equal(t, want, got)
	require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
	empty := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), " ")
	require.Equal(t, "", empty)
}
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if c.waitCounts != nil {
if count, ok := c.waitCounts[accountID]; ok {

View File

@@ -0,0 +1,357 @@
package service
import (
"encoding/json"
"strconv"
"strings"
"testing"
"github.com/tidwall/gjson"
)
// Benchmark sinks: results are stored in package-level vars so the compiler
// cannot dead-code-eliminate the benchmarked calls.
var (
	benchmarkToolContinuationBoolSink bool
	benchmarkWSParseStringSink        string
	benchmarkWSParseMapSink           map[string]any
	benchmarkUsageSink                OpenAIUsage
)
// BenchmarkToolContinuationValidationLegacy measures the pre-optimization
// multi-pass validation over a shared fixture body.
func BenchmarkToolContinuationValidationLegacy(b *testing.B) {
	reqBody := benchmarkToolContinuationRequestBody()
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		benchmarkToolContinuationBoolSink = legacyValidateFunctionCallOutputContext(reqBody)
	}
}
// BenchmarkToolContinuationValidationOptimized measures the optimized
// single-pass validation over the same fixture body as the legacy benchmark.
func BenchmarkToolContinuationValidationOptimized(b *testing.B) {
	reqBody := benchmarkToolContinuationRequestBody()
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		benchmarkToolContinuationBoolSink = optimizedValidateFunctionCallOutputContext(reqBody)
	}
}
// BenchmarkWSIngressPayloadParseLegacy measures the legacy WS ingress payload
// parser on a fixed response.create payload.
func BenchmarkWSIngressPayloadParseLegacy(b *testing.B) {
	raw := benchmarkWSIngressPayloadBytes()
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		eventType, model, promptCacheKey, previousResponseID, payload, err := legacyParseWSIngressPayload(raw)
		if err == nil {
			benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID
			benchmarkWSParseMapSink = payload
		}
	}
}
// BenchmarkWSIngressPayloadParseOptimized measures the optimized WS ingress
// payload parser on the same fixed payload as the legacy benchmark.
func BenchmarkWSIngressPayloadParseOptimized(b *testing.B) {
	raw := benchmarkWSIngressPayloadBytes()
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		eventType, model, promptCacheKey, previousResponseID, payload, err := optimizedParseWSIngressPayload(raw)
		if err == nil {
			benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID
			benchmarkWSParseMapSink = payload
		}
	}
}
// BenchmarkOpenAIUsageExtractLegacy measures the legacy usage extraction over
// a fixed response JSON body.
func BenchmarkOpenAIUsageExtractLegacy(b *testing.B) {
	body := benchmarkOpenAIUsageJSONBytes()
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		usage, ok := legacyExtractOpenAIUsageFromJSONBytes(body)
		if ok {
			benchmarkUsageSink = usage
		}
	}
}
// BenchmarkOpenAIUsageExtractOptimized measures the production usage
// extraction over the same fixed response JSON body as the legacy benchmark.
func BenchmarkOpenAIUsageExtractOptimized(b *testing.B) {
	body := benchmarkOpenAIUsageJSONBytes()
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		usage, ok := extractOpenAIUsageFromJSONBytes(body)
		if ok {
			benchmarkUsageSink = usage
		}
	}
}
// benchmarkToolContinuationRequestBody builds the shared benchmark fixture:
// 24 text items followed by 10 (tool_call, function_call_output,
// item_reference) triples sharing call IDs call_0..call_9, for 54 items total.
func benchmarkToolContinuationRequestBody() map[string]any {
	items := make([]any, 0, 64)
	for i := 0; i < 24; i++ {
		items = append(items, map[string]any{
			"type": "text",
			"text": "benchmark text",
		})
	}
	for i := 0; i < 10; i++ {
		id := "call_" + strconv.Itoa(i)
		items = append(items,
			map[string]any{"type": "tool_call", "call_id": id},
			map[string]any{"type": "function_call_output", "call_id": id},
			map[string]any{"type": "item_reference", "id": id},
		)
	}
	return map[string]any{
		"model": "gpt-5.3-codex",
		"input": items,
	}
}
// benchmarkWSIngressPayloadBytes returns a fixed response.create ingress
// payload used by the WS-parse benchmarks.
func benchmarkWSIngressPayloadBytes() []byte {
	const payload = `{"type":"response.create","model":"gpt-5.3-codex","prompt_cache_key":"cache_bench","previous_response_id":"resp_prev_bench","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`
	return []byte(payload)
}
// benchmarkOpenAIUsageJSONBytes returns a fixed response JSON body containing
// a usage object, used by the usage-extraction benchmarks.
func benchmarkOpenAIUsageJSONBytes() []byte {
	const body = `{"id":"resp_bench","object":"response","model":"gpt-5.3-codex","usage":{"input_tokens":3210,"output_tokens":987,"input_tokens_details":{"cached_tokens":456}}}`
	return []byte(body)
}
// legacyValidateFunctionCallOutputContext is the pre-optimization reference
// implementation kept for the benchmarks. A request containing
// function_call_output items is considered valid when any of these holds, in
// order: a non-blank previous_response_id is present, a tool-call context with
// a call_id exists, or (provided no output item is missing its call_id) every
// output call_id is covered by an item_reference. Requests without any
// function_call_output items are always valid.
func legacyValidateFunctionCallOutputContext(reqBody map[string]any) bool {
	if !legacyHasFunctionCallOutput(reqBody) {
		return true
	}
	previousResponseID, _ := reqBody["previous_response_id"].(string)
	if strings.TrimSpace(previousResponseID) != "" {
		return true
	}
	if legacyHasToolCallContext(reqBody) {
		return true
	}
	if legacyHasFunctionCallOutputMissingCallID(reqBody) {
		return false
	}
	callIDs := legacyFunctionCallOutputCallIDs(reqBody)
	return legacyHasItemReferenceForCallIDs(reqBody, callIDs)
}
// optimizedValidateFunctionCallOutputContext applies the same decision chain
// as the legacy implementation, but sources all input-derived facts from the
// single-pass ValidateFunctionCallOutputContext result.
func optimizedValidateFunctionCallOutputContext(reqBody map[string]any) bool {
	validation := ValidateFunctionCallOutputContext(reqBody)
	if !validation.HasFunctionCallOutput {
		return true
	}
	previousResponseID, _ := reqBody["previous_response_id"].(string)
	if strings.TrimSpace(previousResponseID) != "" {
		return true
	}
	if validation.HasToolCallContext {
		return true
	}
	if validation.HasFunctionCallOutputMissingCallID {
		return false
	}
	return validation.HasItemReferenceForAllCallIDs
}
// legacyHasFunctionCallOutput reports whether the request's "input" slice
// contains at least one item of type "function_call_output". Non-map items and
// a missing/non-slice "input" are ignored.
func legacyHasFunctionCallOutput(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if kind, _ := entry["type"].(string); kind == "function_call_output" {
			return true
		}
	}
	return false
}
// legacyHasToolCallContext reports whether the request's "input" array
// carries a tool_call or function_call item with a non-blank call_id.
// Reference implementation kept for benchmark parity.
func legacyHasToolCallContext(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		switch kind, _ := entry["type"].(string); kind {
		case "tool_call", "function_call":
			if callID, hasID := entry["call_id"].(string); hasID && strings.TrimSpace(callID) != "" {
				return true
			}
		}
	}
	return false
}
func legacyFunctionCallOutputCallIDs(reqBody map[string]any) []string {
if reqBody == nil {
return nil
}
input, ok := reqBody["input"].([]any)
if !ok {
return nil
}
ids := make(map[string]struct{})
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "function_call_output" {
continue
}
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
ids[callID] = struct{}{}
}
}
if len(ids) == 0 {
return nil
}
callIDs := make([]string, 0, len(ids))
for id := range ids {
callIDs = append(callIDs, id)
}
return callIDs
}
// legacyHasFunctionCallOutputMissingCallID reports whether any
// function_call_output item in the request's "input" array lacks a non-blank
// call_id. Reference implementation kept for benchmark parity.
func legacyHasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if kind, _ := entry["type"].(string); kind != "function_call_output" {
			continue
		}
		id, _ := entry["call_id"].(string)
		if strings.TrimSpace(id) == "" {
			return true
		}
	}
	return false
}
// legacyHasItemReferenceForCallIDs reports whether the request's "input"
// array contains an item_reference for every ID in callIDs. False when there
// are no callIDs, no body, or no item_reference items at all.
// Reference implementation kept for benchmark parity.
func legacyHasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
	if reqBody == nil || len(callIDs) == 0 {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	refs := make(map[string]struct{})
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if kind, _ := entry["type"].(string); kind != "item_reference" {
			continue
		}
		id, _ := entry["id"].(string)
		if id = strings.TrimSpace(id); id != "" {
			refs[id] = struct{}{}
		}
	}
	if len(refs) == 0 {
		return false
	}
	for _, want := range callIDs {
		if _, found := refs[want]; !found {
			return false
		}
	}
	return true
}
// legacyParseWSIngressPayload is the reference parser for an inbound WS
// ingress frame: it reads the four routing fields with gjson first, then
// fully unmarshals the frame into a generic map. Kept for benchmark parity
// with optimizedParseWSIngressPayload.
func legacyParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) {
	values := gjson.GetManyBytes(raw, "type", "model", "prompt_cache_key", "previous_response_id")
	eventType = strings.TrimSpace(values[0].String())
	if eventType == "" {
		// A missing or blank type defaults to response.create.
		eventType = "response.create"
	}
	model = strings.TrimSpace(values[1].String())
	promptCacheKey = strings.TrimSpace(values[2].String())
	previousResponseID = strings.TrimSpace(values[3].String())
	payload = make(map[string]any)
	if err = json.Unmarshal(raw, &payload); err != nil {
		return "", "", "", "", nil, err
	}
	// The default type is written back only when the key is absent entirely;
	// an explicit empty string in the frame is left untouched here.
	if _, exists := payload["type"]; !exists {
		payload["type"] = "response.create"
	}
	return eventType, model, promptCacheKey, previousResponseID, payload, nil
}
// optimizedParseWSIngressPayload parses an inbound WS ingress frame with a
// single json.Unmarshal and reads the routing fields from the resulting map,
// avoiding the extra gjson scan done by legacyParseWSIngressPayload.
func optimizedParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) {
	payload = make(map[string]any)
	if err = json.Unmarshal(raw, &payload); err != nil {
		return "", "", "", "", nil, err
	}
	eventType = openAIWSPayloadString(payload, "type")
	if eventType == "" {
		// Default the event type and write it back so downstream consumers
		// of payload see a concrete type.
		eventType = "response.create"
		payload["type"] = eventType
	}
	model = openAIWSPayloadString(payload, "model")
	promptCacheKey = openAIWSPayloadString(payload, "prompt_cache_key")
	previousResponseID = openAIWSPayloadString(payload, "previous_response_id")
	return eventType, model, promptCacheKey, previousResponseID, payload, nil
}
// legacyExtractOpenAIUsageFromJSONBytes extracts token usage from a full
// OpenAI response body via a typed json.Unmarshal. The boolean is false only
// when the body is not valid JSON; a valid body without a "usage" object
// yields zero-valued usage and true.
func legacyExtractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
	var response struct {
		Usage struct {
			InputTokens       int `json:"input_tokens"`
			OutputTokens      int `json:"output_tokens"`
			InputTokenDetails struct {
				CachedTokens int `json:"cached_tokens"`
			} `json:"input_tokens_details"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(body, &response); err != nil {
		return OpenAIUsage{}, false
	}
	return OpenAIUsage{
		InputTokens:          response.Usage.InputTokens,
		OutputTokens:         response.Usage.OutputTokens,
		CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens,
	}, true
}

View File

@@ -515,7 +515,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
require.NoError(t, err)
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent"))
}
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {

View File

@@ -5,8 +5,12 @@ import (
"crypto/subtle"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"time"
@@ -16,6 +20,13 @@ import (
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`)
type soraSessionChunk struct {
index int
value string
}
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct {
sessionStore *openai.SessionStore
@@ -39,7 +50,7 @@ type OpenAIAuthURLResult struct {
}
// GenerateAuthURL generates an OpenAI OAuth authorization URL
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) {
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) {
// Generate PKCE values
state, err := openai.GenerateState()
if err != nil {
@@ -75,11 +86,14 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
}
normalizedPlatform := normalizeOpenAIOAuthPlatform(platform)
clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform)
// Store session
session := &openai.OAuthSession{
State: state,
CodeVerifier: codeVerifier,
ClientID: clientID,
RedirectURI: redirectURI,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
@@ -87,7 +101,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
s.sessionStore.Set(sessionID, session)
// Build authorization URL
authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI)
authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform)
return &OpenAIAuthURLResult{
AuthURL: authURL,
@@ -111,6 +125,7 @@ type OpenAITokenInfo struct {
IDToken string `json:"id_token,omitempty"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
ClientID string `json:"client_id,omitempty"`
Email string `json:"email,omitempty"`
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
@@ -148,9 +163,13 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
if input.RedirectURI != "" {
redirectURI = input.RedirectURI
}
clientID := strings.TrimSpace(session.ClientID)
if clientID == "" {
clientID = openai.ClientID
}
// Exchange code for token
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID)
if err != nil {
return nil, err
}
@@ -158,8 +177,10 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Parse ID token to get user info
var userInfo *openai.UserInfo
if tokenResp.IDToken != "" {
claims, err := openai.ParseIDToken(tokenResp.IDToken)
if err == nil {
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
if parseErr != nil {
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
} else {
userInfo = claims.GetUserInfo()
}
}
@@ -173,6 +194,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
IDToken: tokenResp.IDToken,
ExpiresIn: int64(tokenResp.ExpiresIn),
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
ClientID: clientID,
}
if userInfo != nil {
@@ -200,8 +222,10 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
// Parse ID token to get user info
var userInfo *openai.UserInfo
if tokenResp.IDToken != "" {
claims, err := openai.ParseIDToken(tokenResp.IDToken)
if err == nil {
claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
if parseErr != nil {
slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
} else {
userInfo = claims.GetUserInfo()
}
}
@@ -213,6 +237,9 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
ExpiresIn: int64(tokenResp.ExpiresIn),
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
}
if trimmed := strings.TrimSpace(clientID); trimmed != "" {
tokenInfo.ClientID = trimmed
}
if userInfo != nil {
tokenInfo.Email = userInfo.Email
@@ -226,6 +253,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
sessionToken = normalizeSoraSessionTokenInput(sessionToken)
if strings.TrimSpace(sessionToken) == "" {
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
}
@@ -287,10 +315,141 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
ExpiresIn: expiresIn,
ExpiresAt: expiresAt,
ClientID: openai.SoraClientID,
Email: strings.TrimSpace(sessionResp.User.Email),
}, nil
}
// normalizeSoraSessionTokenInput extracts a usable Sora session token from
// free-form user input: a bare token, a Cookie/Set-Cookie line, or multiple
// lines containing chunked __Secure-*.session-token.N cookies. Chunked
// cookies are reassembled via mergeLatestSoraSessionChunks; when both
// chunked and unchunked values are present, a successfully merged chunk
// group wins, otherwise the last unchunked value does. Returns "" when
// nothing usable is found.
func normalizeSoraSessionTokenInput(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return ""
	}
	matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1)
	if len(matches) == 0 {
		// Not in cookie form: treat the whole input as the token itself.
		return sanitizeSessionToken(trimmed)
	}
	chunkMatches := make([]soraSessionChunk, 0, len(matches))
	singleValues := make([]string, 0, len(matches))
	for _, match := range matches {
		// match[1] is the optional ".N" chunk index, match[2] the value.
		if len(match) < 3 {
			continue
		}
		value := sanitizeSessionToken(match[2])
		if value == "" {
			continue
		}
		if strings.TrimSpace(match[1]) == "" {
			// No ".N" suffix: a complete, unchunked token value.
			singleValues = append(singleValues, value)
			continue
		}
		idx, err := strconv.Atoi(strings.TrimSpace(match[1]))
		if err != nil || idx < 0 {
			continue
		}
		chunkMatches = append(chunkMatches, soraSessionChunk{
			index: idx,
			value: value,
		})
	}
	if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" {
		return merged
	}
	if len(singleValues) > 0 {
		// Duplicates (e.g. repeated Set-Cookie lines): prefer the latest.
		return singleValues[len(singleValues)-1]
	}
	return ""
}
// mergeSoraSessionChunkSegment reassembles one cookie-chunk group into a
// single session token. Chunk index 0 must be present; when requireComplete
// is true every index in [0, requiredMaxIndex] must also be present. Chunks
// are concatenated in ascending index order, with the last value winning for
// duplicate indexes within the segment. Returns "" when the segment cannot
// be assembled.
func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string {
	if len(chunks) == 0 {
		return ""
	}
	// Last write wins for duplicate indexes inside this segment.
	byIndex := make(map[int]string, len(chunks))
	for _, chunk := range chunks {
		byIndex[chunk.index] = chunk.value
	}
	if _, ok := byIndex[0]; !ok {
		return ""
	}
	if requireComplete {
		for idx := 0; idx <= requiredMaxIndex; idx++ {
			if _, ok := byIndex[idx]; !ok {
				return ""
			}
		}
	}
	orderedIndexes := make([]int, 0, len(byIndex))
	for idx := range byIndex {
		orderedIndexes = append(orderedIndexes, idx)
	}
	sort.Ints(orderedIndexes)
	var builder strings.Builder
	for _, idx := range orderedIndexes {
		// strings.Builder.WriteString is documented to always return a nil
		// error, so its result does not need to be checked (the previous
		// error branch here was unreachable).
		builder.WriteString(byIndex[idx])
	}
	return sanitizeSessionToken(builder.String())
}
// mergeLatestSoraSessionChunks merges chunked cookie values into one token,
// preferring the most recent complete group. Each occurrence of chunk index
// 0 starts a new group (a client re-issuing the cookie restarts at .0);
// groups are tried newest-first and, to be accepted, must contain every
// index up to the highest index seen across all chunks. When no complete
// group exists, a lenient merge over all chunks is attempted as a fallback.
func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string {
	if len(chunks) == 0 {
		return ""
	}
	// Highest chunk index observed anywhere defines "complete".
	requiredMaxIndex := 0
	for _, chunk := range chunks {
		if chunk.index > requiredMaxIndex {
			requiredMaxIndex = chunk.index
		}
	}
	// Positions where a new group begins (chunk index 0).
	groupStarts := make([]int, 0, len(chunks))
	for idx, chunk := range chunks {
		if chunk.index == 0 {
			groupStarts = append(groupStarts, idx)
		}
	}
	if len(groupStarts) == 0 {
		// No index-0 chunk at all: best-effort lenient merge.
		return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
	}
	// Try groups from the most recent backwards.
	for i := len(groupStarts) - 1; i >= 0; i-- {
		start := groupStarts[i]
		end := len(chunks)
		if i+1 < len(groupStarts) {
			end = groupStarts[i+1]
		}
		if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" {
			return merged
		}
	}
	return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
}
// sanitizeSessionToken strips surrounding whitespace, enclosing quote
// characters, and a single trailing semicolon from a raw token value.
func sanitizeSessionToken(raw string) string {
	cleaned := strings.Trim(strings.TrimSpace(raw), "\"'`")
	cleaned = strings.TrimSuffix(cleaned, ";")
	return strings.TrimSpace(cleaned)
}
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
@@ -322,9 +481,12 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
creds := map[string]any{
"access_token": tokenInfo.AccessToken,
"refresh_token": tokenInfo.RefreshToken,
"expires_at": expiresAt,
"access_token": tokenInfo.AccessToken,
"expires_at": expiresAt,
}
// 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌
if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
creds["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.IDToken != "" {
@@ -342,6 +504,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
if tokenInfo.OrganizationID != "" {
creds["organization_id"] = tokenInfo.OrganizationID
}
if strings.TrimSpace(tokenInfo.ClientID) != "" {
creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
}
return creds
}
@@ -377,3 +542,12 @@ func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
Transport: transport,
}
}
// normalizeOpenAIOAuthPlatform maps a caller-supplied platform string onto
// the openai package's OAuth platform identifiers. Only Sora (matched
// case-insensitively after trimming) is special-cased; everything else is
// treated as plain OpenAI.
func normalizeOpenAIOAuthPlatform(platform string) string {
	if strings.ToLower(strings.TrimSpace(platform)) == PlatformSora {
		return openai.OAuthPlatformSora
	}
	return openai.OAuthPlatformOpenAI
}

View File

@@ -0,0 +1,67 @@
package service
import (
"context"
"errors"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
// openaiOAuthClientAuthURLStub satisfies the OAuth client interface for
// tests that only exercise GenerateAuthURL; every token operation fails.
type openaiOAuthClientAuthURLStub struct{}

// ExchangeCode is unused by the auth-URL tests.
func (s *openaiOAuthClientAuthURLStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}

// RefreshToken is unused by the auth-URL tests.
func (s *openaiOAuthClientAuthURLStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}

// RefreshTokenWithClientID is unused by the auth-URL tests.
func (s *openaiOAuthClientAuthURLStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}
// TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow verifies that
// the OpenAI platform keeps the Codex CLI client_id and the simplified-flow
// query flag, and that the stored session records the same client_id.
func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
	defer svc.Stop()
	result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformOpenAI)
	require.NoError(t, err)
	require.NotEmpty(t, result.AuthURL)
	require.NotEmpty(t, result.SessionID)
	parsed, err := url.Parse(result.AuthURL)
	require.NoError(t, err)
	q := parsed.Query()
	require.Equal(t, openai.ClientID, q.Get("client_id"))
	require.Equal(t, "true", q.Get("codex_cli_simplified_flow"))
	session, ok := svc.sessionStore.Get(result.SessionID)
	require.True(t, ok)
	require.Equal(t, openai.ClientID, session.ClientID)
}
// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient verifies that
// the Sora platform reuses the Codex CLI client_id (which supports localhost
// redirect_uri) but does not enable codex_cli_simplified_flow.
func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) {
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
	defer svc.Stop()
	result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora)
	require.NoError(t, err)
	require.NotEmpty(t, result.AuthURL)
	require.NotEmpty(t, result.SessionID)
	parsed, err := url.Parse(result.AuthURL)
	require.NoError(t, err)
	q := parsed.Query()
	require.Equal(t, openai.ClientID, q.Get("client_id"))
	require.Empty(t, q.Get("codex_cli_simplified_flow"))
	session, ok := svc.sessionStore.Get(result.SessionID)
	require.True(t, ok)
	require.Equal(t, openai.ClientID, session.ClientID)
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -13,7 +14,7 @@ import (
type openaiOAuthClientNoopStub struct{}
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
@@ -67,3 +68,106 @@ func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testi
require.Error(t, err)
require.Contains(t, err.Error(), "missing access token")
}
// TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine checks
// that a full Set-Cookie attribute line (Domain/Path/flags) is reduced to the
// bare session-token value before being sent upstream.
func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodGet, r.Method)
		require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value")
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
	}))
	defer server.Close()
	// Point the session-auth endpoint at the local stub for this test.
	origin := openAISoraSessionAuthURL
	openAISoraSessionAuthURL = server.URL
	defer func() { openAISoraSessionAuthURL = origin }()
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
	defer svc.Stop()
	raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
	info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
	require.NoError(t, err)
	require.Equal(t, "at-token", info.AccessToken)
}

// TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines
// checks that out-of-order .0/.1 cookie chunks are reassembled in index order.
func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodGet, r.Method)
		require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1")
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
	}))
	defer server.Close()
	origin := openAISoraSessionAuthURL
	openAISoraSessionAuthURL = server.URL
	defer func() { openAISoraSessionAuthURL = origin }()
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
	defer svc.Stop()
	// Chunk .1 appears before .0 on purpose.
	raw := strings.Join([]string{
		"Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly",
		"Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly",
	}, "\n")
	info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
	require.NoError(t, err)
	require.Equal(t, "at-token", info.AccessToken)
}

// TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks
// checks that when two full chunk groups are pasted, the newest one wins.
func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodGet, r.Method)
		require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1")
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
	}))
	defer server.Close()
	origin := openAISoraSessionAuthURL
	openAISoraSessionAuthURL = server.URL
	defer func() { openAISoraSessionAuthURL = origin }()
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
	defer svc.Stop()
	raw := strings.Join([]string{
		"Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly",
		"Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly",
		"Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly",
		"Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly",
	}, "\n")
	info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
	require.NoError(t, err)
	require.Equal(t, "at-token", info.AccessToken)
}

// TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup
// checks that an incomplete trailing group (only .0) is skipped in favor of
// the most recent complete group.
func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodGet, r.Method)
		require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1")
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
	}))
	defer server.Close()
	origin := openAISoraSessionAuthURL
	openAISoraSessionAuthURL = server.URL
	defer func() { openAISoraSessionAuthURL = origin }()
	svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
	defer svc.Stop()
	raw := strings.Join([]string{
		"set-cookie",
		"__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/",
		"set-cookie",
		"__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/",
		"set-cookie",
		"__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/",
	}, "\n")
	info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
	require.NoError(t, err)
	require.Equal(t, "at-token", info.AccessToken)
}

View File

@@ -13,10 +13,12 @@ import (
type openaiOAuthClientStateStub struct {
exchangeCalled int32
lastClientID string
}
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
atomic.AddInt32(&s.exchangeCalled, 1)
s.lastClientID = clientID
return &openai.TokenResponse{
AccessToken: "at",
RefreshToken: "rt",
@@ -95,6 +97,8 @@ func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "at", info.AccessToken)
require.Equal(t, openai.ClientID, info.ClientID)
require.Equal(t, openai.ClientID, client.lastClientID)
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
_, ok := svc.sessionStore.Get("sid")

View File

@@ -0,0 +1,37 @@
package service
import (
"regexp"
"strings"
)
// Kind labels returned by ClassifyOpenAIPreviousResponseIDKind.
const (
	OpenAIPreviousResponseIDKindEmpty      = "empty"
	OpenAIPreviousResponseIDKindResponseID = "response_id"
	OpenAIPreviousResponseIDKindMessageID  = "message_id"
	OpenAIPreviousResponseIDKindUnknown    = "unknown"
)

var (
	openAIResponseIDPattern = regexp.MustCompile(`^resp_[A-Za-z0-9_-]{1,256}$`)
	openAIMessageIDPattern  = regexp.MustCompile(`^(msg|message|item|chatcmpl)_[A-Za-z0-9_-]{1,256}$`)
)

// ClassifyOpenAIPreviousResponseIDKind classifies previous_response_id to improve diagnostics.
func ClassifyOpenAIPreviousResponseIDKind(id string) string {
	candidate := strings.TrimSpace(id)
	switch {
	case candidate == "":
		return OpenAIPreviousResponseIDKindEmpty
	case openAIResponseIDPattern.MatchString(candidate):
		return OpenAIPreviousResponseIDKindResponseID
	case openAIMessageIDPattern.MatchString(strings.ToLower(candidate)):
		// Message-style prefixes are matched case-insensitively.
		return OpenAIPreviousResponseIDKindMessageID
	default:
		return OpenAIPreviousResponseIDKindUnknown
	}
}

// IsOpenAIPreviousResponseIDLikelyMessageID reports whether id looks like a
// message/item identifier rather than a response identifier.
func IsOpenAIPreviousResponseIDLikelyMessageID(id string) bool {
	return ClassifyOpenAIPreviousResponseIDKind(id) == OpenAIPreviousResponseIDKindMessageID
}

View File

@@ -0,0 +1,34 @@
package service
import "testing"
// TestClassifyOpenAIPreviousResponseIDKind covers one representative input
// per kind bucket, including whitespace-only input and an item_ prefix.
func TestClassifyOpenAIPreviousResponseIDKind(t *testing.T) {
	tests := []struct {
		name string
		id   string
		want string
	}{
		{name: "empty", id: " ", want: OpenAIPreviousResponseIDKindEmpty},
		{name: "response_id", id: "resp_0906a621bc423a8d0169a108637ef88197b74b0e2f37ba358f", want: OpenAIPreviousResponseIDKindResponseID},
		{name: "message_id", id: "msg_123456", want: OpenAIPreviousResponseIDKindMessageID},
		{name: "item_id", id: "item_abcdef", want: OpenAIPreviousResponseIDKindMessageID},
		{name: "unknown", id: "foo_123456", want: OpenAIPreviousResponseIDKindUnknown},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			if got := ClassifyOpenAIPreviousResponseIDKind(tc.id); got != tc.want {
				t.Fatalf("ClassifyOpenAIPreviousResponseIDKind(%q)=%q want=%q", tc.id, got, tc.want)
			}
		})
	}
}

// TestIsOpenAIPreviousResponseIDLikelyMessageID checks both the positive
// (msg_) and negative (resp_) cases of the convenience predicate.
func TestIsOpenAIPreviousResponseIDLikelyMessageID(t *testing.T) {
	if !IsOpenAIPreviousResponseIDLikelyMessageID("msg_123") {
		t.Fatal("expected msg_123 to be identified as message id")
	}
	if IsOpenAIPreviousResponseIDLikelyMessageID("resp_123") {
		t.Fatal("expected resp_123 not to be identified as message id")
	}
}

View File

@@ -0,0 +1,214 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/cespare/xxhash/v2"
"github.com/gin-gonic/gin"
)
// openAILegacySessionHashContextKey is the private context key type for the
// legacy (sha256) session hash carried alongside a request.
type openAILegacySessionHashContextKey struct{}

var openAILegacySessionHashKey = openAILegacySessionHashContextKey{}

// Process-wide counters for the legacy session-hash compatibility path.
var (
	openAIStickyLegacyReadFallbackTotal atomic.Int64
	openAIStickyLegacyReadFallbackHit   atomic.Int64
	openAIStickyLegacyDualWriteTotal    atomic.Int64
)

// openAIStickyCompatStats returns a snapshot of the legacy-compat counters:
// read-fallback attempts, read-fallback hits, and legacy dual-writes.
func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal int64) {
	return openAIStickyLegacyReadFallbackTotal.Load(),
		openAIStickyLegacyReadFallbackHit.Load(),
		openAIStickyLegacyDualWriteTotal.Load()
}
// deriveOpenAISessionHashes computes both session-hash encodings for a
// session ID: the current 16-hex-digit xxhash64 form and the legacy sha256
// hex form (kept for cache-key compatibility during migration). Both are ""
// for a blank session ID.
func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) {
	normalized := strings.TrimSpace(sessionID)
	if normalized == "" {
		return "", ""
	}
	currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized))
	sum := sha256.Sum256([]byte(normalized))
	legacyHash = hex.EncodeToString(sum[:])
	return currentHash, legacyHash
}
// withOpenAILegacySessionHash returns ctx with the legacy session hash
// attached; blank hashes are not stored. A nil ctx yields nil.
func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context {
	if ctx == nil {
		return nil
	}
	trimmed := strings.TrimSpace(legacyHash)
	if trimmed == "" {
		return ctx
	}
	return context.WithValue(ctx, openAILegacySessionHashKey, trimmed)
}

// openAILegacySessionHashFromContext reads the legacy session hash stored by
// withOpenAILegacySessionHash; "" when absent, non-string, or blank.
func openAILegacySessionHashFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	value, _ := ctx.Value(openAILegacySessionHashKey).(string)
	return strings.TrimSpace(value)
}

// attachOpenAILegacySessionHashToGin stores the legacy session hash on the
// gin request's context so downstream handlers can read it.
func attachOpenAILegacySessionHashToGin(c *gin.Context, legacyHash string) {
	if c == nil || c.Request == nil {
		return
	}
	c.Request = c.Request.WithContext(withOpenAILegacySessionHash(c.Request.Context(), legacyHash))
}
// openAISessionHashReadOldFallbackEnabled reports whether sticky-session
// lookups may fall back to the legacy cache key. Defaults to true when the
// service or its config is missing.
func (s *OpenAIGatewayService) openAISessionHashReadOldFallbackEnabled() bool {
	if s == nil || s.cfg == nil {
		return true
	}
	return s.cfg.Gateway.OpenAIWS.SessionHashReadOldFallback
}

// openAISessionHashDualWriteOldEnabled reports whether sticky-session writes
// also populate the legacy cache key. Defaults to true when the service or
// its config is missing.
func (s *OpenAIGatewayService) openAISessionHashDualWriteOldEnabled() bool {
	if s == nil || s.cfg == nil {
		return true
	}
	return s.cfg.Gateway.OpenAIWS.SessionHashDualWriteOld
}
// openAISessionCacheKey builds the primary ("openai:"-prefixed) cache key
// for a session hash; "" for a blank hash.
func (s *OpenAIGatewayService) openAISessionCacheKey(sessionHash string) string {
	normalized := strings.TrimSpace(sessionHash)
	if normalized == "" {
		return ""
	}
	return "openai:" + normalized
}

// openAILegacySessionCacheKey builds the legacy cache key from the hash
// carried in ctx. Returns "" when no legacy hash is present or when it would
// collide with the primary key (no extra read/write needed then).
func (s *OpenAIGatewayService) openAILegacySessionCacheKey(ctx context.Context, sessionHash string) string {
	legacyHash := openAILegacySessionHashFromContext(ctx)
	if legacyHash == "" {
		return ""
	}
	legacyKey := "openai:" + legacyHash
	if legacyKey == s.openAISessionCacheKey(sessionHash) {
		return ""
	}
	return legacyKey
}

// openAIStickyLegacyTTL derives the TTL for legacy keys: non-positive values
// fall back to the default sticky-session TTL, and the result is capped at
// 10 minutes so compatibility entries age out quickly.
func (s *OpenAIGatewayService) openAIStickyLegacyTTL(ttl time.Duration) time.Duration {
	legacyTTL := ttl
	if legacyTTL <= 0 {
		legacyTTL = openaiStickySessionTTL
	}
	if legacyTTL > 10*time.Minute {
		return 10 * time.Minute
	}
	return legacyTTL
}
// getStickySessionAccountID resolves the account bound to a session hash.
// The primary key is tried first; when read-fallback is enabled and ctx
// carries a distinct legacy hash, the legacy key is tried next (attempts and
// hits are counted). Returns (0, nil) when the service/cache/hash make a
// lookup impossible.
func (s *OpenAIGatewayService) getStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
	if s == nil || s.cache == nil {
		return 0, nil
	}
	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return 0, nil
	}
	accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
	if err == nil && accountID > 0 {
		return accountID, nil
	}
	if !s.openAISessionHashReadOldFallbackEnabled() {
		return accountID, err
	}
	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey == "" {
		return accountID, err
	}
	openAIStickyLegacyReadFallbackTotal.Add(1)
	legacyAccountID, legacyErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
	if legacyErr == nil && legacyAccountID > 0 {
		openAIStickyLegacyReadFallbackHit.Add(1)
		return legacyAccountID, nil
	}
	// Legacy miss: surface the primary lookup's result unchanged.
	return accountID, err
}

// setStickySessionAccountID binds accountID to the session hash under the
// primary key and, when dual-write is enabled and ctx carries a distinct
// legacy hash, mirrors the binding under the legacy key with a capped TTL.
// A legacy-write failure is returned (the primary write has already landed).
func (s *OpenAIGatewayService) setStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string, accountID int64, ttl time.Duration) error {
	if s == nil || s.cache == nil || accountID <= 0 {
		return nil
	}
	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return nil
	}
	if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), primaryKey, accountID, ttl); err != nil {
		return err
	}
	if !s.openAISessionHashDualWriteOldEnabled() {
		return nil
	}
	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey == "" {
		return nil
	}
	if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), legacyKey, accountID, s.openAIStickyLegacyTTL(ttl)); err != nil {
		return err
	}
	openAIStickyLegacyDualWriteTotal.Add(1)
	return nil
}

// refreshStickySessionTTL extends the primary binding's TTL and, when either
// compat flag is on, best-effort refreshes the legacy key (its error is
// ignored); only the primary refresh error is returned.
func (s *OpenAIGatewayService) refreshStickySessionTTL(ctx context.Context, groupID *int64, sessionHash string, ttl time.Duration) error {
	if s == nil || s.cache == nil {
		return nil
	}
	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return nil
	}
	err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), primaryKey, ttl)
	if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
		return err
	}
	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey != "" {
		// Best-effort: a stale legacy TTL is acceptable.
		_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), legacyKey, s.openAIStickyLegacyTTL(ttl))
	}
	return err
}

// deleteStickySessionAccountID removes the primary binding and, when either
// compat flag is on, best-effort deletes the legacy key (its error is
// ignored); only the primary delete error is returned.
func (s *OpenAIGatewayService) deleteStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) error {
	if s == nil || s.cache == nil {
		return nil
	}
	primaryKey := s.openAISessionCacheKey(sessionHash)
	if primaryKey == "" {
		return nil
	}
	err := s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
	if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
		return err
	}
	legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
	if legacyKey != "" {
		// Best-effort: a lingering legacy entry expires via its capped TTL.
		_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
	}
	return err
}

View File

@@ -0,0 +1,96 @@
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func TestGetStickySessionAccountID_FallbackToLegacyKey(t *testing.T) {
beforeFallbackTotal, beforeFallbackHit, _ := openAIStickyCompatStats()
cache := &stubGatewayCache{
sessionBindings: map[string]int64{
"openai:legacy-hash": 42,
},
}
svc := &OpenAIGatewayService{
cache: cache,
cfg: &config.Config{
Gateway: config.GatewayConfig{
OpenAIWS: config.GatewayOpenAIWSConfig{
SessionHashReadOldFallback: true,
},
},
},
}
ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
accountID, err := svc.getStickySessionAccountID(ctx, nil, "new-hash")
require.NoError(t, err)
require.Equal(t, int64(42), accountID)
afterFallbackTotal, afterFallbackHit, _ := openAIStickyCompatStats()
require.Equal(t, beforeFallbackTotal+1, afterFallbackTotal)
require.Equal(t, beforeFallbackHit+1, afterFallbackHit)
}
// TestSetStickySessionAccountID_DualWriteOldEnabled verifies that with
// SessionHashDualWriteOld enabled, a write binds both the new-format and the
// legacy-format keys, and the dual-write counter advances.
func TestSetStickySessionAccountID_DualWriteOldEnabled(t *testing.T) {
	_, _, beforeDualWriteTotal := openAIStickyCompatStats()
	cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
	svc := &OpenAIGatewayService{
		cache: cache,
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				OpenAIWS: config.GatewayOpenAIWSConfig{
					SessionHashDualWriteOld: true,
				},
			},
		},
	}
	ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
	err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL)
	require.NoError(t, err)
	// Both the new and the legacy keys must carry the binding.
	require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"])
	require.Equal(t, int64(9), cache.sessionBindings["openai:legacy-hash"])
	_, _, afterDualWriteTotal := openAIStickyCompatStats()
	require.Equal(t, beforeDualWriteTotal+1, afterDualWriteTotal)
}
// TestSetStickySessionAccountID_DualWriteOldDisabled verifies that with
// SessionHashDualWriteOld disabled, only the new-format key is written and the
// legacy key stays absent from the cache.
func TestSetStickySessionAccountID_DualWriteOldDisabled(t *testing.T) {
	cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
	svc := &OpenAIGatewayService{
		cache: cache,
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				OpenAIWS: config.GatewayOpenAIWSConfig{
					SessionHashDualWriteOld: false,
				},
			},
		},
	}
	ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
	err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL)
	require.NoError(t, err)
	require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"])
	// The legacy key must not have been written.
	_, exists := cache.sessionBindings["openai:legacy-hash"]
	require.False(t, exists)
}
// TestSnapshotOpenAICompatibilityFallbackMetrics verifies that reading
// ThinkingEnabled via the legacy context-value path bumps both
// metadata-legacy-fallback counters reported by the snapshot.
func TestSnapshotOpenAICompatibilityFallbackMetrics(t *testing.T) {
	before := SnapshotOpenAICompatibilityFallbackMetrics()
	ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true)
	_, _ = ThinkingEnabledFromContext(ctx)
	after := SnapshotOpenAICompatibilityFallbackMetrics()
	// Use >= deltas because other tests running in the same process may also
	// bump these global counters.
	require.GreaterOrEqual(t, after.MetadataLegacyFallbackTotal, before.MetadataLegacyFallbackTotal+1)
	require.GreaterOrEqual(t, after.MetadataLegacyFallbackThinkingEnabledTotal, before.MetadataLegacyFallbackThinkingEnabledTotal+1)
}

View File

@@ -2,6 +2,24 @@ package service
import "strings"
// ToolContinuationSignals aggregates tool-continuation-related signals so the
// request input only needs to be traversed once.
type ToolContinuationSignals struct {
	HasFunctionCallOutput              bool     // input contains at least one function_call_output item
	HasFunctionCallOutputMissingCallID bool     // some function_call_output lacks a non-blank call_id
	HasToolCallContext                 bool     // input contains a tool_call/function_call with a non-blank call_id
	HasItemReference                   bool     // input contains at least one item_reference item
	HasItemReferenceForAllCallIDs      bool     // every collected call_id is covered by an item_reference id
	FunctionCallOutputCallIDs          []string // distinct non-blank call_ids of function_call_output items
}

// FunctionCallOutputValidation summarizes the function_call_output association
// validation results.
type FunctionCallOutputValidation struct {
	HasFunctionCallOutput              bool
	HasToolCallContext                 bool
	HasFunctionCallOutputMissingCallID bool
	HasItemReferenceForAllCallIDs      bool
}
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
// 满足以下任一信号即视为续链previous_response_id、input 内包含 function_call_output/item_reference、
// 或显式声明 tools/tool_choice。
@@ -18,107 +36,191 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
if hasToolChoiceSignal(reqBody) {
return true
}
if inputHasType(reqBody, "function_call_output") {
return true
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
if inputHasType(reqBody, "item_reference") {
return true
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType == "function_call_output" || itemType == "item_reference" {
return true
}
}
return false
}
// AnalyzeToolContinuationSignals walks the request input exactly once and
// extracts every signal related to function_call_output / tool_call /
// item_reference items. It backs the single-purpose predicate helpers in this
// file so callers do not re-traverse the input.
func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals {
	var out ToolContinuationSignals
	if reqBody == nil {
		return out
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return out
	}
	outputCallIDs := map[string]struct{}{}
	refIDs := map[string]struct{}{}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		kind, _ := entry["type"].(string)
		switch kind {
		case "tool_call", "function_call":
			// A tool-call context only counts when it carries a usable call_id.
			if id, _ := entry["call_id"].(string); strings.TrimSpace(id) != "" {
				out.HasToolCallContext = true
			}
		case "function_call_output":
			out.HasFunctionCallOutput = true
			id, _ := entry["call_id"].(string)
			if id = strings.TrimSpace(id); id == "" {
				out.HasFunctionCallOutputMissingCallID = true
			} else {
				outputCallIDs[id] = struct{}{}
			}
		case "item_reference":
			out.HasItemReference = true
			id, _ := entry["id"].(string)
			if id = strings.TrimSpace(id); id != "" {
				refIDs[id] = struct{}{}
			}
		}
	}
	if len(outputCallIDs) == 0 {
		return out
	}
	// Materialize the distinct call_ids and, in the same pass, check whether
	// every one of them is covered by an item_reference id.
	out.FunctionCallOutputCallIDs = make([]string, 0, len(outputCallIDs))
	covered := len(refIDs) > 0
	for id := range outputCallIDs {
		out.FunctionCallOutputCallIDs = append(out.FunctionCallOutputCallIDs, id)
		if covered {
			_, found := refIDs[id]
			covered = found
		}
	}
	out.HasItemReferenceForAllCallIDs = covered
	return out
}
// ValidateFunctionCallOutputContext provides a low-overhead validation result
// for handlers:
//  1. returns immediately when input contains no function_call_output
//  2. returns early when a tool_call/function_call context already exists
//  3. only when outputs exist without tool context, builds the
//     call_id / item_reference sets for coverage checking
func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation {
	result := FunctionCallOutputValidation{}
	if reqBody == nil {
		return result
	}
	input, ok := reqBody["input"].([]any)
	if !ok {
		return result
	}
	// First pass: detect the two cheap flags; bail out as soon as both are set.
	for _, item := range input {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		itemType, _ := itemMap["type"].(string)
		switch itemType {
		case "function_call_output":
			result.HasFunctionCallOutput = true
		case "tool_call", "function_call":
			callID, _ := itemMap["call_id"].(string)
			if strings.TrimSpace(callID) != "" {
				result.HasToolCallContext = true
			}
		}
		if result.HasFunctionCallOutput && result.HasToolCallContext {
			return result
		}
	}
	if !result.HasFunctionCallOutput || result.HasToolCallContext {
		return result
	}
	// Second pass (outputs exist, no tool context): collect the non-blank
	// call_ids and item_reference ids.
	callIDs := make(map[string]struct{})
	referenceIDs := make(map[string]struct{})
	for _, item := range input {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		itemType, _ := itemMap["type"].(string)
		switch itemType {
		case "function_call_output":
			callID, _ := itemMap["call_id"].(string)
			callID = strings.TrimSpace(callID)
			if callID == "" {
				result.HasFunctionCallOutputMissingCallID = true
				continue
			}
			callIDs[callID] = struct{}{}
		case "item_reference":
			idValue, _ := itemMap["id"].(string)
			idValue = strings.TrimSpace(idValue)
			if idValue == "" {
				continue
			}
			referenceIDs[idValue] = struct{}{}
		}
	}
	if len(callIDs) == 0 || len(referenceIDs) == 0 {
		return result
	}
	// Coverage: every call_id must appear among the item_reference ids.
	allReferenced := true
	for callID := range callIDs {
		if _, ok := referenceIDs[callID]; !ok {
			allReferenced = false
			break
		}
	}
	result.HasItemReferenceForAllCallIDs = allReferenced
	return result
}
// HasFunctionCallOutput 判断 input 是否包含 function_call_output用于触发续链校验。
func HasFunctionCallOutput(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
return inputHasType(reqBody, "function_call_output")
return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput
}
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call
// 用于判断 function_call_output 是否具备可关联的上下文。
func HasToolCallContext(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "tool_call" && itemType != "function_call" {
continue
}
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
return true
}
}
return false
return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext
}
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
// 仅返回非空 call_id用于与 item_reference.id 做匹配校验。
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
if reqBody == nil {
return nil
}
input, ok := reqBody["input"].([]any)
if !ok {
return nil
}
ids := make(map[string]struct{})
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "function_call_output" {
continue
}
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
ids[callID] = struct{}{}
}
}
if len(ids) == 0 {
return nil
}
result := make([]string, 0, len(ids))
for id := range ids {
result = append(result, id)
}
return result
return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs
}
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
input, ok := reqBody["input"].([]any)
if !ok {
return false
}
for _, item := range input {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType, _ := itemMap["type"].(string)
if itemType != "function_call_output" {
continue
}
callID, _ := itemMap["call_id"].(string)
if strings.TrimSpace(callID) == "" {
return true
}
}
return false
return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID
}
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
@@ -152,32 +254,13 @@ func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
return false
}
for _, callID := range callIDs {
if _, ok := referenceIDs[callID]; !ok {
if _, ok := referenceIDs[strings.TrimSpace(callID)]; !ok {
return false
}
}
return true
}
// inputHasType reports whether the request's "input" array contains at least
// one map item whose "type" field equals want.
func inputHasType(reqBody map[string]any, want string) bool {
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if kind, _ := entry["type"].(string); kind == want {
			return true
		}
	}
	return false
}
// hasNonEmptyString 判断字段是否为非空字符串。
func hasNonEmptyString(value any) bool {
stringValue, ok := value.(string)

View File

@@ -1,11 +1,15 @@
package service
import (
"encoding/json"
"bytes"
"fmt"
"strconv"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射
@@ -62,169 +66,201 @@ func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, boo
if data == "" || data == "\n" {
return data, false
}
correctedBytes, corrected := c.CorrectToolCallsInSSEBytes([]byte(data))
if !corrected {
return data, false
}
return string(correctedBytes), true
}
// 尝试解析 JSON
var payload map[string]any
if err := json.Unmarshal([]byte(data), &payload); err != nil {
// 不是有效的 JSON直接返回原数据
// CorrectToolCallsInSSEBytes 修正 SSE JSON 数据中的工具调用(字节路径)。
// 返回修正后的数据和是否进行了修正。
func (c *CodexToolCorrector) CorrectToolCallsInSSEBytes(data []byte) ([]byte, bool) {
if len(bytes.TrimSpace(data)) == 0 {
return data, false
}
if !mayContainToolCallPayload(data) {
return data, false
}
if !gjson.ValidBytes(data) {
// 不是有效 JSON直接返回原数据
return data, false
}
updated := data
corrected := false
// 处理 tool_calls 数组
if toolCalls, ok := payload["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
collect := func(changed bool, next []byte) {
if changed {
corrected = true
updated = next
}
}
// 处理 function_call 对象
if functionCall, ok := payload["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
if next, changed := c.correctToolCallsArrayAtPath(updated, "tool_calls"); changed {
collect(changed, next)
}
if next, changed := c.correctFunctionAtPath(updated, "function_call"); changed {
collect(changed, next)
}
if next, changed := c.correctToolCallsArrayAtPath(updated, "delta.tool_calls"); changed {
collect(changed, next)
}
if next, changed := c.correctFunctionAtPath(updated, "delta.function_call"); changed {
collect(changed, next)
}
// 处理 delta.tool_calls
if delta, ok := payload["delta"].(map[string]any); ok {
if toolCalls, ok := delta["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
corrected = true
}
choicesCount := int(gjson.GetBytes(updated, "choices.#").Int())
for i := 0; i < choicesCount; i++ {
prefix := "choices." + strconv.Itoa(i)
if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".message.tool_calls"); changed {
collect(changed, next)
}
if functionCall, ok := delta["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
if next, changed := c.correctFunctionAtPath(updated, prefix+".message.function_call"); changed {
collect(changed, next)
}
}
// 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
if choices, ok := payload["choices"].([]any); ok {
for _, choice := range choices {
if choiceMap, ok := choice.(map[string]any); ok {
// 处理 message 中的工具调用
if message, ok := choiceMap["message"].(map[string]any); ok {
if toolCalls, ok := message["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
corrected = true
}
}
if functionCall, ok := message["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
}
}
// 处理 delta 中的工具调用
if delta, ok := choiceMap["delta"].(map[string]any); ok {
if toolCalls, ok := delta["tool_calls"].([]any); ok {
if c.correctToolCallsArray(toolCalls) {
corrected = true
}
}
if functionCall, ok := delta["function_call"].(map[string]any); ok {
if c.correctFunctionCall(functionCall) {
corrected = true
}
}
}
}
if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".delta.tool_calls"); changed {
collect(changed, next)
}
if next, changed := c.correctFunctionAtPath(updated, prefix+".delta.function_call"); changed {
collect(changed, next)
}
}
if !corrected {
return data, false
}
return updated, true
}
// 序列化回 JSON
correctedBytes, err := json.Marshal(payload)
if err != nil {
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Failed to marshal corrected data: %v", err)
// mayContainToolCallPayload is a cheap pre-filter: most token/text SSE events
// carry no tool fields, so JSON parsing is skipped unless one of the tool-call
// markers appears somewhere in the raw bytes.
func mayContainToolCallPayload(data []byte) bool {
	markers := [][]byte{
		[]byte(`"tool_calls"`),
		[]byte(`"function_call"`),
		[]byte(`"function":{"name"`),
	}
	for _, marker := range markers {
		if bytes.Contains(data, marker) {
			return true
		}
	}
	return false
}
// correctToolCallsArrayAtPath 修正指定路径下 tool_calls 数组中的工具名称。
func (c *CodexToolCorrector) correctToolCallsArrayAtPath(data []byte, toolCallsPath string) ([]byte, bool) {
count := int(gjson.GetBytes(data, toolCallsPath+".#").Int())
if count <= 0 {
return data, false
}
return string(correctedBytes), true
}
// correctToolCallsArray 修正工具调用数组中的工具名称
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
updated := data
corrected := false
for _, toolCall := range toolCalls {
if toolCallMap, ok := toolCall.(map[string]any); ok {
if function, ok := toolCallMap["function"].(map[string]any); ok {
if c.correctFunctionCall(function) {
corrected = true
}
}
for i := 0; i < count; i++ {
functionPath := toolCallsPath + "." + strconv.Itoa(i) + ".function"
if next, changed := c.correctFunctionAtPath(updated, functionPath); changed {
updated = next
corrected = true
}
}
return corrected
return updated, corrected
}
// correctFunctionCall 修正单个函数调用的工具名称和参数
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
name, ok := functionCall["name"].(string)
if !ok || name == "" {
return false
// correctFunctionAtPath 修正指定路径下单个函数调用的工具名称和参数
func (c *CodexToolCorrector) correctFunctionAtPath(data []byte, functionPath string) ([]byte, bool) {
namePath := functionPath + ".name"
nameResult := gjson.GetBytes(data, namePath)
if !nameResult.Exists() || nameResult.Type != gjson.String {
return data, false
}
name := strings.TrimSpace(nameResult.Str)
if name == "" {
return data, false
}
updated := data
corrected := false
// 查找并修正工具名称
if correctName, found := codexToolNameMapping[name]; found {
functionCall["name"] = correctName
c.recordCorrection(name, correctName)
corrected = true
name = correctName // 使用修正后的名称进行参数修正
if next, err := sjson.SetBytes(updated, namePath, correctName); err == nil {
updated = next
c.recordCorrection(name, correctName)
corrected = true
name = correctName // 使用修正后的名称进行参数修正
}
}
// 修正工具参数(基于工具名称)
if c.correctToolParameters(name, functionCall) {
if next, changed := c.correctToolParametersAtPath(updated, functionPath+".arguments", name); changed {
updated = next
corrected = true
}
return corrected
return updated, corrected
}
// correctToolParameters 修正工具参数以符合 OpenCode 规范
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
arguments, ok := functionCall["arguments"]
if !ok {
return false
// correctToolParametersAtPath 修正指定路径下 arguments 参数。
func (c *CodexToolCorrector) correctToolParametersAtPath(data []byte, argumentsPath, toolName string) ([]byte, bool) {
if toolName != "bash" && toolName != "edit" {
return data, false
}
// arguments 可能是字符串JSON或已解析的 map
var argsMap map[string]any
switch v := arguments.(type) {
case string:
// 解析 JSON 字符串
if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
return false
args := gjson.GetBytes(data, argumentsPath)
if !args.Exists() {
return data, false
}
switch args.Type {
case gjson.String:
argsJSON := strings.TrimSpace(args.Str)
if !gjson.Valid(argsJSON) {
return data, false
}
case map[string]any:
argsMap = v
if !gjson.Parse(argsJSON).IsObject() {
return data, false
}
nextArgsJSON, corrected := c.correctToolArgumentsJSON(argsJSON, toolName)
if !corrected {
return data, false
}
next, err := sjson.SetBytes(data, argumentsPath, nextArgsJSON)
if err != nil {
return data, false
}
return next, true
case gjson.JSON:
if !args.IsObject() || !gjson.Valid(args.Raw) {
return data, false
}
nextArgsJSON, corrected := c.correctToolArgumentsJSON(args.Raw, toolName)
if !corrected {
return data, false
}
next, err := sjson.SetRawBytes(data, argumentsPath, []byte(nextArgsJSON))
if err != nil {
return data, false
}
return next, true
default:
return false
return data, false
}
}
// correctToolArgumentsJSON 修正工具参数 JSON对象字符串返回修正后的 JSON 与是否变更。
func (c *CodexToolCorrector) correctToolArgumentsJSON(argsJSON, toolName string) (string, bool) {
if !gjson.Valid(argsJSON) {
return argsJSON, false
}
if !gjson.Parse(argsJSON).IsObject() {
return argsJSON, false
}
updated := argsJSON
corrected := false
// 根据工具名称应用特定的参数修正规则
switch toolName {
case "bash":
// OpenCode bash 支持 workdir有些来源会输出 work_dir。
if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir {
if workDir, exists := argsMap["work_dir"]; exists {
argsMap["workdir"] = workDir
delete(argsMap, "work_dir")
if !gjson.Get(updated, "workdir").Exists() {
if next, changed := moveJSONField(updated, "work_dir", "workdir"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
}
} else {
if _, exists := argsMap["work_dir"]; exists {
delete(argsMap, "work_dir")
if next, changed := deleteJSONField(updated, "work_dir"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
}
@@ -232,67 +268,71 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
case "edit":
// OpenCode edit 参数为 filePath/oldString/newStringcamelCase
if _, exists := argsMap["filePath"]; !exists {
if filePath, exists := argsMap["file_path"]; exists {
argsMap["filePath"] = filePath
delete(argsMap, "file_path")
if !gjson.Get(updated, "filePath").Exists() {
if next, changed := moveJSONField(updated, "file_path", "filePath"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
} else if filePath, exists := argsMap["path"]; exists {
argsMap["filePath"] = filePath
delete(argsMap, "path")
} else if next, changed := moveJSONField(updated, "path", "filePath"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
} else if filePath, exists := argsMap["file"]; exists {
argsMap["filePath"] = filePath
delete(argsMap, "file")
} else if next, changed := moveJSONField(updated, "file", "filePath"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
}
}
if _, exists := argsMap["oldString"]; !exists {
if oldString, exists := argsMap["old_string"]; exists {
argsMap["oldString"] = oldString
delete(argsMap, "old_string")
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
}
if next, changed := moveJSONField(updated, "old_string", "oldString"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
}
if _, exists := argsMap["newString"]; !exists {
if newString, exists := argsMap["new_string"]; exists {
argsMap["newString"] = newString
delete(argsMap, "new_string")
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
}
if next, changed := moveJSONField(updated, "new_string", "newString"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
}
if _, exists := argsMap["replaceAll"]; !exists {
if replaceAll, exists := argsMap["replace_all"]; exists {
argsMap["replaceAll"] = replaceAll
delete(argsMap, "replace_all")
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
}
if next, changed := moveJSONField(updated, "replace_all", "replaceAll"); changed {
updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
}
}
return updated, corrected
}
// 如果修正了参数,需要重新序列化
if corrected {
if _, wasString := arguments.(string); wasString {
// 原本是字符串,序列化回字符串
if newArgsJSON, err := json.Marshal(argsMap); err == nil {
functionCall["arguments"] = string(newArgsJSON)
}
} else {
// 原本是 map直接赋值
functionCall["arguments"] = argsMap
}
// moveJSONField renames a JSON field at path `from` to path `to` inside the
// JSON document `input`, preserving the raw value bytes verbatim. It is a
// no-op (returning input, false) when the destination already exists, the
// source is absent, or any sjson operation fails.
func moveJSONField(input, from, to string) (string, bool) {
	if gjson.Get(input, to).Exists() {
		// Never clobber an existing destination field.
		return input, false
	}
	src := gjson.Get(input, from)
	if !src.Exists() {
		return input, false
	}
	// SetRaw keeps the original value bytes (no re-encoding of the value).
	next, err := sjson.SetRaw(input, to, src.Raw)
	if err != nil {
		return input, false
	}
	next, err = sjson.Delete(next, from)
	if err != nil {
		return input, false
	}
	return next, true
}
return corrected
// deleteJSONField removes the field at path from the JSON document input.
// It returns (input, false) when the field does not exist or deletion fails.
func deleteJSONField(input, path string) (string, bool) {
	if gjson.Get(input, path).Exists() {
		if out, err := sjson.Delete(input, path); err == nil {
			return out, true
		}
	}
	return input, false
}
// recordCorrection 记录一次工具名称修正

View File

@@ -5,6 +5,15 @@ import (
"testing"
)
// TestMayContainToolCallPayload checks the fast-path pre-filter: plain text
// delta events must be skipped, tool_calls events must be parsed.
func TestMayContainToolCallPayload(t *testing.T) {
	if mayContainToolCallPayload([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)) {
		t.Fatalf("plain text event should not trigger tool-call parsing")
	}
	if !mayContainToolCallPayload([]byte(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)) {
		t.Fatalf("tool_calls event should trigger tool-call parsing")
	}
}
func TestCorrectToolCallsInSSEData(t *testing.T) {
corrector := NewCodexToolCorrector()

View File

@@ -0,0 +1,190 @@
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit verifies that
// a previous_response_id bound to an eligible account resolves to that account
// and that its concurrency slot is acquired.
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T) {
	ctx := context.Background()
	groupID := int64(23)
	account := Account{
		ID:          2,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 2,
		Extra: map[string]any{
			// Opt the account into the API-key responses websockets v2 path.
			"openai_apikey_responses_websockets_v2_enabled": true,
		},
	}
	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	cfg := newOpenAIWSV2TestConfig()
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
		openaiWSStateStore: store,
	}
	// Bind the previous response to the account, then resolve it.
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour))
	selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, account.ID, selection.Account.ID)
	require.True(t, selection.Acquired)
	if selection.ReleaseFunc != nil {
		selection.ReleaseFunc()
	}
}
// TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded verifies
// that a bound account present in the exclusion set yields no selection.
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
	ctx := context.Background()
	groupID := int64(23)
	account := Account{
		ID:          8,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		},
	}
	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	cfg := newOpenAIWSV2TestConfig()
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
		openaiWSStateStore: store,
	}
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour))
	// The bound account is excluded, so the sticky lookup must return nil.
	selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}})
	require.NoError(t, err)
	require.Nil(t, selection)
}
// TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored
// verifies that an account flagged openai_ws_force_http does not participate
// in previous_response_id stickiness.
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored(t *testing.T) {
	ctx := context.Background()
	groupID := int64(23)
	account := Account{
		ID:          11,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Extra: map[string]any{
			// force_http overrides the WS v2 opt-in below.
			"openai_ws_force_http":            true,
			"responses_websockets_v2_enabled": true,
		},
	}
	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	cfg := newOpenAIWSV2TestConfig()
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{account}},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
		openaiWSStateStore: store,
	}
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour))
	selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil)
	require.NoError(t, err)
	require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连")
}
// TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky
// verifies that when the sticky account's concurrency slot is busy, selection
// still pins that account (returning a wait plan) instead of falling back to
// a free lower-priority account.
func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(t *testing.T) {
	ctx := context.Background()
	groupID := int64(23)
	accounts := []Account{
		{
			ID:          21,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    0,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		},
		{
			ID:          22,
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Status:      StatusActive,
			Schedulable: true,
			Concurrency: 1,
			Priority:    9,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		},
	}
	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	cfg := newOpenAIWSV2TestConfig()
	cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2
	cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second
	concurrencyCache := stubConcurrencyCache{
		acquireResults: map[int64]bool{
			21: false, // the account bound to previous_response is busy
			22: true,  // a free lower-priority account (would be hit on fallback)
		},
		waitCounts: map[int64]int{
			21: 999,
		},
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
		cache:              cache,
		cfg:                cfg,
		concurrencyService: NewConcurrencyService(concurrencyCache),
		openaiWSStateStore: store,
	}
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour))
	selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil)
	require.NoError(t, err)
	require.NotNil(t, selection)
	require.NotNil(t, selection.Account)
	require.Equal(t, int64(21), selection.Account.ID, "busy previous_response sticky account should remain selected")
	require.False(t, selection.Acquired)
	require.NotNil(t, selection.WaitPlan)
	require.Equal(t, int64(21), selection.WaitPlan.AccountID)
}
// newOpenAIWSV2TestConfig builds a config with every OpenAI-WS v2 switch
// turned on and a one-hour sticky-response-ID TTL, shared by the tests in
// this file.
func newOpenAIWSV2TestConfig() *config.Config {
	cfg := &config.Config{}
	ws := &cfg.Gateway.OpenAIWS
	ws.Enabled = true
	ws.OAuthEnabled = true
	ws.APIKeyEnabled = true
	ws.ResponsesWebsocketsV2 = true
	ws.StickyResponseIDTTLSeconds = 3600
	return cfg
}

View File

@@ -0,0 +1,285 @@
package service
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
coderws "github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024
const (
openAIWSProxyTransportMaxIdleConns = 128
openAIWSProxyTransportMaxIdleConnsPerHost = 64
openAIWSProxyTransportIdleConnTimeout = 90 * time.Second
openAIWSProxyClientCacheMaxEntries = 256
openAIWSProxyClientCacheIdleTTL = 15 * time.Minute
)
// OpenAIWSTransportMetricsSnapshot is a point-in-time view of the proxy HTTP
// client cache effectiveness for OpenAI WS dials.
type OpenAIWSTransportMetricsSnapshot struct {
	ProxyClientCacheHits   int64   `json:"proxy_client_cache_hits"`
	ProxyClientCacheMisses int64   `json:"proxy_client_cache_misses"`
	TransportReuseRatio    float64 `json:"transport_reuse_ratio"`
}

// openAIWSClientConn abstracts a WS client connection so the underlying
// implementation can be swapped.
type openAIWSClientConn interface {
	WriteJSON(ctx context.Context, value any) error
	ReadMessage(ctx context.Context) ([]byte, error)
	Ping(ctx context.Context) error
	Close() error
}

// openAIWSClientDialer abstracts WS connection establishment.
type openAIWSClientDialer interface {
	// Dial connects to wsURL, optionally through proxyURL, and returns the
	// connection plus the handshake HTTP status and headers when available.
	Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error)
}

// openAIWSTransportMetricsDialer is an optional capability interface for
// dialers that can report transport-cache metrics.
type openAIWSTransportMetricsDialer interface {
	SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot
}

// newDefaultOpenAIWSClientDialer returns the production dialer backed by
// coder/websocket with a per-proxy HTTP client cache.
func newDefaultOpenAIWSClientDialer() openAIWSClientDialer {
	return &coderOpenAIWSClientDialer{
		proxyClients: make(map[string]*openAIWSProxyClientEntry),
	}
}

// coderOpenAIWSClientDialer dials via coder/websocket and caches one
// *http.Client per proxy URL so transports (and their idle connections) are
// reused across dials.
type coderOpenAIWSClientDialer struct {
	proxyMu      sync.Mutex                           // guards proxyClients
	proxyClients map[string]*openAIWSProxyClientEntry // keyed by trimmed proxy URL
	proxyHits    atomic.Int64
	proxyMisses  atomic.Int64
}

// openAIWSProxyClientEntry pairs a cached proxy client with its last-use time
// (unix nanoseconds) for idle eviction.
type openAIWSProxyClientEntry struct {
	client           *http.Client
	lastUsedUnixNano int64
}
// Dial establishes a websocket connection to wsURL via coder/websocket. On
// failure it returns the handshake HTTP status code and cloned response
// headers when a response is available; on success the status is 0 and only
// the cloned headers accompany the connection.
func (d *coderOpenAIWSClientDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	targetURL := strings.TrimSpace(wsURL)
	if targetURL == "" {
		return nil, 0, nil, errors.New("ws url is empty")
	}
	opts := &coderws.DialOptions{
		// Clone so the caller's header map is never mutated by the dial.
		HTTPHeader:      cloneHeader(headers),
		CompressionMode: coderws.CompressionContextTakeover,
	}
	if proxy := strings.TrimSpace(proxyURL); proxy != "" {
		proxyClient, err := d.proxyHTTPClient(proxy)
		if err != nil {
			return nil, 0, nil, err
		}
		opts.HTTPClient = proxyClient
	}
	conn, resp, err := coderws.Dial(ctx, targetURL, opts)
	if err != nil {
		status := 0
		respHeaders := http.Header(nil)
		if resp != nil {
			status = resp.StatusCode
			respHeaders = cloneHeader(resp.Header)
		}
		return nil, status, respHeaders, err
	}
	// coder/websocket caps a single message read at 32KB by default; Codex WS
	// events (rate_limits / large deltas) can exceed that threshold, so raise
	// the limit explicitly to avoid local read_fail (message too big).
	conn.SetReadLimit(openAIWSMessageReadLimitBytes)
	respHeaders := http.Header(nil)
	if resp != nil {
		respHeaders = cloneHeader(resp.Header)
	}
	return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil
}
// proxyHTTPClient returns the cached *http.Client for the given proxy URL,
// creating and caching one when absent. Cache maintenance (idle-TTL cleanup
// and capacity enforcement) only runs on the miss path.
func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) {
	if d == nil {
		return nil, errors.New("openai ws dialer is nil")
	}
	key := strings.TrimSpace(proxy)
	if key == "" {
		return nil, errors.New("proxy url is empty")
	}
	proxyTarget, err := url.Parse(key)
	if err != nil {
		return nil, fmt.Errorf("invalid proxy url: %w", err)
	}

	nowNano := time.Now().UnixNano()
	d.proxyMu.Lock()
	defer d.proxyMu.Unlock()

	// Fast path: reuse an existing client and refresh its LRU timestamp.
	if cached, ok := d.proxyClients[key]; ok && cached != nil && cached.client != nil {
		cached.lastUsedUnixNano = nowNano
		d.proxyHits.Add(1)
		return cached.client, nil
	}

	d.cleanupProxyClientsLocked(nowNano)

	client := &http.Client{
		Transport: &http.Transport{
			Proxy:               http.ProxyURL(proxyTarget),
			MaxIdleConns:        openAIWSProxyTransportMaxIdleConns,
			MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost,
			IdleConnTimeout:     openAIWSProxyTransportIdleConnTimeout,
			TLSHandshakeTimeout: 10 * time.Second,
			ForceAttemptHTTP2:   true,
		},
	}
	d.proxyClients[key] = &openAIWSProxyClientEntry{
		client:           client,
		lastUsedUnixNano: nowNano,
	}
	d.ensureProxyClientCapacityLocked()
	d.proxyMisses.Add(1)
	return client, nil
}
// cleanupProxyClientsLocked evicts cached proxy clients idle longer than the
// configured TTL, and drops malformed (nil) entries. Caller holds proxyMu.
func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) {
	if d == nil || len(d.proxyClients) == 0 {
		return
	}
	ttl := openAIWSProxyClientCacheIdleTTL
	if ttl <= 0 {
		// Idle-TTL eviction disabled; entries only leave via capacity eviction.
		return
	}
	for key, entry := range d.proxyClients {
		if entry == nil || entry.client == nil {
			delete(d.proxyClients, key)
			continue
		}
		// Unix-nano difference is exactly the elapsed time.Duration.
		if time.Duration(nowUnixNano-entry.lastUsedUnixNano) > ttl {
			closeOpenAIWSProxyClient(entry.client)
			delete(d.proxyClients, key)
		}
	}
}
// ensureProxyClientCapacityLocked enforces the cache size cap by repeatedly
// evicting the least-recently-used entry until under the limit. Caller holds
// proxyMu.
func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() {
	if d == nil {
		return
	}
	limit := openAIWSProxyClientCacheMaxEntries
	if limit <= 0 {
		return
	}
	for len(d.proxyClients) > limit {
		var (
			victimKey      string
			victimLastUsed int64
			found          bool
		)
		// Linear scan for the oldest entry; nil entries sort as oldest (0).
		for key, entry := range d.proxyClients {
			var lastUsed int64
			if entry != nil {
				lastUsed = entry.lastUsedUnixNano
			}
			if !found || lastUsed < victimLastUsed {
				victimKey = key
				victimLastUsed = lastUsed
				found = true
			}
		}
		if !found {
			return
		}
		if entry := d.proxyClients[victimKey]; entry != nil {
			closeOpenAIWSProxyClient(entry.client)
		}
		delete(d.proxyClients, victimKey)
	}
}
func closeOpenAIWSProxyClient(client *http.Client) {
if client == nil || client.Transport == nil {
return
}
if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
transport.CloseIdleConnections()
}
}
// SnapshotTransportMetrics reports proxy-client cache hit/miss counters and
// the derived transport reuse ratio (hits / total; 0 when no lookups yet).
func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
	if d == nil {
		return OpenAIWSTransportMetricsSnapshot{}
	}
	snapshot := OpenAIWSTransportMetricsSnapshot{
		ProxyClientCacheHits:   d.proxyHits.Load(),
		ProxyClientCacheMisses: d.proxyMisses.Load(),
	}
	if total := snapshot.ProxyClientCacheHits + snapshot.ProxyClientCacheMisses; total > 0 {
		snapshot.TransportReuseRatio = float64(snapshot.ProxyClientCacheHits) / float64(total)
	}
	return snapshot
}
// coderOpenAIWSClientConn adapts a *coderws.Conn to openAIWSClientConn.
type coderOpenAIWSClientConn struct {
	conn *coderws.Conn
}
// WriteJSON marshals value and writes it as one WS message, defaulting a nil
// context to Background.
func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
	if c == nil || c.conn == nil {
		return errOpenAIWSConnClosed
	}
	writeCtx := ctx
	if writeCtx == nil {
		writeCtx = context.Background()
	}
	return wsjson.Write(writeCtx, c.conn, value)
}
// ReadMessage reads the next data frame and returns its payload. A nil
// context defaults to Background; any message type other than text/binary is
// reported as a closed connection.
func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) {
	if c == nil || c.conn == nil {
		return nil, errOpenAIWSConnClosed
	}
	readCtx := ctx
	if readCtx == nil {
		readCtx = context.Background()
	}
	msgType, payload, err := c.conn.Read(readCtx)
	if err != nil {
		return nil, err
	}
	if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
		return nil, errOpenAIWSConnClosed
	}
	return payload, nil
}
// Ping sends a WS ping and waits for the pong, defaulting a nil context to
// Background.
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
	if c == nil || c.conn == nil {
		return errOpenAIWSConnClosed
	}
	pingCtx := ctx
	if pingCtx == nil {
		pingCtx = context.Background()
	}
	return c.conn.Ping(pingCtx)
}
// Close shuts the connection down. Duplicate calls are harmless; errors from
// the close handshake are deliberately ignored.
func (c *coderOpenAIWSClientConn) Close() error {
	if c == nil || c.conn == nil {
		return nil
	}
	// Attempt a graceful close handshake first, then force the socket shut.
	_ = c.conn.Close(coderws.StatusNormalClosure, "")
	_ = c.conn.CloseNow()
	return nil
}

View File

@@ -0,0 +1,112 @@
package service
import (
"fmt"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse verifies that identical
// proxy URLs share one cached *http.Client while distinct URLs get distinct
// clients.
func TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	c1, err := impl.proxyHTTPClient("http://127.0.0.1:8080")
	require.NoError(t, err)
	c2, err := impl.proxyHTTPClient("http://127.0.0.1:8080")
	require.NoError(t, err)
	require.Same(t, c1, c2, "同一代理地址应复用同一个 HTTP 客户端")
	c3, err := impl.proxyHTTPClient("http://127.0.0.1:8081")
	require.NoError(t, err)
	require.NotSame(t, c1, c3, "不同代理地址应分离客户端")
}
// TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL verifies that an
// unparsable proxy URL is rejected with an error.
func TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL(t *testing.T) {
	impl, ok := newDefaultOpenAIWSClientDialer().(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	client, err := impl.proxyHTTPClient("://bad")
	require.Error(t, err)
	_ = client
}
// TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot verifies hit/miss
// accounting: two misses (new proxies), one hit, reuse ratio 1/3.
func TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	_, err := impl.proxyHTTPClient("http://127.0.0.1:18080")
	require.NoError(t, err)
	_, err = impl.proxyHTTPClient("http://127.0.0.1:18080")
	require.NoError(t, err)
	_, err = impl.proxyHTTPClient("http://127.0.0.1:18081")
	require.NoError(t, err)
	snapshot := impl.SnapshotTransportMetrics()
	require.Equal(t, int64(1), snapshot.ProxyClientCacheHits)
	require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses)
	require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001)
}
// TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity verifies the proxy
// client cache never exceeds openAIWSProxyClientCacheMaxEntries even after
// inserting more distinct proxies than the cap.
func TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	total := openAIWSProxyClientCacheMaxEntries + 32
	for i := 0; i < total; i++ {
		_, err := impl.proxyHTTPClient(fmt.Sprintf("http://127.0.0.1:%d", 20000+i))
		require.NoError(t, err)
	}
	impl.proxyMu.Lock()
	cacheSize := len(impl.proxyClients)
	impl.proxyMu.Unlock()
	require.LessOrEqual(t, cacheSize, openAIWSProxyClientCacheMaxEntries, "代理客户端缓存应受容量上限约束")
}
// TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL verifies that an
// entry idle beyond the TTL is evicted on the next cache-miss insert.
func TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	oldProxy := "http://127.0.0.1:28080"
	_, err := impl.proxyHTTPClient(oldProxy)
	require.NoError(t, err)
	// Backdate the entry so it appears idle past the TTL.
	impl.proxyMu.Lock()
	oldEntry := impl.proxyClients[oldProxy]
	require.NotNil(t, oldEntry)
	oldEntry.lastUsedUnixNano = time.Now().Add(-openAIWSProxyClientCacheIdleTTL - time.Minute).UnixNano()
	impl.proxyMu.Unlock()
	// Fetch a fresh proxy to trigger the TTL cleanup on the miss path.
	_, err = impl.proxyHTTPClient("http://127.0.0.1:28081")
	require.NoError(t, err)
	impl.proxyMu.Lock()
	_, exists := impl.proxyClients[oldProxy]
	impl.proxyMu.Unlock()
	require.False(t, exists, "超过空闲 TTL 的代理客户端应被回收")
}
// TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout verifies
// freshly built proxy transports carry the 10s TLS handshake timeout.
func TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout(t *testing.T) {
	dialer := newDefaultOpenAIWSClientDialer()
	impl, ok := dialer.(*coderOpenAIWSClientDialer)
	require.True(t, ok)
	client, err := impl.proxyHTTPClient("http://127.0.0.1:38080")
	require.NoError(t, err)
	require.NotNil(t, client)
	transport, ok := client.Transport.(*http.Transport)
	require.True(t, ok)
	require.NotNil(t, transport)
	require.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
}

View File

@@ -0,0 +1,251 @@
package service
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
)
// TestClassifyOpenAIWSAcquireError pins the classification labels produced
// for dial status codes (426/401/429/5xx/other), queue/preferred-conn
// sentinels, deadline timeouts, and the catch-all/nil cases.
func TestClassifyOpenAIWSAcquireError(t *testing.T) {
	t.Run("dial_426_upgrade_required", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 426, Err: errors.New("upgrade required")}
		require.Equal(t, "upgrade_required", classifyOpenAIWSAcquireError(err))
	})
	t.Run("queue_full", func(t *testing.T) {
		require.Equal(t, "conn_queue_full", classifyOpenAIWSAcquireError(errOpenAIWSConnQueueFull))
	})
	t.Run("preferred_conn_unavailable", func(t *testing.T) {
		require.Equal(t, "preferred_conn_unavailable", classifyOpenAIWSAcquireError(errOpenAIWSPreferredConnUnavailable))
	})
	t.Run("acquire_timeout", func(t *testing.T) {
		require.Equal(t, "acquire_timeout", classifyOpenAIWSAcquireError(context.DeadlineExceeded))
	})
	t.Run("auth_failed_401", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 401, Err: errors.New("unauthorized")}
		require.Equal(t, "auth_failed", classifyOpenAIWSAcquireError(err))
	})
	t.Run("upstream_rate_limited", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 429, Err: errors.New("rate limited")}
		require.Equal(t, "upstream_rate_limited", classifyOpenAIWSAcquireError(err))
	})
	t.Run("upstream_5xx", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 502, Err: errors.New("bad gateway")}
		require.Equal(t, "upstream_5xx", classifyOpenAIWSAcquireError(err))
	})
	t.Run("dial_failed_other_status", func(t *testing.T) {
		err := &openAIWSDialError{StatusCode: 418, Err: errors.New("teapot")}
		require.Equal(t, "dial_failed", classifyOpenAIWSAcquireError(err))
	})
	t.Run("other", func(t *testing.T) {
		require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(errors.New("x")))
	})
	t.Run("nil", func(t *testing.T) {
		require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(nil))
	})
}
// TestClassifyOpenAIWSDialError pins dial-error classification for an
// unfinished handshake message and a wrapped context deadline.
func TestClassifyOpenAIWSDialError(t *testing.T) {
	t.Run("handshake_not_finished", func(t *testing.T) {
		err := &openAIWSDialError{
			StatusCode: http.StatusBadGateway,
			Err:        errors.New("WebSocket protocol error: Handshake not finished"),
		}
		require.Equal(t, "handshake_not_finished", classifyOpenAIWSDialError(err))
	})
	t.Run("context_deadline", func(t *testing.T) {
		err := &openAIWSDialError{
			StatusCode: 0,
			Err:        context.DeadlineExceeded,
		}
		require.Equal(t, "ctx_deadline_exceeded", classifyOpenAIWSDialError(err))
	})
}
// TestSummarizeOpenAIWSDialError verifies the dial-error summary extracts
// status, class, close status/reason placeholders, and the diagnostic
// response headers (Server/Via/Cf-Ray/X-Request-Id).
func TestSummarizeOpenAIWSDialError(t *testing.T) {
	err := &openAIWSDialError{
		StatusCode: http.StatusBadGateway,
		ResponseHeaders: http.Header{
			"Server":       []string{"cloudflare"},
			"Via":          []string{"1.1 example"},
			"Cf-Ray":       []string{"abcd1234"},
			"X-Request-Id": []string{"req_123"},
		},
		Err: errors.New("WebSocket protocol error: Handshake not finished"),
	}
	status, class, closeStatus, closeReason, server, via, cfRay, reqID := summarizeOpenAIWSDialError(err)
	require.Equal(t, http.StatusBadGateway, status)
	require.Equal(t, "handshake_not_finished", class)
	require.Equal(t, "-", closeStatus)
	require.Equal(t, "-", closeReason)
	require.Equal(t, "cloudflare", server)
	require.Equal(t, "1.1 example", via)
	require.Equal(t, "abcd1234", cfRay)
	require.Equal(t, "req_123", reqID)
}
// TestClassifyOpenAIWSErrorEvent verifies error-event payloads map to their
// reason codes and are flagged recoverable.
func TestClassifyOpenAIWSErrorEvent(t *testing.T) {
	reason, recoverable := classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"upgrade_required","message":"Upgrade required"}}`))
	require.Equal(t, "upgrade_required", reason)
	require.True(t, recoverable)
	reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`))
	require.Equal(t, "previous_response_not_found", reason)
	require.True(t, recoverable)
}
// TestClassifyOpenAIWSReconnectReason verifies wrapped fallback errors keep
// their reason and that policy violations are non-retryable while read
// failures are retryable.
func TestClassifyOpenAIWSReconnectReason(t *testing.T) {
	reason, retryable := classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("policy_violation", errors.New("policy")))
	require.Equal(t, "policy_violation", reason)
	require.False(t, retryable)
	reason, retryable = classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("read_event", errors.New("io")))
	require.Equal(t, "read_event", reason)
	require.True(t, retryable)
}
// TestOpenAIWSErrorHTTPStatus pins the error-type → HTTP status mapping for
// WS error events (invalid_request→400, auth→401, permission→403,
// rate_limit→429, server→502).
func TestOpenAIWSErrorHTTPStatus(t *testing.T) {
	require.Equal(t, http.StatusBadRequest, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)))
	require.Equal(t, http.StatusUnauthorized, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"authentication_error","code":"invalid_api_key","message":"auth failed"}}`)))
	require.Equal(t, http.StatusForbidden, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"permission_error","code":"forbidden","message":"forbidden"}}`)))
	require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"rate limited"}}`)))
	require.Equal(t, http.StatusBadGateway, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"server_error","code":"server_error","message":"server"}}`)))
}
// TestResolveOpenAIWSFallbackErrorResponse verifies fallback errors resolve
// into client-facing HTTP responses (status, type, messages), that auth
// failures reuse the dial status, and that plain errors are not resolved.
func TestResolveOpenAIWSFallbackErrorResponse(t *testing.T) {
	t.Run("previous_response_not_found", func(t *testing.T) {
		statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
			wrapOpenAIWSFallback("previous_response_not_found", errors.New("previous response not found")),
		)
		require.True(t, ok)
		require.Equal(t, http.StatusBadRequest, statusCode)
		require.Equal(t, "invalid_request_error", errType)
		require.Equal(t, "previous response not found", clientMessage)
		require.Equal(t, "previous response not found", upstreamMessage)
	})
	t.Run("auth_failed_uses_dial_status", func(t *testing.T) {
		statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
			wrapOpenAIWSFallback("auth_failed", &openAIWSDialError{
				StatusCode: http.StatusForbidden,
				Err:        errors.New("forbidden"),
			}),
		)
		require.True(t, ok)
		require.Equal(t, http.StatusForbidden, statusCode)
		require.Equal(t, "upstream_error", errType)
		require.Equal(t, "forbidden", clientMessage)
		require.Equal(t, "forbidden", upstreamMessage)
	})
	t.Run("non_fallback_error_not_resolved", func(t *testing.T) {
		_, _, _, _, ok := resolveOpenAIWSFallbackErrorResponse(errors.New("plain error"))
		require.False(t, ok)
	})
}
// TestOpenAIWSFallbackCooling verifies mark/clear/expiry of the per-account
// WS fallback cooldown window.
func TestOpenAIWSFallbackCooling(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
	require.False(t, svc.isOpenAIWSFallbackCooling(1))
	svc.markOpenAIWSFallbackCooling(1, "upgrade_required")
	require.True(t, svc.isOpenAIWSFallbackCooling(1))
	svc.clearOpenAIWSFallbackCooling(1)
	require.False(t, svc.isOpenAIWSFallbackCooling(1))
	svc.markOpenAIWSFallbackCooling(2, "x")
	// NOTE(review): real 1.2s sleep rides out the 1s cooldown — a fake clock
	// would keep the suite fast; confirm before changing.
	time.Sleep(1200 * time.Millisecond)
	require.False(t, svc.isOpenAIWSFallbackCooling(2))
}
// TestOpenAIWSRetryBackoff verifies exponential backoff (jitter disabled)
// doubles from the initial delay and clamps at the configured maximum.
func TestOpenAIWSRetryBackoff(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 100
	svc.cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 400
	svc.cfg.Gateway.OpenAIWS.RetryJitterRatio = 0
	require.Equal(t, time.Duration(100)*time.Millisecond, svc.openAIWSRetryBackoff(1))
	require.Equal(t, time.Duration(200)*time.Millisecond, svc.openAIWSRetryBackoff(2))
	require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(3))
	require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(4))
}
// TestOpenAIWSRetryTotalBudget verifies the configured millisecond budget is
// converted to a duration and that zero disables the budget.
func TestOpenAIWSRetryTotalBudget(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 1200
	svc := &OpenAIGatewayService{cfg: cfg}
	require.Equal(t, 1200*time.Millisecond, svc.openAIWSRetryTotalBudget())
	cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 0
	require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget())
}
// TestClassifyOpenAIWSReadFallbackReason pins the read-fallback reason labels
// for WS close errors and generic read failures.
func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) {
	cases := []struct {
		err  error
		want string
	}{
		{err: coderws.CloseError{Code: coderws.StatusPolicyViolation}, want: "policy_violation"},
		{err: coderws.CloseError{Code: coderws.StatusMessageTooBig}, want: "message_too_big"},
		{err: errors.New("io"), want: "read_event"},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, classifyOpenAIWSReadFallbackReason(tc.err))
	}
}
// TestOpenAIWSStoreDisabledConnMode verifies the legacy force-new-conn flag
// maps to strict mode, the explicit "adaptive" setting wins, and the empty
// config defaults to off.
func TestOpenAIWSStoreDisabledConnMode(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
	require.Equal(t, openAIWSStoreDisabledConnModeStrict, svc.openAIWSStoreDisabledConnMode())
	svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "adaptive"
	require.Equal(t, openAIWSStoreDisabledConnModeAdaptive, svc.openAIWSStoreDisabledConnMode())
	svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = ""
	svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false
	require.Equal(t, openAIWSStoreDisabledConnModeOff, svc.openAIWSStoreDisabledConnMode())
}
// TestShouldForceNewConnOnStoreDisabled verifies strict mode always forces a
// new connection, off mode never does, and adaptive mode forces one only for
// specific fallback reasons.
func TestShouldForceNewConnOnStoreDisabled(t *testing.T) {
	require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeStrict, ""))
	require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeOff, "policy_violation"))
	require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "policy_violation"))
	require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "prewarm_message_too_big"))
	require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "read_event"))
}
// TestOpenAIWSRetryMetricsSnapshot verifies retry metric recording: attempt
// count, accumulated backoff milliseconds, exhausted and fast-fallback
// counters.
func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) {
	svc := &OpenAIGatewayService{}
	svc.recordOpenAIWSRetryAttempt(150 * time.Millisecond)
	svc.recordOpenAIWSRetryAttempt(0)
	svc.recordOpenAIWSRetryExhausted()
	svc.recordOpenAIWSNonRetryableFastFallback()
	snapshot := svc.SnapshotOpenAIWSRetryMetrics()
	require.Equal(t, int64(2), snapshot.RetryAttemptsTotal)
	require.Equal(t, int64(150), snapshot.RetryBackoffMsTotal)
	require.Equal(t, int64(1), snapshot.RetryExhaustedTotal)
	require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal)
}
// TestShouldLogOpenAIWSPayloadSchema verifies the first attempt always logs
// the payload schema and later attempts follow the sample rate.
func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) {
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 0
	require.True(t, svc.shouldLogOpenAIWSPayloadSchema(1), "首次尝试应始终记录 payload_schema")
	require.False(t, svc.shouldLogOpenAIWSPayloadSchema(2))
	svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 1
	require.True(t, svc.shouldLogOpenAIWSPayloadSchema(2))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,127 @@
package service
import (
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// Package-level sinks keep benchmark results observable so the compiler
// cannot dead-code-eliminate the measured calls.
var (
	benchmarkOpenAIWSPayloadJSONSink string
	benchmarkOpenAIWSStringSink      string
	benchmarkOpenAIWSBoolSink        bool
	benchmarkOpenAIWSBytesSink       []byte
)
// BenchmarkOpenAIWSForwarderHotPath measures the per-turn WS forwarding hot
// path: payload build, retry-strategy application, metadata injection, and
// the summarization/serialization helpers.
func BenchmarkOpenAIWSForwarderHotPath(b *testing.B) {
	cfg := &config.Config{}
	svc := &OpenAIGatewayService{cfg: cfg}
	account := &Account{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
	reqBody := benchmarkOpenAIWSHotPathRequest()
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		payload := svc.buildOpenAIWSCreatePayload(reqBody, account)
		_, _ = applyOpenAIWSRetryPayloadStrategy(payload, 2)
		setOpenAIWSTurnMetadata(payload, `{"trace":"bench","turn":"1"}`)
		benchmarkOpenAIWSStringSink = openAIWSPayloadString(payload, "previous_response_id")
		benchmarkOpenAIWSBoolSink = payload["tools"] != nil
		benchmarkOpenAIWSStringSink = summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)
		benchmarkOpenAIWSStringSink = summarizeOpenAIWSInput(payload["input"])
		benchmarkOpenAIWSPayloadJSONSink = payloadAsJSON(payload)
	}
}
// benchmarkOpenAIWSHotPathRequest builds a representative response.create
// payload: 24 function tools, 16 user messages, and the common top-level
// request fields exercised by the hot-path benchmark.
func benchmarkOpenAIWSHotPathRequest() map[string]any {
	const (
		toolCount    = 24
		messageCount = 16
	)
	tools := make([]map[string]any, toolCount)
	for i := range tools {
		tools[i] = map[string]any{
			"type":        "function",
			"name":        fmt.Sprintf("tool_%02d", i),
			"description": "benchmark tool schema",
			"parameters": map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{"type": "string"},
					"limit": map[string]any{"type": "number"},
				},
				"required": []string{"query"},
			},
		}
	}
	input := make([]map[string]any, messageCount)
	for i := range input {
		input[i] = map[string]any{
			"role":    "user",
			"type":    "message",
			"content": fmt.Sprintf("benchmark message %d", i),
		}
	}
	return map[string]any{
		"type":                 "response.create",
		"model":                "gpt-5.3-codex",
		"input":                input,
		"tools":                tools,
		"parallel_tool_calls":  true,
		"previous_response_id": "resp_benchmark_prev",
		"prompt_cache_key":     "bench-cache-key",
		"reasoning":            map[string]any{"effort": "medium"},
		"instructions":         "benchmark instructions",
		"store":                false,
	}
}
// BenchmarkOpenAIWSEventEnvelopeParse measures envelope parsing of a typical
// response.completed event (type, response id, response body).
func BenchmarkOpenAIWSEventEnvelopeParse(b *testing.B) {
	event := []byte(`{"type":"response.completed","response":{"id":"resp_bench_1","model":"gpt-5.1","usage":{"input_tokens":12,"output_tokens":8}}}`)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		eventType, responseID, response := parseOpenAIWSEventEnvelope(event)
		benchmarkOpenAIWSStringSink = eventType
		benchmarkOpenAIWSStringSink = responseID
		benchmarkOpenAIWSBoolSink = response.Exists()
	}
}
// BenchmarkOpenAIWSErrorEventFieldReuse measures the parse-once/reuse variant
// of the error-event helpers (classification, summarization, status mapping)
// operating on pre-parsed raw fields.
func BenchmarkOpenAIWSErrorEventFieldReuse(b *testing.B) {
	event := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(event)
		benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw)
		code, errType, errMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw)
		benchmarkOpenAIWSStringSink = code
		benchmarkOpenAIWSStringSink = errType
		benchmarkOpenAIWSStringSink = errMsg
		benchmarkOpenAIWSBoolSink = openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) > 0
	}
}
// BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath measures model
// replacement on an event that carries no model field (fast no-op path).
func BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath(b *testing.B) {
	event := []byte(`{"type":"response.output_text.delta","delta":"hello world"}`)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
	}
}
// BenchmarkReplaceOpenAIWSMessageModel_DualReplace measures model replacement
// when both the root and nested response objects carry the model field.
func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) {
	event := []byte(`{"type":"response.completed","model":"gpt-5.1","response":{"id":"resp_1","model":"gpt-5.1"}}`)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
	}
}

View File

@@ -0,0 +1,73 @@
package service
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
// TestParseOpenAIWSEventEnvelope verifies envelope parsing extracts the event
// type, prefers response.id but falls back to the top-level id, and reports
// whether a response object is present.
func TestParseOpenAIWSEventEnvelope(t *testing.T) {
	eventType, responseID, response := parseOpenAIWSEventEnvelope([]byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`))
	require.Equal(t, "response.completed", eventType)
	require.Equal(t, "resp_1", responseID)
	require.True(t, response.Exists())
	require.Equal(t, `{"id":"resp_1","model":"gpt-5.1"}`, response.Raw)
	eventType, responseID, response = parseOpenAIWSEventEnvelope([]byte(`{"type":"response.delta","id":"evt_1"}`))
	require.Equal(t, "response.delta", eventType)
	require.Equal(t, "evt_1", responseID)
	require.False(t, response.Exists())
}
// TestParseOpenAIWSResponseUsageFromCompletedEvent verifies usage extraction
// from a response.completed event, including cached input tokens.
func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
	usage := &OpenAIUsage{}
	parseOpenAIWSResponseUsageFromCompletedEvent(
		[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`),
		usage,
	)
	require.Equal(t, 11, usage.InputTokens)
	require.Equal(t, 7, usage.OutputTokens)
	require.Equal(t, 3, usage.CacheReadInputTokens)
}
// TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper verifies the
// parse-once *FromRaw helpers produce identical results to their
// whole-message wrapper counterparts.
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
	message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
	codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
	wrappedReason, wrappedRecoverable := classifyOpenAIWSErrorEvent(message)
	rawReason, rawRecoverable := classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw)
	require.Equal(t, wrappedReason, rawReason)
	require.Equal(t, wrappedRecoverable, rawRecoverable)
	wrappedStatus := openAIWSErrorHTTPStatus(message)
	rawStatus := openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
	require.Equal(t, wrappedStatus, rawStatus)
	require.Equal(t, http.StatusBadRequest, rawStatus)
	wrappedCode, wrappedType, wrappedMsg := summarizeOpenAIWSErrorEventFields(message)
	rawCode, rawType, rawMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw)
	require.Equal(t, wrappedCode, rawCode)
	require.Equal(t, wrappedType, rawType)
	require.Equal(t, wrappedMsg, rawMsg)
}
// TestOpenAIWSMessageLikelyContainsToolCalls verifies the tool-call heuristic
// matches tool_calls arrays and function_call items but not plain deltas.
func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) {
	cases := []struct {
		message string
		want    bool
	}{
		{message: `{"type":"response.output_text.delta","delta":"hello"}`, want: false},
		{message: `{"type":"response.output_item.added","item":{"tool_calls":[{"id":"tc1"}]}}`, want: true},
		{message: `{"type":"response.output_item.added","item":{"type":"function_call"}}`, want: true},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, openAIWSMessageLikelyContainsToolCalls([]byte(tc.message)))
	}
}
// TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect verifies model
// replacement for the no-model fast path, root-only, response-only, and
// dual-occurrence messages.
func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) {
	noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`)
	require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, "gpt-5.1", "custom-model")))
	rootOnly := []byte(`{"type":"response.created","model":"gpt-5.1"}`)
	require.Equal(t, `{"type":"response.created","model":"custom-model"}`, string(replaceOpenAIWSMessageModel(rootOnly, "gpt-5.1", "custom-model")))
	responseOnly := []byte(`{"type":"response.completed","response":{"model":"gpt-5.1"}}`)
	require.Equal(t, `{"type":"response.completed","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(responseOnly, "gpt-5.1", "custom-model")))
	both := []byte(`{"model":"gpt-5.1","response":{"model":"gpt-5.1"}}`)
	require.Equal(t, `{"model":"custom-model","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(both, "gpt-5.1", "custom-model")))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,714 @@
package service
import (
"context"
"encoding/json"
"errors"
"io"
"net"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestIsOpenAIWSClientDisconnectError table-tests disconnect detection across
// io/net/context sentinels, WS close codes (policy violation is NOT a client
// disconnect), and wrapped transport error strings.
func TestIsOpenAIWSClientDisconnectError(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		err  error
		want bool
	}{
		{name: "nil", err: nil, want: false},
		{name: "io_eof", err: io.EOF, want: true},
		{name: "net_closed", err: net.ErrClosed, want: true},
		{name: "context_canceled", err: context.Canceled, want: true},
		{name: "ws_normal_closure", err: coderws.CloseError{Code: coderws.StatusNormalClosure}, want: true},
		{name: "ws_going_away", err: coderws.CloseError{Code: coderws.StatusGoingAway}, want: true},
		{name: "ws_no_status", err: coderws.CloseError{Code: coderws.StatusNoStatusRcvd}, want: true},
		{name: "ws_abnormal_1006", err: coderws.CloseError{Code: coderws.StatusAbnormalClosure}, want: true},
		{name: "ws_policy_violation", err: coderws.CloseError{Code: coderws.StatusPolicyViolation}, want: false},
		{name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true},
		{name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true},
		{name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true},
	}
	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			require.Equal(t, tt.want, isOpenAIWSClientDisconnectError(tt.err))
		})
	}
}
// TestIsOpenAIWSIngressPreviousResponseNotFound verifies detection requires a
// wrapped ingress turn error with the previous-response-not-found stage and
// the retried flag unset.
func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) {
	t.Parallel()
	require.False(t, isOpenAIWSIngressPreviousResponseNotFound(nil))
	require.False(t, isOpenAIWSIngressPreviousResponseNotFound(errors.New("plain error")))
	require.False(t, isOpenAIWSIngressPreviousResponseNotFound(
		wrapOpenAIWSIngressTurnError("read_upstream", errors.New("upstream read failed"), false),
	))
	require.False(t, isOpenAIWSIngressPreviousResponseNotFound(
		wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), true),
	))
	require.True(t, isOpenAIWSIngressPreviousResponseNotFound(
		wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), false),
	))
}
// TestOpenAIWSIngressPreviousResponseRecoveryEnabled verifies nil
// service/config default the recovery to enabled, while an explicit zero
// config disables it until the flag is set.
func TestOpenAIWSIngressPreviousResponseRecoveryEnabled(t *testing.T) {
	t.Parallel()
	var nilService *OpenAIGatewayService
	require.True(t, nilService.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil service should default to enabled")
	svcWithNilCfg := &OpenAIGatewayService{}
	require.True(t, svcWithNilCfg.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil config should default to enabled")
	svc := &OpenAIGatewayService{
		cfg: &config.Config{},
	}
	require.False(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled(), "explicit config default should be false")
	svc.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
	require.True(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled())
}
// TestDropPreviousResponseIDFromRawPayload verifies previous_response_id
// removal from raw payloads: no-op on empty/absent key, removal of single and
// duplicate keys, the default delete function, delete-error propagation, and
// best-effort handling of malformed JSON.
func TestDropPreviousResponseIDFromRawPayload(t *testing.T) {
	t.Parallel()
	t.Run("empty_payload", func(t *testing.T) {
		updated, removed, err := dropPreviousResponseIDFromRawPayload(nil)
		require.NoError(t, err)
		require.False(t, removed)
		require.Empty(t, updated)
	})
	t.Run("payload_without_previous_response_id", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
		updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
		require.NoError(t, err)
		require.False(t, removed)
		require.Equal(t, string(payload), string(updated))
	})
	t.Run("normal_delete_success", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
		updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
		require.NoError(t, err)
		require.True(t, removed)
		require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
	})
	t.Run("duplicate_keys_are_removed", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_a","input":[],"previous_response_id":"resp_b"}`)
		updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
		require.NoError(t, err)
		require.True(t, removed)
		require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
	})
	t.Run("nil_delete_fn_uses_default_delete_logic", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
		updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, nil)
		require.NoError(t, err)
		require.True(t, removed)
		require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
	})
	t.Run("delete_error", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
		updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, func(_ []byte, _ string) ([]byte, error) {
			return nil, errors.New("delete failed")
		})
		require.Error(t, err)
		require.False(t, removed)
		require.Equal(t, string(payload), string(updated))
	})
	t.Run("malformed_json_is_still_best_effort_deleted", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_abc"`)
		require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists())
		updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
		require.NoError(t, err)
		require.True(t, removed)
		require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
	})
}
// TestAlignStoreDisabledPreviousResponseID verifies that
// alignStoreDisabledPreviousResponseID rewrites previous_response_id to the
// expected value only when the payload already carries a different one, and
// that the returned "changed" flag reflects whether a rewrite happened.
func TestAlignStoreDisabledPreviousResponseID(t *testing.T) {
	t.Parallel()
	// Nil payload: nothing to align, returned unchanged and empty.
	t.Run("empty_payload", func(t *testing.T) {
		updated, changed, err := alignStoreDisabledPreviousResponseID(nil, "resp_target")
		require.NoError(t, err)
		require.False(t, changed)
		require.Empty(t, updated)
	})
	// Empty expected id: payload is passed through byte-for-byte.
	t.Run("empty_expected", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_old"}`)
		updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "")
		require.NoError(t, err)
		require.False(t, changed)
		require.Equal(t, string(payload), string(updated))
	})
	// No previous_response_id in the payload: alignment does not add one.
	t.Run("missing_previous_response_id", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
		updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
		require.NoError(t, err)
		require.False(t, changed)
		require.Equal(t, string(payload), string(updated))
	})
	// Value already equals the expected id: reported as unchanged.
	t.Run("already_aligned", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_target"}`)
		updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
		require.NoError(t, err)
		require.False(t, changed)
		require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
	})
	// Mismatching value is rewritten to the expected id.
	t.Run("mismatch_rewrites_to_expected", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_old","input":[]}`)
		updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
		require.NoError(t, err)
		require.True(t, changed)
		require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
	})
	// Duplicate JSON keys collapse into a single occurrence of the expected id.
	t.Run("duplicate_keys_rewrites_to_single_expected", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","previous_response_id":"resp_old_1","input":[],"previous_response_id":"resp_old_2"}`)
		updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
		require.NoError(t, err)
		require.True(t, changed)
		require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
	})
}
// TestSetPreviousResponseIDToRawPayload covers setPreviousResponseIDToRawPayload:
// empty payloads and empty ids pass through unchanged, while a non-empty id is
// inserted when missing and overwritten when already present.
func TestSetPreviousResponseIDToRawPayload(t *testing.T) {
	t.Parallel()
	t.Run("empty_payload", func(t *testing.T) {
		got, err := setPreviousResponseIDToRawPayload(nil, "resp_target")
		require.NoError(t, err)
		require.Empty(t, got)
	})
	t.Run("empty_previous_response_id", func(t *testing.T) {
		in := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
		got, err := setPreviousResponseIDToRawPayload(in, "")
		require.NoError(t, err)
		require.Equal(t, string(in), string(got))
	})
	t.Run("set_previous_response_id_when_missing", func(t *testing.T) {
		in := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
		got, err := setPreviousResponseIDToRawPayload(in, "resp_target")
		require.NoError(t, err)
		require.Equal(t, "resp_target", gjson.GetBytes(got, "previous_response_id").String())
		require.Equal(t, "gpt-5.1", gjson.GetBytes(got, "model").String())
	})
	t.Run("overwrite_existing_previous_response_id", func(t *testing.T) {
		in := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_old"}`)
		got, err := setPreviousResponseIDToRawPayload(in, "resp_new")
		require.NoError(t, err)
		require.Equal(t, "resp_new", gjson.GetBytes(got, "previous_response_id").String())
	})
}
// TestShouldInferIngressFunctionCallOutputPreviousResponseID enumerates when a
// missing previous_response_id is inferred for a function_call_output turn:
// store must be disabled, the turn must be beyond the first, the request must
// not already carry an id, and the last turn's response id must be non-empty
// after trimming surrounding whitespace.
func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name                    string
		storeDisabled           bool
		turn                    int
		hasFunctionCallOutput   bool
		currentPreviousResponse string
		expectedPrevious        string
		want                    bool
	}{
		{
			name:                  "infer_when_all_conditions_match",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      "resp_1",
			want:                  true,
		},
		{
			name:                  "skip_when_store_enabled",
			storeDisabled:         false,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      "resp_1",
			want:                  false,
		},
		{
			name:                  "skip_on_first_turn",
			storeDisabled:         true,
			turn:                  1,
			hasFunctionCallOutput: true,
			expectedPrevious:      "resp_1",
			want:                  false,
		},
		{
			name:                  "skip_without_function_call_output",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: false,
			expectedPrevious:      "resp_1",
			want:                  false,
		},
		{
			name:                    "skip_when_request_already_has_previous_response_id",
			storeDisabled:           true,
			turn:                    2,
			hasFunctionCallOutput:   true,
			currentPreviousResponse: "resp_client",
			expectedPrevious:        "resp_1",
			want:                    false,
		},
		{
			name:                  "skip_when_last_turn_response_id_missing",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      "",
			want:                  false,
		},
		{
			// Whitespace-only padding around the id must not defeat inference.
			name:                  "trim_whitespace_before_judgement",
			storeDisabled:         true,
			turn:                  2,
			hasFunctionCallOutput: true,
			expectedPrevious:      " resp_2 ",
			want:                  true,
		},
	}
	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 loop semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := shouldInferIngressFunctionCallOutputPreviousResponseID(
				tt.storeDisabled,
				tt.turn,
				tt.hasFunctionCallOutput,
				tt.currentPreviousResponse,
				tt.expectedPrevious,
			)
			require.Equal(t, tt.want, got)
		})
	}
}
// TestOpenAIWSInputIsPrefixExtended checks whether the current payload's
// "input" sequence is a (possibly key-reordered) extension of the previous
// turn's sequence. A missing input counts as empty; non-array inputs are
// treated as a single item; malformed input JSON yields an error.
func TestOpenAIWSInputIsPrefixExtended(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name      string
		previous  []byte
		current   []byte
		want      bool
		expectErr bool
	}{
		{
			name:     "both_missing_input",
			previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
			current:  []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_1"}`),
			want:     true,
		},
		{
			name:     "previous_missing_current_empty_array",
			previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
			current:  []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`),
			want:     true,
		},
		{
			name:     "previous_missing_current_non_empty_array",
			previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
			current:  []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"input_text","text":"hello"}]}`),
			want:     false,
		},
		{
			// Object key order inside items must not affect the comparison.
			name:     "array_prefix_match",
			previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`),
			current:  []byte(`{"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]}`),
			want:     true,
		},
		{
			name:     "array_prefix_mismatch",
			previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`),
			current:  []byte(`{"input":[{"type":"input_text","text":"different"}]}`),
			want:     false,
		},
		{
			name:     "current_shorter_than_previous",
			previous: []byte(`{"input":[{"type":"input_text","text":"a"},{"type":"input_text","text":"b"}]}`),
			current:  []byte(`{"input":[{"type":"input_text","text":"a"}]}`),
			want:     false,
		},
		{
			name:     "previous_has_input_current_missing",
			previous: []byte(`{"input":[{"type":"input_text","text":"a"}]}`),
			current:  []byte(`{"model":"gpt-5.1"}`),
			want:     false,
		},
		{
			name:     "input_string_treated_as_single_item",
			previous: []byte(`{"input":"hello"}`),
			current:  []byte(`{"input":"hello"}`),
			want:     true,
		},
		{
			name:      "current_invalid_input_json",
			previous:  []byte(`{"input":[]}`),
			current:   []byte(`{"input":[}`),
			expectErr: true,
		},
		{
			name:      "invalid_input_json",
			previous:  []byte(`{"input":[}`),
			current:   []byte(`{"input":[]}`),
			expectErr: true,
		},
	}
	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 loop semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got, err := openAIWSInputIsPrefixExtended(tt.previous, tt.current)
			if tt.expectErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.Equal(t, tt.want, got)
		})
	}
}
// TestNormalizeOpenAIWSJSONForCompare checks that object keys are sorted into
// canonical order and that blank or truncated JSON is rejected with an error.
func TestNormalizeOpenAIWSJSONForCompare(t *testing.T) {
	t.Parallel()
	got, err := normalizeOpenAIWSJSONForCompare([]byte(`{"b":2,"a":1}`))
	require.NoError(t, err)
	require.Equal(t, `{"a":1,"b":2}`, string(got))
	for _, bad := range [][]byte{[]byte(" "), []byte(`{"a":`)} {
		_, err = normalizeOpenAIWSJSONForCompare(bad)
		require.Error(t, err)
	}
}
func TestNormalizeOpenAIWSJSONForCompareOrRaw(t *testing.T) {
t.Parallel()
require.Equal(t, `{"a":1,"b":2}`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"b":2,"a":1}`))))
require.Equal(t, `{"a":`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"a":`))))
}
// TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID verifies that
// "input" and "previous_response_id" are stripped while the remaining fields
// survive normalization, and that nil or non-object payloads are rejected.
func TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(t *testing.T) {
	t.Parallel()
	in := []byte(`{"model":"gpt-5.1","input":[1],"previous_response_id":"resp_x","metadata":{"b":2,"a":1}}`)
	got, err := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(in)
	require.NoError(t, err)
	require.False(t, gjson.GetBytes(got, "input").Exists())
	require.False(t, gjson.GetBytes(got, "previous_response_id").Exists())
	require.Equal(t, float64(1), gjson.GetBytes(got, "metadata.a").Float())
	for _, bad := range [][]byte{nil, []byte(`[]`)} {
		_, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(bad)
		require.Error(t, err)
	}
}
// TestOpenAIWSExtractNormalizedInputSequence checks how the "input" field is
// turned into a normalized item sequence: arrays yield one item per element,
// any scalar or object value becomes a single item, a missing field reports
// exists=false, and malformed array JSON reports exists=true with an error.
func TestOpenAIWSExtractNormalizedInputSequence(t *testing.T) {
	t.Parallel()
	t.Run("empty_payload", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence(nil)
		require.NoError(t, err)
		require.False(t, exists)
		require.Nil(t, items)
	})
	t.Run("input_missing", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"type":"response.create"}`))
		require.NoError(t, err)
		require.False(t, exists)
		require.Nil(t, items)
	})
	t.Run("input_array", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[{"type":"input_text","text":"hello"}]}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
	})
	// A bare object (not wrapped in an array) is treated as one item.
	t.Run("input_object", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":{"type":"input_text","text":"hello"}}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
	})
	// Scalars (string/number/bool/null) each become a single normalized item.
	t.Run("input_string", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":"hello"}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, `"hello"`, string(items[0]))
	})
	t.Run("input_number", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":42}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "42", string(items[0]))
	})
	t.Run("input_bool", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":true}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "true", string(items[0]))
	})
	t.Run("input_null", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":null}`))
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "null", string(items[0]))
	})
	// Malformed array JSON: the field exists, but extraction fails.
	t.Run("input_invalid_array_json", func(t *testing.T) {
		items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[}`))
		require.Error(t, err)
		require.True(t, exists)
		require.Nil(t, items)
	})
}
// TestShouldKeepIngressPreviousResponseID exercises the decision of whether a
// client-supplied previous_response_id may be kept. It is kept when the
// current payload is a strict incremental extension of the previous turn (same
// non-input fields, input is a prefix extension) referencing the known last
// response id, or when the request carries a function_call_output item. Every
// other case drops the id with a machine-readable reason string.
func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
	t.Parallel()
	// Baseline previous-turn payload shared by the subtests below.
	previousPayload := []byte(`{
		"type":"response.create",
		"model":"gpt-5.1",
		"store":false,
		"tools":[{"type":"function","name":"tool_a"}],
		"input":[{"type":"input_text","text":"hello"}]
	}`)
	// Current payload: same non-input fields (tool key order shuffled), input
	// extended by one item, and the expected previous_response_id.
	currentStrictPayload := []byte(`{
		"type":"response.create",
		"model":"gpt-5.1",
		"store":false,
		"tools":[{"name":"tool_a","type":"function"}],
		"previous_response_id":"resp_turn_1",
		"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]
	}`)
	t.Run("strict_incremental_keep", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false)
		require.NoError(t, err)
		require.True(t, keep)
		require.Equal(t, "strict_incremental_ok", reason)
	})
	t.Run("missing_previous_response_id", func(t *testing.T) {
		payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "missing_previous_response_id", reason)
	})
	t.Run("missing_last_turn_response_id", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "missing_last_turn_response_id", reason)
	})
	t.Run("previous_response_id_mismatch", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "previous_response_id_mismatch", reason)
	})
	t.Run("missing_previous_turn_payload", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "missing_previous_turn_payload", reason)
	})
	// A changed non-input field (here the model name) breaks strictness.
	t.Run("non_input_changed", func(t *testing.T) {
		payload := []byte(`{
			"type":"response.create",
			"model":"gpt-5.1-mini",
			"store":false,
			"tools":[{"type":"function","name":"tool_a"}],
			"previous_response_id":"resp_turn_1",
			"input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]
		}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
		require.NoError(t, err)
		require.False(t, keep)
		require.Equal(t, "non_input_changed", reason)
	})
	// An input that is not a prefix extension still keeps the id (delta mode).
	t.Run("delta_input_keeps_previous_response_id", func(t *testing.T) {
		payload := []byte(`{
			"type":"response.create",
			"model":"gpt-5.1",
			"store":false,
			"tools":[{"type":"function","name":"tool_a"}],
			"previous_response_id":"resp_turn_1",
			"input":[{"type":"input_text","text":"different"}]
		}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
		require.NoError(t, err)
		require.True(t, keep)
		require.Equal(t, "strict_incremental_ok", reason)
	})
	// function_call_output short-circuits the strictness checks entirely.
	t.Run("function_call_output_keeps_previous_response_id", func(t *testing.T) {
		payload := []byte(`{
			"type":"response.create",
			"model":"gpt-5.1",
			"store":false,
			"previous_response_id":"resp_external",
			"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
		}`)
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true)
		require.NoError(t, err)
		require.True(t, keep)
		require.Equal(t, "has_function_call_output", reason)
	})
	t.Run("non_input_compare_error", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false)
		require.Error(t, err)
		require.False(t, keep)
		require.Equal(t, "non_input_compare_error", reason)
	})
	t.Run("current_payload_compare_error", func(t *testing.T) {
		keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false)
		require.Error(t, err)
		require.False(t, keep)
		require.Equal(t, "non_input_compare_error", reason)
	})
}
// TestBuildOpenAIWSReplayInputSequence covers reconstruction of the full input
// sequence to replay on a fresh connection: without a previous_response_id the
// current input is used as-is; with one, the current input is either appended
// to the last full sequence (delta) or, when it already extends the last
// sequence, used as the full replacement.
func TestBuildOpenAIWSReplayInputSequence(t *testing.T) {
	t.Parallel()
	// Full input sequence recorded from the previous turn.
	lastFull := []json.RawMessage{
		json.RawMessage(`{"type":"input_text","text":"hello"}`),
	}
	t.Run("no_previous_response_id_use_current", func(t *testing.T) {
		items, exists, err := buildOpenAIWSReplayInputSequence(
			lastFull,
			true,
			[]byte(`{"input":[{"type":"input_text","text":"new"}]}`),
			false,
		)
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 1)
		require.Equal(t, "new", gjson.GetBytes(items[0], "text").String())
	})
	// Delta input is appended after the recorded sequence.
	t.Run("previous_response_id_delta_append", func(t *testing.T) {
		items, exists, err := buildOpenAIWSReplayInputSequence(
			lastFull,
			true,
			[]byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`),
			true,
		)
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 2)
		require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
		require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
	})
	// A full input that already extends the recorded sequence replaces it
	// outright — items must not be duplicated.
	t.Run("previous_response_id_full_input_replace", func(t *testing.T) {
		items, exists, err := buildOpenAIWSReplayInputSequence(
			lastFull,
			true,
			[]byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]}`),
			true,
		)
		require.NoError(t, err)
		require.True(t, exists)
		require.Len(t, items, 2)
		require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
		require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
	})
}
// TestSetOpenAIWSPayloadInputSequence checks that setOpenAIWSPayloadInputSequence
// writes the given items into the payload's "input" field, and that a nil
// sequence is serialized as an empty JSON array rather than null.
func TestSetOpenAIWSPayloadInputSequence(t *testing.T) {
	t.Parallel()
	base := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`)
	t.Run("set_items", func(t *testing.T) {
		seq := []json.RawMessage{
			json.RawMessage(`{"type":"input_text","text":"hello"}`),
			json.RawMessage(`{"type":"input_text","text":"world"}`),
		}
		got, err := setOpenAIWSPayloadInputSequence(base, seq, true)
		require.NoError(t, err)
		require.Equal(t, "hello", gjson.GetBytes(got, "input.0.text").String())
		require.Equal(t, "world", gjson.GetBytes(got, "input.1.text").String())
	})
	t.Run("preserve_empty_array_not_null", func(t *testing.T) {
		got, err := setOpenAIWSPayloadInputSequence(base, nil, true)
		require.NoError(t, err)
		input := gjson.GetBytes(got, "input")
		require.True(t, input.IsArray())
		require.Len(t, input.Array(), 0)
		require.False(t, input.Type == gjson.Null)
	})
}
func TestCloneOpenAIWSRawMessages(t *testing.T) {
t.Parallel()
t.Run("nil_slice", func(t *testing.T) {
cloned := cloneOpenAIWSRawMessages(nil)
require.Nil(t, cloned)
})
t.Run("empty_slice", func(t *testing.T) {
items := make([]json.RawMessage, 0)
cloned := cloneOpenAIWSRawMessages(items)
require.NotNil(t, cloned)
require.Len(t, cloned, 0)
})
}

View File

@@ -0,0 +1,50 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey verifies that the
// attempt-3 retry strategy trims optional fields such as "include" in place
// while preserving prompt_cache_key and semantic fields like "text".
func TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey(t *testing.T) {
	body := map[string]any{
		"model":            "gpt-5.3-codex",
		"prompt_cache_key": "pcache_123",
		"include":          []any{"reasoning.encrypted_content"},
		"text":             map[string]any{"verbosity": "low"},
		"tools":            []any{map[string]any{"type": "function"}},
	}
	name, dropped := applyOpenAIWSRetryPayloadStrategy(body, 3)
	require.Equal(t, "trim_optional_fields", name)
	require.Contains(t, dropped, "include")
	require.NotContains(t, dropped, "prompt_cache_key")
	require.Equal(t, "pcache_123", body["prompt_cache_key"])
	require.NotContains(t, body, "include")
	require.Contains(t, body, "text")
}
// TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields verifies
// that even at attempt 6 the strategy only trims optional fields: instructions,
// tools, tool_choice, parallel_tool_calls, text and prompt_cache_key survive.
func TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields(t *testing.T) {
	body := map[string]any{
		"prompt_cache_key":    "pcache_456",
		"instructions":        "long instructions",
		"tools":               []any{map[string]any{"type": "function"}},
		"parallel_tool_calls": true,
		"tool_choice":         "auto",
		"include":             []any{"reasoning.encrypted_content"},
		"text":                map[string]any{"verbosity": "high"},
	}
	name, dropped := applyOpenAIWSRetryPayloadStrategy(body, 6)
	require.Equal(t, "trim_optional_fields", name)
	require.Contains(t, dropped, "include")
	require.NotContains(t, dropped, "prompt_cache_key")
	require.Equal(t, "pcache_456", body["prompt_cache_key"])
	for _, kept := range []string{"instructions", "tools", "tool_choice", "parallel_tool_calls", "text"} {
		require.Contains(t, body, kept)
	}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,58 @@
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// BenchmarkOpenAIWSPoolAcquire measures Acquire/Release throughput on a warm
// connection pool under parallel load, using a counting test dialer so no real
// network connections are made.
func BenchmarkOpenAIWSPoolAcquire(b *testing.B) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(&openAIWSCountingDialer{})
	account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
	req := openAIWSAcquireRequest{
		Account: account,
		WSURL:   "wss://example.com/v1/responses",
	}
	ctx := context.Background()
	// Warm the pool so the measured loop exercises reuse, not the first dial.
	lease, err := pool.Acquire(ctx, req)
	if err != nil {
		b.Fatalf("warm acquire failed: %v", err)
	}
	lease.Release()
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			var (
				got        *openAIWSConnLease
				acquireErr error
			)
			// A pooled connection may be torn down between attempts; retry a
			// bounded number of times on the "conn closed" sentinel only.
			for retry := 0; retry < 3; retry++ {
				got, acquireErr = pool.Acquire(ctx, req)
				if acquireErr == nil {
					break
				}
				if !errors.Is(acquireErr, errOpenAIWSConnClosed) {
					break
				}
			}
			if acquireErr != nil {
				// Fatalf/FailNow must not be called from RunParallel worker
				// goroutines (testing doc: FailNow must run on the goroutine
				// running the benchmark). Record the failure and stop this
				// worker instead.
				b.Errorf("acquire failed: %v", acquireErr)
				return
			}
			got.Release()
		}
	})
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,117 @@
package service
import "github.com/Wei-Shaw/sub2api/internal/config"
// OpenAIUpstreamTransport identifies the transport protocol used to reach the
// OpenAI upstream.
type OpenAIUpstreamTransport string

const (
	// OpenAIUpstreamTransportAny places no constraint on the transport.
	OpenAIUpstreamTransportAny OpenAIUpstreamTransport = ""
	// OpenAIUpstreamTransportHTTPSSE is plain HTTP with server-sent events.
	OpenAIUpstreamTransportHTTPSSE OpenAIUpstreamTransport = "http_sse"
	// OpenAIUpstreamTransportResponsesWebsocket is the v1 Responses WebSocket.
	OpenAIUpstreamTransportResponsesWebsocket OpenAIUpstreamTransport = "responses_websockets"
	// OpenAIUpstreamTransportResponsesWebsocketV2 is the v2 Responses WebSocket.
	OpenAIUpstreamTransportResponsesWebsocketV2 OpenAIUpstreamTransport = "responses_websockets_v2"
)

// OpenAIWSProtocolDecision is the outcome of a protocol decision: the chosen
// transport plus a machine-readable reason string for observability.
type OpenAIWSProtocolDecision struct {
	Transport OpenAIUpstreamTransport
	Reason    string
}

// OpenAIWSProtocolResolver decides which upstream transport to use for an
// OpenAI account.
type OpenAIWSProtocolResolver interface {
	Resolve(account *Account) OpenAIWSProtocolDecision
}

// defaultOpenAIWSProtocolResolver resolves transports from the gateway config.
type defaultOpenAIWSProtocolResolver struct {
	cfg *config.Config
}

// NewOpenAIWSProtocolResolver creates the default protocol resolver backed by
// the given configuration.
func NewOpenAIWSProtocolResolver(cfg *config.Config) OpenAIWSProtocolResolver {
	return &defaultOpenAIWSProtocolResolver{cfg: cfg}
}
// Resolve decides which upstream transport to use for the given account.
// The checks run in strict priority order: account-level hard gates, global
// config gates, per-auth-type switches, then either the v2 mode router (when
// enabled) or the legacy per-account boolean switch. HTTP SSE is the
// universal fallback; every fallback carries a distinct reason string.
func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProtocolDecision {
	if account == nil {
		return openAIWSHTTPDecision("account_missing")
	}
	if !account.IsOpenAI() {
		return openAIWSHTTPDecision("platform_not_openai")
	}
	// An account-level force-HTTP flag overrides every WS switch below.
	if account.IsOpenAIWSForceHTTPEnabled() {
		return openAIWSHTTPDecision("account_force_http")
	}
	if r == nil || r.cfg == nil {
		return openAIWSHTTPDecision("config_missing")
	}
	wsCfg := r.cfg.Gateway.OpenAIWS
	if wsCfg.ForceHTTP {
		return openAIWSHTTPDecision("global_force_http")
	}
	if !wsCfg.Enabled {
		return openAIWSHTTPDecision("global_disabled")
	}
	// Per-auth-type global switches; unknown auth types never use WS.
	if account.IsOpenAIOAuth() {
		if !wsCfg.OAuthEnabled {
			return openAIWSHTTPDecision("oauth_disabled")
		}
	} else if account.IsOpenAIApiKey() {
		if !wsCfg.APIKeyEnabled {
			return openAIWSHTTPDecision("apikey_disabled")
		}
	} else {
		return openAIWSHTTPDecision("unknown_auth_type")
	}
	if wsCfg.ModeRouterV2Enabled {
		// v2 mode router: the account's ingress mode (off/shared/dedicated)
		// drives the decision; unrecognized modes are treated as "off".
		mode := account.ResolveOpenAIResponsesWebSocketV2Mode(wsCfg.IngressModeDefault)
		switch mode {
		case OpenAIWSIngressModeOff:
			return openAIWSHTTPDecision("account_mode_off")
		case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
			// continue
		default:
			return openAIWSHTTPDecision("account_mode_off")
		}
		// WS scheduling in the v2 router requires a positive concurrency budget.
		if account.Concurrency <= 0 {
			return openAIWSHTTPDecision("account_concurrency_invalid")
		}
		// Prefer WS v2 over v1 when both feature flags are on.
		if wsCfg.ResponsesWebsocketsV2 {
			return OpenAIWSProtocolDecision{
				Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
				Reason:    "ws_v2_mode_" + mode,
			}
		}
		if wsCfg.ResponsesWebsockets {
			return OpenAIWSProtocolDecision{
				Transport: OpenAIUpstreamTransportResponsesWebsocket,
				Reason:    "ws_v1_mode_" + mode,
			}
		}
		return openAIWSHTTPDecision("feature_disabled")
	}
	// Legacy path: a single boolean per-account switch gates WS; v2 is
	// preferred over v1 when both are globally enabled.
	if !account.IsOpenAIResponsesWebSocketV2Enabled() {
		return openAIWSHTTPDecision("account_disabled")
	}
	if wsCfg.ResponsesWebsocketsV2 {
		return OpenAIWSProtocolDecision{
			Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
			Reason:    "ws_v2_enabled",
		}
	}
	if wsCfg.ResponsesWebsockets {
		return OpenAIWSProtocolDecision{
			Transport: OpenAIUpstreamTransportResponsesWebsocket,
			Reason:    "ws_v1_enabled",
		}
	}
	return openAIWSHTTPDecision("feature_disabled")
}
// openAIWSHTTPDecision builds the HTTP SSE fallback decision, recording why
// the WebSocket transport was not selected.
func openAIWSHTTPDecision(reason string) OpenAIWSProtocolDecision {
	decision := OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportHTTPSSE}
	decision.Reason = reason
	return decision
}

View File

@@ -0,0 +1,203 @@
package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestOpenAIWSProtocolResolver_Resolve covers the legacy (non mode-router)
// decision path: v2 is preferred over v1, account-level and global force-HTTP
// flags win over everything, per-auth-type global switches apply, and every
// rejection yields its specific reason string.
func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) {
	baseCfg := &config.Config{}
	baseCfg.Gateway.OpenAIWS.Enabled = true
	baseCfg.Gateway.OpenAIWS.OAuthEnabled = true
	baseCfg.Gateway.OpenAIWS.APIKeyEnabled = true
	baseCfg.Gateway.OpenAIWS.ResponsesWebsockets = false
	baseCfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	// OAuth account with the per-account WSv2 switch enabled.
	openAIOAuthEnabled := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Extra: map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
		},
	}
	// WS v2 takes precedence when enabled.
	t.Run("v2优先", func(t *testing.T) {
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})
	// With v2 off and v1 on, the resolver falls back to WS v1.
	t.Run("v2关闭时回退v1", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
		cfg.Gateway.OpenAIWS.ResponsesWebsockets = true
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocket, decision.Transport)
		require.Equal(t, "ws_v1_enabled", decision.Reason)
	})
	// The passthrough flag must not influence WS protocol selection.
	t.Run("透传开关不影响WS协议判定", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
			"openai_passthrough":                           true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})
	// Account-level force-HTTP overrides the enabled WS switch.
	t.Run("账号级强制HTTP", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": true,
			"openai_ws_force_http":                         true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_force_http", decision.Reason)
	})
	// Globally disabled WS keeps HTTP.
	t.Run("全局关闭保持HTTP", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.Enabled = false
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "global_disabled", decision.Reason)
	})
	// Per-account switch off keeps HTTP.
	t.Run("账号开关关闭保持HTTP", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_oauth_responses_websockets_v2_enabled": false,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_disabled", decision.Reason)
	})
	// OAuth accounts must not honor the API-key-specific switch.
	t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_apikey_responses_websockets_v2_enabled": true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_disabled", decision.Reason)
	})
	// The legacy extra key "openai_ws_enabled" is still honored.
	t.Run("兼容旧键openai_ws_enabled", func(t *testing.T) {
		account := *openAIOAuthEnabled
		account.Extra = map[string]any{
			"openai_ws_enabled": true,
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_enabled", decision.Reason)
	})
	// Global per-auth-type switch (OAuth) off keeps HTTP.
	t.Run("按账号类型开关控制", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.OAuthEnabled = false
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "oauth_disabled", decision.Reason)
	})
	// Global per-auth-type switch (API key) off keeps HTTP even when the
	// account itself opts in.
	t.Run("API Key 账号关闭开关时回退HTTP", func(t *testing.T) {
		cfg := *baseCfg
		cfg.Gateway.OpenAIWS.APIKeyEnabled = false
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeAPIKey,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "apikey_disabled", decision.Reason)
	})
	// Unknown auth types always fall back to HTTP.
	t.Run("未知认证类型回退HTTP", func(t *testing.T) {
		account := &Account{
			Platform: PlatformOpenAI,
			Type:     "unknown_type",
			Extra: map[string]any{
				"responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "unknown_auth_type", decision.Reason)
	})
}
// TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2 covers the v2 mode-router
// path: per-account ingress modes (off/shared/dedicated) drive the decision,
// the legacy boolean maps to "shared", and a non-positive concurrency budget
// is rejected.
func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
	cfg := &config.Config{}
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
	cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
	account := &Account{
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Concurrency: 1,
		Extra: map[string]any{
			"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
		},
	}
	t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
	})
	t.Run("off mode routes to http", func(t *testing.T) {
		offAccount := &Account{
			Platform:    PlatformOpenAI,
			Type:        AccountTypeOAuth,
			Concurrency: 1,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(offAccount)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_mode_off", decision.Reason)
	})
	// The old boolean per-account switch is interpreted as "shared" mode.
	t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
		legacyAccount := &Account{
			Platform:    PlatformOpenAI,
			Type:        AccountTypeAPIKey,
			Concurrency: 1,
			Extra: map[string]any{
				"openai_apikey_responses_websockets_v2_enabled": true,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
		require.Equal(t, "ws_v2_mode_shared", decision.Reason)
	})
	// Zero concurrency (field omitted) is invalid in the v2 router.
	t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
		invalidConcurrency := &Account{
			Platform: PlatformOpenAI,
			Type:     AccountTypeOAuth,
			Extra: map[string]any{
				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
			},
		}
		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
		require.Equal(t, "account_concurrency_invalid", decision.Reason)
	})
}

View File

@@ -0,0 +1,440 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
)
// Tunables for the WSv2 state store. NOTE(review): the cleanup/limit semantics
// are inferred from the constant names — the sweep implementation is not in
// this chunk; confirm against the cleanup path before relying on them.
const (
	// Key prefix for response_id -> account_id entries in the gateway cache.
	openAIWSResponseAccountCachePrefix = "openai:response:"
	openAIWSStateStoreCleanupInterval  = time.Minute
	openAIWSStateStoreCleanupMaxPerMap = 512
	openAIWSStateStoreMaxEntriesPerMap = 65536
	// Timeout applied to remote (Redis) cache calls.
	openAIWSStateStoreRedisTimeout = 3 * time.Second
)
// openAIWSAccountBinding is a local-cache entry mapping a response_id to the
// account that produced it, with an expiry for lazy cleanup.
type openAIWSAccountBinding struct {
	accountID int64
	expiresAt time.Time
}

// openAIWSConnBinding maps a response_id to a connection id.
type openAIWSConnBinding struct {
	connID    string
	expiresAt time.Time
}

// openAIWSTurnStateBinding stores opaque per-session turn state.
type openAIWSTurnStateBinding struct {
	turnState string
	expiresAt time.Time
}

// openAIWSSessionConnBinding pins a session hash to a connection id.
type openAIWSSessionConnBinding struct {
	connID    string
	expiresAt time.Time
}
// OpenAIWSStateStore manages WSv2 stickiness state:
//   - response_id -> account_id, used for continuation routing
//   - response_id -> conn_id, used for in-connection context reuse
//
// The response_id -> account_id binding prefers GatewayCache (Redis) and also
// maintains a local hot cache; the response_id -> conn_id binding is valid
// only within the current process.
type OpenAIWSStateStore interface {
	BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error
	GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error)
	DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error

	BindResponseConn(responseID, connID string, ttl time.Duration)
	GetResponseConn(responseID string) (string, bool)
	DeleteResponseConn(responseID string)

	BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
	GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
	DeleteSessionTurnState(groupID int64, sessionHash string)

	BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
	GetSessionConn(groupID int64, sessionHash string) (string, bool)
	DeleteSessionConn(groupID int64, sessionHash string)
}
// defaultOpenAIWSStateStore is the default OpenAIWSStateStore implementation.
// Each binding map has its own RWMutex so unrelated lookups do not contend;
// expired entries are collected lazily by maybeCleanup.
type defaultOpenAIWSStateStore struct {
	// cache is the optional distributed backend for response→account bindings.
	cache GatewayCache

	responseToAccountMu sync.RWMutex
	responseToAccount   map[string]openAIWSAccountBinding

	responseToConnMu sync.RWMutex
	responseToConn   map[string]openAIWSConnBinding

	sessionToTurnStateMu sync.RWMutex
	sessionToTurnState   map[string]openAIWSTurnStateBinding

	sessionToConnMu sync.RWMutex
	sessionToConn   map[string]openAIWSSessionConnBinding

	// lastCleanupUnixNano gates maybeCleanup to at most one round per interval.
	lastCleanupUnixNano atomic.Int64
}
// NewOpenAIWSStateStore creates the default WS state store. cache may be nil,
// in which case response→account bindings live only in this process.
func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
	store := &defaultOpenAIWSStateStore{
		cache:              cache,
		responseToAccount:  make(map[string]openAIWSAccountBinding, 256),
		responseToConn:     make(map[string]openAIWSConnBinding, 256),
		sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256),
		sessionToConn:      make(map[string]openAIWSSessionConnBinding, 256),
	}
	// Seed the cleanup clock so the first cleanup waits a full interval.
	store.lastCleanupUnixNano.Store(time.Now().UnixNano())
	return store
}
// BindResponseAccount records response→account stickiness so follow-up
// requests referencing responseID route back to accountID. The binding is
// written to the local map first, then mirrored to GatewayCache (when
// configured) under a short independent timeout; the cache write error, if
// any, is returned but the local binding remains in place.
// Blank IDs / non-positive accounts are ignored; ttl<=0 defaults to one hour.
func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" || accountID <= 0 {
		return nil
	}
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()
	expiresAt := time.Now().Add(ttl)
	s.responseToAccountMu.Lock()
	// Keep the map bounded before inserting a potentially new key.
	ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap)
	s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt}
	s.responseToAccountMu.Unlock()
	if s.cache == nil {
		return nil
	}
	cacheKey := openAIWSResponseAccountCacheKey(id)
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
}
// GetResponseAccount resolves the sticky account for responseID. The local
// hot cache is consulted first (only unexpired entries count); on a miss it
// falls back to GatewayCache. Returns (0, nil) when no binding is found.
func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return 0, nil
	}
	s.maybeCleanup()
	now := time.Now()
	s.responseToAccountMu.RLock()
	if binding, ok := s.responseToAccount[id]; ok {
		if now.Before(binding.expiresAt) {
			accountID := binding.accountID
			s.responseToAccountMu.RUnlock()
			return accountID, nil
		}
	}
	s.responseToAccountMu.RUnlock()
	if s.cache == nil {
		return 0, nil
	}
	cacheKey := openAIWSResponseAccountCacheKey(id)
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey)
	if err != nil || accountID <= 0 {
		// A cache read failure must not block the main flow; degrade to a miss.
		return 0, nil
	}
	return accountID, nil
}
// DeleteResponseAccount drops the response→account binding locally and, when
// a GatewayCache is configured, from the distributed cache as well.
func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return nil
	}
	s.responseToAccountMu.Lock()
	delete(s.responseToAccount, id)
	s.responseToAccountMu.Unlock()
	if s.cache == nil {
		return nil
	}
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
	defer cancel()
	return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id))
}
// BindResponseConn records, process-locally only, which upstream connection
// produced responseID so later turns can reuse its in-connection context.
// Blank IDs are ignored; ttl<=0 defaults to one hour.
func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
	id := normalizeOpenAIWSResponseID(responseID)
	conn := strings.TrimSpace(connID)
	if id == "" || conn == "" {
		return
	}
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()
	s.responseToConnMu.Lock()
	ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap)
	s.responseToConn[id] = openAIWSConnBinding{
		connID:    conn,
		expiresAt: time.Now().Add(ttl),
	}
	s.responseToConnMu.Unlock()
}
// GetResponseConn returns the connection ID bound to responseID when a live,
// non-empty binding exists in this process.
func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return "", false
	}
	s.maybeCleanup()
	now := time.Now()
	s.responseToConnMu.RLock()
	binding, found := s.responseToConn[id]
	s.responseToConnMu.RUnlock()
	switch {
	case !found:
		return "", false
	case now.After(binding.expiresAt):
		// Expired entries are left for the incremental cleanup to collect.
		return "", false
	case strings.TrimSpace(binding.connID) == "":
		return "", false
	default:
		return binding.connID, true
	}
}
// DeleteResponseConn removes the process-local response→conn binding, if any.
func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return
	}
	s.responseToConnMu.Lock()
	delete(s.responseToConn, id)
	s.responseToConnMu.Unlock()
}
// BindSessionTurnState stores the serialized turn state for a (group,
// session-hash) pair, process-locally. Keys are group-scoped via
// openAIWSSessionTurnStateKey, so identical hashes in different groups do
// not collide. Blank keys/states are ignored; ttl<=0 defaults to one hour.
func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	state := strings.TrimSpace(turnState)
	if key == "" || state == "" {
		return
	}
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()
	s.sessionToTurnStateMu.Lock()
	ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap)
	s.sessionToTurnState[key] = openAIWSTurnStateBinding{
		turnState: state,
		expiresAt: time.Now().Add(ttl),
	}
	s.sessionToTurnStateMu.Unlock()
}
// GetSessionTurnState returns the stored turn state for the (group,
// session-hash) pair when a live, non-empty binding exists.
func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if key == "" {
		return "", false
	}
	s.maybeCleanup()
	now := time.Now()
	s.sessionToTurnStateMu.RLock()
	binding, ok := s.sessionToTurnState[key]
	s.sessionToTurnStateMu.RUnlock()
	if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.turnState) == "" {
		return "", false
	}
	return binding.turnState, true
}
// DeleteSessionTurnState removes the turn-state binding for the (group,
// session-hash) pair, if any.
func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if key == "" {
		return
	}
	s.sessionToTurnStateMu.Lock()
	delete(s.sessionToTurnState, key)
	s.sessionToTurnStateMu.Unlock()
}
// BindSessionConn stores a preferred connection ID for a (group,
// session-hash) pair, process-locally. It shares the key format with the
// turn-state map but uses a separate map/lock. Blank inputs are ignored;
// ttl<=0 defaults to one hour.
func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	conn := strings.TrimSpace(connID)
	if key == "" || conn == "" {
		return
	}
	ttl = normalizeOpenAIWSTTL(ttl)
	s.maybeCleanup()
	s.sessionToConnMu.Lock()
	ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap)
	s.sessionToConn[key] = openAIWSSessionConnBinding{
		connID:    conn,
		expiresAt: time.Now().Add(ttl),
	}
	s.sessionToConnMu.Unlock()
}
// GetSessionConn returns the preferred connection ID for the (group,
// session-hash) pair when a live, non-empty binding exists.
func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if key == "" {
		return "", false
	}
	s.maybeCleanup()
	now := time.Now()
	s.sessionToConnMu.RLock()
	binding, ok := s.sessionToConn[key]
	s.sessionToConnMu.RUnlock()
	if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
		return "", false
	}
	return binding.connID, true
}
// DeleteSessionConn removes the session→conn binding for the (group,
// session-hash) pair, if any.
func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if key == "" {
		return
	}
	s.sessionToConnMu.Lock()
	delete(s.sessionToConn, key)
	s.sessionToConnMu.Unlock()
}
// maybeCleanup runs at most one incremental cleanup round per
// openAIWSStateStoreCleanupInterval. The CAS on lastCleanupUnixNano elects a
// single caller per interval; losers (and callers inside the interval)
// return immediately, so the common path is two atomic ops.
func (s *defaultOpenAIWSStateStore) maybeCleanup() {
	if s == nil {
		return
	}
	now := time.Now()
	last := time.Unix(0, s.lastCleanupUnixNano.Load())
	if now.Sub(last) < openAIWSStateStoreCleanupInterval {
		return
	}
	if !s.lastCleanupUnixNano.CompareAndSwap(last.UnixNano(), now.UnixNano()) {
		return
	}
	// Incremental capped cleanup: avoids a long full-map scan blocking
	// callers at high scale. Each round touches at most
	// openAIWSStateStoreCleanupMaxPerMap entries per map.
	s.responseToAccountMu.Lock()
	cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap)
	s.responseToAccountMu.Unlock()
	s.responseToConnMu.Lock()
	cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap)
	s.responseToConnMu.Unlock()
	s.sessionToTurnStateMu.Lock()
	cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap)
	s.sessionToTurnStateMu.Unlock()
	s.sessionToConnMu.Lock()
	cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap)
	s.sessionToConnMu.Unlock()
}
// cleanupExpiredBindings deletes entries whose expiry (as reported by
// expiresAt) is before now, visiting at most maxScan entries per call so a
// single round stays cheap on large maps. Go's randomized map iteration
// order means successive rounds eventually reach every expired key.
//
// This generic helper replaces four byte-identical copies of the same loop
// (the file already relies on generics in ensureBindingCapacity).
func cleanupExpiredBindings[T any](bindings map[string]T, now time.Time, maxScan int, expiresAt func(T) time.Time) {
	if len(bindings) == 0 || maxScan <= 0 {
		return
	}
	scanned := 0
	for key, binding := range bindings {
		if now.After(expiresAt(binding)) {
			delete(bindings, key)
		}
		scanned++
		if scanned >= maxScan {
			break
		}
	}
}

// cleanupExpiredAccountBindings incrementally purges expired response→account entries.
func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) {
	cleanupExpiredBindings(bindings, now, maxScan, func(b openAIWSAccountBinding) time.Time { return b.expiresAt })
}

// cleanupExpiredConnBindings incrementally purges expired response→conn entries.
func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) {
	cleanupExpiredBindings(bindings, now, maxScan, func(b openAIWSConnBinding) time.Time { return b.expiresAt })
}

// cleanupExpiredTurnStateBindings incrementally purges expired session→turn-state entries.
func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
	cleanupExpiredBindings(bindings, now, maxScan, func(b openAIWSTurnStateBinding) time.Time { return b.expiresAt })
}

// cleanupExpiredSessionConnBindings incrementally purges expired session→conn entries.
func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
	cleanupExpiredBindings(bindings, now, maxScan, func(b openAIWSSessionConnBinding) time.Time { return b.expiresAt })
}
// ensureBindingCapacity keeps bindings below maxEntries ahead of an insert of
// incomingKey. Updating an existing key never evicts. When the map is full,
// one arbitrary entry (per Go's randomized map iteration) is removed.
func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
	if len(bindings) < maxEntries || maxEntries <= 0 {
		return
	}
	if _, exists := bindings[incomingKey]; exists {
		return
	}
	// Hard-cap protection: evict an arbitrary entry; bounded memory takes
	// priority over retention.
	for key := range bindings {
		delete(bindings, key)
		return
	}
}
// normalizeOpenAIWSResponseID canonicalizes a response ID by stripping
// surrounding whitespace; an all-whitespace ID normalizes to "".
func normalizeOpenAIWSResponseID(responseID string) string {
	trimmed := strings.TrimSpace(responseID)
	return trimmed
}
// openAIWSResponseAccountCacheKey derives the GatewayCache key for a response
// ID. The ID is SHA-256 hashed so keys have a fixed length and raw response
// IDs never appear in Redis.
func openAIWSResponseAccountCacheKey(responseID string) string {
	sum := sha256.Sum256([]byte(responseID))
	return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:])
}
func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
if ttl <= 0 {
return time.Hour
}
return ttl
}
// openAIWSSessionTurnStateKey builds the group-scoped map key
// "<groupID>:<sessionHash>". A blank (or whitespace-only) hash yields "",
// which callers treat as "no key".
func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
	trimmed := strings.TrimSpace(sessionHash)
	if trimmed == "" {
		return ""
	}
	return fmt.Sprintf("%d:%s", groupID, trimmed)
}
// withOpenAIWSStateStoreRedisTimeout derives a context carrying the store's
// short Redis deadline. A nil ctx is tolerated and replaced with Background.
func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	return context.WithTimeout(parent, openAIWSStateStoreRedisTimeout)
}

View File

@@ -0,0 +1,235 @@
package service
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// Exercises the full bind → get → delete lifecycle of the response→account
// mapping against a stub GatewayCache.
func TestOpenAIWSStateStore_BindGetDeleteResponseAccount(t *testing.T) {
	cache := &stubGatewayCache{}
	store := NewOpenAIWSStateStore(cache)
	ctx := context.Background()
	groupID := int64(7)
	require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_abc", 101, time.Minute))
	accountID, err := store.GetResponseAccount(ctx, groupID, "resp_abc")
	require.NoError(t, err)
	require.Equal(t, int64(101), accountID)
	require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_abc"))
	accountID, err = store.GetResponseAccount(ctx, groupID, "resp_abc")
	require.NoError(t, err)
	require.Zero(t, accountID)
}

// Verifies response→conn bindings honor their TTL (uses a real-time sleep).
func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)
	store.BindResponseConn("resp_conn", "conn_1", 30*time.Millisecond)
	connID, ok := store.GetResponseConn("resp_conn")
	require.True(t, ok)
	require.Equal(t, "conn_1", connID)
	time.Sleep(60 * time.Millisecond)
	_, ok = store.GetResponseConn("resp_conn")
	require.False(t, ok)
}

// Verifies session turn-state TTL expiry and per-group key isolation.
func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)
	store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond)
	state, ok := store.GetSessionTurnState(9, "session_hash_1")
	require.True(t, ok)
	require.Equal(t, "turn_state_1", state)
	// group isolation
	_, ok = store.GetSessionTurnState(10, "session_hash_1")
	require.False(t, ok)
	time.Sleep(60 * time.Millisecond)
	_, ok = store.GetSessionTurnState(9, "session_hash_1")
	require.False(t, ok)
}

// Verifies session→conn TTL expiry and per-group key isolation.
func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)
	store.BindSessionConn(9, "session_hash_conn_1", "conn_1", 30*time.Millisecond)
	connID, ok := store.GetSessionConn(9, "session_hash_conn_1")
	require.True(t, ok)
	require.Equal(t, "conn_1", connID)
	// group isolation
	_, ok = store.GetSessionConn(10, "session_hash_conn_1")
	require.False(t, ok)
	time.Sleep(60 * time.Millisecond)
	_, ok = store.GetSessionConn(9, "session_hash_conn_1")
	require.False(t, ok)
}
// A Redis-only hit must not be latched into the local map: once the upstream
// cache entry disappears, subsequent lookups must miss rather than serve a
// stale local copy.
func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) {
	cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
	store := NewOpenAIWSStateStore(cache)
	ctx := context.Background()
	groupID := int64(17)
	responseID := "resp_cache_stale"
	cacheKey := openAIWSResponseAccountCacheKey(responseID)
	cache.sessionBindings[cacheKey] = 501
	accountID, err := store.GetResponseAccount(ctx, groupID, responseID)
	require.NoError(t, err)
	require.Equal(t, int64(501), accountID)
	delete(cache.sessionBindings, cacheKey)
	accountID, err = store.GetResponseAccount(ctx, groupID, responseID)
	require.NoError(t, err)
	require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射")
}

// Seeds 2048 expired entries (well above the 512 per-round scan cap) and
// checks that one cleanup round makes partial progress while repeated rounds
// eventually drain everything.
func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) {
	raw := NewOpenAIWSStateStore(nil)
	store, ok := raw.(*defaultOpenAIWSStateStore)
	require.True(t, ok)
	expiredAt := time.Now().Add(-time.Minute)
	total := 2048
	store.responseToConnMu.Lock()
	for i := 0; i < total; i++ {
		store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{
			connID:    "conn_incremental",
			expiresAt: expiredAt,
		}
	}
	store.responseToConnMu.Unlock()
	// Rewind the cleanup clock so maybeCleanup is eligible to run.
	store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
	store.maybeCleanup()
	store.responseToConnMu.RLock()
	remainingAfterFirst := len(store.responseToConn)
	store.responseToConnMu.RUnlock()
	require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展")
	require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键")
	for i := 0; i < 8; i++ {
		store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
		store.maybeCleanup()
	}
	store.responseToConnMu.RLock()
	remaining := len(store.responseToConn)
	store.responseToConnMu.RUnlock()
	require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键")
}
// A full map must evict exactly one arbitrary entry to make room for a new key.
func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) {
	bindings := map[string]int{
		"a": 1,
		"b": 2,
	}
	ensureBindingCapacity(bindings, "c", 2)
	bindings["c"] = 3
	require.Len(t, bindings, 2)
	require.Equal(t, 3, bindings["c"])
}

// Updating a key already present must not trigger any eviction.
func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) {
	bindings := map[string]int{
		"a": 1,
		"b": 2,
	}
	ensureBindingCapacity(bindings, "a", 2)
	bindings["a"] = 9
	require.Len(t, bindings, 2)
	require.Equal(t, 9, bindings["a"])
}
// openAIWSStateStoreTimeoutProbeCache is a GatewayCache test double that
// records whether each call received a context carrying a deadline and how
// far away that deadline was, so tests can assert the store attaches its own
// short Redis timeout.
type openAIWSStateStoreTimeoutProbeCache struct {
	setHasDeadline    bool
	getHasDeadline    bool
	deleteHasDeadline bool
	setDeadlineDelta  time.Duration
	getDeadlineDelta  time.Duration
	delDeadlineDelta  time.Duration
}

// GetSessionAccountID records deadline info and always reports account 123.
func (c *openAIWSStateStoreTimeoutProbeCache) GetSessionAccountID(ctx context.Context, _ int64, _ string) (int64, error) {
	if deadline, ok := ctx.Deadline(); ok {
		c.getHasDeadline = true
		c.getDeadlineDelta = time.Until(deadline)
	}
	return 123, nil
}

// SetSessionAccountID records deadline info and always fails, letting tests
// observe that local state survives a cache-write error.
func (c *openAIWSStateStoreTimeoutProbeCache) SetSessionAccountID(ctx context.Context, _ int64, _ string, _ int64, _ time.Duration) error {
	if deadline, ok := ctx.Deadline(); ok {
		c.setHasDeadline = true
		c.setDeadlineDelta = time.Until(deadline)
	}
	return errors.New("set failed")
}

// RefreshSessionTTL is a no-op needed to satisfy the GatewayCache interface.
func (c *openAIWSStateStoreTimeoutProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error {
	return nil
}

// DeleteSessionAccountID records deadline info and succeeds.
func (c *openAIWSStateStoreTimeoutProbeCache) DeleteSessionAccountID(ctx context.Context, _ int64, _ string) (err error) {
	if deadline, ok := ctx.Deadline(); ok {
		c.deleteHasDeadline = true
		c.delDeadlineDelta = time.Until(deadline)
	}
	return nil
}
// Asserts every Redis-backed operation runs under an independent ~3s
// deadline, and that a local-cache hit skips the Redis read entirely.
func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) {
	probe := &openAIWSStateStoreTimeoutProbeCache{}
	store := NewOpenAIWSStateStore(probe)
	ctx := context.Background()
	groupID := int64(5)
	err := store.BindResponseAccount(ctx, groupID, "resp_timeout_probe", 11, time.Minute)
	require.Error(t, err)
	accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe")
	require.NoError(t, getErr)
	require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号")
	require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe"))
	require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文")
	require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文")
	require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取")
	require.Greater(t, probe.setDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second)
	require.Greater(t, probe.delDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe.delDeadlineDelta, 3*time.Second)
	// Fresh store: a miss in the local map must reach Redis with a deadline.
	probe2 := &openAIWSStateStoreTimeoutProbeCache{}
	store2 := NewOpenAIWSStateStore(probe2)
	accountID2, err2 := store2.GetResponseAccount(ctx, groupID, "resp_cache_only")
	require.NoError(t, err2)
	require.Equal(t, int64(123), accountID2)
	require.True(t, probe2.getHasDeadline, "GetSessionAccountID 在缓存未命中时应携带独立超时上下文")
	require.Greater(t, probe2.getDeadlineDelta, 2*time.Second)
	require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second)
}

// The timeout helper must attach a deadline even to a plain Background parent.
func TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) {
	ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
	defer cancel()
	require.NotNil(t, ctx)
	_, ok := ctx.Deadline()
	require.True(t, ok, "应附加短超时")
}

View File

@@ -13,7 +13,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
"github.com/lib/pq"
@@ -480,7 +479,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq
attemptCtx := ctx
if switches > 0 {
attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches)
attemptCtx = WithAccountSwitchCount(attemptCtx, switches, false)
}
exec := func() *opsRetryExecution {
defer selection.ReleaseFunc()
@@ -675,6 +674,7 @@ func newOpsRetryContext(ctx context.Context, errorLog *OpsErrorLogDetail) (*gin.
}
c.Request = req
SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
return c, w
}

View File

@@ -0,0 +1,47 @@
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// Replaying an ops error log must rebuild a gin context with the original
// path, UA, and whitelisted headers only (auth headers are never replayed),
// and must mark the client transport as HTTP.
func TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders(t *testing.T) {
	errorLog := &OpsErrorLogDetail{
		OpsErrorLog: OpsErrorLog{
			RequestPath: "/openai/v1/responses",
		},
		UserAgent: "ops-retry-agent/1.0",
		RequestHeaders: `{
"anthropic-beta":"beta-v1",
"ANTHROPIC-VERSION":"2023-06-01",
"authorization":"Bearer should-not-forward"
}`,
	}
	c, w := newOpsRetryContext(context.Background(), errorLog)
	require.NotNil(t, c)
	require.NotNil(t, w)
	require.NotNil(t, c.Request)
	require.Equal(t, "/openai/v1/responses", c.Request.URL.Path)
	require.Equal(t, "application/json", c.Request.Header.Get("Content-Type"))
	require.Equal(t, "ops-retry-agent/1.0", c.Request.Header.Get("User-Agent"))
	require.Equal(t, "beta-v1", c.Request.Header.Get("anthropic-beta"))
	require.Equal(t, "2023-06-01", c.Request.Header.Get("anthropic-version"))
	require.Empty(t, c.Request.Header.Get("authorization"), "未在白名单内的敏感头不应被重放")
	require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
}

// Malformed stored headers must not break replay: the context still gets a
// usable request ("/" path) and the HTTP transport marker.
func TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport(t *testing.T) {
	errorLog := &OpsErrorLogDetail{
		RequestHeaders: "{invalid-json",
	}
	c, _ := newOpsRetryContext(context.Background(), errorLog)
	require.NotNil(t, c)
	require.NotNil(t, c.Request)
	require.Equal(t, "/", c.Request.URL.Path)
	require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
}

View File

@@ -27,6 +27,11 @@ const (
OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms"
OpsResponseLatencyMsKey = "ops_response_latency_ms"
OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms"
// OpenAI WS 关键观测字段
OpsOpenAIWSQueueWaitMsKey = "ops_openai_ws_queue_wait_ms"
OpsOpenAIWSConnPickMsKey = "ops_openai_ws_conn_pick_ms"
OpsOpenAIWSConnReusedKey = "ops_openai_ws_conn_reused"
OpsOpenAIWSConnIDKey = "ops_openai_ws_conn_id"
// OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。
// ops_error_logger 中间件检查此 key为 true 时跳过错误记录。

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// RateLimitService 处理限流和过载状态管理
@@ -33,6 +34,10 @@ type geminiUsageCacheEntry struct {
totals GeminiUsageTotals
}
type geminiUsageTotalsBatchProvider interface {
GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]GeminiUsageTotals, error)
}
const geminiPrecheckCacheTTL = time.Minute
// NewRateLimitService 创建RateLimitService实例
@@ -162,6 +167,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
logger.LegacyPrintf(
"service.ratelimit",
"[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s",
account.ID,
account.Platform,
account.Type,
strings.TrimSpace(headers.Get("x-request-id")),
strings.TrimSpace(headers.Get("cf-ray")),
upstreamMsg,
truncateForLog(responseBody, 1024),
)
s.handleAuthError(ctx, account, msg)
shouldDisable = true
case 429:
@@ -225,7 +241,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return true, err
}
@@ -272,7 +288,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if limit > 0 {
start := now.Truncate(time.Minute)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return true, err
}
@@ -302,6 +318,218 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
return true, nil
}
// PreCheckUsageBatch performs quota precheck for multiple accounts in one request.
// Returned map value=false means the account should be skipped.
//
// Every input account starts as allowed (true); only Gemini accounts with a
// resolvable quota are actually checked, in two phases:
//  1. daily window (per-account cache first, then one batch DB read for misses)
//  2. per-minute window (batch DB read, only for accounts still allowed)
func (s *RateLimitService) PreCheckUsageBatch(ctx context.Context, accounts []*Account, requestedModel string) (map[int64]bool, error) {
	result := make(map[int64]bool, len(accounts))
	for _, account := range accounts {
		if account == nil {
			continue
		}
		result[account.ID] = true
	}
	if len(accounts) == 0 || requestedModel == "" {
		return result, nil
	}
	// Without a usage repo or quota service there is nothing to check.
	if s.usageRepo == nil || s.geminiQuotaService == nil {
		return result, nil
	}
	modelClass := geminiModelClassFromName(requestedModel)
	now := time.Now()
	dailyStart := geminiDailyWindowStart(now)
	minuteStart := now.Truncate(time.Minute)
	type quotaAccount struct {
		account *Account
		quota   GeminiQuota
	}
	// Collect the Gemini accounts that actually have a quota configured.
	quotaAccounts := make([]quotaAccount, 0, len(accounts))
	for _, account := range accounts {
		if account == nil || account.Platform != PlatformGemini {
			continue
		}
		quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
		if !ok {
			continue
		}
		quotaAccounts = append(quotaAccounts, quotaAccount{
			account: account,
			quota:   quota,
		})
	}
	if len(quotaAccounts) == 0 {
		return result, nil
	}
	// 1) Daily precheck (cached + batch DB fallback)
	dailyTotalsByID := make(map[int64]GeminiUsageTotals, len(quotaAccounts))
	dailyMissIDs := make([]int64, 0, len(quotaAccounts))
	for _, item := range quotaAccounts {
		limit := geminiDailyLimit(item.quota, modelClass)
		if limit <= 0 {
			continue
		}
		accountID := item.account.ID
		if totals, ok := s.getGeminiUsageTotals(accountID, dailyStart, now); ok {
			dailyTotalsByID[accountID] = totals
			continue
		}
		dailyMissIDs = append(dailyMissIDs, accountID)
	}
	if len(dailyMissIDs) > 0 {
		totalsBatch, err := s.getGeminiUsageTotalsBatch(ctx, dailyMissIDs, dailyStart, now)
		if err != nil {
			return result, err
		}
		for _, accountID := range dailyMissIDs {
			totals := totalsBatch[accountID]
			dailyTotalsByID[accountID] = totals
			// Warm the per-account cache for subsequent single prechecks.
			s.setGeminiUsageTotals(accountID, dailyStart, now, totals)
		}
	}
	for _, item := range quotaAccounts {
		limit := geminiDailyLimit(item.quota, modelClass)
		if limit <= 0 {
			continue
		}
		accountID := item.account.ID
		used := geminiUsedRequests(item.quota, modelClass, dailyTotalsByID[accountID], true)
		if used >= limit {
			resetAt := geminiDailyResetTime(now)
			slog.Info("gemini_precheck_daily_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt)
			result[accountID] = false
		}
	}
	// 2) Minute precheck (batch DB)
	minuteIDs := make([]int64, 0, len(quotaAccounts))
	for _, item := range quotaAccounts {
		accountID := item.account.ID
		// Skip accounts already rejected by the daily phase.
		if !result[accountID] {
			continue
		}
		if geminiMinuteLimit(item.quota, modelClass) <= 0 {
			continue
		}
		minuteIDs = append(minuteIDs, accountID)
	}
	if len(minuteIDs) == 0 {
		return result, nil
	}
	minuteTotalsByID, err := s.getGeminiUsageTotalsBatch(ctx, minuteIDs, minuteStart, now)
	if err != nil {
		return result, err
	}
	for _, item := range quotaAccounts {
		accountID := item.account.ID
		if !result[accountID] {
			continue
		}
		limit := geminiMinuteLimit(item.quota, modelClass)
		if limit <= 0 {
			continue
		}
		used := geminiUsedRequests(item.quota, modelClass, minuteTotalsByID[accountID], false)
		if used >= limit {
			resetAt := minuteStart.Add(time.Minute)
			slog.Info("gemini_precheck_minute_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt)
			result[accountID] = false
		}
	}
	return result, nil
}
// getGeminiUsageTotalsBatch fetches usage totals for the given accounts in
// the [start, end] window. It deduplicates and drops non-positive IDs, uses
// the repo's batch API when available, and otherwise falls back to one
// per-account query. Missing accounts map to zero-value totals.
func (s *RateLimitService) getGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, start, end time.Time) (map[int64]GeminiUsageTotals, error) {
	result := make(map[int64]GeminiUsageTotals, len(accountIDs))
	if len(accountIDs) == 0 {
		return result, nil
	}
	// De-duplicate while preserving first-seen order.
	ids := make([]int64, 0, len(accountIDs))
	seen := make(map[int64]struct{}, len(accountIDs))
	for _, accountID := range accountIDs {
		if accountID <= 0 {
			continue
		}
		if _, ok := seen[accountID]; ok {
			continue
		}
		seen[accountID] = struct{}{}
		ids = append(ids, accountID)
	}
	if len(ids) == 0 {
		return result, nil
	}
	// Preferred path: one batch query when the repo implements it.
	if batchReader, ok := s.usageRepo.(geminiUsageTotalsBatchProvider); ok {
		stats, err := batchReader.GetGeminiUsageTotalsBatch(ctx, ids, start, end)
		if err != nil {
			return nil, err
		}
		for _, accountID := range ids {
			result[accountID] = stats[accountID]
		}
		return result, nil
	}
	// Fallback: N individual queries (only hit when the repo lacks batch support).
	for _, accountID := range ids {
		stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, end, 0, 0, accountID, 0, nil, nil, nil)
		if err != nil {
			return nil, err
		}
		result[accountID] = geminiAggregateUsage(stats)
	}
	return result, nil
}
// geminiDailyLimit returns the effective requests-per-day limit for the
// model class. A positive shared RPD quota overrides per-class quotas.
func geminiDailyLimit(quota GeminiQuota, modelClass geminiModelClass) int64 {
	if shared := quota.SharedRPD; shared > 0 {
		return shared
	}
	if modelClass == geminiModelFlash {
		return quota.FlashRPD
	}
	return quota.ProRPD
}
// geminiMinuteLimit returns the effective requests-per-minute limit for the
// model class. A positive shared RPM quota overrides per-class quotas.
func geminiMinuteLimit(quota GeminiQuota, modelClass geminiModelClass) int64 {
	if shared := quota.SharedRPM; shared > 0 {
		return shared
	}
	if modelClass == geminiModelFlash {
		return quota.FlashRPM
	}
	return quota.ProRPM
}
// geminiUsedRequests returns the request count that should be compared
// against the active limit: the pro+flash sum when a shared quota (RPD for
// daily, RPM otherwise) is configured, else only the requested class.
func geminiUsedRequests(quota GeminiQuota, modelClass geminiModelClass, totals GeminiUsageTotals, daily bool) int64 {
	sharedLimit := quota.SharedRPM
	if daily {
		sharedLimit = quota.SharedRPD
	}
	if sharedLimit > 0 {
		return totals.ProRequests + totals.FlashRequests
	}
	if modelClass == geminiModelFlash {
		return totals.FlashRequests
	}
	return totals.ProRequests
}
func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) {
s.usageCacheMu.RLock()
defer s.usageCacheMu.RUnlock()

View File

@@ -0,0 +1,216 @@
package service
import (
"context"
"sync/atomic"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
// requestMetadataContextKey is the unexported context key type for RequestMetadata.
type requestMetadataContextKey struct{}

// requestMetadataKey is the single key under which RequestMetadata is stored in a context.
var requestMetadataKey = requestMetadataContextKey{}

// RequestMetadata aggregates per-request flags in one context value,
// replacing a set of individual legacy ctxkey entries. Pointer fields
// distinguish "explicitly set" from "never set".
type RequestMetadata struct {
	IsMaxTokensOneHaikuRequest *bool
	ThinkingEnabled            *bool
	PrefetchedStickyAccountID  *int64
	PrefetchedStickyGroupID    *int64
	SingleAccountRetry         *bool
	AccountSwitchCount         *int
}
// Fallback counters: each is incremented when a reader had to fall back to a
// legacy ctxkey entry because the unified RequestMetadata was absent. They
// exist to observe migration progress off the old keys.
var (
	requestMetadataFallbackIsMaxTokensOneHaikuTotal atomic.Int64
	requestMetadataFallbackThinkingEnabledTotal     atomic.Int64
	requestMetadataFallbackPrefetchedStickyAccount  atomic.Int64
	requestMetadataFallbackPrefetchedStickyGroup    atomic.Int64
	requestMetadataFallbackSingleAccountRetryTotal  atomic.Int64
	requestMetadataFallbackAccountSwitchCountTotal  atomic.Int64
)

// RequestMetadataFallbackStats returns a snapshot of every legacy-key
// fallback counter, in the same order as the counters are declared.
func RequestMetadataFallbackStats() (isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount int64) {
	return requestMetadataFallbackIsMaxTokensOneHaikuTotal.Load(),
		requestMetadataFallbackThinkingEnabledTotal.Load(),
		requestMetadataFallbackPrefetchedStickyAccount.Load(),
		requestMetadataFallbackPrefetchedStickyGroup.Load(),
		requestMetadataFallbackSingleAccountRetryTotal.Load(),
		requestMetadataFallbackAccountSwitchCountTotal.Load()
}
// metadataFromContext extracts the RequestMetadata attached to ctx, or nil
// when ctx is nil or carries no metadata.
func metadataFromContext(ctx context.Context) *RequestMetadata {
	if ctx == nil {
		return nil
	}
	if md, ok := ctx.Value(requestMetadataKey).(*RequestMetadata); ok {
		return md
	}
	return nil
}
// updateRequestMetadata copies the current RequestMetadata (if any), applies
// update to the copy, and stores the copy under requestMetadataKey, so
// metadata in parent contexts is never mutated. When bridgeOldKeys is true
// and legacyBridge is non-nil, the value is also mirrored onto the legacy
// ctxkey entries for readers that have not migrated yet.
func updateRequestMetadata(
	ctx context.Context,
	bridgeOldKeys bool,
	update func(md *RequestMetadata),
	legacyBridge func(ctx context.Context) context.Context,
) context.Context {
	if ctx == nil {
		return nil
	}
	current := metadataFromContext(ctx)
	next := &RequestMetadata{}
	if current != nil {
		// Shallow copy; setters always install fresh pointers, so shared
		// pointees are never written through.
		*next = *current
	}
	update(next)
	ctx = context.WithValue(ctx, requestMetadataKey, next)
	if bridgeOldKeys && legacyBridge != nil {
		ctx = legacyBridge(ctx)
	}
	return ctx
}
func WithIsMaxTokensOneHaikuRequest(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
v := value
md.IsMaxTokensOneHaikuRequest = &v
}, func(base context.Context) context.Context {
return context.WithValue(base, ctxkey.IsMaxTokensOneHaikuRequest, value)
})
}
func WithThinkingEnabled(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
v := value
md.ThinkingEnabled = &v
}, func(base context.Context) context.Context {
return context.WithValue(base, ctxkey.ThinkingEnabled, value)
})
}
func WithPrefetchedStickySession(ctx context.Context, accountID, groupID int64, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
account := accountID
group := groupID
md.PrefetchedStickyAccountID = &account
md.PrefetchedStickyGroupID = &group
}, func(base context.Context) context.Context {
bridged := context.WithValue(base, ctxkey.PrefetchedStickyAccountID, accountID)
return context.WithValue(bridged, ctxkey.PrefetchedStickyGroupID, groupID)
})
}
func WithSingleAccountRetry(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
v := value
md.SingleAccountRetry = &v
}, func(base context.Context) context.Context {
return context.WithValue(base, ctxkey.SingleAccountRetry, value)
})
}
// WithAccountSwitchCount records how many account switches have happened for
// this request; with bridgeOldKeys the count is also mirrored under the
// legacy ctxkey entry.
func WithAccountSwitchCount(ctx context.Context, value int, bridgeOldKeys bool) context.Context {
	return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
		// Copy before taking the address so each context owns its own int.
		v := value
		md.AccountSwitchCount = &v
	}, func(base context.Context) context.Context {
		return context.WithValue(base, ctxkey.AccountSwitchCount, value)
	})
}
// IsMaxTokensOneHaikuRequestFromContext reads the max_tokens=1 haiku flag.
// The typed request metadata takes precedence; the legacy ctxkey entry is a
// fallback whose use is counted for migration tracking. The second result
// reports whether any value was found.
func IsMaxTokensOneHaikuRequestFromContext(ctx context.Context) (bool, bool) {
	if md := metadataFromContext(ctx); md != nil && md.IsMaxTokensOneHaikuRequest != nil {
		return *md.IsMaxTokensOneHaikuRequest, true
	}
	if ctx == nil {
		return false, false
	}
	if value, ok := ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok {
		requestMetadataFallbackIsMaxTokensOneHaikuTotal.Add(1)
		return value, true
	}
	return false, false
}
// ThinkingEnabledFromContext reports whether thinking mode was recorded for
// this request. Typed metadata wins; the legacy ctxkey value is only
// consulted as a fallback, and each fallback hit is counted.
func ThinkingEnabledFromContext(ctx context.Context) (bool, bool) {
	md := metadataFromContext(ctx)
	if md != nil && md.ThinkingEnabled != nil {
		return *md.ThinkingEnabled, true
	}
	if ctx != nil {
		if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
			requestMetadataFallbackThinkingEnabledTotal.Add(1)
			return enabled, true
		}
	}
	return false, false
}
// PrefetchedStickyGroupIDFromContext returns the sticky-session group ID, if
// one was recorded. Typed metadata is preferred; the legacy ctxkey value
// (historically stored as either int64 or int) is the counted fallback.
func PrefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) {
	if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyGroupID != nil {
		return *md.PrefetchedStickyGroupID, true
	}
	if ctx == nil {
		return 0, false
	}
	v := ctx.Value(ctxkey.PrefetchedStickyGroupID)
	switch t := v.(type) {
	case int64:
		requestMetadataFallbackPrefetchedStickyGroup.Add(1)
		return t, true
	case int:
		requestMetadataFallbackPrefetchedStickyGroup.Add(1)
		return int64(t), true
	}
	return 0, false
}
// PrefetchedStickyAccountIDFromContext returns the sticky-session account ID,
// if one was recorded. Typed metadata is preferred; the legacy ctxkey value
// (historically stored as either int64 or int) is the counted fallback.
func PrefetchedStickyAccountIDFromContext(ctx context.Context) (int64, bool) {
	if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyAccountID != nil {
		return *md.PrefetchedStickyAccountID, true
	}
	if ctx == nil {
		return 0, false
	}
	v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
	switch t := v.(type) {
	case int64:
		requestMetadataFallbackPrefetchedStickyAccount.Add(1)
		return t, true
	case int:
		requestMetadataFallbackPrefetchedStickyAccount.Add(1)
		return int64(t), true
	}
	return 0, false
}
// SingleAccountRetryFromContext reports whether single-account retry was
// recorded. Typed metadata takes precedence; the legacy ctxkey entry is the
// counted fallback.
func SingleAccountRetryFromContext(ctx context.Context) (bool, bool) {
	if md := metadataFromContext(ctx); md != nil && md.SingleAccountRetry != nil {
		return *md.SingleAccountRetry, true
	}
	if ctx == nil {
		return false, false
	}
	if value, ok := ctx.Value(ctxkey.SingleAccountRetry).(bool); ok {
		requestMetadataFallbackSingleAccountRetryTotal.Add(1)
		return value, true
	}
	return false, false
}
// AccountSwitchCountFromContext returns the recorded number of account
// switches for this request. Typed metadata is preferred; the legacy ctxkey
// value (stored as int or int64) is the counted fallback.
func AccountSwitchCountFromContext(ctx context.Context) (int, bool) {
	md := metadataFromContext(ctx)
	if md != nil && md.AccountSwitchCount != nil {
		return *md.AccountSwitchCount, true
	}
	if ctx == nil {
		return 0, false
	}
	switch count := ctx.Value(ctxkey.AccountSwitchCount).(type) {
	case int:
		requestMetadataFallbackAccountSwitchCountTotal.Add(1)
		return count, true
	case int64:
		requestMetadataFallbackAccountSwitchCountTotal.Add(1)
		return int(count), true
	default:
		return 0, false
	}
}

View File

@@ -0,0 +1,119 @@
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
// TestRequestMetadataWriteAndRead_NoBridge verifies that every With* writer
// round-trips through its typed reader and that, with bridging disabled,
// none of the legacy ctxkey entries are populated.
func TestRequestMetadataWriteAndRead_NoBridge(t *testing.T) {
	ctx := context.Background()
	ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, false)
	ctx = WithThinkingEnabled(ctx, true, false)
	ctx = WithPrefetchedStickySession(ctx, 123, 456, false)
	ctx = WithSingleAccountRetry(ctx, true, false)
	ctx = WithAccountSwitchCount(ctx, 2, false)
	isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx)
	require.True(t, ok)
	require.True(t, isHaiku)
	thinking, ok := ThinkingEnabledFromContext(ctx)
	require.True(t, ok)
	require.True(t, thinking)
	accountID, ok := PrefetchedStickyAccountIDFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, int64(123), accountID)
	groupID, ok := PrefetchedStickyGroupIDFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, int64(456), groupID)
	singleRetry, ok := SingleAccountRetryFromContext(ctx)
	require.True(t, ok)
	require.True(t, singleRetry)
	switchCount, ok := AccountSwitchCountFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, 2, switchCount)
	// Without bridging, the legacy keys must stay unset.
	require.Nil(t, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest))
	require.Nil(t, ctx.Value(ctxkey.ThinkingEnabled))
	require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyAccountID))
	require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyGroupID))
	require.Nil(t, ctx.Value(ctxkey.SingleAccountRetry))
	require.Nil(t, ctx.Value(ctxkey.AccountSwitchCount))
}
// TestRequestMetadataWrite_BridgeLegacyKeys verifies that, with bridging
// enabled, every With* writer mirrors its value into the legacy ctxkey entry.
func TestRequestMetadataWrite_BridgeLegacyKeys(t *testing.T) {
	ctx := context.Background()
	ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, true)
	ctx = WithThinkingEnabled(ctx, true, true)
	ctx = WithPrefetchedStickySession(ctx, 123, 456, true)
	ctx = WithSingleAccountRetry(ctx, true, true)
	ctx = WithAccountSwitchCount(ctx, 2, true)
	require.Equal(t, true, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest))
	require.Equal(t, true, ctx.Value(ctxkey.ThinkingEnabled))
	require.Equal(t, int64(123), ctx.Value(ctxkey.PrefetchedStickyAccountID))
	require.Equal(t, int64(456), ctx.Value(ctxkey.PrefetchedStickyGroupID))
	require.Equal(t, true, ctx.Value(ctxkey.SingleAccountRetry))
	require.Equal(t, 2, ctx.Value(ctxkey.AccountSwitchCount))
}
// TestRequestMetadataRead_LegacyFallbackAndStats verifies that readers fall
// back to the legacy ctxkey entries when no typed metadata exists, and that
// each fallback hit increments the corresponding counter exactly once.
func TestRequestMetadataRead_LegacyFallbackAndStats(t *testing.T) {
	beforeHaiku, beforeThinking, beforeAccount, beforeGroup, beforeSingleRetry, beforeSwitchCount := RequestMetadataFallbackStats()
	ctx := context.Background()
	// Populate only the legacy keys; note AccountSwitchCount uses int64 to
	// exercise the int64 branch of the fallback type switch.
	ctx = context.WithValue(ctx, ctxkey.IsMaxTokensOneHaikuRequest, true)
	ctx = context.WithValue(ctx, ctxkey.ThinkingEnabled, true)
	ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyAccountID, int64(321))
	ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(654))
	ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true)
	ctx = context.WithValue(ctx, ctxkey.AccountSwitchCount, int64(3))
	isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx)
	require.True(t, ok)
	require.True(t, isHaiku)
	thinking, ok := ThinkingEnabledFromContext(ctx)
	require.True(t, ok)
	require.True(t, thinking)
	accountID, ok := PrefetchedStickyAccountIDFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, int64(321), accountID)
	groupID, ok := PrefetchedStickyGroupIDFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, int64(654), groupID)
	singleRetry, ok := SingleAccountRetryFromContext(ctx)
	require.True(t, ok)
	require.True(t, singleRetry)
	switchCount, ok := AccountSwitchCountFromContext(ctx)
	require.True(t, ok)
	require.Equal(t, 3, switchCount)
	afterHaiku, afterThinking, afterAccount, afterGroup, afterSingleRetry, afterSwitchCount := RequestMetadataFallbackStats()
	require.Equal(t, beforeHaiku+1, afterHaiku)
	require.Equal(t, beforeThinking+1, afterThinking)
	require.Equal(t, beforeAccount+1, afterAccount)
	require.Equal(t, beforeGroup+1, afterGroup)
	require.Equal(t, beforeSingleRetry+1, afterSingleRetry)
	require.Equal(t, beforeSwitchCount+1, afterSwitchCount)
}
// TestRequestMetadataRead_PreferMetadataOverLegacy verifies that typed
// metadata wins over a conflicting legacy ctxkey value, and that writing
// without bridging leaves the legacy value untouched.
func TestRequestMetadataRead_PreferMetadataOverLegacy(t *testing.T) {
	ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, false)
	ctx = WithThinkingEnabled(ctx, true, false)
	thinking, ok := ThinkingEnabledFromContext(ctx)
	require.True(t, ok)
	require.True(t, thinking)
	require.Equal(t, false, ctx.Value(ctxkey.ThinkingEnabled))
}

View File

@@ -0,0 +1,13 @@
package service
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
)
// compileResponseHeaderFilter builds the response-header filter from the
// security section of cfg; a nil config yields a nil (disabled) filter.
func compileResponseHeaderFilter(cfg *config.Config) *responseheaders.CompiledHeaderFilter {
	if cfg != nil {
		return responseheaders.CompileHeaderFilter(cfg.Security.ResponseHeaders)
	}
	return nil
}

View File

@@ -305,13 +305,78 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p
if payload == nil {
return nil
}
ids := parseInt64Slice(payload["account_ids"])
for _, id := range ids {
if err := s.handleAccountEvent(ctx, &id, payload); err != nil {
return err
if s.accountRepo == nil {
return nil
}
rawIDs := parseInt64Slice(payload["account_ids"])
if len(rawIDs) == 0 {
return nil
}
ids := make([]int64, 0, len(rawIDs))
seen := make(map[int64]struct{}, len(rawIDs))
for _, id := range rawIDs {
if id <= 0 {
continue
}
if _, exists := seen[id]; exists {
continue
}
seen[id] = struct{}{}
ids = append(ids, id)
}
if len(ids) == 0 {
return nil
}
preloadGroupIDs := parseInt64Slice(payload["group_ids"])
accounts, err := s.accountRepo.GetByIDs(ctx, ids)
if err != nil {
return err
}
found := make(map[int64]struct{}, len(accounts))
rebuildGroupSet := make(map[int64]struct{}, len(preloadGroupIDs))
for _, gid := range preloadGroupIDs {
if gid > 0 {
rebuildGroupSet[gid] = struct{}{}
}
}
return nil
for _, account := range accounts {
if account == nil || account.ID <= 0 {
continue
}
found[account.ID] = struct{}{}
if s.cache != nil {
if err := s.cache.SetAccount(ctx, account); err != nil {
return err
}
}
for _, gid := range account.GroupIDs {
if gid > 0 {
rebuildGroupSet[gid] = struct{}{}
}
}
}
if s.cache != nil {
for _, id := range ids {
if _, ok := found[id]; ok {
continue
}
if err := s.cache.DeleteAccount(ctx, id); err != nil {
return err
}
}
}
rebuildGroupIDs := make([]int64, 0, len(rebuildGroupSet))
for gid := range rebuildGroupSet {
rebuildGroupIDs = append(rebuildGroupIDs, gid)
}
return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change")
}
func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error {

View File

@@ -9,14 +9,17 @@ import (
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
)
type SettingRepository interface {
@@ -34,6 +37,7 @@ type SettingService struct {
settingRepo SettingRepository
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
onS3Update func() // Callback when Sora S3 settings are updated
version string // Application version
}
@@ -76,6 +80,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton,
SettingKeyPurchaseSubscriptionEnabled,
SettingKeyPurchaseSubscriptionURL,
SettingKeySoraClientEnabled,
SettingKeyLinuxDoConnectEnabled,
}
@@ -114,6 +119,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
LinuxDoOAuthEnabled: linuxDoEnabled,
}, nil
}
@@ -124,6 +130,11 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
s.onUpdate = callback
}
// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。
func (s *SettingService) SetOnS3UpdateCallback(callback func()) {
s.onS3Update = callback
}
// SetVersion sets the application version for injection into public settings
func (s *SettingService) SetVersion(version string) {
s.version = version
@@ -157,6 +168,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
SoraClientEnabled bool `json:"sora_client_enabled"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version,omitempty"`
}{
@@ -178,6 +190,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version,
}, nil
@@ -232,6 +245,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
@@ -383,6 +397,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
SettingKeySoraClientEnabled: "false",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySMTPPort: "587",
@@ -436,6 +451,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
}
// 解析整数类型
@@ -854,3 +870,607 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
}
// soraS3ProfilesStore is the JSON document persisted under
// SettingKeySoraS3Profiles: the active profile ID plus all profile items.
type soraS3ProfilesStore struct {
	ActiveProfileID string                   `json:"active_profile_id"`
	Items           []soraS3ProfileStoreItem `json:"items"`
}
// soraS3ProfileStoreItem is one persisted Sora S3 profile, including the
// secret access key (this struct is storage-only and never serialized to the
// frontend).
type soraS3ProfileStoreItem struct {
	ProfileID                string `json:"profile_id"`
	Name                     string `json:"name"`
	Enabled                  bool   `json:"enabled"`
	Endpoint                 string `json:"endpoint"`
	Region                   string `json:"region"`
	Bucket                   string `json:"bucket"`
	AccessKeyID              string `json:"access_key_id"`
	SecretAccessKey          string `json:"secret_access_key"`
	Prefix                   string `json:"prefix"`
	ForcePathStyle           bool   `json:"force_path_style"`
	CDNURL                   string `json:"cdn_url"`
	DefaultStorageQuotaBytes int64  `json:"default_storage_quota_bytes"`
	UpdatedAt                string `json:"updated_at"`
}
// GetSoraS3Settings returns the Sora S3 storage settings. For compatibility
// with the old single-config semantics it returns the currently active
// profile, or zero-valued settings when no profile exists.
func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
	profiles, err := s.ListSoraS3Profiles(ctx)
	if err != nil {
		return nil, err
	}
	activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
	if activeProfile == nil {
		return &SoraS3Settings{}, nil
	}
	return &SoraS3Settings{
		Enabled:                   activeProfile.Enabled,
		Endpoint:                  activeProfile.Endpoint,
		Region:                    activeProfile.Region,
		Bucket:                    activeProfile.Bucket,
		AccessKeyID:               activeProfile.AccessKeyID,
		SecretAccessKey:           activeProfile.SecretAccessKey,
		SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured,
		Prefix:                    activeProfile.Prefix,
		ForcePathStyle:            activeProfile.ForcePathStyle,
		CDNURL:                    activeProfile.CDNURL,
		DefaultStorageQuotaBytes:  activeProfile.DefaultStorageQuotaBytes,
	}, nil
}
// SetSoraS3Settings updates the Sora S3 storage settings. For compatibility
// with the old single-config semantics it writes into the currently active
// profile, creating a "Default" profile first when none is active.
func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error {
	if settings == nil {
		return fmt.Errorf("settings cannot be nil")
	}
	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return err
	}
	now := time.Now().UTC().Format(time.RFC3339)
	activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID)
	if activeIndex < 0 {
		// No active profile: create one, falling back to a timestamped ID if
		// "default" is already taken by an existing item.
		activeID := "default"
		if hasSoraS3ProfileID(store.Items, activeID) {
			activeID = fmt.Sprintf("default-%d", time.Now().Unix())
		}
		store.Items = append(store.Items, soraS3ProfileStoreItem{
			ProfileID: activeID,
			Name:      "Default",
			UpdatedAt: now,
		})
		store.ActiveProfileID = activeID
		activeIndex = len(store.Items) - 1
	}
	active := store.Items[activeIndex]
	active.Enabled = settings.Enabled
	active.Endpoint = strings.TrimSpace(settings.Endpoint)
	active.Region = strings.TrimSpace(settings.Region)
	active.Bucket = strings.TrimSpace(settings.Bucket)
	active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID)
	active.Prefix = strings.TrimSpace(settings.Prefix)
	active.ForcePathStyle = settings.ForcePathStyle
	active.CDNURL = strings.TrimSpace(settings.CDNURL)
	active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0)
	// An empty secret means "keep the stored secret" (the frontend never
	// echoes it back); only overwrite when a new value was provided.
	if settings.SecretAccessKey != "" {
		active.SecretAccessKey = settings.SecretAccessKey
	}
	active.UpdatedAt = now
	store.Items[activeIndex] = active
	return s.persistSoraS3ProfilesStore(ctx, store)
}
// ListSoraS3Profiles returns all Sora S3 profiles together with the active
// profile ID.
func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) {
	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return nil, err
	}
	return convertSoraS3ProfilesStore(store), nil
}
// CreateSoraS3Profile creates a new Sora S3 profile. The new profile becomes
// active when setActive is true or when no profile was active before.
// Returns ErrSoraS3ProfileExists when the profile ID is already taken.
func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) {
	if profile == nil {
		return nil, fmt.Errorf("profile cannot be nil")
	}
	profileID := strings.TrimSpace(profile.ProfileID)
	if profileID == "" {
		return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
	}
	name := strings.TrimSpace(profile.Name)
	if name == "" {
		return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
	}
	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return nil, err
	}
	if hasSoraS3ProfileID(store.Items, profileID) {
		return nil, ErrSoraS3ProfileExists
	}
	now := time.Now().UTC().Format(time.RFC3339)
	// Note: the secret is stored as given (not trimmed) — secrets may contain
	// significant characters.
	store.Items = append(store.Items, soraS3ProfileStoreItem{
		ProfileID:                profileID,
		Name:                     name,
		Enabled:                  profile.Enabled,
		Endpoint:                 strings.TrimSpace(profile.Endpoint),
		Region:                   strings.TrimSpace(profile.Region),
		Bucket:                   strings.TrimSpace(profile.Bucket),
		AccessKeyID:              strings.TrimSpace(profile.AccessKeyID),
		SecretAccessKey:          profile.SecretAccessKey,
		Prefix:                   strings.TrimSpace(profile.Prefix),
		ForcePathStyle:           profile.ForcePathStyle,
		CDNURL:                   strings.TrimSpace(profile.CDNURL),
		DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0),
		UpdatedAt:                now,
	})
	if setActive || store.ActiveProfileID == "" {
		store.ActiveProfileID = profileID
	}
	if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
		return nil, err
	}
	profiles := convertSoraS3ProfilesStore(store)
	created := findSoraS3ProfileByID(profiles.Items, profileID)
	if created == nil {
		return nil, ErrSoraS3ProfileNotFound
	}
	return created, nil
}
// UpdateSoraS3Profile updates an existing Sora S3 profile identified by
// profileID. An empty SecretAccessKey keeps the stored secret; all other
// fields are overwritten. Returns ErrSoraS3ProfileNotFound when no profile
// has that ID.
func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) {
	if profile == nil {
		return nil, fmt.Errorf("profile cannot be nil")
	}
	targetID := strings.TrimSpace(profileID)
	if targetID == "" {
		return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
	}
	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return nil, err
	}
	targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
	if targetIndex < 0 {
		return nil, ErrSoraS3ProfileNotFound
	}
	target := store.Items[targetIndex]
	name := strings.TrimSpace(profile.Name)
	if name == "" {
		return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
	}
	target.Name = name
	target.Enabled = profile.Enabled
	target.Endpoint = strings.TrimSpace(profile.Endpoint)
	target.Region = strings.TrimSpace(profile.Region)
	target.Bucket = strings.TrimSpace(profile.Bucket)
	target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID)
	target.Prefix = strings.TrimSpace(profile.Prefix)
	target.ForcePathStyle = profile.ForcePathStyle
	target.CDNURL = strings.TrimSpace(profile.CDNURL)
	target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0)
	// Empty secret means "keep the stored secret"; the frontend never echoes
	// the existing value back.
	if profile.SecretAccessKey != "" {
		target.SecretAccessKey = profile.SecretAccessKey
	}
	target.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
	store.Items[targetIndex] = target
	if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
		return nil, err
	}
	profiles := convertSoraS3ProfilesStore(store)
	updated := findSoraS3ProfileByID(profiles.Items, targetID)
	if updated == nil {
		return nil, ErrSoraS3ProfileNotFound
	}
	return updated, nil
}
// DeleteSoraS3Profile removes the Sora S3 profile with the given ID. When the
// deleted profile was active, the first remaining profile (if any) becomes
// active. Returns ErrSoraS3ProfileNotFound when no profile has that ID.
func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error {
	targetID := strings.TrimSpace(profileID)
	if targetID == "" {
		return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
	}
	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return err
	}
	targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
	if targetIndex < 0 {
		return ErrSoraS3ProfileNotFound
	}
	store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...)
	if store.ActiveProfileID == targetID {
		store.ActiveProfileID = ""
		if len(store.Items) > 0 {
			store.ActiveProfileID = store.Items[0].ProfileID
		}
	}
	return s.persistSoraS3ProfilesStore(ctx, store)
}
// SetActiveSoraS3Profile marks the Sora S3 profile with the given ID as
// active (also bumping its UpdatedAt) and returns the activated profile.
// Returns ErrSoraS3ProfileNotFound when no profile has that ID.
func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) {
	targetID := strings.TrimSpace(profileID)
	if targetID == "" {
		return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
	}
	store, err := s.loadSoraS3ProfilesStore(ctx)
	if err != nil {
		return nil, err
	}
	targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
	if targetIndex < 0 {
		return nil, ErrSoraS3ProfileNotFound
	}
	store.ActiveProfileID = targetID
	store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339)
	if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
		return nil, err
	}
	profiles := convertSoraS3ProfilesStore(store)
	active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
	if active == nil {
		return nil, ErrSoraS3ProfileNotFound
	}
	return active, nil
}
// loadSoraS3ProfilesStore reads the multi-profile store from settings.
// Resolution order:
//  1. SettingKeySoraS3Profiles, when present and valid JSON (normalized).
//  2. The legacy single-profile keys, migrated in memory into a one-item
//     "default" store (used when the profiles key is missing, or when its
//     JSON is corrupt but the legacy keys are still readable).
//  3. An empty store when neither source holds any configuration.
func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) {
	raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles)
	if err == nil {
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" {
			return &soraS3ProfilesStore{}, nil
		}
		var store soraS3ProfilesStore
		if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil {
			// Corrupt JSON: fall back to the legacy keys, and only surface
			// the unmarshal error when the legacy keys are unreadable too.
			legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
			if legacyErr != nil {
				return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr)
			}
			return soraS3StoreFromLegacy(legacy), nil
		}
		normalized := normalizeSoraS3ProfilesStore(store)
		return &normalized, nil
	}
	if !errors.Is(err, ErrSettingNotFound) {
		return nil, fmt.Errorf("get sora s3 profiles: %w", err)
	}
	// Profiles key absent: migrate from the legacy single-profile keys.
	legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
	if legacyErr != nil {
		return nil, legacyErr
	}
	return soraS3StoreFromLegacy(legacy), nil
}

// soraS3StoreFromLegacy converts legacy single-profile settings into a
// one-item profile store named "default". Empty legacy settings yield an
// empty store.
func soraS3StoreFromLegacy(legacy *SoraS3Settings) *soraS3ProfilesStore {
	if isEmptyLegacySoraS3Settings(legacy) {
		return &soraS3ProfilesStore{}
	}
	now := time.Now().UTC().Format(time.RFC3339)
	return &soraS3ProfilesStore{
		ActiveProfileID: "default",
		Items: []soraS3ProfileStoreItem{
			{
				ProfileID:                "default",
				Name:                     "Default",
				Enabled:                  legacy.Enabled,
				Endpoint:                 strings.TrimSpace(legacy.Endpoint),
				Region:                   strings.TrimSpace(legacy.Region),
				Bucket:                   strings.TrimSpace(legacy.Bucket),
				AccessKeyID:              strings.TrimSpace(legacy.AccessKeyID),
				SecretAccessKey:          legacy.SecretAccessKey,
				Prefix:                   strings.TrimSpace(legacy.Prefix),
				ForcePathStyle:           legacy.ForcePathStyle,
				CDNURL:                   strings.TrimSpace(legacy.CDNURL),
				DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
				UpdatedAt:                now,
			},
		},
	}
}
// persistSoraS3ProfilesStore normalizes and saves the profile store, mirrors
// the active profile into the legacy single-profile setting keys (so readers
// of the old keys stay consistent), and fires the update callbacks.
func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error {
	if store == nil {
		return fmt.Errorf("sora s3 profiles store cannot be nil")
	}
	normalized := normalizeSoraS3ProfilesStore(*store)
	data, err := json.Marshal(normalized)
	if err != nil {
		return fmt.Errorf("marshal sora s3 profiles: %w", err)
	}
	updates := map[string]string{
		SettingKeySoraS3Profiles: string(data),
	}
	active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID)
	if active == nil {
		// No profiles left: clear all legacy keys.
		updates[SettingKeySoraS3Enabled] = "false"
		updates[SettingKeySoraS3Endpoint] = ""
		updates[SettingKeySoraS3Region] = ""
		updates[SettingKeySoraS3Bucket] = ""
		updates[SettingKeySoraS3AccessKeyID] = ""
		updates[SettingKeySoraS3Prefix] = ""
		updates[SettingKeySoraS3ForcePathStyle] = "false"
		updates[SettingKeySoraS3CDNURL] = ""
		updates[SettingKeySoraDefaultStorageQuotaBytes] = "0"
		updates[SettingKeySoraS3SecretAccessKey] = ""
	} else {
		// Mirror the active profile into the legacy keys.
		updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled)
		updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint)
		updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region)
		updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket)
		updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID)
		updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix)
		updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle)
		updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL)
		updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10)
		updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey
	}
	if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
		return err
	}
	// Notify listeners: general settings cache first, then the S3 client
	// cache refresh.
	if s.onUpdate != nil {
		s.onUpdate()
	}
	if s.onS3Update != nil {
		s.onS3Update()
	}
	return nil
}
// getLegacySoraS3Settings reads the legacy single-profile S3 setting keys and
// assembles them into a SoraS3Settings value. An unparsable quota value
// leaves DefaultStorageQuotaBytes at zero rather than failing.
func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
	keys := []string{
		SettingKeySoraS3Enabled,
		SettingKeySoraS3Endpoint,
		SettingKeySoraS3Region,
		SettingKeySoraS3Bucket,
		SettingKeySoraS3AccessKeyID,
		SettingKeySoraS3SecretAccessKey,
		SettingKeySoraS3Prefix,
		SettingKeySoraS3ForcePathStyle,
		SettingKeySoraS3CDNURL,
		SettingKeySoraDefaultStorageQuotaBytes,
	}
	values, err := s.settingRepo.GetMultiple(ctx, keys)
	if err != nil {
		return nil, fmt.Errorf("get legacy sora s3 settings: %w", err)
	}
	result := &SoraS3Settings{
		Enabled:                   values[SettingKeySoraS3Enabled] == "true",
		Endpoint:                  values[SettingKeySoraS3Endpoint],
		Region:                    values[SettingKeySoraS3Region],
		Bucket:                    values[SettingKeySoraS3Bucket],
		AccessKeyID:               values[SettingKeySoraS3AccessKeyID],
		SecretAccessKey:           values[SettingKeySoraS3SecretAccessKey],
		SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "",
		Prefix:                    values[SettingKeySoraS3Prefix],
		ForcePathStyle:            values[SettingKeySoraS3ForcePathStyle] == "true",
		CDNURL:                    values[SettingKeySoraS3CDNURL],
	}
	if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil {
		result.DefaultStorageQuotaBytes = v
	}
	return result, nil
}
// normalizeSoraS3ProfilesStore returns a cleaned copy of store: IDs and text
// fields are trimmed, blank IDs receive a generated unique ID, duplicate IDs
// are dropped (first occurrence wins), quotas are clamped to >= 0, missing
// timestamps are stamped with now, and ActiveProfileID is forced to refer to
// an existing item (the first item when the stored one is gone).
func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore {
	seen := make(map[string]struct{}, len(store.Items))
	normalized := soraS3ProfilesStore{
		ActiveProfileID: strings.TrimSpace(store.ActiveProfileID),
		Items:           make([]soraS3ProfileStoreItem, 0, len(store.Items)),
	}
	now := time.Now().UTC().Format(time.RFC3339)
	for idx := range store.Items {
		item := store.Items[idx]
		item.ProfileID = strings.TrimSpace(item.ProfileID)
		if item.ProfileID == "" {
			// Generate a positional ID, disambiguating against IDs already
			// accepted so a blank-ID profile is never silently dropped as a
			// "duplicate" of an unrelated item.
			candidate := fmt.Sprintf("profile-%d", idx+1)
			for suffix := 2; ; suffix++ {
				if _, exists := seen[candidate]; !exists {
					break
				}
				candidate = fmt.Sprintf("profile-%d-%d", idx+1, suffix)
			}
			item.ProfileID = candidate
		}
		if _, exists := seen[item.ProfileID]; exists {
			continue
		}
		seen[item.ProfileID] = struct{}{}
		item.Name = strings.TrimSpace(item.Name)
		if item.Name == "" {
			item.Name = item.ProfileID
		}
		item.Endpoint = strings.TrimSpace(item.Endpoint)
		item.Region = strings.TrimSpace(item.Region)
		item.Bucket = strings.TrimSpace(item.Bucket)
		item.AccessKeyID = strings.TrimSpace(item.AccessKeyID)
		item.Prefix = strings.TrimSpace(item.Prefix)
		item.CDNURL = strings.TrimSpace(item.CDNURL)
		item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0)
		item.UpdatedAt = strings.TrimSpace(item.UpdatedAt)
		if item.UpdatedAt == "" {
			item.UpdatedAt = now
		}
		normalized.Items = append(normalized.Items, item)
	}
	if len(normalized.Items) == 0 {
		normalized.ActiveProfileID = ""
		return normalized
	}
	if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 {
		return normalized
	}
	normalized.ActiveProfileID = normalized.Items[0].ProfileID
	return normalized
}
// convertSoraS3ProfilesStore maps the persisted store into the service-level
// SoraS3ProfileList, deriving IsActive and SecretAccessKeyConfigured for each
// item. A nil store yields an empty list.
func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList {
	if store == nil {
		return &SoraS3ProfileList{}
	}
	items := make([]SoraS3Profile, 0, len(store.Items))
	for idx := range store.Items {
		item := store.Items[idx]
		items = append(items, SoraS3Profile{
			ProfileID:                 item.ProfileID,
			Name:                      item.Name,
			IsActive:                  item.ProfileID == store.ActiveProfileID,
			Enabled:                   item.Enabled,
			Endpoint:                  item.Endpoint,
			Region:                    item.Region,
			Bucket:                    item.Bucket,
			AccessKeyID:               item.AccessKeyID,
			SecretAccessKey:           item.SecretAccessKey,
			SecretAccessKeyConfigured: item.SecretAccessKey != "",
			Prefix:                    item.Prefix,
			ForcePathStyle:            item.ForcePathStyle,
			CDNURL:                    item.CDNURL,
			DefaultStorageQuotaBytes:  item.DefaultStorageQuotaBytes,
			UpdatedAt:                 item.UpdatedAt,
		})
	}
	return &SoraS3ProfileList{
		ActiveProfileID: store.ActiveProfileID,
		Items:           items,
	}
}
// pickActiveSoraS3Profile returns the profile matching activeProfileID,
// falling back to the first profile when there is no match; nil for an empty
// list.
func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile {
	if len(items) == 0 {
		return nil
	}
	for idx := range items {
		if items[idx].ProfileID == activeProfileID {
			return &items[idx]
		}
	}
	return &items[0]
}
// findSoraS3ProfileByID returns a pointer into items for the profile with the
// given ID, or nil when absent.
func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile {
	for idx := range items {
		if profileID != items[idx].ProfileID {
			continue
		}
		return &items[idx]
	}
	return nil
}
// pickActiveSoraS3ProfileFromStore returns the store item matching
// activeProfileID, falling back to the first item when there is no match;
// nil for an empty list.
func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem {
	for idx := range items {
		if items[idx].ProfileID == activeProfileID {
			return &items[idx]
		}
	}
	if len(items) == 0 {
		return nil
	}
	return &items[0]
}
// findSoraS3ProfileIndex returns the index of the store item with the given
// profile ID, or -1 when absent.
func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int {
	result := -1
	for idx := range items {
		if items[idx].ProfileID == profileID {
			result = idx
			break
		}
	}
	return result
}
// hasSoraS3ProfileID reports whether an item with the given profile ID exists.
func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool {
	return findSoraS3ProfileIndex(items, profileID) >= 0
}
// isEmptyLegacySoraS3Settings reports whether the legacy single-profile
// settings carry no meaningful configuration at all (nil, or every field at
// its zero/blank value).
func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool {
	if settings == nil {
		return true
	}
	textFields := []string{
		settings.Endpoint,
		settings.Region,
		settings.Bucket,
		settings.AccessKeyID,
		settings.Prefix,
		settings.CDNURL,
	}
	for _, field := range textFields {
		if strings.TrimSpace(field) != "" {
			return false
		}
	}
	return !settings.Enabled &&
		settings.SecretAccessKey == "" &&
		settings.DefaultStorageQuotaBytes == 0
}
// maxInt64 clamps value so it is never smaller than min.
func maxInt64(value int64, min int64) int64 {
	if value >= min {
		return value
	}
	return min
}

View File

@@ -39,6 +39,7 @@ type SystemSettings struct {
HideCcsImportButton bool
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
SoraClientEnabled bool
DefaultConcurrency int
DefaultBalance float64
@@ -81,11 +82,52 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
SoraClientEnabled bool
LinuxDoOAuthEnabled bool
Version string
}
// SoraS3Settings is the legacy single-profile Sora S3 storage configuration.
type SoraS3Settings struct {
	Enabled                   bool   `json:"enabled"`
	Endpoint                  string `json:"endpoint"`
	Region                    string `json:"region"`
	Bucket                    string `json:"bucket"`
	AccessKeyID               string `json:"access_key_id"`
	SecretAccessKey           string `json:"secret_access_key"`            // internal use only; not returned to the frontend directly
	SecretAccessKeyConfigured bool   `json:"secret_access_key_configured"` // for frontend display
	Prefix                    string `json:"prefix"`
	ForcePathStyle            bool   `json:"force_path_style"`
	CDNURL                    string `json:"cdn_url"`
	DefaultStorageQuotaBytes  int64  `json:"default_storage_quota_bytes"`
}
// SoraS3Profile is one entry of the multi-profile Sora S3 configuration
// (service-internal model).
type SoraS3Profile struct {
	ProfileID                 string `json:"profile_id"`
	Name                      string `json:"name"`
	IsActive                  bool   `json:"is_active"`
	Enabled                   bool   `json:"enabled"`
	Endpoint                  string `json:"endpoint"`
	Region                    string `json:"region"`
	Bucket                    string `json:"bucket"`
	AccessKeyID               string `json:"access_key_id"`
	SecretAccessKey           string `json:"-"`                            // internal use only; never serialized to the frontend
	SecretAccessKeyConfigured bool   `json:"secret_access_key_configured"` // for frontend display
	Prefix                    string `json:"prefix"`
	ForcePathStyle            bool   `json:"force_path_style"`
	CDNURL                    string `json:"cdn_url"`
	DefaultStorageQuotaBytes  int64  `json:"default_storage_quota_bytes"`
	UpdatedAt                 string `json:"updated_at"`
}
// SoraS3ProfileList is the multi-profile Sora S3 configuration list together
// with the ID of the currently active profile.
type SoraS3ProfileList struct {
	ActiveProfileID string          `json:"active_profile_id"`
	Items           []SoraS3Profile `json:"items"`
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type StreamTimeoutSettings struct {
// Enabled 是否启用流超时处理

View File

@@ -43,6 +43,7 @@ type SoraVideoRequest struct {
Frames int
Model string
Size string
VideoCount int
MediaID string
RemixTargetID string
CameoIDs []string

View File

@@ -21,6 +21,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
@@ -63,8 +64,8 @@ var soraBlockedCIDRs = mustParseCIDRs([]string{
// SoraGatewayService handles forwarding requests to Sora upstream.
type SoraGatewayService struct {
soraClient SoraClient
mediaStorage *SoraMediaStorage
rateLimitService *RateLimitService
httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传
cfg *config.Config
}
@@ -100,14 +101,14 @@ type soraPreflightChecker interface {
func NewSoraGatewayService(
soraClient SoraClient,
mediaStorage *SoraMediaStorage,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
cfg *config.Config,
) *SoraGatewayService {
return &SoraGatewayService{
soraClient: soraClient,
mediaStorage: mediaStorage,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
cfg: cfg,
}
}
@@ -115,6 +116,15 @@ func NewSoraGatewayService(
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
startTime := time.Now()
// apikey 类型账号HTTP 透传到上游,不走 SoraSDKClient
if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" {
if s.httpUpstream == nil {
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream)
return nil, errors.New("httpUpstream not configured for sora apikey forwarding")
}
return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime)
}
if s.soraClient == nil || !s.soraClient.Enabled() {
if c != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
@@ -296,6 +306,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
taskID := ""
var err error
videoCount := parseSoraVideoCount(reqBody)
switch modelCfg.Type {
case "image":
taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
@@ -321,6 +332,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
VideoCount: videoCount,
MediaID: mediaID,
RemixTargetID: remixTargetID,
CameoIDs: extractSoraCameoIDs(reqBody),
@@ -378,16 +390,9 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}
}
// 直调路径(/sora/v1/chat/completions保持纯透传不执行本地/S3 媒体落盘。
// 媒体存储由客户端 API 路径(/api/v1/sora/generate的异步流程负责。
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
if storeErr != nil {
// 存储失败时降级使用原始 URL不中断用户请求
log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr)
} else {
finalURLs = s.normalizeSoraMediaURLs(stored)
}
}
if watermarkPostID != "" && watermarkOpts.DeletePost {
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
@@ -463,6 +468,20 @@ func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
}
}
// parseSoraVideoCount extracts the requested number of videos from the
// request body, accepting the aliases "video_count", "videos" and
// "n_variants". The first positive value wins and is clamped to [1, 3];
// a nil body or non-positive values default to 1.
func parseSoraVideoCount(body map[string]any) int {
	if body == nil {
		return 1
	}
	for _, alias := range [...]string{"video_count", "videos", "n_variants"} {
		if n := parseIntWithDefault(body, alias, 0); n > 0 {
			return clampInt(n, 1, 3)
		}
	}
	return 1
}
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
if body == nil {
return def
@@ -508,6 +527,42 @@ func parseStringWithDefault(body map[string]any, key, def string) string {
return def
}
// parseIntWithDefault reads body[key] as an int, tolerating the numeric
// forms a JSON decode can produce (float64) as well as native int types and
// numeric strings. A nil body, missing key or unparseable value yields def.
func parseIntWithDefault(body map[string]any, key string, def int) int {
	if body == nil {
		return def
	}
	raw, present := body[key]
	if !present {
		return def
	}
	switch v := raw.(type) {
	case int:
		return v
	case int32:
		return int(v)
	case int64:
		return int(v)
	case float64:
		// encoding/json decodes numbers as float64; truncate toward zero.
		return int(v)
	case string:
		if n, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
			return n
		}
	}
	return def
}
// clampInt restricts v to the inclusive range [minVal, maxVal].
func clampInt(v, minVal, maxVal int) int {
	switch {
	case v < minVal:
		return minVal
	case v > maxVal:
		return maxVal
	default:
		return v
	}
}
func extractSoraCameoIDs(body map[string]any) []string {
if body == nil {
return nil
@@ -904,6 +959,21 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
}
var upstreamErr *SoraUpstreamError
if errors.As(err, &upstreamErr) {
accountID := int64(0)
if account != nil {
accountID = account.ID
}
logger.LegacyPrintf(
"service.sora",
"[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s",
accountID,
model,
upstreamErr.StatusCode,
strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")),
strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")),
strings.TrimSpace(upstreamErr.Message),
truncateForLog(upstreamErr.Body, 1024),
)
if s.rateLimitService != nil && account != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
}

View File

@@ -179,6 +179,31 @@ func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
require.True(t, client.storyboard)
}
// TestSoraGatewayService_ForwardVideoCount verifies that the "video_count"
// field in the request body is parsed and propagated to the upstream video
// request captured by the stub client.
func TestSoraGatewayService_ForwardVideoCount(t *testing.T) {
	// Stub reports an immediately completed task so Forward does not poll.
	client := &stubSoraClientForPoll{
		videoStatus: &SoraVideoTaskStatus{
			Status: "completed",
			URLs:   []string{"https://example.com/v.mp4"},
		},
	}
	cfg := &config.Config{
		Sora: config.SoraConfig{
			Client: config.SoraClientConfig{
				PollIntervalSeconds: 1,
				MaxPollAttempts:     1,
			},
		},
	}
	svc := NewSoraGatewayService(client, nil, nil, cfg)
	account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
	body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`)
	result, err := svc.Forward(context.Background(), nil, account, body, false)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, 3, client.videoReq.VideoCount)
}
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
client := &stubSoraClientForPoll{}
cfg := &config.Config{
@@ -524,3 +549,10 @@ func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
require.True(t, opts.Enabled)
require.False(t, opts.FallbackOnFailure)
}
// TestParseSoraVideoCount covers the nil-body default, float input, the
// string alias with clamping to the max of 3, and a non-positive value
// falling back to 1.
func TestParseSoraVideoCount(t *testing.T) {
	require.Equal(t, 1, parseSoraVideoCount(nil))
	require.Equal(t, 2, parseSoraVideoCount(map[string]any{"video_count": float64(2)}))
	require.Equal(t, 3, parseSoraVideoCount(map[string]any{"videos": "5"})) // clamped from 5 down to 3
	require.Equal(t, 1, parseSoraVideoCount(map[string]any{"n_variants": 0}))
}

View File

@@ -0,0 +1,63 @@
package service
import (
"context"
"time"
)
// SoraGeneration represents one Sora client generation record.
type SoraGeneration struct {
	ID             int64      `json:"id"`
	UserID         int64      `json:"user_id"`
	APIKeyID       *int64     `json:"api_key_id,omitempty"`
	Model          string     `json:"model"`
	Prompt         string     `json:"prompt"`
	MediaType      string     `json:"media_type"` // video / image
	Status         string     `json:"status"`     // pending / generating / completed / failed / cancelled
	MediaURL       string     `json:"media_url"`  // primary media URL (pre-signed or CDN)
	MediaURLs      []string   `json:"media_urls"` // URL list when multiple images are produced
	FileSizeBytes  int64      `json:"file_size_bytes"`
	StorageType    string     `json:"storage_type"`   // s3 / local / upstream / none
	S3ObjectKeys   []string   `json:"s3_object_keys"` // S3 object keys
	UpstreamTaskID string     `json:"upstream_task_id"`
	ErrorMessage   string     `json:"error_message"`
	CreatedAt      time.Time  `json:"created_at"`
	CompletedAt    *time.Time `json:"completed_at,omitempty"`
}
// Sora generation record status constants.
const (
	SoraGenStatusPending    = "pending"
	SoraGenStatusGenerating = "generating"
	SoraGenStatusCompleted  = "completed"
	SoraGenStatusFailed     = "failed"
	SoraGenStatusCancelled  = "cancelled"
)

// Sora storage type constants.
const (
	SoraStorageTypeS3       = "s3"
	SoraStorageTypeLocal    = "local"
	SoraStorageTypeUpstream = "upstream"
	SoraStorageTypeNone     = "none"
)
// SoraGenerationListParams carries the query parameters for listing
// generation records.
type SoraGenerationListParams struct {
	UserID      int64
	Status      string // optional filter
	StorageType string // optional filter
	MediaType   string // optional filter
	Page        int
	PageSize    int
}
// SoraGenerationRepository is the persistence interface for generation records.
type SoraGenerationRepository interface {
	Create(ctx context.Context, gen *SoraGeneration) error
	GetByID(ctx context.Context, id int64) (*SoraGeneration, error)
	Update(ctx context.Context, gen *SoraGeneration) error
	Delete(ctx context.Context, id int64) error
	List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error)
	CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error)
}

View File

@@ -0,0 +1,332 @@
package service
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var (
	// ErrSoraGenerationConcurrencyLimit indicates the user's number of
	// in-flight tasks exceeds the limit.
	ErrSoraGenerationConcurrencyLimit = errors.New("sora generation concurrent limit exceeded")
	// ErrSoraGenerationStateConflict indicates the record's state has already
	// changed (for example, the task was cancelled).
	ErrSoraGenerationStateConflict = errors.New("sora generation state conflict")
	// ErrSoraGenerationNotActive indicates the task is not in a cancellable state.
	ErrSoraGenerationNotActive = errors.New("sora generation is not active")
)

// soraGenerationActiveLimit caps a user's concurrent pending/generating tasks.
const soraGenerationActiveLimit = 3

// soraGenerationRepoAtomicCreator is an optional repository capability:
// create a pending record and enforce the active-task cap in a single step.
type soraGenerationRepoAtomicCreator interface {
	CreatePendingWithLimit(ctx context.Context, gen *SoraGeneration, activeStatuses []string, maxActive int64) error
}

// soraGenerationRepoConditionalUpdater is an optional repository capability:
// compare-and-swap style status transitions. Each method reports via its
// bool result whether a row was actually updated.
type soraGenerationRepoConditionalUpdater interface {
	UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error)
	UpdateCompletedIfActive(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64, completedAt time.Time) (bool, error)
	UpdateFailedIfActive(ctx context.Context, id int64, errMsg string, completedAt time.Time) (bool, error)
	UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error)
	UpdateStorageIfCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) (bool, error)
}
// SoraGenerationService manages CRUD for Sora client generation records.
type SoraGenerationService struct {
	genRepo      SoraGenerationRepository
	s3Storage    *SoraS3Storage    // optional: used for S3 object cleanup and pre-signed URLs
	quotaService *SoraQuotaService // optional: used to release storage quota on delete
}
// NewSoraGenerationService wires up a generation-record service from its
// repository and the optional S3 storage / quota collaborators (either may
// be nil).
func NewSoraGenerationService(
	genRepo SoraGenerationRepository,
	s3Storage *SoraS3Storage,
	quotaService *SoraQuotaService,
) *SoraGenerationService {
	svc := &SoraGenerationService{genRepo: genRepo}
	svc.s3Storage = s3Storage
	svc.quotaService = quotaService
	return svc
}
// CreatePending inserts a new generation record in the pending state.
// When the repository supports atomic creation with a concurrency cap, the
// insert and the active-task check happen in one step and
// ErrSoraGenerationConcurrencyLimit is returned unwrapped on overflow;
// otherwise the record is created unconditionally.
func (s *SoraGenerationService) CreatePending(ctx context.Context, userID int64, apiKeyID *int64, model, prompt, mediaType string) (*SoraGeneration, error) {
	gen := &SoraGeneration{
		UserID:      userID,
		APIKeyID:    apiKeyID,
		Model:       model,
		Prompt:      prompt,
		MediaType:   mediaType,
		Status:      SoraGenStatusPending,
		StorageType: SoraStorageTypeNone,
	}
	if atomicCreator, atomic := s.genRepo.(soraGenerationRepoAtomicCreator); atomic {
		activeStatuses := []string{SoraGenStatusPending, SoraGenStatusGenerating}
		err := atomicCreator.CreatePendingWithLimit(ctx, gen, activeStatuses, soraGenerationActiveLimit)
		switch {
		case err == nil:
			// fall through to the shared log + return below
		case errors.Is(err, ErrSoraGenerationConcurrencyLimit):
			return nil, err
		default:
			return nil, fmt.Errorf("create generation: %w", err)
		}
	} else if err := s.genRepo.Create(ctx, gen); err != nil {
		return nil, fmt.Errorf("create generation: %w", err)
	}
	logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
	return gen, nil
}
// MarkGenerating transitions a pending record to the generating state and
// records the upstream task ID. Returns ErrSoraGenerationStateConflict when
// the record is no longer pending.
func (s *SoraGenerationService) MarkGenerating(ctx context.Context, id int64, upstreamTaskID string) error {
	if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
		updated, err := updater.UpdateGeneratingIfPending(ctx, id, upstreamTaskID)
		if err == nil && !updated {
			err = ErrSoraGenerationStateConflict
		}
		return err
	}
	// Fallback for repositories without conditional updates: read-modify-write.
	gen, err := s.genRepo.GetByID(ctx, id)
	if err != nil {
		return err
	}
	if gen.Status != SoraGenStatusPending {
		return ErrSoraGenerationStateConflict
	}
	gen.Status = SoraGenStatusGenerating
	gen.UpstreamTaskID = upstreamTaskID
	return s.genRepo.Update(ctx, gen)
}
// MarkCompleted transitions an active (pending/generating) record to
// completed and records its media and storage metadata plus the completion
// time. Returns ErrSoraGenerationStateConflict when the record is no longer
// active.
func (s *SoraGenerationService) MarkCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) error {
	now := time.Now()
	if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
		updated, err := updater.UpdateCompletedIfActive(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes, now)
		if err == nil && !updated {
			err = ErrSoraGenerationStateConflict
		}
		return err
	}
	// Fallback for repositories without conditional updates: read-modify-write.
	gen, err := s.genRepo.GetByID(ctx, id)
	if err != nil {
		return err
	}
	switch gen.Status {
	case SoraGenStatusPending, SoraGenStatusGenerating:
		// still active, transition allowed
	default:
		return ErrSoraGenerationStateConflict
	}
	gen.Status = SoraGenStatusCompleted
	gen.MediaURL = mediaURL
	gen.MediaURLs = mediaURLs
	gen.StorageType = storageType
	gen.S3ObjectKeys = s3Keys
	gen.FileSizeBytes = fileSizeBytes
	gen.CompletedAt = &now
	return s.genRepo.Update(ctx, gen)
}
// MarkFailed transitions an active (pending/generating) record to failed,
// storing the error message and the completion time. Returns
// ErrSoraGenerationStateConflict when the record is no longer active.
func (s *SoraGenerationService) MarkFailed(ctx context.Context, id int64, errMsg string) error {
	now := time.Now()
	if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
		updated, err := updater.UpdateFailedIfActive(ctx, id, errMsg, now)
		if err == nil && !updated {
			err = ErrSoraGenerationStateConflict
		}
		return err
	}
	// Fallback for repositories without conditional updates: read-modify-write.
	gen, err := s.genRepo.GetByID(ctx, id)
	if err != nil {
		return err
	}
	switch gen.Status {
	case SoraGenStatusPending, SoraGenStatusGenerating:
		// still active, transition allowed
	default:
		return ErrSoraGenerationStateConflict
	}
	gen.Status = SoraGenStatusFailed
	gen.ErrorMessage = errMsg
	gen.CompletedAt = &now
	return s.genRepo.Update(ctx, gen)
}
// MarkCancelled transitions an active (pending/generating) record to
// cancelled. Returns ErrSoraGenerationNotActive when the record is not in a
// cancellable state.
func (s *SoraGenerationService) MarkCancelled(ctx context.Context, id int64) error {
	now := time.Now()
	if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
		updated, err := updater.UpdateCancelledIfActive(ctx, id, now)
		if err == nil && !updated {
			err = ErrSoraGenerationNotActive
		}
		return err
	}
	// Fallback for repositories without conditional updates: read-modify-write.
	gen, err := s.genRepo.GetByID(ctx, id)
	if err != nil {
		return err
	}
	switch gen.Status {
	case SoraGenStatusPending, SoraGenStatusGenerating:
		// still active, cancellation allowed
	default:
		return ErrSoraGenerationNotActive
	}
	gen.Status = SoraGenStatusCancelled
	gen.CompletedAt = &now
	return s.genRepo.Update(ctx, gen)
}
// UpdateStorageForCompleted rewrites the storage metadata of an already
// completed record without touching completed_at. Returns
// ErrSoraGenerationStateConflict when the record is not in the completed
// state.
func (s *SoraGenerationService) UpdateStorageForCompleted(
	ctx context.Context,
	id int64,
	mediaURL string,
	mediaURLs []string,
	storageType string,
	s3Keys []string,
	fileSizeBytes int64,
) error {
	if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
		updated, err := updater.UpdateStorageIfCompleted(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes)
		if err == nil && !updated {
			err = ErrSoraGenerationStateConflict
		}
		return err
	}
	// Fallback for repositories without conditional updates: read-modify-write.
	gen, err := s.genRepo.GetByID(ctx, id)
	if err != nil {
		return err
	}
	if gen.Status != SoraGenStatusCompleted {
		return ErrSoraGenerationStateConflict
	}
	gen.MediaURL = mediaURL
	gen.MediaURLs = mediaURLs
	gen.StorageType = storageType
	gen.S3ObjectKeys = s3Keys
	gen.FileSizeBytes = fileSizeBytes
	return s.genRepo.Update(ctx, gen)
}
// GetByID loads a generation record and enforces that it belongs to userID.
func (s *SoraGenerationService) GetByID(ctx context.Context, id, userID int64) (*SoraGeneration, error) {
	gen, err := s.genRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if gen.UserID == userID {
		return gen, nil
	}
	return nil, fmt.Errorf("无权访问此生成记录")
}
// List returns a page of generation records, normalizing pagination to
// page >= 1 and 1 <= pageSize <= 100 (defaulting pageSize to 20).
func (s *SoraGenerationService) List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
	if params.Page < 1 {
		params.Page = 1
	}
	switch {
	case params.PageSize < 1:
		params.PageSize = 20
	case params.PageSize > 100:
		params.PageSize = 100
	}
	return s.genRepo.List(ctx, params)
}
// Delete removes a record owned by userID, first cleaning up S3 objects and
// releasing storage quota on a best-effort basis: cleanup failures are
// logged but never abort the delete itself.
func (s *SoraGenerationService) Delete(ctx context.Context, id, userID int64) error {
	gen, err := s.genRepo.GetByID(ctx, id)
	if err != nil {
		return err
	}
	if gen.UserID != userID {
		return fmt.Errorf("无权删除此生成记录")
	}
	// Best-effort S3 object cleanup.
	if s.s3Storage != nil && gen.StorageType == SoraStorageTypeS3 && len(gen.S3ObjectKeys) > 0 {
		if cleanupErr := s.s3Storage.DeleteObjects(ctx, gen.S3ObjectKeys); cleanupErr != nil {
			logger.LegacyPrintf("service.sora_gen", "[SoraGen] S3 清理失败 id=%d err=%v", id, cleanupErr)
		}
	}
	// Best-effort quota release for S3- and locally-stored media.
	quotaHeld := gen.StorageType == SoraStorageTypeS3 || gen.StorageType == SoraStorageTypeLocal
	if s.quotaService != nil && quotaHeld && gen.FileSizeBytes > 0 {
		if releaseErr := s.quotaService.ReleaseUsage(ctx, userID, gen.FileSizeBytes); releaseErr != nil {
			logger.LegacyPrintf("service.sora_gen", "[SoraGen] 配额释放失败 id=%d err=%v", id, releaseErr)
		}
	}
	return s.genRepo.Delete(ctx, id)
}
// CountActiveByUser reports how many pending/generating tasks the user has;
// callers use it to enforce the concurrency cap.
func (s *SoraGenerationService) CountActiveByUser(ctx context.Context, userID int64) (int64, error) {
	activeStatuses := []string{SoraGenStatusPending, SoraGenStatusGenerating}
	return s.genRepo.CountByUserAndStatus(ctx, userID, activeStatuses)
}
// ResolveMediaURLs populates MediaURL/MediaURLs for S3-backed records by
// generating an access URL for every stored object key (one goroutine per
// key). Records with a non-S3 storage type, a nil gen, a nil s3Storage or no
// object keys are left untouched. On the first URL-generation error, the
// error is returned and gen is NOT mutated.
func (s *SoraGenerationService) ResolveMediaURLs(ctx context.Context, gen *SoraGeneration) error {
	if gen == nil || gen.StorageType != SoraStorageTypeS3 || s.s3Storage == nil {
		return nil
	}
	if len(gen.S3ObjectKeys) == 0 {
		return nil
	}
	urls := make([]string, len(gen.S3ObjectKeys))
	var wg sync.WaitGroup
	var firstErr error
	var errMu sync.Mutex
	for idx, key := range gen.S3ObjectKeys {
		wg.Add(1)
		// idx/key are passed as arguments so each goroutine sees its own copy.
		go func(i int, objectKey string) {
			defer wg.Done()
			url, err := s.s3Storage.GetAccessURL(ctx, objectKey)
			if err != nil {
				// Record only the first error; later ones are dropped.
				errMu.Lock()
				if firstErr == nil {
					firstErr = err
				}
				errMu.Unlock()
				return
			}
			urls[i] = url // each goroutine writes a distinct slot; no lock needed
		}(idx, key)
	}
	wg.Wait()
	if firstErr != nil {
		return firstErr
	}
	gen.MediaURL = urls[0]
	gen.MediaURLs = urls
	return nil
}

View File

@@ -0,0 +1,875 @@
//go:build unit
package service
import (
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/require"
)
// ==================== Stub: SoraGenerationRepository ====================
var _ SoraGenerationRepository = (*stubGenRepo)(nil)

// stubGenRepo is an in-memory SoraGenerationRepository test double. Each
// xxxErr field, when set, is returned verbatim by the matching method;
// countValue, when positive, overrides the computed CountByUserAndStatus
// result.
type stubGenRepo struct {
	gens       map[int64]*SoraGeneration
	nextID     int64
	createErr  error
	getErr     error
	updateErr  error
	deleteErr  error
	listErr    error
	countErr   error
	countValue int64
}

// newStubGenRepo returns an empty stub whose generated IDs start at 1.
func newStubGenRepo() *stubGenRepo {
	return &stubGenRepo{gens: make(map[int64]*SoraGeneration), nextID: 1}
}
// Create assigns the next sequential ID, stamps CreatedAt and stores gen.
func (r *stubGenRepo) Create(_ context.Context, gen *SoraGeneration) error {
	if r.createErr != nil {
		return r.createErr
	}
	gen.ID = r.nextID
	gen.CreatedAt = time.Now()
	r.nextID++
	r.gens[gen.ID] = gen
	return nil
}

// GetByID returns the stored record or a generic "not found" error.
func (r *stubGenRepo) GetByID(_ context.Context, id int64) (*SoraGeneration, error) {
	if r.getErr != nil {
		return nil, r.getErr
	}
	if gen, ok := r.gens[id]; ok {
		return gen, nil
	}
	return nil, fmt.Errorf("not found")
}

// Update overwrites the stored record unconditionally.
func (r *stubGenRepo) Update(_ context.Context, gen *SoraGeneration) error {
	if r.updateErr != nil {
		return r.updateErr
	}
	r.gens[gen.ID] = gen
	return nil
}

// Delete removes the record; deleting a missing ID is a silent no-op.
func (r *stubGenRepo) Delete(_ context.Context, id int64) error {
	if r.deleteErr != nil {
		return r.deleteErr
	}
	delete(r.gens, id)
	return nil
}

// List applies the user/status/storage/media filters but ignores pagination.
func (r *stubGenRepo) List(_ context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
	if r.listErr != nil {
		return nil, 0, r.listErr
	}
	var result []*SoraGeneration
	for _, gen := range r.gens {
		if gen.UserID != params.UserID {
			continue
		}
		if params.Status != "" && gen.Status != params.Status {
			continue
		}
		if params.StorageType != "" && gen.StorageType != params.StorageType {
			continue
		}
		if params.MediaType != "" && gen.MediaType != params.MediaType {
			continue
		}
		result = append(result, gen)
	}
	return result, int64(len(result)), nil
}

// CountByUserAndStatus counts the user's records in any of the given
// statuses; countValue, when positive, short-circuits the computation.
func (r *stubGenRepo) CountByUserAndStatus(_ context.Context, userID int64, statuses []string) (int64, error) {
	if r.countErr != nil {
		return 0, r.countErr
	}
	if r.countValue > 0 {
		return r.countValue, nil
	}
	var count int64
	statusSet := make(map[string]struct{})
	for _, s := range statuses {
		statusSet[s] = struct{}{}
	}
	for _, gen := range r.gens {
		if gen.UserID == userID {
			if _, ok := statusSet[gen.Status]; ok {
				count++
			}
		}
	}
	return count, nil
}
}
// ==================== Stub: UserRepository (used by SoraQuotaService) ====================
var _ UserRepository = (*stubUserRepoForQuota)(nil)

// stubUserRepoForQuota is a minimal in-memory UserRepository: only GetByID
// and Update are functional; the rest are inert stubs satisfying the
// interface.
type stubUserRepoForQuota struct {
	users     map[int64]*User
	updateErr error
}

func newStubUserRepoForQuota() *stubUserRepoForQuota {
	return &stubUserRepoForQuota{users: make(map[int64]*User)}
}

func (r *stubUserRepoForQuota) GetByID(_ context.Context, id int64) (*User, error) {
	if u, ok := r.users[id]; ok {
		return u, nil
	}
	return nil, fmt.Errorf("user not found")
}

func (r *stubUserRepoForQuota) Update(_ context.Context, user *User) error {
	if r.updateErr != nil {
		return r.updateErr
	}
	r.users[user.ID] = user
	return nil
}

// The remaining methods are no-op stubs required only to satisfy the interface.
func (r *stubUserRepoForQuota) Create(context.Context, *User) error { return nil }
func (r *stubUserRepoForQuota) GetByEmail(context.Context, string) (*User, error) {
	return nil, nil
}
func (r *stubUserRepoForQuota) GetFirstAdmin(context.Context) (*User, error) { return nil, nil }
func (r *stubUserRepoForQuota) Delete(context.Context, int64) error          { return nil }
func (r *stubUserRepoForQuota) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (r *stubUserRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (r *stubUserRepoForQuota) UpdateBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForQuota) DeductBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForQuota) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, error) {
	return false, nil
}
func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
	return 0, nil
}
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error                { return nil }
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error               { return nil }
// ==================== Helpers: build SoraS3Storage fixtures with cached CDN config ====================

// newS3StorageWithCDN creates a SoraS3Storage with a pre-cached CDN
// configuration, avoiding real AWS client initialization. Used to exercise
// the CDN path of GetAccessURL.
func newS3StorageWithCDN(cdnURL string) *SoraS3Storage {
	storage := &SoraS3Storage{}
	storage.cfg = &SoraS3Settings{
		Enabled: true,
		Bucket:  "test-bucket",
		CDNURL:  cdnURL,
	}
	// A non-nil client makes getClient hit the cache instead of dialing AWS.
	storage.client = s3.New(s3.Options{})
	return storage
}

// newS3StorageFailingDelete creates a SoraS3Storage with settingService=nil
// so DeleteObjects fails (no configuration can be loaded). Used to test that
// Delete continues even when S3 cleanup fails.
func newS3StorageFailingDelete() *SoraS3Storage {
	return &SoraS3Storage{} // settingService is nil → getConfig returns an error
}
// ==================== CreatePending ====================

// Happy path: the record is created pending, with no storage and no API key.
func TestCreatePending_Success(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)
	gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "一只猫跳舞", "video")
	require.NoError(t, err)
	require.Equal(t, int64(1), gen.ID)
	require.Equal(t, int64(1), gen.UserID)
	require.Equal(t, "sora2-landscape-10s", gen.Model)
	require.Equal(t, "一只猫跳舞", gen.Prompt)
	require.Equal(t, "video", gen.MediaType)
	require.Equal(t, SoraGenStatusPending, gen.Status)
	require.Equal(t, SoraStorageTypeNone, gen.StorageType)
	require.Nil(t, gen.APIKeyID)
}

// A supplied APIKeyID is stored on the record.
func TestCreatePending_WithAPIKeyID(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)
	apiKeyID := int64(42)
	gen, err := svc.CreatePending(context.Background(), 1, &apiKeyID, "gpt-image", "画一朵花", "image")
	require.NoError(t, err)
	require.NotNil(t, gen.APIKeyID)
	require.Equal(t, int64(42), *gen.APIKeyID)
}

// Repository failures are wrapped with a "create generation" prefix.
func TestCreatePending_RepoError(t *testing.T) {
	repo := newStubGenRepo()
	repo.createErr = fmt.Errorf("db write error")
	svc := NewSoraGenerationService(repo, nil, nil)
	gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
	require.Error(t, err)
	require.Nil(t, gen)
	require.Contains(t, err.Error(), "create generation")
}
// ==================== MarkGenerating ====================

// A pending record transitions to generating and records the upstream task ID.
func TestMarkGenerating_Success(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkGenerating(context.Background(), 1, "upstream-task-123")
	require.NoError(t, err)
	require.Equal(t, SoraGenStatusGenerating, repo.gens[1].Status)
	require.Equal(t, "upstream-task-123", repo.gens[1].UpstreamTaskID)
}

// A missing record surfaces the repository's lookup error.
func TestMarkGenerating_NotFound(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkGenerating(context.Background(), 999, "")
	require.Error(t, err)
}

// Repository update failures are propagated.
func TestMarkGenerating_UpdateError(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
	repo.updateErr = fmt.Errorf("update failed")
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkGenerating(context.Background(), 1, "")
	require.Error(t, err)
}
// ==================== MarkCompleted ====================

// A generating record transitions to completed with full storage metadata
// and a completion timestamp.
func TestMarkCompleted_Success(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkCompleted(context.Background(), 1,
		"https://cdn.example.com/video.mp4",
		[]string{"https://cdn.example.com/video.mp4"},
		SoraStorageTypeS3,
		[]string{"sora/1/2024/01/01/uuid.mp4"},
		1048576,
	)
	require.NoError(t, err)
	gen := repo.gens[1]
	require.Equal(t, SoraGenStatusCompleted, gen.Status)
	require.Equal(t, "https://cdn.example.com/video.mp4", gen.MediaURL)
	require.Equal(t, []string{"https://cdn.example.com/video.mp4"}, gen.MediaURLs)
	require.Equal(t, SoraStorageTypeS3, gen.StorageType)
	require.Equal(t, []string{"sora/1/2024/01/01/uuid.mp4"}, gen.S3ObjectKeys)
	require.Equal(t, int64(1048576), gen.FileSizeBytes)
	require.NotNil(t, gen.CompletedAt)
}

// A missing record surfaces the repository's lookup error.
func TestMarkCompleted_NotFound(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkCompleted(context.Background(), 999, "", nil, "", nil, 0)
	require.Error(t, err)
}

// Repository update failures are propagated.
func TestMarkCompleted_UpdateError(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
	repo.updateErr = fmt.Errorf("update failed")
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkCompleted(context.Background(), 1, "url", nil, SoraStorageTypeUpstream, nil, 0)
	require.Error(t, err)
}
// ==================== MarkFailed ====================

// A generating record transitions to failed, keeping the error message and
// a completion timestamp.
func TestMarkFailed_Success(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkFailed(context.Background(), 1, "上游返回 500 错误")
	require.NoError(t, err)
	gen := repo.gens[1]
	require.Equal(t, SoraGenStatusFailed, gen.Status)
	require.Equal(t, "上游返回 500 错误", gen.ErrorMessage)
	require.NotNil(t, gen.CompletedAt)
}

// A missing record surfaces the repository's lookup error.
func TestMarkFailed_NotFound(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkFailed(context.Background(), 999, "error")
	require.Error(t, err)
}

// Repository update failures are propagated.
func TestMarkFailed_UpdateError(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
	repo.updateErr = fmt.Errorf("update failed")
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkFailed(context.Background(), 1, "err")
	require.Error(t, err)
}
// ==================== MarkCancelled ====================

// A pending record can be cancelled.
func TestMarkCancelled_Pending(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkCancelled(context.Background(), 1)
	require.NoError(t, err)
	require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
	require.NotNil(t, repo.gens[1].CompletedAt)
}

// A generating record can be cancelled.
func TestMarkCancelled_Generating(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkCancelled(context.Background(), 1)
	require.NoError(t, err)
	require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
}

// A completed record cannot be cancelled: ErrSoraGenerationNotActive.
func TestMarkCancelled_Completed(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkCancelled(context.Background(), 1)
	require.Error(t, err)
	require.ErrorIs(t, err, ErrSoraGenerationNotActive)
}

// A failed record cannot be cancelled.
func TestMarkCancelled_Failed(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusFailed}
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkCancelled(context.Background(), 1)
	require.Error(t, err)
}

// Cancelling twice is rejected.
func TestMarkCancelled_AlreadyCancelled(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCancelled}
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkCancelled(context.Background(), 1)
	require.Error(t, err)
}

// A missing record surfaces the repository's lookup error.
func TestMarkCancelled_NotFound(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkCancelled(context.Background(), 999)
	require.Error(t, err)
}

// Repository update failures are propagated.
func TestMarkCancelled_UpdateError(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
	repo.updateErr = fmt.Errorf("update failed")
	svc := NewSoraGenerationService(repo, nil, nil)
	err := svc.MarkCancelled(context.Background(), 1)
	require.Error(t, err)
}
// ==================== GetByID ====================

// The owner can read their own record.
func TestGetByID_Success(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, Model: "sora2-landscape-10s"}
	svc := NewSoraGenerationService(repo, nil, nil)
	gen, err := svc.GetByID(context.Background(), 1, 1)
	require.NoError(t, err)
	require.Equal(t, int64(1), gen.ID)
	require.Equal(t, "sora2-landscape-10s", gen.Model)
}

// A different user is denied access.
func TestGetByID_WrongUser(t *testing.T) {
	repo := newStubGenRepo()
	repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
	svc := NewSoraGenerationService(repo, nil, nil)
	gen, err := svc.GetByID(context.Background(), 1, 1)
	require.Error(t, err)
	require.Nil(t, gen)
	require.Contains(t, err.Error(), "无权访问")
}

// A missing record surfaces the repository's lookup error.
func TestGetByID_NotFound(t *testing.T) {
	repo := newStubGenRepo()
	svc := NewSoraGenerationService(repo, nil, nil)
	gen, err := svc.GetByID(context.Background(), 999, 1)
	require.Error(t, err)
	require.Nil(t, gen)
}
// ==================== List ====================

// TestList_Success verifies listing is scoped to the requesting user.
func TestList_Success(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, MediaType: "video"}
	genRepo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusPending, MediaType: "image"}
	genRepo.gens[3] = &SoraGeneration{ID: 3, UserID: 2, Status: SoraGenStatusCompleted, MediaType: "video"}
	service := NewSoraGenerationService(genRepo, nil, nil)
	records, total, err := service.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 20})
	require.NoError(t, err)
	require.Len(t, records, 2) // only records owned by userID=1
	require.Equal(t, int64(2), total)
}

// TestList_DefaultPagination verifies page=0/pageSize=0 are normalized
// (expected to become page=1, pageSize=20).
func TestList_DefaultPagination(t *testing.T) {
	service := NewSoraGenerationService(newStubGenRepo(), nil, nil)
	_, _, err := service.List(context.Background(), SoraGenerationListParams{UserID: 1})
	require.NoError(t, err)
}

// TestList_MaxPageSize verifies pageSize above 100 is clamped to 100.
func TestList_MaxPageSize(t *testing.T) {
	service := NewSoraGenerationService(newStubGenRepo(), nil, nil)
	_, _, err := service.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 200})
	require.NoError(t, err)
}

// TestList_Error verifies a repository listing failure propagates.
func TestList_Error(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.listErr = fmt.Errorf("db error")
	service := NewSoraGenerationService(genRepo, nil, nil)
	_, _, err := service.List(context.Background(), SoraGenerationListParams{UserID: 1})
	require.Error(t, err)
}
// ==================== Delete ====================

// TestDelete_Success verifies an owner can delete their record.
func TestDelete_Success(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, StorageType: SoraStorageTypeUpstream}
	service := NewSoraGenerationService(genRepo, nil, nil)
	require.NoError(t, service.Delete(context.Background(), 1, 1))
	_, stillThere := genRepo.gens[1]
	require.False(t, stillThere)
}

// TestDelete_WrongUser verifies another user's record cannot be deleted.
func TestDelete_WrongUser(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
	service := NewSoraGenerationService(genRepo, nil, nil)
	err := service.Delete(context.Background(), 1, 1)
	require.Error(t, err)
	require.Contains(t, err.Error(), "无权删除")
}

// TestDelete_NotFound verifies deleting a missing record is an error.
func TestDelete_NotFound(t *testing.T) {
	service := NewSoraGenerationService(newStubGenRepo(), nil, nil)
	require.Error(t, service.Delete(context.Background(), 999, 1))
}

// TestDelete_S3Cleanup_NilS3 verifies S3 cleanup is skipped when no S3 storage
// is configured (s3Storage == nil).
func TestDelete_S3Cleanup_NilS3(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
	service := NewSoraGenerationService(genRepo, nil, nil)
	require.NoError(t, service.Delete(context.Background(), 1, 1))
}

// TestDelete_QuotaRelease_NilQuota verifies quota release is skipped when no
// quota service is configured.
func TestDelete_QuotaRelease_NilQuota(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, FileSizeBytes: 1024}
	service := NewSoraGenerationService(genRepo, nil, nil)
	require.NoError(t, service.Delete(context.Background(), 1, 1))
}

// TestDelete_NonS3NoCleanup verifies non-S3 records delete without S3 cleanup.
func TestDelete_NonS3NoCleanup(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeLocal, FileSizeBytes: 1024}
	service := NewSoraGenerationService(genRepo, nil, nil)
	require.NoError(t, service.Delete(context.Background(), 1, 1))
}

// TestDelete_DeleteRepoError verifies a repository delete failure propagates.
func TestDelete_DeleteRepoError(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeUpstream}
	genRepo.deleteErr = fmt.Errorf("delete failed")
	service := NewSoraGenerationService(genRepo, nil, nil)
	require.Error(t, service.Delete(context.Background(), 1, 1))
}
// ==================== CountActiveByUser ====================

// TestCountActiveByUser_Success verifies only pending/generating records count.
func TestCountActiveByUser_Success(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
	genRepo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusGenerating}
	genRepo.gens[3] = &SoraGeneration{ID: 3, UserID: 1, Status: SoraGenStatusCompleted} // not active
	service := NewSoraGenerationService(genRepo, nil, nil)
	active, err := service.CountActiveByUser(context.Background(), 1)
	require.NoError(t, err)
	require.Equal(t, int64(2), active)
}

// TestCountActiveByUser_NoActive verifies a zero count with no active records.
func TestCountActiveByUser_NoActive(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
	service := NewSoraGenerationService(genRepo, nil, nil)
	active, err := service.CountActiveByUser(context.Background(), 1)
	require.NoError(t, err)
	require.Equal(t, int64(0), active)
}

// TestCountActiveByUser_Error verifies a repository count failure propagates.
func TestCountActiveByUser_Error(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.countErr = fmt.Errorf("db error")
	service := NewSoraGenerationService(genRepo, nil, nil)
	_, err := service.CountActiveByUser(context.Background(), 1)
	require.Error(t, err)
}
// ==================== ResolveMediaURLs ====================

// TestResolveMediaURLs_NilGen verifies a nil record is a no-op.
func TestResolveMediaURLs_NilGen(t *testing.T) {
	service := NewSoraGenerationService(newStubGenRepo(), nil, nil)
	require.NoError(t, service.ResolveMediaURLs(context.Background(), nil))
}

// TestResolveMediaURLs_NonS3 verifies non-S3 URLs pass through untouched.
func TestResolveMediaURLs_NonS3(t *testing.T) {
	service := NewSoraGenerationService(newStubGenRepo(), nil, nil)
	record := &SoraGeneration{StorageType: SoraStorageTypeUpstream, MediaURL: "https://original.com/v.mp4"}
	require.NoError(t, service.ResolveMediaURLs(context.Background(), record))
	require.Equal(t, "https://original.com/v.mp4", record.MediaURL) // unchanged
}

// TestResolveMediaURLs_S3NilStorage verifies S3 records are skipped when no
// S3 storage is configured.
func TestResolveMediaURLs_S3NilStorage(t *testing.T) {
	service := NewSoraGenerationService(newStubGenRepo(), nil, nil)
	record := &SoraGeneration{StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
	require.NoError(t, service.ResolveMediaURLs(context.Background(), record))
}

// TestResolveMediaURLs_Local verifies local-storage URLs pass through untouched.
func TestResolveMediaURLs_Local(t *testing.T) {
	service := NewSoraGenerationService(newStubGenRepo(), nil, nil)
	record := &SoraGeneration{StorageType: SoraStorageTypeLocal, MediaURL: "/video/2024/01/01/file.mp4"}
	require.NoError(t, service.ResolveMediaURLs(context.Background(), record))
	require.Equal(t, "/video/2024/01/01/file.mp4", record.MediaURL) // unchanged
}
// ==================== Full status-transition flows ====================

// TestStatusTransition_PendingToCompletedFlow walks pending → generating → completed.
func TestStatusTransition_PendingToCompletedFlow(t *testing.T) {
	genRepo := newStubGenRepo()
	service := NewSoraGenerationService(genRepo, nil, nil)
	// 1. create as pending
	record, err := service.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
	require.NoError(t, err)
	require.Equal(t, SoraGenStatusPending, record.Status)
	// 2. mark generating
	require.NoError(t, service.MarkGenerating(context.Background(), record.ID, "task-123"))
	require.Equal(t, SoraGenStatusGenerating, genRepo.gens[record.ID].Status)
	// 3. mark completed
	err = service.MarkCompleted(context.Background(), record.ID, "https://s3.com/video.mp4", nil, SoraStorageTypeS3, []string{"key"}, 1024)
	require.NoError(t, err)
	require.Equal(t, SoraGenStatusCompleted, genRepo.gens[record.ID].Status)
}

// TestStatusTransition_PendingToFailedFlow walks pending → generating → failed
// and checks the error message is recorded.
func TestStatusTransition_PendingToFailedFlow(t *testing.T) {
	genRepo := newStubGenRepo()
	service := NewSoraGenerationService(genRepo, nil, nil)
	record, _ := service.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
	_ = service.MarkGenerating(context.Background(), record.ID, "")
	require.NoError(t, service.MarkFailed(context.Background(), record.ID, "上游超时"))
	require.Equal(t, SoraGenStatusFailed, genRepo.gens[record.ID].Status)
	require.Equal(t, "上游超时", genRepo.gens[record.ID].ErrorMessage)
}

// TestStatusTransition_PendingToCancelledFlow cancels straight from pending.
func TestStatusTransition_PendingToCancelledFlow(t *testing.T) {
	genRepo := newStubGenRepo()
	service := NewSoraGenerationService(genRepo, nil, nil)
	record, _ := service.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
	require.NoError(t, service.MarkCancelled(context.Background(), record.ID))
	require.Equal(t, SoraGenStatusCancelled, genRepo.gens[record.ID].Status)
}

// TestStatusTransition_GeneratingToCancelledFlow cancels from generating.
func TestStatusTransition_GeneratingToCancelledFlow(t *testing.T) {
	genRepo := newStubGenRepo()
	service := NewSoraGenerationService(genRepo, nil, nil)
	record, _ := service.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
	_ = service.MarkGenerating(context.Background(), record.ID, "")
	require.NoError(t, service.MarkCancelled(context.Background(), record.ID))
	require.Equal(t, SoraGenStatusCancelled, genRepo.gens[record.ID].Status)
}
// ==================== 权限隔离测试 ====================
func TestUserIsolation_CannotAccessOthersRecord(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
// 用户 2 尝试访问用户 1 的记录
_, err := svc.GetByID(context.Background(), gen.ID, 2)
require.Error(t, err)
require.Contains(t, err.Error(), "无权访问")
}
func TestUserIsolation_CannotDeleteOthersRecord(t *testing.T) {
repo := newStubGenRepo()
svc := NewSoraGenerationService(repo, nil, nil)
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
err := svc.Delete(context.Background(), gen.ID, 2)
require.Error(t, err)
require.Contains(t, err.Error(), "无权删除")
}
// ==================== Delete: S3 cleanup + quota release ====================

// TestDelete_S3Cleanup_WithS3Storage: S3 storage is present but deleteObjects
// fails (settingService=nil); Delete must still succeed — S3 errors are only logged.
func TestDelete_S3Cleanup_WithS3Storage(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{
		ID: 1, UserID: 1,
		StorageType:  SoraStorageTypeS3,
		S3ObjectKeys: []string{"sora/1/2024/01/01/abc.mp4"},
	}
	service := NewSoraGenerationService(genRepo, newS3StorageFailingDelete(), nil)
	require.NoError(t, service.Delete(context.Background(), 1, 1)) // S3 cleanup failure does not block deletion
	_, stillThere := genRepo.gens[1]
	require.False(t, stillThere)
}

// TestDelete_QuotaRelease_WithQuotaService: deleting an S3 record with a quota
// service configured releases the user's storage quota.
func TestDelete_QuotaRelease_WithQuotaService(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{
		ID: 1, UserID: 1,
		StorageType:   SoraStorageTypeS3,
		FileSizeBytes: 1048576, // 1MB
	}
	quotaUsers := newStubUserRepoForQuota()
	quotaUsers.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2097152} // 2MB
	service := NewSoraGenerationService(genRepo, nil, NewSoraQuotaService(quotaUsers, nil, nil))
	require.NoError(t, service.Delete(context.Background(), 1, 1))
	// quota released: 2MB - 1MB = 1MB
	require.Equal(t, int64(1048576), quotaUsers.users[1].SoraStorageUsedBytes)
}

// TestDelete_S3Cleanup_And_QuotaRelease: S3 cleanup and quota release both fire.
func TestDelete_S3Cleanup_And_QuotaRelease(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{
		ID: 1, UserID: 1,
		StorageType:   SoraStorageTypeS3,
		S3ObjectKeys:  []string{"key1"},
		FileSizeBytes: 512,
	}
	quotaUsers := newStubUserRepoForQuota()
	quotaUsers.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
	service := NewSoraGenerationService(genRepo, newS3StorageFailingDelete(), NewSoraQuotaService(quotaUsers, nil, nil))
	require.NoError(t, service.Delete(context.Background(), 1, 1))
	_, stillThere := genRepo.gens[1]
	require.False(t, stillThere)
	require.Equal(t, int64(512), quotaUsers.users[1].SoraStorageUsedBytes)
}

// TestDelete_QuotaRelease_LocalStorage: local storage also releases quota.
func TestDelete_QuotaRelease_LocalStorage(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{
		ID: 1, UserID: 1,
		StorageType:   SoraStorageTypeLocal,
		FileSizeBytes: 1024,
	}
	quotaUsers := newStubUserRepoForQuota()
	quotaUsers.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2048}
	service := NewSoraGenerationService(genRepo, nil, NewSoraQuotaService(quotaUsers, nil, nil))
	require.NoError(t, service.Delete(context.Background(), 1, 1))
	require.Equal(t, int64(1024), quotaUsers.users[1].SoraStorageUsedBytes)
}

// TestDelete_QuotaRelease_ZeroFileSize: FileSizeBytes=0 skips quota release.
func TestDelete_QuotaRelease_ZeroFileSize(t *testing.T) {
	genRepo := newStubGenRepo()
	genRepo.gens[1] = &SoraGeneration{
		ID: 1, UserID: 1,
		StorageType:   SoraStorageTypeS3,
		FileSizeBytes: 0,
	}
	quotaUsers := newStubUserRepoForQuota()
	quotaUsers.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
	service := NewSoraGenerationService(genRepo, nil, NewSoraQuotaService(quotaUsers, nil, nil))
	require.NoError(t, service.Delete(context.Background(), 1, 1))
	require.Equal(t, int64(1024), quotaUsers.users[1].SoraStorageUsedBytes)
}
// ==================== ResolveMediaURLs: S3 + CDN 路径 ====================
func TestResolveMediaURLs_S3_CDN_SingleKey(t *testing.T) {
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.NoError(t, err)
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", gen.MediaURL)
}
func TestResolveMediaURLs_S3_CDN_MultipleKeys(t *testing.T) {
s3Storage := newS3StorageWithCDN("https://cdn.example.com/")
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{
"sora/1/2024/01/01/img1.png",
"sora/1/2024/01/01/img2.png",
"sora/1/2024/01/01/img3.png",
},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.NoError(t, err)
// 主 URL 更新为第一个 key 的 CDN URL
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURL)
// 多图 URLs 全部更新
require.Len(t, gen.MediaURLs, 3)
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURLs[0])
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img2.png", gen.MediaURLs[1])
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img3.png", gen.MediaURLs[2])
}
func TestResolveMediaURLs_S3_EmptyKeys(t *testing.T) {
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.NoError(t, err)
require.Equal(t, "original", gen.MediaURL) // 不变
}
func TestResolveMediaURLs_S3_GetAccessURL_Error(t *testing.T) {
// 使用无 settingService 的 S3 StoragegetClient 会失败
s3Storage := newS3StorageFailingDelete() // 同样 GetAccessURL 也会失败
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.Error(t, err) // GetAccessURL 失败应传播错误
}
func TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond(t *testing.T) {
// 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖
s3Storage := newS3StorageFailingDelete()
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
gen := &SoraGeneration{
StorageType: SoraStorageTypeS3,
S3ObjectKeys: []string{
"sora/1/2024/01/01/img1.png",
"sora/1/2024/01/01/img2.png",
},
MediaURL: "original",
}
err := svc.ResolveMediaURLs(context.Background(), gen)
require.Error(t, err) // 第一个 key 的 GetAccessURL 就会失败
}

View File

@@ -157,6 +157,64 @@ func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string,
return results, nil
}
// TotalSizeByRelativePaths sums the on-disk size of local media paths
// (only /image and /video paths are considered). Non-local paths and missing
// files are skipped; any other stat error aborts with (0, err).
func (s *SoraMediaStorage) TotalSizeByRelativePaths(paths []string) (int64, error) {
	if s == nil || len(paths) == 0 {
		return 0, nil
	}
	var sum int64
	for _, rel := range paths {
		full, err := s.resolveLocalPath(rel)
		if err != nil {
			// Not a managed local media path; ignore it.
			continue
		}
		info, statErr := os.Stat(full)
		switch {
		case statErr == nil:
			// Only count regular files (skip directories, symlink targets, etc.).
			if info.Mode().IsRegular() {
				sum += info.Size()
			}
		case os.IsNotExist(statErr):
			continue
		default:
			return 0, statErr
		}
	}
	return sum, nil
}
// DeleteByRelativePaths removes local media files (only /image and /video
// paths). Non-local paths and already-missing files are ignored; the last
// removal error encountered, if any, is returned.
func (s *SoraMediaStorage) DeleteByRelativePaths(paths []string) error {
	if s == nil || len(paths) == 0 {
		return nil
	}
	var lastErr error
	for _, rel := range paths {
		full, err := s.resolveLocalPath(rel)
		if err != nil {
			continue
		}
		removeErr := os.Remove(full)
		if removeErr == nil || os.IsNotExist(removeErr) {
			continue
		}
		lastErr = removeErr
	}
	return lastErr
}
// resolveLocalPath maps a relative media URL path ("/image/..." or "/video/...")
// to an absolute filesystem path under the storage root. path.Clean collapses
// any ".." segments first, so traversal outside those prefixes fails the check.
func (s *SoraMediaStorage) resolveLocalPath(relativePath string) (string, error) {
	if s == nil || strings.TrimSpace(relativePath) == "" {
		return "", errors.New("empty path")
	}
	cleaned := path.Clean(relativePath)
	isMedia := strings.HasPrefix(cleaned, "/image/") || strings.HasPrefix(cleaned, "/video/")
	if !isMedia {
		return "", errors.New("not a local media path")
	}
	if strings.TrimSpace(s.root) == "" {
		return "", errors.New("storage root not configured")
	}
	rel := strings.TrimPrefix(cleaned, "/")
	return filepath.Join(s.root, filepath.FromSlash(rel)), nil
}
func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL string) (string, error) {
if strings.TrimSpace(rawURL) == "" {
return "", errors.New("empty url")

View File

@@ -1,6 +1,9 @@
package service
import (
"regexp"
"sort"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -247,6 +250,218 @@ func GetSoraModelConfig(model string) (SoraModelConfig, bool) {
return cfg, ok
}
// SoraModelFamily describes a model family exposed to the front-end Sora client.
type SoraModelFamily struct {
	ID           string   `json:"id"`
	Name         string   `json:"name"`
	Type         string   `json:"type"`
	Orientations []string `json:"orientations"`
	Durations    []int    `json:"durations,omitempty"`
}
var (
	// videoSuffixRe matches video model IDs of the form {family}-{orientation}-{N}s.
	videoSuffixRe = regexp.MustCompile(`-(landscape|portrait)-(\d+)s$`)
	// imageSuffixRe matches image model IDs of the form {family}-{orientation}.
	imageSuffixRe = regexp.MustCompile(`-(landscape|portrait)$`)
	// soraFamilyNames maps family IDs to human-readable display names;
	// families not listed here fall back to their raw ID.
	soraFamilyNames = map[string]string{
		"sora2":       "Sora 2",
		"sora2pro":    "Sora 2 Pro",
		"sora2pro-hd": "Sora 2 Pro HD",
		"gpt-image":   "GPT Image",
	}
)
// BuildSoraModelFamilies aggregates soraModelConfigs into model families,
// collecting the orientations and durations each family supports.
// Ordering: video families first, image families second, ID order within each.
func BuildSoraModelFamilies() []SoraModelFamily {
	type aggregate struct {
		kind         string
		orientations map[string]bool
		durations    map[int]bool
	}
	byFamily := make(map[string]*aggregate)
	for modelID, cfg := range soraModelConfigs {
		if cfg.Type == "prompt_enhance" {
			continue
		}
		var family, orientation string
		var seconds int
		switch cfg.Type {
		case "video":
			if m := videoSuffixRe.FindStringSubmatch(modelID); m != nil {
				family = modelID[:len(modelID)-len(m[0])]
				orientation = m[1]
				seconds, _ = strconv.Atoi(m[2])
			}
		case "image":
			if m := imageSuffixRe.FindStringSubmatch(modelID); m != nil {
				family = modelID[:len(modelID)-len(m[0])]
				orientation = m[1]
			} else {
				// Suffix-less image models (e.g. gpt-image) count as "square".
				family = modelID
				orientation = "square"
			}
		}
		if family == "" {
			continue
		}
		agg := byFamily[family]
		if agg == nil {
			agg = &aggregate{
				kind:         cfg.Type,
				orientations: make(map[string]bool),
				durations:    make(map[int]bool),
			}
			byFamily[family] = agg
		}
		if orientation != "" {
			agg.orientations[orientation] = true
		}
		if seconds > 0 {
			agg.durations[seconds] = true
		}
	}
	ids := make([]string, 0, len(byFamily))
	for id := range byFamily {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(i, j int) bool {
		a, b := byFamily[ids[i]], byFamily[ids[j]]
		if a.kind != b.kind {
			return a.kind == "video"
		}
		return ids[i] < ids[j]
	})
	out := make([]SoraModelFamily, 0, len(ids))
	for _, id := range ids {
		agg := byFamily[id]
		fam := SoraModelFamily{ID: id, Name: soraFamilyNames[id], Type: agg.kind}
		if fam.Name == "" {
			fam.Name = id
		}
		for o := range agg.orientations {
			fam.Orientations = append(fam.Orientations, o)
		}
		sort.Strings(fam.Orientations)
		for d := range agg.durations {
			fam.Durations = append(fam.Durations, d)
		}
		sort.Ints(fam.Durations)
		out = append(out, fam)
	}
	return out
}
// BuildSoraModelFamiliesFromIDs aggregates model families from an arbitrary
// model-ID list (used when parsing upstream model listings). Video and image
// models are recognized by naming convention; unmatched IDs are dropped.
// Ordering matches BuildSoraModelFamilies: video first, then ID order.
func BuildSoraModelFamiliesFromIDs(modelIDs []string) []SoraModelFamily {
	type aggregate struct {
		kind         string
		orientations map[string]bool
		durations    map[int]bool
	}
	byFamily := make(map[string]*aggregate)
	for _, raw := range modelIDs {
		modelID := strings.ToLower(strings.TrimSpace(raw))
		if modelID == "" || strings.HasPrefix(modelID, "prompt-enhance") {
			continue
		}
		var family, orientation, kind string
		var seconds int
		if m := videoSuffixRe.FindStringSubmatch(modelID); m != nil {
			// Video model: {family}-{orientation}-{duration}s
			family = modelID[:len(modelID)-len(m[0])]
			orientation = m[1]
			seconds, _ = strconv.Atoi(m[2])
			kind = "video"
		} else if m := imageSuffixRe.FindStringSubmatch(modelID); m != nil {
			// Image model with orientation: {family}-{orientation}
			family = modelID[:len(modelID)-len(m[0])]
			orientation = m[1]
			kind = "image"
		} else if cfg, known := soraModelConfigs[modelID]; known && cfg.Type == "image" {
			// Known suffix-less image model (e.g. gpt-image).
			family = modelID
			orientation = "square"
			kind = "image"
		} else if strings.Contains(modelID, "image") {
			// Unknown ID containing "image": assume an image model.
			family = modelID
			orientation = "square"
			kind = "image"
		} else {
			continue
		}
		if family == "" {
			continue
		}
		agg := byFamily[family]
		if agg == nil {
			agg = &aggregate{
				kind:         kind,
				orientations: make(map[string]bool),
				durations:    make(map[int]bool),
			}
			byFamily[family] = agg
		}
		if orientation != "" {
			agg.orientations[orientation] = true
		}
		if seconds > 0 {
			agg.durations[seconds] = true
		}
	}
	ids := make([]string, 0, len(byFamily))
	for id := range byFamily {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(i, j int) bool {
		a, b := byFamily[ids[i]], byFamily[ids[j]]
		if a.kind != b.kind {
			return a.kind == "video"
		}
		return ids[i] < ids[j]
	})
	out := make([]SoraModelFamily, 0, len(ids))
	for _, id := range ids {
		agg := byFamily[id]
		fam := SoraModelFamily{ID: id, Name: soraFamilyNames[id], Type: agg.kind}
		if fam.Name == "" {
			fam.Name = id
		}
		for o := range agg.orientations {
			fam.Orientations = append(fam.Orientations, o)
		}
		sort.Strings(fam.Orientations)
		for d := range agg.durations {
			fam.Durations = append(fam.Durations, d)
		}
		sort.Ints(fam.Durations)
		out = append(out, fam)
	}
	return out
}
// DefaultSoraModels returns the default Sora model list.
func DefaultSoraModels(cfg *config.Config) []openai.Model {
models := make([]openai.Model, 0, len(soraModelIDs))

View File

@@ -0,0 +1,257 @@
package service
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// SoraQuotaService manages per-user Sora storage quotas.
// Effective quota precedence: user-level → group-level → system default.
type SoraQuotaService struct {
	userRepo       UserRepository
	groupRepo      GroupRepository
	settingService *SettingService // optional; nil means no system default quota
}
// NewSoraQuotaService builds a quota service instance from its repositories
// and the (optional) settings service.
func NewSoraQuotaService(
	userRepo UserRepository,
	groupRepo GroupRepository,
	settingService *SettingService,
) *SoraQuotaService {
	svc := &SoraQuotaService{
		userRepo:       userRepo,
		groupRepo:      groupRepo,
		settingService: settingService,
	}
	return svc
}
// QuotaInfo is the quota snapshot returned to clients.
type QuotaInfo struct {
	QuotaBytes     int64  `json:"quota_bytes"`     // total quota (0 = unlimited)
	UsedBytes      int64  `json:"used_bytes"`      // bytes currently used
	AvailableBytes int64  `json:"available_bytes"` // remaining bytes (0 when unlimited)
	QuotaSource    string `json:"quota_source"`    // where the quota came from: user / group / system / unlimited
	Source         string `json:"source,omitempty"` // legacy alias of QuotaSource, kept for compatibility
}
// ErrSoraStorageQuotaExceeded is the sentinel returned by atomic repositories
// when an increment would exceed the effective quota.
var ErrSoraStorageQuotaExceeded = errors.New("sora storage quota exceeded")

// QuotaExceededError carries the quota/usage context of an exceeded quota.
type QuotaExceededError struct {
	QuotaBytes int64 // effective quota at the time of the failure
	UsedBytes  int64 // usage at the time of the failure
}
// Error implements the error interface; it is safe to call on a nil receiver.
func (e *QuotaExceededError) Error() string {
	if e != nil {
		return fmt.Sprintf("存储配额不足(已用 %d / 配额 %d 字节)", e.UsedBytes, e.QuotaBytes)
	}
	return "存储配额不足"
}
// soraQuotaAtomicUserRepository is an optional capability of UserRepository
// implementations that support atomic usage accounting; it is detected via a
// type assertion in AddUsage/ReleaseUsage, which otherwise fall back to a
// read-modify-write update.
type soraQuotaAtomicUserRepository interface {
	AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error)
	ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error)
}
// GetQuota resolves the effective storage quota for a user.
// Precedence: user-level quota > the largest quota among the user's allowed
// groups > system default. When none apply the quota is unlimited (0 bytes).
func (s *SoraQuotaService) GetQuota(ctx context.Context, userID int64) (*QuotaInfo, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}
	quotaBytes, source := int64(0), "unlimited"
	if user.SoraStorageQuotaBytes > 0 {
		// 1. user-level quota wins outright.
		quotaBytes, source = user.SoraStorageQuotaBytes, "user"
	} else {
		// 2. group-level: take the largest quota among reachable groups;
		// lookup failures for individual groups are skipped.
		var groupMax int64
		for _, gid := range user.AllowedGroups {
			group, gerr := s.groupRepo.GetByID(ctx, gid)
			if gerr != nil {
				continue
			}
			if group.SoraStorageQuotaBytes > groupMax {
				groupMax = group.SoraStorageQuotaBytes
			}
		}
		if groupMax > 0 {
			quotaBytes, source = groupMax, "group"
		} else if def := s.getSystemDefaultQuota(ctx); def > 0 {
			// 3. system default, consulted only when no group quota applied.
			quotaBytes, source = def, "system"
		}
	}
	return &QuotaInfo{
		QuotaBytes:     quotaBytes,
		UsedBytes:      user.SoraStorageUsedBytes,
		AvailableBytes: calcAvailableBytes(quotaBytes, user.SoraStorageUsedBytes),
		QuotaSource:    source,
		Source:         source, // legacy alias
	}, nil
}
// CheckQuota reports whether userID can store additionalBytes more.
// A nil return means the quota is sufficient or unlimited.
func (s *SoraQuotaService) CheckQuota(ctx context.Context, userID int64, additionalBytes int64) error {
	info, err := s.GetQuota(ctx, userID)
	if err != nil {
		return err
	}
	// QuotaBytes == 0 means unlimited.
	if info.QuotaBytes == 0 || info.UsedBytes+additionalBytes <= info.QuotaBytes {
		return nil
	}
	return &QuotaExceededError{
		QuotaBytes: info.QuotaBytes,
		UsedBytes:  info.UsedBytes,
	}
}
// AddUsage records bytes of new storage usage (called after a successful
// upload). It pre-checks the effective quota, then prefers the repository's
// atomic path when available; otherwise it falls back to read-modify-write
// (not race-free — the atomic path re-validates the quota server-side).
func (s *SoraQuotaService) AddUsage(ctx context.Context, userID int64, bytes int64) error {
	if bytes <= 0 {
		return nil
	}
	info, err := s.GetQuota(ctx, userID)
	if err != nil {
		return err
	}
	if info.QuotaBytes > 0 && info.UsedBytes+bytes > info.QuotaBytes {
		return &QuotaExceededError{QuotaBytes: info.QuotaBytes, UsedBytes: info.UsedBytes}
	}
	if atomicRepo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
		newUsed, addErr := atomicRepo.AddSoraStorageUsageWithQuota(ctx, userID, bytes, info.QuotaBytes)
		if addErr != nil {
			if errors.Is(addErr, ErrSoraStorageQuotaExceeded) {
				return &QuotaExceededError{QuotaBytes: info.QuotaBytes, UsedBytes: info.UsedBytes}
			}
			return fmt.Errorf("update user quota usage (atomic): %w", addErr)
		}
		logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, newUsed)
		return nil
	}
	// Fallback: plain read-modify-write on the user row.
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return fmt.Errorf("get user for quota update: %w", err)
	}
	user.SoraStorageUsedBytes += bytes
	if err := s.userRepo.Update(ctx, user); err != nil {
		return fmt.Errorf("update user quota usage: %w", err)
	}
	logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
	return nil
}
// ReleaseUsage subtracts bytes of storage usage (called after files are
// deleted). The atomic repository path is preferred; the read-modify-write
// fallback clamps the stored value at zero.
func (s *SoraQuotaService) ReleaseUsage(ctx context.Context, userID int64, bytes int64) error {
	if bytes <= 0 {
		return nil
	}
	if atomicRepo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
		newUsed, relErr := atomicRepo.ReleaseSoraStorageUsageAtomic(ctx, userID, bytes)
		if relErr != nil {
			return fmt.Errorf("update user quota release (atomic): %w", relErr)
		}
		logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, newUsed)
		return nil
	}
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return fmt.Errorf("get user for quota release: %w", err)
	}
	// Never let recorded usage go negative.
	if user.SoraStorageUsedBytes -= bytes; user.SoraStorageUsedBytes < 0 {
		user.SoraStorageUsedBytes = 0
	}
	if err := s.userRepo.Update(ctx, user); err != nil {
		return fmt.Errorf("update user quota release: %w", err)
	}
	logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
	return nil
}
// calcAvailableBytes returns the remaining quota, clamped to zero.
// A non-positive quota means "unlimited" and yields 0 by convention.
func calcAvailableBytes(quotaBytes, usedBytes int64) int64 {
	if quotaBytes <= 0 || usedBytes >= quotaBytes {
		return 0
	}
	return quotaBytes - usedBytes
}
// getSystemDefaultQuota reads the system-wide default quota from the Sora S3
// settings; it returns 0 (unlimited) when settings are absent or unreadable.
func (s *SoraQuotaService) getSystemDefaultQuota(ctx context.Context) int64 {
	if s.settingService == nil {
		return 0
	}
	cfg, err := s.settingService.GetSoraS3Settings(ctx)
	if err != nil {
		return 0
	}
	return cfg.DefaultStorageQuotaBytes
}
// GetQuotaFromSettings returns the system default quota from settings
// (exported wrapper around getSystemDefaultQuota for external callers).
func (s *SoraQuotaService) GetQuotaFromSettings(ctx context.Context) int64 {
	return s.getSystemDefaultQuota(ctx)
}
// SetUserSoraQuota sets a user-level storage quota (admin operation).
// NOTE(review): the original comment named this "SetUserQuota"; corrected to
// match the function name per godoc convention.
func SetUserSoraQuota(ctx context.Context, userRepo UserRepository, userID int64, quotaBytes int64) error {
	user, err := userRepo.GetByID(ctx, userID)
	if err != nil {
		return err
	}
	user.SoraStorageQuotaBytes = quotaBytes
	return userRepo.Update(ctx, user)
}
// ParseQuotaBytes parses a decimal quota string into a byte count.
// Any invalid input — empty, non-numeric, or out of int64 range — yields 0,
// which callers treat as "no limit". (Previously the ParseInt error was
// discarded, so out-of-range input silently clamped to MaxInt64/MinInt64.)
func ParseQuotaBytes(s string) int64 {
	v, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return 0
	}
	return v
}

Some files were not shown because too many files have changed in this diff Show More