Mirror of https://gitee.com/wanwujie/sub2api (synced 2026-04-15 04:14:46 +08:00)

feat(sync): full code sync from release
@@ -364,6 +364,8 @@ type GatewayConfig struct {
	// OpenAIPassthroughAllowTimeoutHeaders: whether OpenAI passthrough mode forwards client timeout headers.
	// Disabled by default, to keep headers such as x-stainless-timeout from making the upstream cut the stream short.
	OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
	// OpenAIWS: OpenAI Responses WebSocket configuration (enabled by default; can be rolled back to HTTP as needed)
	OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`

	// HTTP upstream connection pool settings (performance tuning for high-concurrency scenarios)
	// MaxIdleConns: maximum number of idle connections across all hosts
@@ -450,6 +452,101 @@ type GatewayConfig struct {
	ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
}

// GatewayOpenAIWSConfig is the OpenAI Responses WebSocket configuration.
// Note: enabled globally by default; to roll back, use force_http or turn off enabled.
type GatewayOpenAIWSConfig struct {
	// ModeRouterV2Enabled: switch for the new WS mode router (default false; legacy behavior is kept while it is off)
	ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
	// IngressModeDefault: default ingress mode (off/shared/dedicated)
	IngressModeDefault string `mapstructure:"ingress_mode_default"`
	// Enabled: global master switch (default true)
	Enabled bool `mapstructure:"enabled"`
	// OAuthEnabled: whether OpenAI OAuth accounts may use WS
	OAuthEnabled bool `mapstructure:"oauth_enabled"`
	// APIKeyEnabled: whether OpenAI API Key accounts may use WS
	APIKeyEnabled bool `mapstructure:"apikey_enabled"`
	// ForceHTTP: force HTTP globally (for emergency rollback)
	ForceHTTP bool `mapstructure:"force_http"`
	// AllowStoreRecovery: allow store=true to be restored by policy under WSv2 (default false)
	AllowStoreRecovery bool `mapstructure:"allow_store_recovery"`
	// IngressPreviousResponseRecoveryEnabled: in ingress mode, whether a previous_response_not_found error may trigger one automatic retry with previous_response_id removed (default true)
	IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"`
	// StoreDisabledConnMode: connection strategy when store=false and no reusable session connection exists (strict/adaptive/off)
	// - strict: always open a new connection (isolation first)
	// - adaptive: only force a new connection after a high-risk failure (compromise between performance and isolation)
	// - off: never force a new connection (reuse first)
	StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"`
	// StoreDisabledForceNewConn: whether to force a new connection when store=false and no reusable sticky connection exists (default true, to guarantee session isolation)
	// Kept for legacy configs; only takes effect when StoreDisabledConnMode is empty.
	StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"`
	// PrewarmGenerateEnabled: whether WSv2 generate=false prewarming is enabled (default false)
	PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"`

	// Feature flags: v2 takes precedence over v1
	ResponsesWebsockets bool `mapstructure:"responses_websockets"`
	ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"`

	// Connection pool parameters
	MaxConnsPerAccount int `mapstructure:"max_conns_per_account"`
	MinIdlePerAccount int `mapstructure:"min_idle_per_account"`
	MaxIdlePerAccount int `mapstructure:"max_idle_per_account"`
	// DynamicMaxConnsByAccountConcurrencyEnabled: whether the pool cap is computed dynamically from account concurrency
	DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"`
	// OAuthMaxConnsFactor: pool factor for OAuth accounts (effective=ceil(concurrency*factor))
	OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"`
	// APIKeyMaxConnsFactor: pool factor for API Key accounts (effective=ceil(concurrency*factor))
	APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"`
	DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
	ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
	WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
	PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"`
	QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"`
	// EventFlushBatchSize: batch flush threshold for WS streaming writes (number of events)
	EventFlushBatchSize int `mapstructure:"event_flush_batch_size"`
	// EventFlushIntervalMS: maximum wait for WS streaming writes (milliseconds); 0 means flush on batch size only
	EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"`
	// PrewarmCooldownMS: cooldown between connection-pool prewarm triggers (milliseconds)
	PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"`
	// FallbackCooldownSeconds: cooldown window after a WS fallback, to avoid WS/HTTP flapping; 0 disables the cooldown
	FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"`
	// RetryBackoffInitialMS: initial WS retry backoff (milliseconds); <=0 disables backoff
	RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"`
	// RetryBackoffMaxMS: maximum WS retry backoff (milliseconds)
	RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"`
	// RetryJitterRatio: jitter ratio applied to WS retry backoff (0-1)
	RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"`
	// RetryTotalBudgetMS: total retry budget per request (milliseconds); 0 disables the budget limit
	RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"`
	// PayloadLogSampleRate: sample rate for payload_schema logging (0-1)
	PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"`

	// Account scheduling and stickiness parameters
	LBTopK int `mapstructure:"lb_top_k"`
	// StickySessionTTLSeconds: TTL of the session_hash -> account_id sticky mapping
	StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"`
	// SessionHashReadOldFallback: during the session-hash migration, whether a miss on the new key may fall back to reading the old SHA-256 key
	SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"`
	// SessionHashDualWriteOld: during the session-hash migration, whether to dual-write the old SHA-256 key (short TTL)
	SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"`
	// MetadataBridgeEnabled: during the RequestMetadata migration, whether the legacy ctxkey.* compatibility bridge is kept
	MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"`
	// StickyResponseIDTTLSeconds: TTL of the response_id -> account_id sticky mapping
	StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"`
	// StickyPreviousResponseTTLSeconds: legacy key kept for compatibility (used as fallback when the new key is unset)
	StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`

	SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
}

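The dynamic pool-cap fields above all feed one formula, effective = ceil(concurrency * factor). A minimal sketch of that calculation, assuming the static MaxConnsPerAccount still acts as fallback and upper bound (the helper name and those two assumptions are illustrative, not code from this commit; assumes "math" is imported):

// effectiveMaxConns is a hypothetical helper showing the documented formula
// effective = ceil(concurrency * factor). When the dynamic mode is off or the
// concurrency is unknown it falls back to the static per-account limit.
func effectiveMaxConns(cfg GatewayOpenAIWSConfig, accountConcurrency int, isOAuth bool) int {
	if !cfg.DynamicMaxConnsByAccountConcurrencyEnabled || accountConcurrency <= 0 {
		return cfg.MaxConnsPerAccount
	}
	factor := cfg.APIKeyMaxConnsFactor
	if isOAuth {
		factor = cfg.OAuthMaxConnsFactor
	}
	effective := int(math.Ceil(float64(accountConcurrency) * factor))
	if effective > cfg.MaxConnsPerAccount {
		effective = cfg.MaxConnsPerAccount // assumed cap; the real policy may differ
	}
	return effective
}
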
// GatewayOpenAIWSSchedulerScoreWeights holds the account scheduling score weights.
type GatewayOpenAIWSSchedulerScoreWeights struct {
	Priority float64 `mapstructure:"priority"`
	Load float64 `mapstructure:"load"`
	Queue float64 `mapstructure:"queue"`
	ErrorRate float64 `mapstructure:"error_rate"`
	TTFT float64 `mapstructure:"ttft"`
}

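This diff does not show how the scheduler consumes these weights. One plausible reading, consistent with the later validation (weights must be non-negative and not all zero), is a normalized weighted sum over per-account signals; a rough sketch under that assumption only:

// scoreAccount is a hypothetical illustration, not the project's scoring code.
// It assumes each signal is already normalized to [0,1], where higher
// load/queue/errorRate/ttft means a worse candidate.
func scoreAccount(w GatewayOpenAIWSSchedulerScoreWeights, priority, load, queue, errorRate, ttft float64) float64 {
	sum := w.Priority + w.Load + w.Queue + w.ErrorRate + w.TTFT
	if sum <= 0 {
		return 0 // Validate() rejects all-zero weights; guarded here anyway
	}
	raw := w.Priority*priority +
		w.Load*(1-load) +
		w.Queue*(1-queue) +
		w.ErrorRate*(1-errorRate) +
		w.TTFT*(1-ttft)
	return raw / sum // weighted average, back in [0,1]
}
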
// GatewayUsageRecordConfig configures the asynchronous usage-record queue
type GatewayUsageRecordConfig struct {
	// WorkerCount: initial number of workers (used as the initial concurrency cap when auto-scaling is enabled)
@@ -886,6 +983,12 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
	cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
	cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)

	// Compatibility with the legacy key gateway.openai_ws.sticky_previous_response_ttl_seconds.
	// Fall back to the legacy key when the new key is unset (<=0); the new key takes precedence.
	if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
		cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
	}

	// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
	cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
	if cfg.Totp.EncryptionKey == "" {
@@ -945,7 +1048,7 @@ func setDefaults() {
	viper.SetDefault("server.read_header_timeout", 30) // 30s to read request headers
	viper.SetDefault("server.idle_timeout", 120)       // 120s idle timeout
	viper.SetDefault("server.trusted_proxies", []string{})
	viper.SetDefault("server.max_request_body_size", int64(100*1024*1024))
	viper.SetDefault("server.max_request_body_size", int64(256*1024*1024))
	// H2C defaults
	viper.SetDefault("server.h2c.enabled", false)
	viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 concurrent streams
@@ -1088,9 +1191,9 @@ func setDefaults() {
	// RateLimit
	viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)

	// Pricing - configuration for syncing model pricing and context-window data from model-price-repo
	viper.SetDefault("pricing.remote_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json")
	viper.SetDefault("pricing.hash_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256")
	// Pricing - sync model pricing and context-window data from model-price-repo (pinned to a commit to avoid branch drift)
	viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")
	viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256")
	viper.SetDefault("pricing.data_dir", "./data")
	viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
	viper.SetDefault("pricing.update_interval_hours", 24)
@@ -1157,9 +1260,55 @@ func setDefaults() {
	viper.SetDefault("gateway.max_account_switches_gemini", 3)
	viper.SetDefault("gateway.force_codex_cli", false)
	viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
	// OpenAI Responses WebSocket (enabled by default; force_http provides an emergency rollback)
	viper.SetDefault("gateway.openai_ws.enabled", true)
	viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
	viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
	viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
	viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
	viper.SetDefault("gateway.openai_ws.force_http", false)
	viper.SetDefault("gateway.openai_ws.allow_store_recovery", false)
	viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true)
	viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict")
	viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true)
	viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false)
	viper.SetDefault("gateway.openai_ws.responses_websockets", false)
	viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true)
	viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128)
	viper.SetDefault("gateway.openai_ws.min_idle_per_account", 4)
	viper.SetDefault("gateway.openai_ws.max_idle_per_account", 12)
	viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true)
	viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0)
	viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0)
	viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10)
	viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900)
	viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120)
	viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7)
	viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64)
	viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1)
	viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10)
	viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300)
	viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30)
	viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120)
	viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000)
	viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2)
	viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000)
	viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2)
	viper.SetDefault("gateway.openai_ws.lb_top_k", 7)
	viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600)
	viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true)
	viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true)
	viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true)
	viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600)
	viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600)
	viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0)
	viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0)
	viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
	viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
	viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
	viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
	viper.SetDefault("gateway.antigravity_extra_retries", 10)
	viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
	viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
	viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
	viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
	viper.SetDefault("gateway.gemini_debug_response_headers", false)
@@ -1747,6 +1896,118 @@ func (c *Config) Validate() error {
		(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
		return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
	}
	// Compatibility with the legacy key sticky_previous_response_ttl_seconds
	if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
		c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
	}
	if c.Gateway.OpenAIWS.MaxConnsPerAccount <= 0 {
		return fmt.Errorf("gateway.openai_ws.max_conns_per_account must be positive")
	}
	if c.Gateway.OpenAIWS.MinIdlePerAccount < 0 {
		return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be non-negative")
	}
	if c.Gateway.OpenAIWS.MaxIdlePerAccount < 0 {
		return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be non-negative")
	}
	if c.Gateway.OpenAIWS.MinIdlePerAccount > c.Gateway.OpenAIWS.MaxIdlePerAccount {
		return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account")
	}
	if c.Gateway.OpenAIWS.MaxIdlePerAccount > c.Gateway.OpenAIWS.MaxConnsPerAccount {
		return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account")
	}
	if c.Gateway.OpenAIWS.OAuthMaxConnsFactor <= 0 {
		return fmt.Errorf("gateway.openai_ws.oauth_max_conns_factor must be positive")
	}
	if c.Gateway.OpenAIWS.APIKeyMaxConnsFactor <= 0 {
		return fmt.Errorf("gateway.openai_ws.apikey_max_conns_factor must be positive")
	}
	if c.Gateway.OpenAIWS.DialTimeoutSeconds <= 0 {
		return fmt.Errorf("gateway.openai_ws.dial_timeout_seconds must be positive")
	}
	if c.Gateway.OpenAIWS.ReadTimeoutSeconds <= 0 {
		return fmt.Errorf("gateway.openai_ws.read_timeout_seconds must be positive")
	}
	if c.Gateway.OpenAIWS.WriteTimeoutSeconds <= 0 {
		return fmt.Errorf("gateway.openai_ws.write_timeout_seconds must be positive")
	}
	if c.Gateway.OpenAIWS.PoolTargetUtilization <= 0 || c.Gateway.OpenAIWS.PoolTargetUtilization > 1 {
		return fmt.Errorf("gateway.openai_ws.pool_target_utilization must be within (0,1]")
	}
	if c.Gateway.OpenAIWS.QueueLimitPerConn <= 0 {
		return fmt.Errorf("gateway.openai_ws.queue_limit_per_conn must be positive")
	}
	if c.Gateway.OpenAIWS.EventFlushBatchSize <= 0 {
		return fmt.Errorf("gateway.openai_ws.event_flush_batch_size must be positive")
	}
	if c.Gateway.OpenAIWS.EventFlushIntervalMS < 0 {
		return fmt.Errorf("gateway.openai_ws.event_flush_interval_ms must be non-negative")
	}
	if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 {
		return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative")
	}
	if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 {
		return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative")
	}
	if c.Gateway.OpenAIWS.RetryBackoffInitialMS < 0 {
		return fmt.Errorf("gateway.openai_ws.retry_backoff_initial_ms must be non-negative")
	}
	if c.Gateway.OpenAIWS.RetryBackoffMaxMS < 0 {
		return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be non-negative")
	}
	if c.Gateway.OpenAIWS.RetryBackoffInitialMS > 0 && c.Gateway.OpenAIWS.RetryBackoffMaxMS > 0 &&
		c.Gateway.OpenAIWS.RetryBackoffMaxMS < c.Gateway.OpenAIWS.RetryBackoffInitialMS {
		return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be >= retry_backoff_initial_ms")
	}
	if c.Gateway.OpenAIWS.RetryJitterRatio < 0 || c.Gateway.OpenAIWS.RetryJitterRatio > 1 {
		return fmt.Errorf("gateway.openai_ws.retry_jitter_ratio must be within [0,1]")
	}
	if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 {
		return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative")
	}
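Read together, the retry checks above describe a capped exponential backoff with jitter under an overall time budget. A minimal sketch of deriving the n-th delay from these fields, assuming a doubling policy (this commit does not contain the retry loop itself; assumes "math", "math/rand" and "time" are imported). A caller would additionally stop retrying once elapsed time exceeds RetryTotalBudgetMS when that budget is non-zero:

// retryDelay is a hypothetical helper: the delay doubles from
// RetryBackoffInitialMS, is capped at RetryBackoffMaxMS, and gets a random
// +/- RetryJitterRatio jitter; attempt is zero-based. A non-positive initial
// backoff disables waiting, matching the field comment.
func retryDelay(cfg GatewayOpenAIWSConfig, attempt int) time.Duration {
	if cfg.RetryBackoffInitialMS <= 0 {
		return 0
	}
	ms := float64(cfg.RetryBackoffInitialMS) * math.Pow(2, float64(attempt))
	if cfg.RetryBackoffMaxMS > 0 && ms > float64(cfg.RetryBackoffMaxMS) {
		ms = float64(cfg.RetryBackoffMaxMS)
	}
	jitter := 1 + cfg.RetryJitterRatio*(2*rand.Float64()-1) // factor in [1-r, 1+r]
	return time.Duration(ms*jitter) * time.Millisecond
}
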
	if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
		switch mode {
		case "off", "shared", "dedicated":
		default:
			return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
		}
	}
	if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
		switch mode {
		case "strict", "adaptive", "off":
		default:
			return fmt.Errorf("gateway.openai_ws.store_disabled_conn_mode must be one of strict|adaptive|off")
		}
	}
	if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 {
		return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]")
	}
	if c.Gateway.OpenAIWS.LBTopK <= 0 {
		return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive")
	}
	if c.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 {
		return fmt.Errorf("gateway.openai_ws.sticky_session_ttl_seconds must be positive")
	}
	if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 {
		return fmt.Errorf("gateway.openai_ws.sticky_response_id_ttl_seconds must be positive")
	}
	if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 {
		return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative")
	}
	if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 ||
		c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
		c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||
		c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 ||
		c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 {
		return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative")
	}
	weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority +
		c.Gateway.OpenAIWS.SchedulerScoreWeights.Load +
		c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue +
		c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate +
		c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT
	if weightSum <= 0 {
		return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero")
	}
	if c.Gateway.MaxLineSize < 0 {
		return fmt.Errorf("gateway.max_line_size must be non-negative")
	}

@@ -6,6 +6,7 @@ import (
	"time"

	"github.com/spf13/viper"
	"github.com/stretchr/testify/require"
)

func resetViperWithJWTSecret(t *testing.T) {
@@ -75,6 +76,103 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Gateway.OpenAIWS.Enabled {
|
||||
t.Fatalf("Gateway.OpenAIWS.Enabled = false, want true")
|
||||
}
|
||||
if !cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 {
|
||||
t.Fatalf("Gateway.OpenAIWS.ResponsesWebsocketsV2 = false, want true")
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.ResponsesWebsockets {
|
||||
t.Fatalf("Gateway.OpenAIWS.ResponsesWebsockets = true, want false")
|
||||
}
|
||||
if !cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled {
|
||||
t.Fatalf("Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false, want true")
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor != 1.0 {
|
||||
t.Fatalf("Gateway.OpenAIWS.OAuthMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor)
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor != 1.0 {
|
||||
t.Fatalf("Gateway.OpenAIWS.APIKeyMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor)
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.StickySessionTTLSeconds != 3600 {
|
||||
t.Fatalf("Gateway.OpenAIWS.StickySessionTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickySessionTTLSeconds)
|
||||
}
|
||||
if !cfg.Gateway.OpenAIWS.SessionHashReadOldFallback {
|
||||
t.Fatalf("Gateway.OpenAIWS.SessionHashReadOldFallback = false, want true")
|
||||
}
|
||||
if !cfg.Gateway.OpenAIWS.SessionHashDualWriteOld {
|
||||
t.Fatalf("Gateway.OpenAIWS.SessionHashDualWriteOld = false, want true")
|
||||
}
|
||||
if !cfg.Gateway.OpenAIWS.MetadataBridgeEnabled {
|
||||
t.Fatalf("Gateway.OpenAIWS.MetadataBridgeEnabled = false, want true")
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 3600 {
|
||||
t.Fatalf("Gateway.OpenAIWS.StickyResponseIDTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.FallbackCooldownSeconds != 30 {
|
||||
t.Fatalf("Gateway.OpenAIWS.FallbackCooldownSeconds = %d, want 30", cfg.Gateway.OpenAIWS.FallbackCooldownSeconds)
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.EventFlushBatchSize != 1 {
|
||||
t.Fatalf("Gateway.OpenAIWS.EventFlushBatchSize = %d, want 1", cfg.Gateway.OpenAIWS.EventFlushBatchSize)
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.EventFlushIntervalMS != 10 {
|
||||
t.Fatalf("Gateway.OpenAIWS.EventFlushIntervalMS = %d, want 10", cfg.Gateway.OpenAIWS.EventFlushIntervalMS)
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.PrewarmCooldownMS != 300 {
|
||||
t.Fatalf("Gateway.OpenAIWS.PrewarmCooldownMS = %d, want 300", cfg.Gateway.OpenAIWS.PrewarmCooldownMS)
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.RetryBackoffInitialMS != 120 {
|
||||
t.Fatalf("Gateway.OpenAIWS.RetryBackoffInitialMS = %d, want 120", cfg.Gateway.OpenAIWS.RetryBackoffInitialMS)
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.RetryBackoffMaxMS != 2000 {
|
||||
t.Fatalf("Gateway.OpenAIWS.RetryBackoffMaxMS = %d, want 2000", cfg.Gateway.OpenAIWS.RetryBackoffMaxMS)
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.RetryJitterRatio != 0.2 {
|
||||
t.Fatalf("Gateway.OpenAIWS.RetryJitterRatio = %v, want 0.2", cfg.Gateway.OpenAIWS.RetryJitterRatio)
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.RetryTotalBudgetMS != 5000 {
|
||||
t.Fatalf("Gateway.OpenAIWS.RetryTotalBudgetMS = %d, want 5000", cfg.Gateway.OpenAIWS.RetryTotalBudgetMS)
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.PayloadLogSampleRate != 0.2 {
|
||||
t.Fatalf("Gateway.OpenAIWS.PayloadLogSampleRate = %v, want 0.2", cfg.Gateway.OpenAIWS.PayloadLogSampleRate)
|
||||
}
|
||||
if !cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
|
||||
t.Fatalf("Gateway.OpenAIWS.StoreDisabledForceNewConn = false, want true")
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" {
|
||||
t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict")
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
|
||||
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
|
||||
}
|
||||
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
|
||||
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0")
|
||||
t.Setenv("GATEWAY_OPENAI_WS_STICKY_PREVIOUS_RESPONSE_TTL_SECONDS", "7200")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 7200 {
|
||||
t.Fatalf("StickyResponseIDTTLSeconds = %d, want 7200", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
@@ -993,6 +1091,16 @@ func TestValidateConfigErrors(t *testing.T) {
			mutate:  func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
			wantErr: "gateway.stream_keepalive_interval",
		},
		{
			name:    "gateway openai ws oauth max conns factor",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0 },
			wantErr: "gateway.openai_ws.oauth_max_conns_factor",
		},
		{
			name:    "gateway openai ws apikey max conns factor",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 },
			wantErr: "gateway.openai_ws.apikey_max_conns_factor",
		},
		{
			name:    "gateway stream data interval range",
			mutate:  func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
@@ -1174,6 +1282,165 @@ func TestValidateConfigErrors(t *testing.T) {
	}
}

func TestValidateConfig_OpenAIWSRules(t *testing.T) {
	buildValid := func(t *testing.T) *Config {
		t.Helper()
		resetViperWithJWTSecret(t)
		cfg, err := Load()
		require.NoError(t, err)
		return cfg
	}

	t.Run("sticky response id ttl backfilled from legacy key", func(t *testing.T) {
		cfg := buildValid(t)
		cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
		cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 7200

		require.NoError(t, cfg.Validate())
		require.Equal(t, 7200, cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
	})

	cases := []struct {
		name    string
		mutate  func(*Config)
		wantErr string
	}{
		{
			name:    "max_conns_per_account must be positive",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.MaxConnsPerAccount = 0 },
			wantErr: "gateway.openai_ws.max_conns_per_account",
		},
		{
			name:    "min_idle_per_account must not be negative",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.MinIdlePerAccount = -1 },
			wantErr: "gateway.openai_ws.min_idle_per_account",
		},
		{
			name:    "max_idle_per_account must not be negative",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.MaxIdlePerAccount = -1 },
			wantErr: "gateway.openai_ws.max_idle_per_account",
		},
		{
			name: "min_idle_per_account must not exceed max_idle_per_account",
			mutate: func(c *Config) {
				c.Gateway.OpenAIWS.MinIdlePerAccount = 3
				c.Gateway.OpenAIWS.MaxIdlePerAccount = 2
			},
			wantErr: "gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account",
		},
		{
			name: "max_idle_per_account must not exceed max_conns_per_account",
			mutate: func(c *Config) {
				c.Gateway.OpenAIWS.MaxConnsPerAccount = 2
				c.Gateway.OpenAIWS.MinIdlePerAccount = 1
				c.Gateway.OpenAIWS.MaxIdlePerAccount = 3
			},
			wantErr: "gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account",
		},
		{
			name:    "dial_timeout_seconds must be positive",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.DialTimeoutSeconds = 0 },
			wantErr: "gateway.openai_ws.dial_timeout_seconds",
		},
		{
			name:    "read_timeout_seconds must be positive",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.ReadTimeoutSeconds = 0 },
			wantErr: "gateway.openai_ws.read_timeout_seconds",
		},
		{
			name:    "write_timeout_seconds must be positive",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.WriteTimeoutSeconds = 0 },
			wantErr: "gateway.openai_ws.write_timeout_seconds",
		},
		{
			name:    "pool_target_utilization must be within (0,1]",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.PoolTargetUtilization = 0 },
			wantErr: "gateway.openai_ws.pool_target_utilization",
		},
		{
			name:    "queue_limit_per_conn must be positive",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.QueueLimitPerConn = 0 },
			wantErr: "gateway.openai_ws.queue_limit_per_conn",
		},
		{
			name:    "fallback_cooldown_seconds must not be negative",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.FallbackCooldownSeconds = -1 },
			wantErr: "gateway.openai_ws.fallback_cooldown_seconds",
		},
		{
			name:    "store_disabled_conn_mode must be one of strict|adaptive|off",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.StoreDisabledConnMode = "invalid" },
			wantErr: "gateway.openai_ws.store_disabled_conn_mode",
		},
		{
			name:    "ingress_mode_default must be one of off|shared|dedicated",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
			wantErr: "gateway.openai_ws.ingress_mode_default",
		},
		{
			name:    "payload_log_sample_rate must be within [0,1]",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.PayloadLogSampleRate = 1.2 },
			wantErr: "gateway.openai_ws.payload_log_sample_rate",
		},
		{
			name:    "retry_total_budget_ms must not be negative",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 },
			wantErr: "gateway.openai_ws.retry_total_budget_ms",
		},
		{
			name:    "lb_top_k must be positive",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 },
			wantErr: "gateway.openai_ws.lb_top_k",
		},
		{
			name:    "sticky_session_ttl_seconds must be positive",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.StickySessionTTLSeconds = 0 },
			wantErr: "gateway.openai_ws.sticky_session_ttl_seconds",
		},
		{
			name: "sticky_response_id_ttl_seconds must be positive",
			mutate: func(c *Config) {
				c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
				c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 0
			},
			wantErr: "gateway.openai_ws.sticky_response_id_ttl_seconds",
		},
		{
			name:    "sticky_previous_response_ttl_seconds must not be negative",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = -1 },
			wantErr: "gateway.openai_ws.sticky_previous_response_ttl_seconds",
		},
		{
			name:    "scheduler_score_weights must not be negative",
			mutate:  func(c *Config) { c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = -0.1 },
			wantErr: "gateway.openai_ws.scheduler_score_weights.* must be non-negative",
		},
		{
			name: "scheduler_score_weights must not all be zero",
			mutate: func(c *Config) {
				c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0
				c.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0
				c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0
				c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0
				c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0
			},
			wantErr: "gateway.openai_ws.scheduler_score_weights must not all be zero",
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			cfg := buildValid(t)
			tc.mutate(cfg)

			err := cfg.Validate()
			require.Error(t, err)
			require.Contains(t, err.Error(), tc.wantErr)
		})
	}
}

func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, err := Load()

@@ -104,6 +104,9 @@ var DefaultAntigravityModelMapping = map[string]string{
	"gemini-3.1-flash-image": "gemini-3.1-flash-image",
	// Gemini 3.1 image preview mapping
	"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
	// Gemini 3 image compatibility mapping (migrating to 3.1 image)
	"gemini-3-pro-image":         "gemini-3.1-flash-image",
	"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
	// Other official models
	"gpt-oss-120b-medium":     "gpt-oss-120b-medium",
	"tab_flash_lite_preview":  "tab_flash_lite_preview",

backend/internal/domain/constants_test.go (new file, +24)
@@ -0,0 +1,24 @@
package domain

import "testing"

func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
	t.Parallel()

	cases := map[string]string{
		"gemini-3.1-flash-image":         "gemini-3.1-flash-image",
		"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
		"gemini-3-pro-image":             "gemini-3.1-flash-image",
		"gemini-3-pro-image-preview":     "gemini-3.1-flash-image",
	}

	for from, want := range cases {
		got, ok := DefaultAntigravityModelMapping[from]
		if !ok {
			t.Fatalf("expected mapping for %q to exist", from)
		}
		if got != want {
			t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want)
		}
	}
}
@@ -1337,6 +1337,34 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
	response.Success(c, stats)
}

// BatchTodayStatsRequest is the request body for batch today-stats queries.
type BatchTodayStatsRequest struct {
	AccountIDs []int64 `json:"account_ids" binding:"required"`
}

// GetBatchTodayStats returns today's statistics for multiple accounts in one call.
// POST /api/v1/admin/accounts/today-stats/batch
func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
	var req BatchTodayStatsRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, "Invalid request: "+err.Error())
		return
	}

	if len(req.AccountIDs) == 0 {
		response.Success(c, gin.H{"stats": map[string]any{}})
		return
	}

	stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	response.Success(c, gin.H{"stats": stats})
}

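For reference, the request body of the new endpoint is a single account_ids array; a brief sketch of building it (assumes "encoding/json" is imported; the sample IDs are invented for the example):

// Illustrative only: constructing the body for
// POST /api/v1/admin/accounts/today-stats/batch.
payload, _ := json.Marshal(BatchTodayStatsRequest{AccountIDs: []int64{101, 102, 103}})
// payload == {"account_ids":[101,102,103]}
// The handler answers with {"stats": {...}} and short-circuits an empty
// account_ids list to an empty stats map.
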
// SetSchedulableRequest represents the request body for setting schedulable status
type SetSchedulableRequest struct {
	Schedulable bool `json:"schedulable"`

@@ -3,6 +3,7 @@ package admin

import (
	"errors"
	"strconv"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -186,7 +187,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {

// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
	startTime, endTime := parseTimeRange(c)
	granularity := c.DefaultQuery("granularity", "day")
@@ -194,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
	// Parse optional filter params
	var userID, apiKeyID, accountID, groupID int64
	var model string
	var requestType *int16
	var stream *bool
	var billingType *int8

@@ -220,9 +222,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
	if modelStr := c.Query("model"); modelStr != "" {
		model = modelStr
	}
	if streamStr := c.Query("stream"); streamStr != "" {
	if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
		parsed, err := service.ParseUsageRequestType(requestTypeStr)
		if err != nil {
			response.BadRequest(c, err.Error())
			return
		}
		value := int16(parsed)
		requestType = &value
	} else if streamStr := c.Query("stream"); streamStr != "" {
		if streamVal, err := strconv.ParseBool(streamStr); err == nil {
			stream = &streamVal
		} else {
			response.BadRequest(c, "Invalid stream value, use true or false")
			return
		}
	}
	if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
@@ -235,7 +248,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
		}
	}

	trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
	trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
	if err != nil {
		response.Error(c, 500, "Failed to get usage trend")
		return
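As the new branch above implies, request_type takes priority over stream: when request_type is present the stream parameter is not consulted at all (the accompanying handler tests exercise exactly this with request_type=ws_v2&stream=bad). Example queries: GET /api/v1/admin/dashboard/trend?request_type=ws_v2 filters by request type, while GET /api/v1/admin/dashboard/trend?stream=true is only honored when request_type is absent.
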
@@ -251,12 +264,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {

// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
	startTime, endTime := parseTimeRange(c)

	// Parse optional filter params
	var userID, apiKeyID, accountID, groupID int64
	var requestType *int16
	var stream *bool
	var billingType *int8

@@ -280,9 +294,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
			groupID = id
		}
	}
	if streamStr := c.Query("stream"); streamStr != "" {
	if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
		parsed, err := service.ParseUsageRequestType(requestTypeStr)
		if err != nil {
			response.BadRequest(c, err.Error())
			return
		}
		value := int16(parsed)
		requestType = &value
	} else if streamStr := c.Query("stream"); streamStr != "" {
		if streamVal, err := strconv.ParseBool(streamStr); err == nil {
			stream = &streamVal
		} else {
			response.BadRequest(c, "Invalid stream value, use true or false")
			return
		}
	}
	if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
@@ -295,7 +320,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
		}
	}

	stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
	stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
	if err != nil {
		response.Error(c, 500, "Failed to get model statistics")
		return

@@ -0,0 +1,132 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardUsageRepoCapture struct {
|
||||
service.UsageLogRepository
|
||||
trendRequestType *int16
|
||||
trendStream *bool
|
||||
modelRequestType *int16
|
||||
modelStream *bool
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, error) {
|
||||
s.trendRequestType = requestType
|
||||
s.trendStream = stream
|
||||
return []usagestats.TrendDataPoint{}, nil
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.ModelStat, error) {
|
||||
s.modelRequestType = requestType
|
||||
s.modelStream = stream
|
||||
return []usagestats.ModelStat{}, nil
|
||||
}
|
||||
|
||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestDashboardTrendRequestTypePriority(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.trendRequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType)
|
||||
require.Nil(t, repo.trendStream)
|
||||
}
|
||||
|
||||
func TestDashboardTrendInvalidRequestType(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardTrendInvalidStream(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsRequestTypePriority(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.modelRequestType)
|
||||
require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType)
|
||||
require.Nil(t, repo.modelStream)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsInvalidRequestType(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
backend/internal/handler/admin/data_management_handler.go (new file, +523)
@@ -0,0 +1,523 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DataManagementHandler struct {
|
||||
dataManagementService *service.DataManagementService
|
||||
}
|
||||
|
||||
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
|
||||
return &DataManagementHandler{dataManagementService: dataManagementService}
|
||||
}
|
||||
|
||||
type TestS3ConnectionRequest struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region" binding:"required"`
|
||||
Bucket string `json:"bucket" binding:"required"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
}
|
||||
|
||||
type CreateBackupJobRequest struct {
|
||||
BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"`
|
||||
UploadToS3 bool `json:"upload_to_s3"`
|
||||
S3ProfileID string `json:"s3_profile_id"`
|
||||
PostgresID string `json:"postgres_profile_id"`
|
||||
RedisID string `json:"redis_profile_id"`
|
||||
IdempotencyKey string `json:"idempotency_key"`
|
||||
}
|
||||
|
||||
type CreateSourceProfileRequest struct {
|
||||
ProfileID string `json:"profile_id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
|
||||
SetActive bool `json:"set_active"`
|
||||
}
|
||||
|
||||
type UpdateSourceProfileRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
|
||||
}
|
||||
|
||||
type CreateS3ProfileRequest struct {
|
||||
ProfileID string `json:"profile_id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
SetActive bool `json:"set_active"`
|
||||
}
|
||||
|
||||
type UpdateS3ProfileRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) {
|
||||
health := h.getAgentHealth(c)
|
||||
payload := gin.H{
|
||||
"enabled": health.Enabled,
|
||||
"reason": health.Reason,
|
||||
"socket_path": health.SocketPath,
|
||||
}
|
||||
if health.Agent != nil {
|
||||
payload["agent"] = gin.H{
|
||||
"status": health.Agent.Status,
|
||||
"version": health.Agent.Version,
|
||||
"uptime_seconds": health.Agent.UptimeSeconds,
|
||||
}
|
||||
}
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) GetConfig(c *gin.Context) {
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
|
||||
var req service.DataManagementConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) TestS3(c *gin.Context) {
|
||||
var req TestS3ConnectionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
|
||||
Enabled: true,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
UseSSL: req.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
|
||||
var req CreateBackupJobRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey)
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
triggeredBy := "admin:unknown"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
|
||||
BackupType: req.BackupType,
|
||||
UploadToS3: req.UploadToS3,
|
||||
S3ProfileID: req.S3ProfileID,
|
||||
PostgresID: req.PostgresID,
|
||||
RedisID: req.RedisID,
|
||||
TriggeredBy: triggeredBy,
|
||||
IdempotencyKey: req.IdempotencyKey,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType == "" {
|
||||
response.BadRequest(c, "Invalid source_type")
|
||||
return
|
||||
}
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"items": items})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateSourceProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
|
||||
SourceType: sourceType,
|
||||
ProfileID: req.ProfileID,
|
||||
Name: req.Name,
|
||||
Config: req.Config,
|
||||
SetActive: req.SetActive,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSourceProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
|
||||
SourceType: sourceType,
|
||||
ProfileID: profileID,
|
||||
Name: req.Name,
|
||||
Config: req.Config,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"items": items})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
|
||||
var req CreateS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
|
||||
ProfileID: req.ProfileID,
|
||||
Name: req.Name,
|
||||
SetActive: req.SetActive,
|
||||
S3: service.DataManagementS3Config{
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
UseSSL: req.UseSSL,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
|
||||
var req UpdateS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
|
||||
ProfileID: profileID,
|
||||
Name: req.Name,
|
||||
S3: service.DataManagementS3Config{
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
UseSSL: req.UseSSL,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
pageSize := int32(20)
|
||||
if raw := strings.TrimSpace(c.Query("page_size")); raw != "" {
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil || v <= 0 {
|
||||
response.BadRequest(c, "Invalid page_size")
|
||||
return
|
||||
}
|
||||
pageSize = int32(v)
|
||||
}
|
||||
|
||||
result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
|
||||
PageSize: pageSize,
|
||||
PageToken: c.Query("page_token"),
|
||||
Status: c.Query("status"),
|
||||
BackupType: c.Query("backup_type"),
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
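// Illustrative sketch (not part of the sync): exercising the page_size
// validation in ListBackupJobs with httptest. The route path is an assumption
// borrowed from the admin routes used in the tests later in this commit.
func exampleListBackupJobsPageSize(h *DataManagementHandler) int {
    gin.SetMode(gin.TestMode)
    r := gin.New()
    r.GET("/api/v1/admin/data-management/backup-jobs", h.ListBackupJobs)

    rec := httptest.NewRecorder()
    // page_size must parse as a positive integer; "0" is rejected with 400.
    req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/backup-jobs?page_size=0", nil)
    r.ServeHTTP(rec, req)
    return rec.Code // expected: http.StatusBadRequest
}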
|
||||
|
||||
func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
|
||||
jobID := strings.TrimSpace(c.Param("job_id"))
|
||||
if jobID == "" {
|
||||
response.BadRequest(c, "Invalid backup job ID")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, job)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
|
||||
if h.dataManagementService == nil {
|
||||
err := infraerrors.ServiceUnavailable(
|
||||
service.DataManagementAgentUnavailableReason,
|
||||
"data management agent service is not configured",
|
||||
).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath})
|
||||
response.ErrorFrom(c, err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
|
||||
if h.dataManagementService == nil {
|
||||
return service.DataManagementAgentHealth{
|
||||
Enabled: false,
|
||||
Reason: service.DataManagementAgentUnavailableReason,
|
||||
SocketPath: service.DefaultDataManagementAgentSocketPath,
|
||||
}
|
||||
}
|
||||
return h.dataManagementService.GetAgentHealth(c.Request.Context())
|
||||
}
|
||||
|
||||
func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string {
|
||||
headerKey := strings.TrimSpace(headerValue)
|
||||
if headerKey != "" {
|
||||
return headerKey
|
||||
}
|
||||
return strings.TrimSpace(bodyValue)
|
||||
}
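// Illustrative sketch (not part of the sync): normalizeBackupIdempotencyKey
// prefers a non-blank header value and falls back to the body value, trimming
// whitespace on both sides (mirrors TestNormalizeBackupIdempotencyKey below).
func exampleNormalizeBackupIdempotencyKey() {
    _ = normalizeBackupIdempotencyKey("from-header", "from-body") // "from-header"
    _ = normalizeBackupIdempotencyKey("   ", " from-body ")       // "from-body"
    _ = normalizeBackupIdempotencyKey("", "")                     // ""
}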
|
||||
@@ -0,0 +1,78 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type apiEnvelope struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
|
||||
h := NewDataManagementHandler(svc)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil)
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var envelope apiEnvelope
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
|
||||
require.Equal(t, 0, envelope.Code)
|
||||
|
||||
var data struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Reason string `json:"reason"`
|
||||
SocketPath string `json:"socket_path"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(envelope.Data, &data))
|
||||
require.False(t, data.Enabled)
|
||||
require.Equal(t, service.DataManagementDeprecatedReason, data.Reason)
|
||||
require.Equal(t, svc.SocketPath(), data.SocketPath)
|
||||
}
|
||||
|
||||
func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
|
||||
h := NewDataManagementHandler(svc)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/api/v1/admin/data-management/config", h.GetConfig)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil)
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
|
||||
var envelope apiEnvelope
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
|
||||
require.Equal(t, http.StatusServiceUnavailable, envelope.Code)
|
||||
require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason)
|
||||
}
|
||||
|
||||
func TestNormalizeBackupIdempotencyKey(t *testing.T) {
|
||||
require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body"))
|
||||
require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body "))
|
||||
require.Equal(t, "", normalizeBackupIdempotencyKey("", ""))
|
||||
}
|
||||
@@ -51,6 +51,8 @@ type CreateGroupRequest struct {
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -84,6 +86,8 @@ type UpdateGroupRequest struct {
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -198,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -248,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -47,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
req = OpenAIGenerateAuthURLRequest{}
|
||||
}
|
||||
|
||||
result, err := h.openaiOAuthService.GenerateAuthURL(
    c.Request.Context(),
    req.ProxyID,
    req.RedirectURI,
    oauthPlatformFromPath(c),
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -123,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜
clientID := strings.TrimSpace(req.ClientID)
if clientID == "" {
    platform := oauthPlatformFromPath(c)
    clientID, _ = openai.OAuthClientConfigByPlatform(platform)
}

tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -62,7 +62,8 @@ const (
|
||||
)
|
||||
|
||||
var wsConnCount atomic.Int32
var wsConnCountByIPMu sync.Mutex
var wsConnCountByIP = make(map[string]int32)
|
||||
|
||||
const qpsWSIdleStopDelay = 30 * time.Second
|
||||
|
||||
@@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool {
|
||||
    if strings.TrimSpace(clientIP) == "" || limit <= 0 {
        return true
    }

    wsConnCountByIPMu.Lock()
    defer wsConnCountByIPMu.Unlock()
    current := wsConnCountByIP[clientIP]
    if current >= limit {
        return false
    }
    wsConnCountByIP[clientIP] = current + 1
    return true
}

func releaseOpsWSIPSlot(clientIP string) {
    if strings.TrimSpace(clientIP) == "" {
        return
    }

    wsConnCountByIPMu.Lock()
    defer wsConnCountByIPMu.Unlock()
    current, ok := wsConnCountByIP[clientIP]
    if !ok {
        return
    }
    if current <= 1 {
        delete(wsConnCountByIP, clientIP)
        return
    }
    wsConnCountByIP[clientIP] = current - 1
}
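// Illustrative sketch (not part of the sync): the intended acquire/release
// pairing around one ops WebSocket connection. The limit value is a
// hypothetical placeholder.
func exampleOpsWSIPSlot(clientIP string) bool {
    const perIPLimit int32 = 4 // hypothetical per-IP connection cap
    if !tryAcquireOpsWSIPSlot(clientIP, perIPLimit) {
        return false // over the per-IP limit; reject the upgrade
    }
    defer releaseOpsWSIPSlot(clientIP)
    // ... serve the WebSocket session here ...
    return true
}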
|
||||
|
||||
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -20,15 +21,17 @@ type SettingHandler struct {
|
||||
emailService *service.EmailService
|
||||
turnstileService *service.TurnstileService
|
||||
opsService *service.OpsService
|
||||
soraS3Storage *service.SoraS3Storage
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
opsService: opsService,
|
||||
soraS3Storage: soraS3Storage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +79,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
@@ -133,6 +137,7 @@ type UpdateSettingsRequest struct {
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -319,6 +324,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
SoraClientEnabled: req.SoraClientEnabled,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
@@ -400,6 +406,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: updatedSettings.SoraClientEnabled,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
@@ -750,6 +757,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
|
||||
if settings == nil {
|
||||
return dto.SoraS3Settings{}
|
||||
}
|
||||
return dto.SoraS3Settings{
|
||||
Enabled: settings.Enabled,
|
||||
Endpoint: settings.Endpoint,
|
||||
Region: settings.Region,
|
||||
Bucket: settings.Bucket,
|
||||
AccessKeyID: settings.AccessKeyID,
|
||||
SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
|
||||
Prefix: settings.Prefix,
|
||||
ForcePathStyle: settings.ForcePathStyle,
|
||||
CDNURL: settings.CDNURL,
|
||||
DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
|
||||
return dto.SoraS3Profile{
|
||||
ProfileID: profile.ProfileID,
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
Enabled: profile.Enabled,
|
||||
Endpoint: profile.Endpoint,
|
||||
Region: profile.Region,
|
||||
Bucket: profile.Bucket,
|
||||
AccessKeyID: profile.AccessKeyID,
|
||||
SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
|
||||
Prefix: profile.Prefix,
|
||||
ForcePathStyle: profile.ForcePathStyle,
|
||||
CDNURL: profile.CDNURL,
|
||||
DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(endpoint) == "" {
|
||||
return fmt.Errorf("S3 Endpoint is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(bucket) == "" {
|
||||
return fmt.Errorf("S3 Bucket is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(accessKeyID) == "" {
|
||||
return fmt.Errorf("S3 Access Key ID is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("S3 Secret Access Key is required when enabled")
|
||||
}
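// Illustrative sketch (not part of the sync): how validateSoraS3RequiredWhenEnabled
// treats the stored-secret flag. A disabled config always passes; an enabled one
// needs endpoint, bucket, access key, and either a fresh secret or a stored one.
func exampleValidateSoraS3() {
    _ = validateSoraS3RequiredWhenEnabled(false, "", "", "", "", false)                             // nil: disabled
    _ = validateSoraS3RequiredWhenEnabled(true, "https://s3.example.com", "media", "AK", "", true)  // nil: secret already stored
    _ = validateSoraS3RequiredWhenEnabled(true, "https://s3.example.com", "media", "AK", "", false) // error: secret required
}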
|
||||
|
||||
func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == profileID {
|
||||
return &items[idx]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
|
||||
// GET /api/v1/admin/settings/sora-s3
|
||||
func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3SettingsDTO(settings))
|
||||
}
|
||||
|
||||
// ListSoraS3Profiles 获取 Sora S3 多配置
|
||||
// GET /api/v1/admin/settings/sora-s3/profiles
|
||||
func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
|
||||
result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
items := make([]dto.SoraS3Profile, 0, len(result.Items))
|
||||
for idx := range result.Items {
|
||||
items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
|
||||
}
|
||||
response.Success(c, dto.ListSoraS3ProfilesResponse{
|
||||
ActiveProfileID: result.ActiveProfileID,
|
||||
Items: items,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
|
||||
type UpdateSoraS3SettingsRequest struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
type CreateSoraS3ProfileRequest struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
SetActive bool `json:"set_active"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
type UpdateSoraS3ProfileRequest struct {
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// CreateSoraS3Profile 创建 Sora S3 配置
|
||||
// POST /api/v1/admin/settings/sora-s3/profiles
|
||||
func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
|
||||
var req CreateSoraS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
response.BadRequest(c, "Name is required")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.ProfileID) == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
|
||||
ProfileID: req.ProfileID,
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
}, req.SetActive)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, toSoraS3ProfileDTO(*created))
|
||||
}
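// Illustrative sketch (not part of the sync): a request body that should pass
// CreateSoraS3Profile validation; all values are placeholders.
//
//	POST /api/v1/admin/settings/sora-s3/profiles
//	{
//	  "profile_id": "primary",
//	  "name": "Primary bucket",
//	  "set_active": true,
//	  "enabled": true,
//	  "endpoint": "https://s3.example.com",
//	  "region": "auto",
//	  "bucket": "sora-media",
//	  "access_key_id": "AKIAEXAMPLE",
//	  "secret_access_key": "example-secret",
//	  "force_path_style": true
//	}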
|
||||
|
||||
// UpdateSoraS3Profile 更新 Sora S3 配置
|
||||
// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
||||
func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSoraS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
response.BadRequest(c, "Name is required")
|
||||
return
|
||||
}
|
||||
|
||||
existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
existing := findSoraS3ProfileByID(existingList.Items, profileID)
|
||||
if existing == nil {
|
||||
response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
|
||||
return
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.ErrorFrom(c, updateErr)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, toSoraS3ProfileDTO(*updated))
|
||||
}
|
||||
|
||||
// DeleteSoraS3Profile 删除 Sora S3 配置
|
||||
// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
||||
func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
// SetActiveSoraS3Profile 切换激活 Sora S3 配置
|
||||
// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
|
||||
func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3ProfileDTO(*active))
|
||||
}
|
||||
|
||||
// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
|
||||
// PUT /api/v1/admin/settings/sora-s3
|
||||
func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
|
||||
var req UpdateSoraS3SettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.SoraS3Settings{
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
}
|
||||
if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3SettingsDTO(updatedSettings))
|
||||
}
|
||||
|
||||
// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket)
|
||||
// POST /api/v1/admin/settings/sora-s3/test
|
||||
func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||
if h.soraS3Storage == nil {
|
||||
response.Error(c, 500, "S3 存储服务未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSoraS3SettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Enabled {
|
||||
response.BadRequest(c, "S3 未启用,无法测试连接")
|
||||
return
|
||||
}
|
||||
|
||||
if req.SecretAccessKey == "" {
|
||||
if req.ProfileID != "" {
|
||||
profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err == nil {
|
||||
profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
|
||||
if profile != nil {
|
||||
req.SecretAccessKey = profile.SecretAccessKey
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.SecretAccessKey == "" {
|
||||
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err == nil {
|
||||
req.SecretAccessKey = existing.SecretAccessKey
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testCfg := &service.SoraS3Settings{
|
||||
Enabled: true,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
}
|
||||
if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
|
||||
response.Error(c, 400, "S3 连接测试失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||
}
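// Illustrative sketch (not part of the sync): TestSoraS3Connection lets the
// caller omit secret_access_key; the stored secret is then resolved from the
// referenced profile first and from the legacy single config as a fallback.
//
//	POST /api/v1/admin/settings/sora-s3/test
//	{
//	  "profile_id": "primary",
//	  "enabled": true,
//	  "endpoint": "https://s3.example.com",
//	  "bucket": "sora-media",
//	  "access_key_id": "AKIAEXAMPLE"
//	}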
|
||||
|
||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||
type UpdateStreamTimeoutSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"request_type": "invalid",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 99)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"request_type": "ws_v2",
|
||||
"stream": false,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.created, 1)
|
||||
created := repo.created[0]
|
||||
require.NotNil(t, created.Filters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType)
|
||||
require.Nil(t, created.Filters.Stream)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 99)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"stream": true,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.created, 1)
|
||||
created := repo.created[0]
|
||||
require.Nil(t, created.Filters.RequestType)
|
||||
require.NotNil(t, created.Filters.Stream)
|
||||
require.True(t, *created.Filters.Stream)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
|
||||
@@ -51,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct {
|
||||
AccountID *int64 `json:"account_id"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Model *string `json:"model"`
|
||||
RequestType *string `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
Timezone string `json:"timezone"`
|
||||
@@ -101,8 +102,17 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
|
||||
model := c.Query("model")
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
@@ -152,6 +162,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
@@ -214,8 +225,17 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
|
||||
model := c.Query("model")
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
@@ -278,6 +298,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: &startTime,
|
||||
@@ -432,6 +453,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
|
||||
var requestType *int16
|
||||
stream := req.Stream
|
||||
if req.RequestType != nil {
|
||||
parsed, err := service.ParseUsageRequestType(*req.RequestType)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
stream = nil
|
||||
}
|
||||
|
||||
filters := service.UsageCleanupFilters{
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
@@ -440,7 +474,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
AccountID: req.AccountID,
|
||||
GroupID: req.GroupID,
|
||||
Model: req.Model,
|
||||
RequestType: requestType,
Stream:      stream,
|
||||
BillingType: req.BillingType,
|
||||
}
|
||||
|
||||
@@ -464,9 +499,13 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
if filters.Model != nil {
|
||||
model = *filters.Model
|
||||
}
|
||||
var streamValue any
if filters.Stream != nil {
    streamValue = *filters.Stream
}
|
||||
var requestTypeName any
|
||||
if filters.RequestType != nil {
|
||||
requestTypeName = service.RequestTypeFromInt16(*filters.RequestType).String()
|
||||
}
|
||||
var billingType any
|
||||
if filters.BillingType != nil {
|
||||
@@ -481,7 +520,7 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
Body: req,
|
||||
}
|
||||
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
@@ -490,7 +529,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
accountID,
|
||||
groupID,
|
||||
model,
|
||||
requestTypeName,
streamValue,
|
||||
billingType,
|
||||
req.Timezone,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type adminUsageRepoCapture struct {
|
||||
service.UsageLogRepository
|
||||
listFilters usagestats.UsageLogFilters
|
||||
statsFilters usagestats.UsageLogFilters
|
||||
}
|
||||
|
||||
func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
s.listFilters = filters
|
||||
return []service.UsageLog{}, &pagination.PaginationResult{
|
||||
Total: 0,
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
Pages: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||
s.statsFilters = filters
|
||||
return &usagestats.UsageStats{}, nil
|
||||
}
|
||||
|
||||
func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
usageSvc := service.NewUsageService(repo, nil, nil, nil)
|
||||
handler := NewUsageHandler(usageSvc, nil, nil, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/usage", handler.List)
|
||||
router.GET("/admin/usage/stats", handler.Stats)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAdminUsageListRequestTypePriority(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.listFilters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
|
||||
require.Nil(t, repo.listFilters.Stream)
|
||||
}
|
||||
|
||||
func TestAdminUsageListInvalidRequestType(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageListInvalidStream(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.statsFilters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType)
|
||||
require.Nil(t, repo.statsFilters.Stream)
|
||||
}
|
||||
|
||||
func TestAdminUsageStatsInvalidRequestType(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageStatsInvalidStream(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
@@ -34,13 +34,14 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi
|
||||
|
||||
// CreateUserRequest represents admin create user request
|
||||
type CreateUserRequest struct {
    Email                 string  `json:"email" binding:"required,email"`
    Password              string  `json:"password" binding:"required,min=6"`
    Username              string  `json:"username"`
    Notes                 string  `json:"notes"`
    Balance               float64 `json:"balance"`
    Concurrency           int     `json:"concurrency"`
    AllowedGroups         []int64 `json:"allowed_groups"`
    SoraStorageQuotaBytes int64   `json:"sora_storage_quota_bytes"`
}
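// Illustrative sketch (not part of the sync): an admin create-user payload with
// the new sora_storage_quota_bytes field (10 GiB here); the route path is an
// assumption based on the admin API conventions.
//
//	POST /api/v1/admin/users
//	{
//	  "email": "user@example.com",
//	  "password": "secret123",
//	  "balance": 10,
//	  "concurrency": 3,
//	  "allowed_groups": [1, 2],
//	  "sora_storage_quota_bytes": 10737418240
//	}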
|
||||
|
||||
// UpdateUserRequest represents admin update user request
|
||||
@@ -56,7 +57,8 @@ type UpdateUserRequest struct {
|
||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates            map[int64]*float64 `json:"group_rates"`
SoraStorageQuotaBytes *int64             `json:"sora_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
@@ -174,13 +176,14 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||
Email:                 req.Email,
Password:              req.Password,
Username:              req.Username,
Notes:                 req.Notes,
Balance:               req.Balance,
Concurrency:           req.Concurrency,
AllowedGroups:         req.AllowedGroups,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -207,15 +210,16 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
|
||||
// 使用指针类型直接传递,nil 表示未提供该字段
|
||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||
Email:                 req.Email,
Password:              req.Password,
Username:              req.Username,
Notes:                 req.Notes,
Balance:               req.Balance,
Concurrency:           req.Concurrency,
Status:                req.Status,
AllowedGroups:         req.AllowedGroups,
GroupRates:            req.GroupRates,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@@ -113,9 +113,8 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证(邮箱验证码注册场景避免重复校验一次性 token)
if err := h.authService.VerifyTurnstileForRegister(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c), req.VerifyCode); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -59,9 +59,11 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||
return nil
|
||||
}
|
||||
return &AdminUser{
|
||||
User:                  *base,
Notes:                 u.Notes,
GroupRates:            u.GroupRates,
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
SoraStorageUsedBytes:  u.SoraStorageUsedBytes,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,6 +154,7 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
@@ -385,6 +388,8 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
||||
|
||||
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
||||
requestType := l.EffectiveRequestType()
|
||||
stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode)
|
||||
return UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
@@ -409,7 +414,9 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
ActualCost: l.ActualCost,
|
||||
RateMultiplier: l.RateMultiplier,
|
||||
BillingType: l.BillingType,
|
||||
RequestType:  requestType.String(),
Stream:       stream,
OpenAIWSMode: openAIWSMode,
|
||||
DurationMs: l.DurationMs,
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
ImageCount: l.ImageCount,
|
||||
@@ -464,6 +471,7 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
|
||||
AccountID: task.Filters.AccountID,
|
||||
GroupID: task.Filters.GroupID,
|
||||
Model: task.Filters.Model,
|
||||
RequestType: requestTypeStringPtr(task.Filters.RequestType),
|
||||
Stream: task.Filters.Stream,
|
||||
BillingType: task.Filters.BillingType,
|
||||
},
|
||||
@@ -479,6 +487,14 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
|
||||
}
|
||||
}
|
||||
|
||||
func requestTypeStringPtr(requestType *int16) *string {
|
||||
if requestType == nil {
|
||||
return nil
|
||||
}
|
||||
value := service.RequestTypeFromInt16(*requestType).String()
|
||||
return &value
|
||||
}
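// Illustrative sketch (not part of the sync): requestTypeStringPtr maps a stored
// int16 request type to its string name and preserves nil (see the mapper tests
// later in this commit).
func exampleRequestTypeStringPtr() {
    v := int16(service.RequestTypeStream)
    _ = requestTypeStringPtr(&v)  // -> pointer to "stream"
    _ = requestTypeStringPtr(nil) // -> nil
}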
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
if s == nil {
|
||||
return nil
|
||||
|
||||
backend/internal/handler/dto/mappers_usage_test.go (new file, 73 lines)
@@ -0,0 +1,73 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUsageLogFromService_IncludesOpenAIWSMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wsLog := &service.UsageLog{
|
||||
RequestID: "req_1",
|
||||
Model: "gpt-5.3-codex",
|
||||
OpenAIWSMode: true,
|
||||
}
|
||||
httpLog := &service.UsageLog{
|
||||
RequestID: "resp_1",
|
||||
Model: "gpt-5.3-codex",
|
||||
OpenAIWSMode: false,
|
||||
}
|
||||
|
||||
require.True(t, UsageLogFromService(wsLog).OpenAIWSMode)
|
||||
require.False(t, UsageLogFromService(httpLog).OpenAIWSMode)
|
||||
require.True(t, UsageLogFromServiceAdmin(wsLog).OpenAIWSMode)
|
||||
require.False(t, UsageLogFromServiceAdmin(httpLog).OpenAIWSMode)
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_PrefersRequestTypeForLegacyFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_2",
|
||||
Model: "gpt-5.3-codex",
|
||||
RequestType: service.RequestTypeWSV2,
|
||||
Stream: false,
|
||||
OpenAIWSMode: false,
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.Equal(t, "ws_v2", userDTO.RequestType)
|
||||
require.True(t, userDTO.Stream)
|
||||
require.True(t, userDTO.OpenAIWSMode)
|
||||
require.Equal(t, "ws_v2", adminDTO.RequestType)
|
||||
require.True(t, adminDTO.Stream)
|
||||
require.True(t, adminDTO.OpenAIWSMode)
|
||||
}
|
||||
|
||||
func TestUsageCleanupTaskFromService_RequestTypeMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
requestType := int16(service.RequestTypeStream)
|
||||
task := &service.UsageCleanupTask{
|
||||
ID: 1,
|
||||
Status: service.UsageCleanupStatusPending,
|
||||
Filters: service.UsageCleanupFilters{
|
||||
RequestType: &requestType,
|
||||
},
|
||||
}
|
||||
|
||||
dtoTask := UsageCleanupTaskFromService(task)
|
||||
require.NotNil(t, dtoTask)
|
||||
require.NotNil(t, dtoTask.Filters.RequestType)
|
||||
require.Equal(t, "stream", *dtoTask.Filters.RequestType)
|
||||
}
|
||||
|
||||
func TestRequestTypeStringPtrNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, requestTypeStringPtr(nil))
|
||||
}
|
||||
@@ -37,6 +37,7 @@ type SystemSettings struct {
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
@@ -79,9 +80,48 @@ type PublicSettings struct {
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
|
||||
type SoraS3Settings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
|
||||
type SoraS3Profile struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ListSoraS3ProfilesResponse Sora S3 配置列表响应
|
||||
type ListSoraS3ProfilesResponse struct {
|
||||
ActiveProfileID string `json:"active_profile_id"`
|
||||
Items []SoraS3Profile `json:"items"`
|
||||
}
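// Illustrative sketch (not part of the sync): the JSON shape produced by
// ListSoraS3ProfilesResponse, with a single abbreviated profile entry and
// placeholder values.
//
//	{
//	  "active_profile_id": "primary",
//	  "items": [
//	    {
//	      "profile_id": "primary",
//	      "name": "Primary bucket",
//	      "is_active": true,
//	      "enabled": true,
//	      "secret_access_key_configured": true,
//	      "default_storage_quota_bytes": 10737418240
//	    }
//	  ]
//	}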
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||
type StreamTimeoutSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@@ -26,7 +26,9 @@ type AdminUser struct {
|
||||
Notes string `json:"notes"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
GroupRates            map[int64]float64 `json:"group_rates,omitempty"`
SoraStorageQuotaBytes int64             `json:"sora_storage_quota_bytes"`
SoraStorageUsedBytes  int64             `json:"sora_storage_used_bytes"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
@@ -80,6 +82,9 @@ type Group struct {
|
||||
// 无效请求兜底分组
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -278,10 +283,12 @@ type UsageLog struct {
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
|
||||
BillingType  int8   `json:"billing_type"`
RequestType  string `json:"request_type"`
Stream       bool   `json:"stream"`
OpenAIWSMode bool   `json:"openai_ws_mode"`
DurationMs   *int   `json:"duration_ms"`
FirstTokenMs *int   `json:"first_token_ms"`
|
||||
|
||||
// 图片生成字段
|
||||
ImageCount int `json:"image_count"`
|
||||
@@ -324,6 +331,7 @@ type UsageCleanupFilters struct {
|
||||
AccountID *int64 `json:"account_id,omitempty"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
RequestType *string `json:"request_type,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
BillingType *int8 `json:"billing_type,omitempty"`
|
||||
}
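// Illustrative sketch (not part of the sync): a cleanup-task filters object as
// echoed back to the admin UI; request_type supersedes the legacy stream flag,
// so only one of the two is expected to be set (values are placeholders).
//
//	"filters": {
//	  "model": "gpt-5.3-codex",
//	  "request_type": "ws_v2",
//	  "billing_type": 1
//	}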
|
||||
|
||||
@@ -2,11 +2,12 @@ package handler

import (
	"context"
	"log"
	"net/http"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"go.uber.org/zap"
)

// TempUnscheduler is used for the temporary ban applied in HandleFailoverError once same-account retries are exhausted.
@@ -78,8 +79,12 @@ func (s *FailoverState) HandleFailoverError(
	// Same-account retry: for transient errors marked RetryableOnSameAccount, retry on the same account first
	if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
		s.SameAccountRetryCount[accountID]++
		log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
			accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
		logger.FromContext(ctx).Warn("gateway.failover_same_account_retry",
			zap.Int64("account_id", accountID),
			zap.Int("upstream_status", failoverErr.StatusCode),
			zap.Int("same_account_retry_count", s.SameAccountRetryCount[accountID]),
			zap.Int("same_account_retry_max", maxSameAccountRetries),
		)
		if !sleepWithContext(ctx, sameAccountRetryDelay) {
			return FailoverCanceled
		}
@@ -101,8 +106,12 @@ func (s *FailoverState) HandleFailoverError(

	// Increment the account-switch counter
	s.SwitchCount++
	log.Printf("Account %d: upstream error %d, switching account %d/%d",
		accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
	logger.FromContext(ctx).Warn("gateway.failover_switch_account",
		zap.Int64("account_id", accountID),
		zap.Int("upstream_status", failoverErr.StatusCode),
		zap.Int("switch_count", s.SwitchCount),
		zap.Int("max_switches", s.MaxSwitches),
	)

	// Linearly increasing delay between account switches on the Antigravity platform
	if platform == service.PlatformAntigravity {
@@ -127,13 +136,18 @@ func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAc
		s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
		s.SwitchCount <= s.MaxSwitches {

		log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
			singleAccountBackoffDelay, s.SwitchCount)
		logger.FromContext(ctx).Warn("gateway.failover_single_account_backoff",
			zap.Duration("backoff_delay", singleAccountBackoffDelay),
			zap.Int("switch_count", s.SwitchCount),
			zap.Int("max_switches", s.MaxSwitches),
		)
		if !sleepWithContext(ctx, singleAccountBackoffDelay) {
			return FailoverCanceled
		}
		log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
			s.SwitchCount, s.MaxSwitches)
		logger.FromContext(ctx).Warn("gateway.failover_single_account_retry",
			zap.Int("switch_count", s.SwitchCount),
			zap.Int("max_switches", s.MaxSwitches),
		)
		s.FailedAccountIDs = make(map[int64]struct{})
		return FailoverContinue
	}

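Both back-off branches above bail out through sleepWithContext when the request context is cancelled. That helper is referenced but not included in this hunk; the following is a minimal sketch of what a context-aware sleep with this signature typically looks like (only the name and call shape come from the code above, the body is an assumption):

func sleepWithContext(ctx context.Context, d time.Duration) bool {
	// Returns false when ctx is cancelled before the delay elapses,
	// which the callers above translate into FailoverCanceled.
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}
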
@@ -6,9 +6,10 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
@@ -27,6 +29,10 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const gatewayCompatibilityMetricsLogInterval = 1024
|
||||
|
||||
var gatewayCompatibilityMetricsLogCounter atomic.Uint64
|
||||
|
||||
// GatewayHandler handles API gateway requests
|
||||
type GatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
@@ -109,9 +115,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -140,16 +147,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
|
||||
ctx := service.WithIsMaxTokensOneHaikuRequest(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中(复用已解析请求,避免二次反序列化)。
|
||||
SetClaudeCodeClientContext(c, body, parsedReq)
|
||||
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
|
||||
|
||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
||||
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
@@ -247,8 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if apiKey.GroupID != nil {
|
||||
prefetchedGroupID = *apiKey.GroupID
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
|
||||
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
@@ -261,7 +267,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
@@ -275,7 +281,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
@@ -364,7 +370,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
@@ -439,7 +445,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
@@ -458,7 +464,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
@@ -547,7 +553,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
@@ -956,20 +962,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
	// Stream already started, send error as SSE event then close
	flusher, ok := c.Writer.(http.Flusher)
	if ok {
		// Send error event in SSE format with proper JSON marshaling
		errorData := map[string]any{
			"type": "error",
			"error": map[string]string{
				"type":    errType,
				"message": message,
			},
		}
		jsonBytes, err := json.Marshal(errorData)
		if err != nil {
			_ = c.Error(err)
			return
		}
		errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
		// The SSE error event has a fixed schema; building it with strconv.Quote avoids the extra Marshal allocation.
		errorEvent := `data: {"type":"error","error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
		if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
			_ = c.Error(err)
		}
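The shortcut above assumes strconv.Quote escapes errType and message in a way a JSON parser accepts, which holds for the printable error strings the gateway emits (control characters would come out as \xNN, which JSON does not allow, so the fixed-schema path leans on that assumption). A small standalone check with hypothetical values:

package main

import (
	"encoding/json"
	"fmt"
	"strconv"
)

func main() {
	// Hypothetical error message containing quotes and a newline.
	msg := "upstream said: \"rate limited\"\nplease retry"

	quoted := strconv.Quote(msg)      // Go-style quoting
	marshaled, _ := json.Marshal(msg) // canonical JSON string encoding
	fmt.Println(quoted == string(marshaled)) // true for this input

	// The concatenated event body still parses as JSON.
	event := `{"type":"error","error":{"type":"api_error","message":` + quoted + `}}`
	var decoded map[string]any
	fmt.Println(json.Unmarshal([]byte(event), &decoded) == nil) // true
}
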
@@ -1024,9 +1018,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -1041,9 +1036,6 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
@@ -1051,9 +1043,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
// count_tokens 走 messages 严格校验时,复用已解析请求,避免二次反序列化。
|
||||
SetClaudeCodeClientContext(c, body, parsedReq)
|
||||
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
|
||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
||||
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
|
||||
|
||||
// 验证 model 必填
|
||||
if parsedReq.Model == "" {
|
||||
@@ -1217,24 +1211,8 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
|
||||
textDeltas = []string{"New", " Conversation"}
|
||||
}
|
||||
|
||||
// Build message_start event with proper JSON marshaling
|
||||
messageStart := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
messageStartJSON, _ := json.Marshal(messageStart)
|
||||
// Build message_start event with fixed schema.
|
||||
messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}`
|
||||
|
||||
// Build events
|
||||
events := []string{
|
||||
@@ -1244,31 +1222,12 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
|
||||
|
||||
// Add text deltas
|
||||
for _, text := range textDeltas {
|
||||
delta := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]string{
|
||||
"type": "text_delta",
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
deltaJSON, _ := json.Marshal(delta)
|
||||
deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}`
|
||||
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
|
||||
}
|
||||
|
||||
// Add final events
|
||||
messageDelta := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
}
|
||||
messageDeltaJSON, _ := json.Marshal(messageDelta)
|
||||
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}`
|
||||
|
||||
events = append(events,
|
||||
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
|
||||
@@ -1366,6 +1325,30 @@ func billingErrorDetails(err error) (status int, code, message string) {
	return http.StatusForbidden, "billing_error", msg
}

func (h *GatewayHandler) metadataBridgeEnabled() bool {
	if h == nil || h.cfg == nil {
		return true
	}
	return h.cfg.Gateway.OpenAIWS.MetadataBridgeEnabled
}

func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger) {
	if reqLog == nil {
		return
	}
	if gatewayCompatibilityMetricsLogCounter.Add(1)%gatewayCompatibilityMetricsLogInterval != 0 {
		return
	}
	metrics := service.SnapshotOpenAICompatibilityFallbackMetrics()
	reqLog.Info("gateway.compatibility_fallback_metrics",
		zap.Int64("session_hash_legacy_read_fallback_total", metrics.SessionHashLegacyReadFallbackTotal),
		zap.Int64("session_hash_legacy_read_fallback_hit", metrics.SessionHashLegacyReadFallbackHit),
		zap.Int64("session_hash_legacy_dual_write_total", metrics.SessionHashLegacyDualWriteTotal),
		zap.Float64("session_hash_legacy_read_hit_rate", metrics.SessionHashLegacyReadHitRate),
		zap.Int64("metadata_legacy_fallback_total", metrics.MetadataLegacyFallbackTotal),
	)
}

func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
	if task == nil {
		return
@@ -1377,5 +1360,13 @@ func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
	// Fallback path: run the task synchronously when no worker pool is injected, instead of reverting to unbounded goroutines.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	defer func() {
		if recovered := recover(); recovered != nil {
			logger.L().With(
				zap.String("component", "handler.gateway.messages"),
				zap.Any("panic", recovered),
			).Error("gateway.usage_record_task_panic_recovered")
		}
	}()
	task(ctx)
}

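The fallback above only runs when no worker pool was injected into the handler; the pool side is not part of this hunk. The sketch below shows one plausible shape for it: a bounded queue whose non-blocking Submit returns false when full, which is exactly the case the synchronous fallback covers. The type, names, and sizes here are hypothetical, not the project's implementation, and the snippet assumes the standard context and time packages are imported.

// boundedTaskPool is a hypothetical bounded worker pool for usage-record tasks.
type boundedTaskPool struct {
	tasks chan func(context.Context)
}

func newBoundedTaskPool(workers, queueSize int) *boundedTaskPool {
	p := &boundedTaskPool{tasks: make(chan func(context.Context), queueSize)}
	for i := 0; i < workers; i++ {
		go func() {
			for task := range p.tasks {
				runOne(task)
			}
		}()
	}
	return p
}

func runOne(task func(context.Context)) {
	// Same guard rails as the synchronous fallback: bounded lifetime and panic isolation.
	defer func() { _ = recover() }()
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	task(ctx)
}

// Submit enqueues without blocking; a false return tells the caller to fall back
// to synchronous execution, mirroring submitUsageRecordTask above.
func (p *boundedTaskPool) Submit(task func(context.Context)) bool {
	select {
	case p.tasks <- task:
		return true
	default:
		return false
	}
}
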
|
||||
@@ -119,6 +119,13 @@ func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.A
|
||||
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, id := range accountIDs {
|
||||
result[id] = 0
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
|
||||
@@ -18,12 +18,17 @@ import (
// claudeCodeValidator is a singleton validator for Claude Code client detection
var claudeCodeValidator = service.NewClaudeCodeValidator()

const claudeCodeParsedRequestContextKey = "claude_code_parsed_request"

// SetClaudeCodeClientContext checks whether the request comes from a Claude Code client and records the result on the context.
// The updated context is written back onto the request.
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) {
	if c == nil || c.Request == nil {
		return
	}
	if parsedReq != nil {
		c.Set(claudeCodeParsedRequestContextKey, parsedReq)
	}
	// Fast path: a non-Claude-CLI User-Agent is judged false immediately, avoiding a second JSON unmarshal on the hot path.
	if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
		ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
@@ -37,8 +42,11 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
		isClaudeCode = true
	} else {
		// Only parse the body once the request is confirmed as Claude CLI on the messages path.
		var bodyMap map[string]any
		if len(body) > 0 {
		bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq)
		if bodyMap == nil {
			bodyMap = claudeCodeBodyMapFromContextCache(c)
		}
		if bodyMap == nil && len(body) > 0 {
			_ = json.Unmarshal(body, &bodyMap)
		}
		isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
@@ -49,6 +57,42 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
	c.Request = c.Request.WithContext(ctx)
}

func claudeCodeBodyMapFromParsedRequest(parsedReq *service.ParsedRequest) map[string]any {
	if parsedReq == nil {
		return nil
	}
	bodyMap := map[string]any{
		"model": parsedReq.Model,
	}
	if parsedReq.System != nil || parsedReq.HasSystem {
		bodyMap["system"] = parsedReq.System
	}
	if parsedReq.MetadataUserID != "" {
		bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID}
	}
	return bodyMap
}

func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any {
	if c == nil {
		return nil
	}
	if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok {
		if bodyMap, ok := cached.(map[string]any); ok {
			return bodyMap
		}
	}
	if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok {
		switch v := cached.(type) {
		case *service.ParsedRequest:
			return claudeCodeBodyMapFromParsedRequest(v)
		case service.ParsedRequest:
			return claudeCodeBodyMapFromParsedRequest(&v)
		}
	}
	return nil
}

|
||||
// 并发槽位等待相关常量
|
||||
//
|
||||
// 性能优化说明:
|
||||
|
||||
@@ -33,6 +33,14 @@ func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accoun
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
result[accountID] = 0
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -49,6 +49,14 @@ func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context,
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
out := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
out[accountID] = 0
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
@@ -133,7 +141,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "curl/8.6.0")
|
||||
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
|
||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
@@ -141,7 +149,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
|
||||
SetClaudeCodeClientContext(c, nil)
|
||||
SetClaudeCodeClientContext(c, nil, nil)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
@@ -152,7 +160,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
@@ -160,11 +168,51 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
// 缺少严格校验所需 header + body 字段
|
||||
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
|
||||
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`), nil)
|
||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing.T) {
|
||||
t.Run("reuse parsed request without body unmarshal", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
c.Request.Header.Set("X-App", "claude-code")
|
||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
parsedReq := &service.ParsedRequest{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
System: []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
|
||||
}
|
||||
|
||||
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), parsedReq)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("reuse context cache without body unmarshal", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
c.Request.Header.Set("X-App", "claude-code")
|
||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||
c.Set(service.OpenAIParsedRequestBodyKey, map[string]any{
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"system": []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
|
||||
})
|
||||
|
||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
|
||||
cache := &helperConcurrencyCacheStub{
|
||||
accountSeq: []bool{false, true},
|
||||
|
||||
@@ -7,16 +7,15 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -168,7 +167,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
stream := action == "streamGenerateContent"
|
||||
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -268,8 +267,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if apiKey.GroupID != nil {
|
||||
prefetchedGroupID = *apiKey.GroupID
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
|
||||
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
@@ -349,7 +347,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
@@ -363,7 +361,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
@@ -456,7 +454,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
||||
|
||||
@@ -11,6 +11,7 @@ type AdminHandlers struct {
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
@@ -40,6 +41,7 @@ type Handlers struct {
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
SoraGateway *SoraGatewayHandler
|
||||
SoraClient *SoraClientHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
}
|
||||
|
||||
@@ -5,17 +5,20 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
@@ -64,6 +67,11 @@ func NewOpenAIGatewayHandler(
|
||||
// Responses handles OpenAI Responses API endpoint
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
setOpenAIClientTransportHTTP(c)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
@@ -85,9 +93,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -125,43 +136,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
reqStream := streamResult.Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
|
||||
if previousResponseID != "" {
|
||||
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
|
||||
reqLog = reqLog.With(
|
||||
zap.Bool("has_previous_response_id", true),
|
||||
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||
zap.Int("previous_response_id_len", len(previousResponseID)),
|
||||
)
|
||||
if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "previous_response_id_looks_like_message_id"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
||||
// 或带 id 且与 call_id 匹配的 item_reference。
|
||||
// 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal
|
||||
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err == nil {
|
||||
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
|
||||
if service.HasFunctionCallOutput(reqBody) {
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_call_id"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_item_reference"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
@@ -173,51 +171,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
// 0. 先尝试直接抢占用户槽位(快速路径)
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
waitCounted := false
|
||||
if !userAcquired {
|
||||
// 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
if waitErr != nil {
|
||||
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
|
||||
// 按现有降级语义:等待计数异常时放行后续抢槽流程
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if waitErr == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 用户槽位已获取:退出等待队列计数。
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
@@ -241,7 +199,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
previousResponseID,
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_select_failed",
|
||||
zap.Error(err),
|
||||
@@ -258,80 +224,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
if previousResponseID != "" && selection != nil && selection.Account != nil {
|
||||
reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID))
|
||||
}
|
||||
reqLog.Debug("openai.account_schedule_decision",
|
||||
zap.String("layer", scheduleDecision.Layer),
|
||||
zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit),
|
||||
zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
|
||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||
zap.Int("top_k", scheduleDecision.TopK),
|
||||
zap.Int64("latency_ms", scheduleDecision.LatencyMs),
|
||||
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||
)
|
||||
account := selection.Account
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 先快速尝试一次账号槽位,命中则跳过等待计数写入。
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
c.Request.Context(),
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if fastAcquired {
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// Forward request
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
@@ -353,6 +269,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
@@ -368,14 +286,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("openai.forward_failed",
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
reqLog.Error("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
@@ -411,6 +340,525 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||||
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
return true
|
||||
}
|
||||
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
// 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。
|
||||
return true
|
||||
}
|
||||
|
||||
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
|
||||
validation := service.ValidateFunctionCallOutputContext(reqBody)
|
||||
if !validation.HasFunctionCallOutput {
|
||||
return true
|
||||
}
|
||||
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext {
|
||||
return true
|
||||
}
|
||||
|
||||
if validation.HasFunctionCallOutputMissingCallID {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_call_id"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return false
|
||||
}
|
||||
if validation.HasItemReferenceForAllCallIDs {
|
||||
return true
|
||||
}
|
||||
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_item_reference"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return false
|
||||
}
|
||||
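For reference, a request body that passes the pre-check above either carries previous_response_id or pairs each function_call_output with a call_id that also appears on a function_call (or a matching item_reference) in the same input. The sketch below shows that shape; the model name, call_id, and field values are hypothetical, and the item layout follows the Responses API input-item format as described here rather than an exhaustive spec:

// Hypothetical payload: the function_call_output references call_abc123, and a
// function_call with the same call_id is present in input, so neither
// previous_response_id nor an item_reference is required.
body := []byte(`{
  "model": "some-model",
  "input": [
    {"type": "function_call", "call_id": "call_abc123", "name": "get_weather", "arguments": "{\"city\":\"Berlin\"}"},
    {"type": "function_call_output", "call_id": "call_abc123", "output": "{\"temp_c\":21}"}
  ]
}`)

// Same existence probe the handler uses before falling back to a full unmarshal.
hasOutput := gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() // true
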
|
||||
func (h *OpenAIGatewayHandler) acquireResponsesUserSlot(
|
||||
c *gin.Context,
|
||||
userID int64,
|
||||
userConcurrency int,
|
||||
reqStream bool,
|
||||
streamStarted *bool,
|
||||
reqLog *zap.Logger,
|
||||
) (func(), bool) {
|
||||
ctx := c.Request.Context()
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
if userAcquired {
|
||||
return wrapReleaseOnDone(ctx, userReleaseFunc), true
|
||||
}
|
||||
|
||||
maxWait := service.CalculateMaxWait(userConcurrency)
|
||||
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait)
|
||||
if waitErr != nil {
|
||||
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
|
||||
// 按现有降级语义:等待计数异常时放行后续抢槽流程
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
waitCounted := waitErr == nil && canWait
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 槽位获取成功后,立刻退出等待计数。
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
|
||||
waitCounted = false
|
||||
}
|
||||
return wrapReleaseOnDone(ctx, userReleaseFunc), true
|
||||
}
|
||||
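Both slot helpers hand their release function through wrapReleaseOnDone so that a cancelled or disconnected request still frees the slot even when the normal deferred release never runs. That wrapper is defined elsewhere in the package; the following is a minimal sketch of the pattern under the assumption that it watches ctx.Done() and guarantees the release runs at most once (it needs the standard context and sync packages):

func wrapReleaseOnDone(ctx context.Context, release func()) func() {
	if release == nil {
		return nil
	}
	var once sync.Once
	doRelease := func() { once.Do(release) }
	go func() {
		// Client disconnect or timeout: free the slot without waiting for the handler path.
		<-ctx.Done()
		doRelease()
	}()
	return doRelease
}
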
|
||||
func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
|
||||
c *gin.Context,
|
||||
groupID *int64,
|
||||
sessionHash string,
|
||||
selection *service.AccountSelectionResult,
|
||||
reqStream bool,
|
||||
streamStarted *bool,
|
||||
reqLog *zap.Logger,
|
||||
) (func(), bool) {
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
account := selection.Account
|
||||
if selection.Acquired {
|
||||
return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
|
||||
}
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
ctx,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
if fastAcquired {
|
||||
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
return wrapReleaseOnDone(ctx, fastReleaseFunc), true
|
||||
}
|
||||
|
||||
canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if waitErr != nil {
|
||||
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr))
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
accountWaitCounted := waitErr == nil && canWait
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}
|
||||
defer releaseWait()
|
||||
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
return wrapReleaseOnDone(ctx, accountReleaseFunc), true
|
||||
}
|
||||
|
||||
// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint
|
||||
// GET /openai/v1/responses (Upgrade: websocket)
|
||||
func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
if !isOpenAIWSUpgradeRequest(c.Request) {
|
||||
h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)")
|
||||
return
|
||||
}
|
||||
setOpenAIClientTransportWS(c)
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.responses_ws",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.Bool("openai_ws_mode", true),
|
||||
)
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
reqLog.Info("openai.websocket_ingress_started")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))
|
||||
|
||||
wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_accept_failed",
|
||||
zap.Error(err),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("request_user_agent", userAgent),
|
||||
zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))),
|
||||
zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))),
|
||||
zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))),
|
||||
zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""),
|
||||
)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = wsConn.CloseNow()
|
||||
}()
|
||||
wsConn.SetReadLimit(16 * 1024 * 1024)
|
||||
|
||||
ctx := c.Request.Context()
|
||||
readCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
msgType, firstMessage, err := wsConn.Read(readCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_read_first_message_failed",
|
||||
zap.Error(err),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("close_status", closeStatus),
|
||||
zap.String("close_reason", closeReason),
|
||||
zap.Duration("read_timeout", 30*time.Second),
|
||||
)
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message")
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type")
|
||||
return
|
||||
}
|
||||
if !gjson.ValidBytes(firstMessage) {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload")
|
||||
return
|
||||
}
|
||||
|
||||
reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String())
|
||||
if reqModel == "" {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload")
|
||||
return
|
||||
}
|
||||
previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String())
|
||||
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
|
||||
if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
|
||||
return
|
||||
}
|
||||
reqLog = reqLog.With(
|
||||
zap.Bool("ws_ingress", true),
|
||||
zap.String("model", reqModel),
|
||||
zap.Bool("has_previous_response_id", previousResponseID != ""),
|
||||
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||
)
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
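From the client side, the ingress above expects a WebSocket upgrade on the responses endpoint followed by a first JSON message that carries at least model (and optionally previous_response_id). A minimal client sketch using the same coder/websocket package; the gateway URL, API key header, model name, and payload layout are hypothetical, and the snippet assumes imports of context, fmt, log, net/http, time, and coderws "github.com/coder/websocket":

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

// Hypothetical gateway address and API key.
conn, _, err := coderws.Dial(ctx, "wss://gateway.example.com/openai/v1/responses", &coderws.DialOptions{
	HTTPHeader: http.Header{"Authorization": {"Bearer sk-example"}},
})
if err != nil {
	log.Fatal(err)
}
defer conn.CloseNow()

// First message: a response.create-style payload; the handler requires "model".
first := []byte(`{"type":"response.create","model":"some-model","stream":true,"input":[{"role":"user","content":"hello"}]}`)
if err := conn.Write(ctx, coderws.MessageText, first); err != nil {
	log.Fatal(err)
}

_, reply, err := conn.Read(ctx)
if err != nil {
	log.Fatal(err)
}
fmt.Println(string(reply))
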
|
||||
var currentUserRelease func()
|
||||
var currentAccountRelease func()
|
||||
releaseTurnSlots := func() {
|
||||
if currentAccountRelease != nil {
|
||||
currentAccountRelease()
|
||||
currentAccountRelease = nil
|
||||
}
|
||||
if currentUserRelease != nil {
|
||||
currentUserRelease()
|
||||
currentUserRelease = nil
|
||||
}
|
||||
}
|
||||
// 必须尽早注册,确保任何 early return 都能释放已获取的并发槽位。
|
||||
defer releaseTurnSlots()
|
||||
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
|
||||
return
|
||||
}
|
||||
if !userAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
|
||||
return
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHashWithFallback(
|
||||
c,
|
||||
firstMessage,
|
||||
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
|
||||
)
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
apiKey.GroupID,
|
||||
previousResponseID,
|
||||
sessionHash,
|
||||
reqModel,
|
||||
nil,
|
||||
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
return
|
||||
}
|
||||
|
||||
account := selection.Account
|
||||
accountMaxConcurrency := account.Concurrency
|
||||
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
|
||||
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
|
||||
}
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
ctx,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
|
||||
return
|
||||
}
|
||||
if !fastAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
}
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
|
||||
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
|
||||
return
|
||||
}
|
||||
|
||||
reqLog.Debug("openai.websocket_account_selected",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.String("schedule_layer", scheduleDecision.Layer),
|
||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||
)
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
|
||||
releaseTurnSlots()
|
||||
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
|
||||
}
|
||||
if !userAcquired {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
|
||||
}
|
||||
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
|
||||
if err != nil {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
|
||||
}
|
||||
if !accountAcquired {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
return nil
|
||||
},
|
||||
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||
releaseTurnSlots()
|
||||
if turnErr != nil || result == nil {
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("request_id", result.RequestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_proxy_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
zap.String("close_status", closeStatus),
|
||||
zap.String("close_reason", closeReason),
|
||||
)
|
||||
var closeErr *service.OpenAIWSClientCloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
|
||||
return
|
||||
}
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||
return
|
||||
}
|
||||
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
|
||||
recovered := recover()
|
||||
if recovered == nil {
|
||||
return
|
||||
}
|
||||
|
||||
started := false
|
||||
if streamStarted != nil {
|
||||
started = *streamStarted
|
||||
}
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, started)
|
||||
requestLogger(c, "handler.openai_gateway.responses").Error(
|
||||
"openai.responses_panic_recovered",
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Any("panic", recovered),
|
||||
zap.ByteString("stack", debug.Stack()),
|
||||
)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
|
||||
missing := h.missingResponsesDependencies()
|
||||
if len(missing) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if reqLog == nil {
|
||||
reqLog = requestLogger(c, "handler.openai_gateway.responses")
|
||||
}
|
||||
reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing))
|
||||
|
||||
if c != nil && c.Writer != nil && !c.Writer.Written() {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "Service temporarily unavailable",
|
||||
},
|
||||
})
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string {
|
||||
missing := make([]string, 0, 5)
|
||||
if h == nil {
|
||||
return append(missing, "handler")
|
||||
}
|
||||
if h.gatewayService == nil {
|
||||
missing = append(missing, "gatewayService")
|
||||
}
|
||||
if h.billingCacheService == nil {
|
||||
missing = append(missing, "billingCacheService")
|
||||
}
|
||||
if h.apiKeyService == nil {
|
||||
missing = append(missing, "apiKeyService")
|
||||
}
|
||||
if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil {
|
||||
missing = append(missing, "concurrencyHelper")
|
||||
}
|
||||
return missing
|
||||
}
|
||||
|
||||
func getContextInt64(c *gin.Context, key string) (int64, bool) {
|
||||
if c == nil || key == "" {
|
||||
return 0, false
|
||||
@@ -444,6 +892,14 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("openai.usage_record_task_panic_recovered")
|
||||
}
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
@@ -515,19 +971,8 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in OpenAI SSE format with proper JSON marshaling
|
||||
errorData := map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||
// SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。
|
||||
errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
@@ -549,6 +994,16 @@ func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, stream
|
||||
return true
|
||||
}
|
||||
|
||||
func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool {
|
||||
if wroteFallback {
|
||||
return false
|
||||
}
|
||||
if c == nil || c.Writer == nil {
|
||||
return false
|
||||
}
|
||||
return c.Writer.Written()
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
@@ -558,3 +1013,61 @@ func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func setOpenAIClientTransportHTTP(c *gin.Context) {
|
||||
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP)
|
||||
}
|
||||
|
||||
func setOpenAIClientTransportWS(c *gin.Context) {
|
||||
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
|
||||
}
|
||||
|
||||
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
|
||||
gid := int64(0)
|
||||
if groupID != nil {
|
||||
gid = *groupID
|
||||
}
|
||||
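// Yields e.g. "openai_ws_ingress:2:1:101" for group 2, user 1, api key 101; a nil group falls back to 0.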
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
|
||||
}
|
||||
|
||||
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade")
|
||||
}
|
||||
|
||||
func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
reason = strings.TrimSpace(reason)
|
||||
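// RFC 6455 caps the close-frame reason at 123 bytes; truncating at 120 keeps the frame within that limit.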
if len(reason) > 120 {
|
||||
reason = reason[:120]
|
||||
}
|
||||
_ = conn.Close(status, reason)
|
||||
_ = conn.CloseNow()
|
||||
}
|
||||
|
||||
func summarizeWSCloseErrorForLog(err error) (string, string) {
|
||||
if err == nil {
|
||||
return "-", "-"
|
||||
}
|
||||
statusCode := coderws.CloseStatus(err)
|
||||
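// coderws.CloseStatus returns -1 when err does not wrap a websocket close frame.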
if statusCode == -1 {
|
||||
return "-", "-"
|
||||
}
|
||||
closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
|
||||
closeReason := "-"
|
||||
var closeErr coderws.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
reason := strings.TrimSpace(closeErr.Reason)
|
||||
if reason != "" {
|
||||
closeReason = reason
|
||||
}
|
||||
}
|
||||
return closeStatus, closeReason
|
||||
}
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -105,6 +112,27 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
||||
assert.Equal(t, "test error", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc(t *testing.T) {
|
||||
payload := `{"model":"gpt-5","input":"hello"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload))
|
||||
req.ContentLength = int64(len(payload))
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, payload, string(body))
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8)))
|
||||
req.Body = http.MaxBytesReader(rec, req.Body, 4)
|
||||
|
||||
_, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
|
||||
require.Error(t, err)
|
||||
var maxErr *http.MaxBytesError
|
||||
require.ErrorAs(t, err, &maxErr)
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -141,6 +169,387 @@ func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *test
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
|
||||
func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("fallback_written_should_not_downgrade", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true))
|
||||
})
|
||||
|
||||
t.Run("context_nil_should_not_downgrade", func(t *testing.T) {
|
||||
require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false))
|
||||
})
|
||||
|
||||
t.Run("response_not_written_should_not_downgrade", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
|
||||
})
|
||||
|
||||
t.Run("response_already_written_should_downgrade", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusForbidden, "already written")
|
||||
require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
streamStarted := false
|
||||
require.NotPanics(t, func() {
|
||||
func() {
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
panic("test panic")
|
||||
}()
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
streamStarted := false
|
||||
require.NotPanics(t, func() {
|
||||
func() {
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
}()
|
||||
})
|
||||
|
||||
require.False(t, c.Writer.Written())
|
||||
assert.Equal(t, "", w.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
streamStarted := false
|
||||
require.NotPanics(t, func() {
|
||||
func() {
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
panic("test panic")
|
||||
}()
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIMissingResponsesDependencies(t *testing.T) {
|
||||
t.Run("nil_handler", func(t *testing.T) {
|
||||
var h *OpenAIGatewayHandler
|
||||
require.Equal(t, []string{"handler"}, h.missingResponsesDependencies())
|
||||
})
|
||||
|
||||
t.Run("all_dependencies_missing", func(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
require.Equal(t,
|
||||
[]string{"gatewayService", "billingCacheService", "apiKeyService", "concurrencyHelper"},
|
||||
h.missingResponsesDependencies(),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("all_dependencies_present", func(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{
|
||||
concurrencyService: &service.ConcurrencyService{},
|
||||
},
|
||||
}
|
||||
require.Empty(t, h.missingResponsesDependencies())
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
|
||||
t.Run("missing_dependencies_returns_503", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
ok := h.ensureResponsesDependencies(c, nil)
|
||||
|
||||
require.False(t, ok)
|
||||
require.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
errorObj, exists := parsed["error"].(map[string]any)
|
||||
require.True(t, exists)
|
||||
assert.Equal(t, "api_error", errorObj["type"])
|
||||
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
|
||||
})
|
||||
|
||||
t.Run("already_written_response_not_overridden", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
ok := h.ensureResponsesDependencies(c, nil)
|
||||
|
||||
require.False(t, ok)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
})
|
||||
|
||||
t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{
|
||||
concurrencyService: &service.ConcurrencyService{},
|
||||
},
|
||||
}
|
||||
ok := h.ensureResponsesDependencies(c, nil)
|
||||
|
||||
require.True(t, ok)
|
||||
require.False(t, c.Writer.Written())
|
||||
assert.Equal(t, "", w.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
groupID := int64(2)
|
||||
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 10,
|
||||
GroupID: &groupID,
|
||||
})
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||
UserID: 1,
|
||||
Concurrency: 1,
|
||||
})
|
||||
|
||||
// 故意使用未初始化依赖,验证快速失败而不是崩溃。
|
||||
h := &OpenAIGatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.Responses(c)
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "api_error", errorObj["type"])
|
||||
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(`{"model":"gpt-5"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
|
||||
`{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`,
|
||||
))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
groupID := int64(2)
|
||||
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 101,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: 1},
|
||||
})
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||
UserID: 1,
|
||||
Concurrency: 1,
|
||||
})
|
||||
|
||||
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("Upgrade", "websocket")
|
||||
c.Request.Header.Set("Connection", "Upgrade")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.ResponsesWebSocket(c)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.ResponsesWebSocket(c)
|
||||
|
||||
require.Equal(t, http.StatusUpgradeRequired, w.Code)
|
||||
require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
|
||||
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
|
||||
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`,
|
||||
))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, _, err = clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.Error(t, err)
|
||||
var closeErr coderws.CloseError
|
||||
require.ErrorAs(t, err, &closeErr)
|
||||
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return false, errors.New("user slot unavailable")
|
||||
},
|
||||
}
|
||||
h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache)
|
||||
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
|
||||
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`,
|
||||
))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, _, err = clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.Error(t, err)
|
||||
var closeErr coderws.CloseError
|
||||
require.ErrorAs(t, err, &closeErr)
|
||||
require.Equal(t, coderws.StatusInternalError, closeErr.Code)
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||
}
|
||||
|
||||
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
setOpenAIClientTransportHTTP(c)
|
||||
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestSetOpenAIClientTransportWS(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
setOpenAIClientTransportWS(c)
|
||||
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
||||
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -228,3 +637,41 @@ func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
|
||||
require.NoError(t, setErr)
|
||||
require.True(t, gjson.ValidBytes(result))
|
||||
}
|
||||
|
||||
func newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler {
|
||||
t.Helper()
|
||||
if cache == nil {
|
||||
cache = &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server {
|
||||
t.Helper()
|
||||
groupID := int64(2)
|
||||
apiKey := &service.APIKey{
|
||||
ID: 101,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: subject.UserID},
|
||||
}
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), subject)
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
return httptest.NewServer(router)
|
||||
}
|
||||
|
||||
@@ -311,6 +311,35 @@ type opsCaptureWriter struct {
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
const opsCaptureWriterLimit = 64 * 1024
|
||||
|
||||
var opsCaptureWriterPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &opsCaptureWriter{limit: opsCaptureWriterLimit}
|
||||
},
|
||||
}
|
||||
|
||||
func acquireOpsCaptureWriter(rw gin.ResponseWriter) *opsCaptureWriter {
|
||||
w, ok := opsCaptureWriterPool.Get().(*opsCaptureWriter)
|
||||
if !ok || w == nil {
|
||||
w = &opsCaptureWriter{}
|
||||
}
|
||||
w.ResponseWriter = rw
|
||||
w.limit = opsCaptureWriterLimit
|
||||
w.buf.Reset()
|
||||
return w
|
||||
}
|
||||
|
||||
func releaseOpsCaptureWriter(w *opsCaptureWriter) {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
w.ResponseWriter = nil
|
||||
w.limit = opsCaptureWriterLimit
|
||||
w.buf.Reset()
|
||||
opsCaptureWriterPool.Put(w)
|
||||
}
|
||||
|
||||
func (w *opsCaptureWriter) Write(b []byte) (int, error) {
|
||||
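// Only buffer bodies for error responses (status >= 400), and never beyond the configured limit.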
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
|
||||
remaining := w.limit - w.buf.Len()
|
||||
@@ -342,7 +371,16 @@ func (w *opsCaptureWriter) WriteString(s string) (int, error) {
|
||||
// - Streaming errors after the response has started (SSE) may still need explicit logging.
|
||||
func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024}
|
||||
originalWriter := c.Writer
|
||||
w := acquireOpsCaptureWriter(originalWriter)
|
||||
defer func() {
|
||||
// Restore the original writer before returning so outer middlewares
|
||||
// don't observe a pooled wrapper that has been released.
|
||||
if c.Writer == w {
|
||||
c.Writer = originalWriter
|
||||
}
|
||||
releaseOpsCaptureWriter(w)
|
||||
}()
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -173,3 +174,43 @@ func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
}
|
||||
|
||||
func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
writer := acquireOpsCaptureWriter(c.Writer)
|
||||
require.NotNil(t, writer)
|
||||
_, err := writer.buf.WriteString("temp-error-body")
|
||||
require.NoError(t, err)
|
||||
|
||||
releaseOpsCaptureWriter(writer)
|
||||
|
||||
reused := acquireOpsCaptureWriter(c.Writer)
|
||||
defer releaseOpsCaptureWriter(reused)
|
||||
|
||||
require.Zero(t, reused.buf.Len(), "writer should be reset before reuse")
|
||||
}
|
||||
|
||||
func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(middleware2.Recovery())
|
||||
r.Use(middleware2.RequestLogger())
|
||||
r.Use(middleware2.Logger())
|
||||
r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
r.ServeHTTP(rec, req)
|
||||
})
|
||||
require.Equal(t, http.StatusNoContent, rec.Code)
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
backend/internal/handler/sora_client_handler.go (new file, 979 lines)
@@ -0,0 +1,979 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// 上游模型缓存 TTL
|
||||
modelCacheTTL = 1 * time.Hour // 上游获取成功
|
||||
modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
|
||||
)
|
||||
|
||||
// SoraClientHandler 处理 Sora 客户端 API 请求。
|
||||
type SoraClientHandler struct {
|
||||
genService *service.SoraGenerationService
|
||||
quotaService *service.SoraQuotaService
|
||||
s3Storage *service.SoraS3Storage
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
gatewayService *service.GatewayService
|
||||
mediaStorage *service.SoraMediaStorage
|
||||
apiKeyService *service.APIKeyService
|
||||
|
||||
// 上游模型缓存
|
||||
modelCacheMu sync.RWMutex
|
||||
cachedFamilies []service.SoraModelFamily
|
||||
modelCacheTime time.Time
|
||||
modelCacheUpstream bool // 是否来自上游(决定 TTL)
|
||||
}
|
||||
|
||||
// NewSoraClientHandler 创建 Sora 客户端 Handler。
|
||||
func NewSoraClientHandler(
|
||||
genService *service.SoraGenerationService,
|
||||
quotaService *service.SoraQuotaService,
|
||||
s3Storage *service.SoraS3Storage,
|
||||
soraGatewayService *service.SoraGatewayService,
|
||||
gatewayService *service.GatewayService,
|
||||
mediaStorage *service.SoraMediaStorage,
|
||||
apiKeyService *service.APIKeyService,
|
||||
) *SoraClientHandler {
|
||||
return &SoraClientHandler{
|
||||
genService: genService,
|
||||
quotaService: quotaService,
|
||||
s3Storage: s3Storage,
|
||||
soraGatewayService: soraGatewayService,
|
||||
gatewayService: gatewayService,
|
||||
mediaStorage: mediaStorage,
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRequest 生成请求。
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model" binding:"required"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
MediaType string `json:"media_type"` // video / image,默认 video
|
||||
VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
|
||||
ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
|
||||
}
|
||||
|
||||
// Generate 异步生成 — 创建 pending 记录后立即返回。
|
||||
// POST /api/v1/sora/generate
|
||||
func (h *SoraClientHandler) Generate(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var req GenerateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.MediaType == "" {
|
||||
req.MediaType = "video"
|
||||
}
|
||||
req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
|
||||
|
||||
// 并发数检查(最多 3 个)
|
||||
activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if activeCount >= 3 {
|
||||
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
||||
return
|
||||
}
|
||||
|
||||
// 配额检查(粗略检查,实际文件大小在上传后才知道)
|
||||
if h.quotaService != nil {
|
||||
if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, "aErr) {
|
||||
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 API Key ID 和 Group ID
|
||||
var apiKeyID *int64
|
||||
var groupID *int64
|
||||
|
||||
if req.APIKeyID != nil && h.apiKeyService != nil {
|
||||
// 前端传递了 api_key_id,需要校验
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "API Key 不存在")
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != userID {
|
||||
response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
|
||||
return
|
||||
}
|
||||
if apiKey.Status != service.StatusAPIKeyActive {
|
||||
response.Error(c, http.StatusForbidden, "API Key 不可用")
|
||||
return
|
||||
}
|
||||
apiKeyID = &apiKey.ID
|
||||
groupID = apiKey.GroupID
|
||||
} else if id, ok := c.Get("api_key_id"); ok {
|
||||
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
|
||||
if v, ok := id.(int64); ok {
|
||||
apiKeyID = &v
|
||||
}
|
||||
}
|
||||
|
||||
gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
|
||||
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 启动后台异步生成 goroutine
|
||||
go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"generation_id": gen.ID,
|
||||
"status": gen.Status,
|
||||
})
|
||||
}
|
||||
|
||||
// processGeneration 后台异步执行 Sora 生成任务。
|
||||
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
|
||||
func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// 标记为生成中
|
||||
if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
mediaType,
|
||||
videoCount,
|
||||
strings.TrimSpace(imageInput) != "",
|
||||
len(strings.TrimSpace(prompt)),
|
||||
)
|
||||
|
||||
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
|
||||
if groupID == nil {
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
|
||||
}
|
||||
|
||||
if h.gatewayService == nil {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
// 选择 Sora 账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
err,
|
||||
)
|
||||
_ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
account.ID,
|
||||
account.Name,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
)
|
||||
|
||||
// 构建 chat completions 请求体(非流式)
|
||||
body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
|
||||
|
||||
if h.soraGatewayService == nil {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
// 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
|
||||
recorder := httptest.NewRecorder()
|
||||
mockGinCtx, _ := gin.CreateTestContext(recorder)
|
||||
mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
|
||||
|
||||
// 调用 Forward(非流式)
|
||||
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
|
||||
genID,
|
||||
account.ID,
|
||||
model,
|
||||
recorder.Code,
|
||||
trimForLog(recorder.Body.String(), 400),
|
||||
err,
|
||||
)
|
||||
// 检查是否已取消
|
||||
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
return
|
||||
}
|
||||
_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
|
||||
mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
|
||||
if mediaURL == "" {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
|
||||
genID,
|
||||
account.ID,
|
||||
model,
|
||||
recorder.Code,
|
||||
trimForLog(recorder.Body.String(), 400),
|
||||
)
|
||||
_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查任务是否已被取消
|
||||
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
|
||||
return
|
||||
}
|
||||
|
||||
// 三层降级存储:S3 → 本地 → 上游临时 URL
|
||||
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
|
||||
|
||||
usageAdded := false
|
||||
if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
|
||||
if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, "aErr) {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
usageAdded = true
|
||||
}
|
||||
|
||||
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
|
||||
gen, _ = h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 标记完成
|
||||
if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
||||
}
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
|
||||
}
|
||||
|
||||
// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
|
||||
func (h *SoraClientHandler) storeMediaWithDegradation(
|
||||
ctx context.Context, userID int64, mediaType string,
|
||||
mediaURL string, mediaURLs []string,
|
||||
) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
|
||||
urls := mediaURLs
|
||||
if len(urls) == 0 {
|
||||
urls = []string{mediaURL}
|
||||
}
|
||||
|
||||
// 第一层:尝试 S3
|
||||
if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
|
||||
keys := make([]string, 0, len(urls))
|
||||
var totalSize int64
|
||||
allOK := true
|
||||
for _, u := range urls {
|
||||
key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
|
||||
allOK = false
|
||||
// 清理已上传的文件
|
||||
if len(keys) > 0 {
|
||||
_ = h.s3Storage.DeleteObjects(ctx, keys)
|
||||
}
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
totalSize += size
|
||||
}
|
||||
if allOK && len(keys) > 0 {
|
||||
accessURLs := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
|
||||
_ = h.s3Storage.DeleteObjects(ctx, keys)
|
||||
allOK = false
|
||||
break
|
||||
}
|
||||
accessURLs = append(accessURLs, accessURL)
|
||||
}
|
||||
if allOK && len(accessURLs) > 0 {
|
||||
return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 第二层:尝试本地存储
|
||||
if h.mediaStorage != nil && h.mediaStorage.Enabled() {
|
||||
storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
|
||||
if err == nil && len(storedPaths) > 0 {
|
||||
firstPath := storedPaths[0]
|
||||
totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
|
||||
if sizeErr != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
|
||||
}
|
||||
return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
|
||||
}
|
||||
|
||||
// 第三层:保留上游临时 URL
|
||||
return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
|
||||
}
|
||||
|
||||
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
|
||||
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
|
||||
body := map[string]any{
|
||||
"model": model,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": prompt},
|
||||
},
|
||||
"stream": false,
|
||||
}
|
||||
if imageInput != "" {
|
||||
body["image_input"] = imageInput
|
||||
}
|
||||
if videoCount > 1 {
|
||||
body["video_count"] = videoCount
|
||||
}
|
||||
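// Illustrative result (field order may differ) for videoCount=2 and no image input:
//   {"model":"<model>","messages":[{"role":"user","content":"<prompt>"}],"stream":false,"video_count":2}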
b, _ := json.Marshal(body)
|
||||
return b
|
||||
}
|
||||
|
||||
func normalizeVideoCount(mediaType string, videoCount int) int {
|
||||
if mediaType != "video" {
|
||||
return 1
|
||||
}
|
||||
if videoCount <= 0 {
|
||||
return 1
|
||||
}
|
||||
if videoCount > 3 {
|
||||
return 3
|
||||
}
|
||||
return videoCount
|
||||
}
|
||||
|
||||
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
|
||||
// OAuth 路径:ForwardResult.MediaURL 已填充。
|
||||
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
|
||||
func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
|
||||
// 优先从 ForwardResult 获取(OAuth 路径)
|
||||
if result != nil && result.MediaURL != "" {
|
||||
// 尝试从响应体获取完整 URL 列表
|
||||
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
||||
return urls[0], urls
|
||||
}
|
||||
return result.MediaURL, []string{result.MediaURL}
|
||||
}
|
||||
|
||||
// 从响应体解析(APIKey 路径)
|
||||
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
||||
return urls[0], urls
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
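// Accepts bodies like {"media_urls":["https://…/a.mp4","https://…/b.mp4"]} or {"media_url":"https://…/a.mp4"} (illustrative).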
|
||||
func parseMediaURLsFromBody(body []byte) []string {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 优先 media_urls(多图数组)
|
||||
if rawURLs, ok := resp["media_urls"]; ok {
|
||||
if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
|
||||
urls := make([]string, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
if s, ok := item.(string); ok && s != "" {
|
||||
urls = append(urls, s)
|
||||
}
|
||||
}
|
||||
if len(urls) > 0 {
|
||||
return urls
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到 media_url(单个 URL)
|
||||
if url, ok := resp["media_url"].(string); ok && url != "" {
|
||||
return []string{url}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListGenerations 查询生成记录列表。
|
||||
// GET /api/v1/sora/generations
|
||||
func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
params := service.SoraGenerationListParams{
|
||||
UserID: userID,
|
||||
Status: c.Query("status"),
|
||||
StorageType: c.Query("storage_type"),
|
||||
MediaType: c.Query("media_type"),
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
gens, total, err := h.genService.List(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 为 S3 记录动态生成预签名 URL
|
||||
for _, gen := range gens {
|
||||
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"data": gens,
|
||||
"total": total,
|
||||
"page": page,
|
||||
})
|
||||
}
|
||||
|
||||
// GetGeneration 查询生成记录详情。
|
||||
// GET /api/v1/sora/generations/:id
|
||||
func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
||||
response.Success(c, gen)
|
||||
}
|
||||
|
||||
// DeleteGeneration 删除生成记录。
|
||||
// DELETE /api/v1/sora/generations/:id
|
||||
func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
|
||||
if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
|
||||
paths := gen.MediaURLs
|
||||
if len(paths) == 0 && gen.MediaURL != "" {
|
||||
paths = []string{gen.MediaURL}
|
||||
}
|
||||
if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "已删除"})
|
||||
}
|
||||
|
||||
// GetQuota 查询用户存储配额。
|
||||
// GET /api/v1/sora/quota
|
||||
func (h *SoraClientHandler) GetQuota(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
if h.quotaService == nil {
|
||||
response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
|
||||
return
|
||||
}
|
||||
|
||||
quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, quota)
|
||||
}
|
||||
|
||||
// CancelGeneration 取消生成任务。
|
||||
// POST /api/v1/sora/generations/:id/cancel
|
||||
func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
// 权限校验
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
_ = gen
|
||||
|
||||
if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationNotActive) {
|
||||
response.Error(c, http.StatusConflict, "任务已结束,无法取消")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "已取消"})
|
||||
}
|
||||
|
||||
// SaveToStorage 手动保存 upstream 记录到 S3。
|
||||
// POST /api/v1/sora/generations/:id/save
|
||||
func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if gen.StorageType != service.SoraStorageTypeUpstream {
|
||||
response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
|
||||
return
|
||||
}
|
||||
if gen.MediaURL == "" {
|
||||
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
||||
return
|
||||
}
|
||||
|
||||
if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
|
||||
response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
|
||||
return
|
||||
}
|
||||
|
||||
sourceURLs := gen.MediaURLs
|
||||
if len(sourceURLs) == 0 && gen.MediaURL != "" {
|
||||
sourceURLs = []string{gen.MediaURL}
|
||||
}
|
||||
if len(sourceURLs) == 0 {
|
||||
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
||||
return
|
||||
}
|
||||
|
||||
uploadedKeys := make([]string, 0, len(sourceURLs))
|
||||
accessURLs := make([]string, 0, len(sourceURLs))
|
||||
var totalSize int64
|
||||
|
||||
for _, sourceURL := range sourceURLs {
|
||||
objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
|
||||
if uploadErr != nil {
|
||||
if len(uploadedKeys) > 0 {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
}
|
||||
var upstreamErr *service.UpstreamDownloadError
|
||||
if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
|
||||
response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
|
||||
return
|
||||
}
|
||||
accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
|
||||
if err != nil {
|
||||
uploadedKeys = append(uploadedKeys, objectKey)
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
uploadedKeys = append(uploadedKeys, objectKey)
|
||||
accessURLs = append(accessURLs, accessURL)
|
||||
totalSize += fileSize
|
||||
}
|
||||
|
||||
usageAdded := false
|
||||
if totalSize > 0 && h.quotaService != nil {
|
||||
if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, &quotaErr) {
|
||||
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
usageAdded = true
|
||||
}
|
||||
|
||||
if err := h.genService.UpdateStorageForCompleted(
|
||||
c.Request.Context(),
|
||||
id,
|
||||
accessURLs[0],
|
||||
accessURLs,
|
||||
service.SoraStorageTypeS3,
|
||||
uploadedKeys,
|
||||
totalSize,
|
||||
); err != nil {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "已保存到 S3",
|
||||
"object_key": uploadedKeys[0],
|
||||
"object_keys": uploadedKeys,
|
||||
})
|
||||
}
|
||||
|
||||
// GetStorageStatus 返回存储状态。
|
||||
// GET /api/v1/sora/storage-status
|
||||
func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
|
||||
s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
|
||||
s3Healthy := false
|
||||
if s3Enabled {
|
||||
s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
|
||||
}
|
||||
localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
|
||||
response.Success(c, gin.H{
|
||||
"s3_enabled": s3Enabled,
|
||||
"s3_healthy": s3Healthy,
|
||||
"local_enabled": localEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
|
||||
switch storageType {
|
||||
case service.SoraStorageTypeS3:
|
||||
if h.s3Storage != nil && len(s3Keys) > 0 {
|
||||
if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
|
||||
}
|
||||
}
|
||||
case service.SoraStorageTypeLocal:
|
||||
if h.mediaStorage != nil && len(localPaths) > 0 {
|
||||
if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
|
||||
func getUserIDFromContext(c *gin.Context) int64 {
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
|
||||
return subject.UserID
|
||||
}
|
||||
|
||||
if id, ok := c.Get("user_id"); ok {
|
||||
switch v := id.(type) {
|
||||
case int64:
|
||||
return v
|
||||
case float64:
|
||||
return int64(v)
|
||||
case string:
|
||||
n, _ := strconv.ParseInt(v, 10, 64)
|
||||
return n
|
||||
}
|
||||
}
|
||||
// 尝试从 JWT claims 获取
|
||||
if id, ok := c.Get("userID"); ok {
|
||||
if v, ok := id.(int64); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func groupIDForLog(groupID *int64) int64 {
|
||||
if groupID == nil {
|
||||
return 0
|
||||
}
|
||||
return *groupID
|
||||
}
|
||||
|
||||
func trimForLog(raw string, maxLen int) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if maxLen <= 0 || len(trimmed) <= maxLen {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:maxLen] + "...(truncated)"
|
||||
}
|
||||
|
||||
// GetModels 获取可用 Sora 模型家族列表。
|
||||
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
|
||||
// GET /api/v1/sora/models
|
||||
func (h *SoraClientHandler) GetModels(c *gin.Context) {
|
||||
families := h.getModelFamilies(c.Request.Context())
|
||||
response.Success(c, families)
|
||||
}
|
||||
|
||||
// getModelFamilies 获取模型家族列表(带缓存)。
|
||||
func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
|
||||
// 读锁检查缓存
|
||||
h.modelCacheMu.RLock()
|
||||
ttl := modelCacheTTL
|
||||
if !h.modelCacheUpstream {
|
||||
ttl = modelCacheFailedTTL
|
||||
}
|
||||
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
|
||||
families := h.cachedFamilies
|
||||
h.modelCacheMu.RUnlock()
|
||||
return families
|
||||
}
|
||||
h.modelCacheMu.RUnlock()
|
||||
|
||||
// 写锁更新缓存
|
||||
h.modelCacheMu.Lock()
|
||||
defer h.modelCacheMu.Unlock()
|
||||
|
||||
// double-check
|
||||
ttl = modelCacheTTL
|
||||
if !h.modelCacheUpstream {
|
||||
ttl = modelCacheFailedTTL
|
||||
}
|
||||
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
|
||||
return h.cachedFamilies
|
||||
}
|
||||
|
||||
// 尝试从上游获取
|
||||
families, err := h.fetchUpstreamModels(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
|
||||
families = service.BuildSoraModelFamilies()
|
||||
h.cachedFamilies = families
|
||||
h.modelCacheTime = time.Now()
|
||||
h.modelCacheUpstream = false
|
||||
return families
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
|
||||
h.cachedFamilies = families
|
||||
h.modelCacheTime = time.Now()
|
||||
h.modelCacheUpstream = true
|
||||
return families
|
||||
}
|
||||
|
||||
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
|
||||
func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
|
||||
if h.gatewayService == nil {
|
||||
return nil, fmt.Errorf("gatewayService 未初始化")
|
||||
}
|
||||
|
||||
// 设置 ForcePlatform 用于 Sora 账号选择
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
|
||||
|
||||
// 选择一个 Sora 账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
|
||||
}
|
||||
|
||||
// 仅支持 API Key 类型账号
|
||||
if account.Type != service.AccountTypeAPIKey {
|
||||
return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
|
||||
}
|
||||
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("账号缺少 api_key")
|
||||
}
|
||||
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("账号缺少 base_url")
|
||||
}
|
||||
|
||||
// 构建上游模型列表请求
|
||||
modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求上游失败: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析 OpenAI 格式的模型列表
|
||||
var modelsResp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &modelsResp); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if len(modelsResp.Data) == 0 {
|
||||
return nil, fmt.Errorf("上游返回空模型列表")
|
||||
}
|
||||
|
||||
// 提取模型 ID
|
||||
modelIDs := make([]string, 0, len(modelsResp.Data))
|
||||
for _, m := range modelsResp.Data {
|
||||
modelIDs = append(modelIDs, m.ID)
|
||||
}
|
||||
|
||||
// 转换为模型家族
|
||||
families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
|
||||
if len(families) == 0 {
|
||||
return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
|
||||
}
|
||||
|
||||
return families, nil
|
||||
}
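For reference, a hedged sketch of the only payload shape fetchUpstreamModels consumes — data[].id from an OpenAI-style models list. The JSON below is an illustrative assumption; "sora2-example" is a placeholder, not a confirmed upstream model ID.

// Illustrative sample only — the parser above reads nothing but data[].id.
sample := []byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-example"}]}`)
var parsed struct {
	Data []struct {
		ID string `json:"id"`
	} `json:"data"`
}
_ = json.Unmarshal(sample, &parsed) // parsed.Data[0].ID == "sora2-landscape-10s"; the IDs then feed BuildSoraModelFamiliesFromIDs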
|
||||
backend/internal/handler/sora_client_handler_test.go (new file, 3135 lines; diff suppressed because it is too large)
@@ -7,7 +7,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -107,7 +107,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -461,6 +461,14 @@ func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask)
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("sora.usage_record_task_panic_recovered")
|
||||
}
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
|
||||
@@ -314,10 +314,10 @@ func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID i
|
||||
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -65,8 +66,17 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
// Parse additional filters
|
||||
model := c.Query("model")
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
@@ -114,6 +124,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
UserID: subject.UserID, // Always filter by current user for security
|
||||
APIKeyID: apiKeyID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
|
||||
backend/internal/handler/usage_handler_request_type_test.go (new file, 80 lines)
@@ -0,0 +1,80 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userUsageRepoCapture struct {
|
||||
service.UsageLogRepository
|
||||
listFilters usagestats.UsageLogFilters
|
||||
}
|
||||
|
||||
func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
s.listFilters = filters
|
||||
return []service.UsageLog{}, &pagination.PaginationResult{
|
||||
Total: 0,
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
Pages: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newUserUsageRequestTypeTestRouter(repo *userUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
usageSvc := service.NewUsageService(repo, nil, nil, nil)
|
||||
handler := NewUsageHandler(usageSvc, nil)
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/usage", handler.List)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestUserUsageListRequestTypePriority(t *testing.T) {
|
||||
repo := &userUsageRepoCapture{}
|
||||
router := newUserUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=ws_v2&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, int64(42), repo.listFilters.UserID)
|
||||
require.NotNil(t, repo.listFilters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
|
||||
require.Nil(t, repo.listFilters.Stream)
|
||||
}
|
||||
|
||||
func TestUserUsageListInvalidRequestType(t *testing.T) {
|
||||
repo := &userUsageRepoCapture{}
|
||||
router := newUserUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestUserUsageListInvalidStream(t *testing.T) {
|
||||
repo := &userUsageRepoCapture{}
|
||||
router := newUserUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/usage?stream=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
@@ -61,6 +61,22 @@ func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||
h := &GatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
@@ -98,6 +114,22 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
|
||||
@@ -134,3 +166,19 @@ func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func ProvideAdminHandlers(
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
dataManagementHandler *admin.DataManagementHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -35,6 +36,7 @@ func ProvideAdminHandlers(
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
DataManagement: dataManagementHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -75,6 +77,7 @@ func ProvideHandlers(
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
soraGatewayHandler *SoraGatewayHandler,
|
||||
soraClientHandler *SoraClientHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
_ *service.IdempotencyCoordinator,
|
||||
@@ -92,6 +95,7 @@ func ProvideHandlers(
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
SoraGateway: soraGatewayHandler,
|
||||
SoraClient: soraClientHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
}
|
||||
@@ -119,6 +123,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewDataManagementHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
|
||||
@@ -152,6 +152,7 @@ var claudeModels = []modelDef{
|
||||
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
|
||||
}
|
||||
|
||||
@@ -165,6 +166,8 @@ var geminiModels = []modelDef{
|
||||
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
}
|
||||
|
||||
backend/internal/pkg/antigravity/claude_types_test.go (new file, 26 lines)
@@ -0,0 +1,26 @@
|
||||
package antigravity
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
models := DefaultModels()
|
||||
byID := make(map[string]ClaudeModel, len(models))
|
||||
for _, m := range models {
|
||||
byID[m.ID] = m
|
||||
}
|
||||
|
||||
requiredIDs := []string{
|
||||
"claude-opus-4-6-thinking",
|
||||
"gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview",
|
||||
"gemini-3-pro-image", // legacy compatibility
|
||||
}
|
||||
|
||||
for _, id := range requiredIDs {
|
||||
if _, ok := byID[id]; !ok {
|
||||
t.Fatalf("expected model %q to be exposed in DefaultModels", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -70,7 +70,7 @@ type GeminiGenerationConfig struct {
|
||||
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持)
|
||||
// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持)
|
||||
type GeminiImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
|
||||
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"
|
||||
|
||||
@@ -53,7 +53,8 @@ const (
|
||||
var defaultUserAgentVersion = "1.19.6"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
// 默认值使用占位符,生产环境请通过环境变量注入真实值。
|
||||
var defaultClientSecret = "GOCSPX-your-client-secret"
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取版本号,未设置则使用默认值
|
||||
|
||||
@@ -612,14 +612,14 @@ func TestBuildAuthorizationURL_参数验证(t *testing.T) {
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": RedirectURI,
|
||||
"response_type": "code",
|
||||
"scope": Scopes,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"redirect_uri": RedirectURI,
|
||||
"response_type": "code",
|
||||
"scope": Scopes,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"include_granted_scopes": "true",
|
||||
}
|
||||
|
||||
@@ -684,7 +684,7 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
|
||||
}
|
||||
if secret != "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" {
|
||||
if secret != "GOCSPX-your-client-secret" {
|
||||
t.Errorf("默认 client_secret 不匹配: got %s", secret)
|
||||
}
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
|
||||
@@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToHTTP_MetadataDeepCopy(t *testing.T) {
|
||||
md := map[string]string{"k": "v"}
|
||||
appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md)
|
||||
|
||||
code, body := ToHTTP(appErr)
|
||||
require.Equal(t, http.StatusBadRequest, code)
|
||||
require.Equal(t, "v", body.Metadata["k"])
|
||||
|
||||
md["k"] = "changed"
|
||||
require.Equal(t, "v", body.Metadata["k"])
|
||||
|
||||
appErr.Metadata["k"] = "changed-again"
|
||||
require.Equal(t, "v", body.Metadata["k"])
|
||||
}
|
||||
|
||||
@@ -16,6 +16,16 @@ func ToHTTP(err error) (statusCode int, body Status) {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
cloned := Clone(appErr)
|
||||
return int(cloned.Code), cloned.Status
|
||||
body = Status{
|
||||
Code: appErr.Code,
|
||||
Reason: appErr.Reason,
|
||||
Message: appErr.Message,
|
||||
}
|
||||
if appErr.Metadata != nil {
|
||||
body.Metadata = make(map[string]string, len(appErr.Metadata))
|
||||
for k, v := range appErr.Metadata {
|
||||
body.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
return int(appErr.Code), body
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ const (
|
||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||
// restrict which scopes are allowed for this client.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
GeminiCLIOAuthClientSecret = "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"
|
||||
GeminiCLIOAuthClientSecret = "GOCSPX-your-client-secret"
|
||||
|
||||
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
|
||||
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
|
||||
|
||||
@@ -32,6 +32,7 @@ const (
|
||||
defaultMaxIdleConns = 100 // 最大空闲连接数
|
||||
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
|
||||
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
|
||||
validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL
|
||||
)
|
||||
|
||||
// Options 定义共享 HTTP 客户端的构建参数
|
||||
@@ -53,6 +54,9 @@ type Options struct {
|
||||
// sharedClients 存储按配置参数缓存的 http.Client 实例
|
||||
var sharedClients sync.Map
|
||||
|
||||
// 允许测试替换校验函数,生产默认指向真实实现。
|
||||
var validateResolvedIP = urlvalidator.ValidateResolvedIP
|
||||
|
||||
// GetClient 返回共享的 HTTP 客户端实例
|
||||
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
|
||||
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
|
||||
@@ -84,7 +88,7 @@ func buildClient(opts Options) (*http.Client, error) {
|
||||
|
||||
var rt http.RoundTripper = transport
|
||||
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
|
||||
rt = &validatedTransport{base: transport}
|
||||
rt = newValidatedTransport(transport)
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: rt,
|
||||
@@ -149,17 +153,56 @@ func buildClientKey(opts Options) string {
|
||||
}
|
||||
|
||||
type validatedTransport struct {
|
||||
base http.RoundTripper
|
||||
base http.RoundTripper
|
||||
validatedHosts sync.Map // map[string]time.Time, value 为过期时间
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func newValidatedTransport(base http.RoundTripper) *validatedTransport {
|
||||
return &validatedTransport{
|
||||
base: base,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool {
|
||||
if t == nil {
|
||||
return false
|
||||
}
|
||||
raw, ok := t.validatedHosts.Load(host)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
expireAt, ok := raw.(time.Time)
|
||||
if !ok {
|
||||
t.validatedHosts.Delete(host)
|
||||
return false
|
||||
}
|
||||
if now.Before(expireAt) {
|
||||
return true
|
||||
}
|
||||
t.validatedHosts.Delete(host)
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req != nil && req.URL != nil {
|
||||
host := strings.TrimSpace(req.URL.Hostname())
|
||||
host := strings.ToLower(strings.TrimSpace(req.URL.Hostname()))
|
||||
if host != "" {
|
||||
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
|
||||
return nil, err
|
||||
now := time.Now()
|
||||
if t != nil && t.now != nil {
|
||||
now = t.now()
|
||||
}
|
||||
if !t.isValidatedHost(host, now) {
|
||||
if err := validateResolvedIP(host); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.validatedHosts.Store(host, now.Add(validatedHostTTL))
|
||||
}
|
||||
}
|
||||
}
|
||||
if t == nil || t.base == nil {
|
||||
return nil, fmt.Errorf("validated transport base is nil")
|
||||
}
|
||||
return t.base.RoundTrip(req)
|
||||
}
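A hedged sketch of obtaining a client that goes through this validated transport. The GetClient signature ((Options) (*http.Client, error)) is inferred from the comments above and is an assumption, as is leaving AllowPrivateHosts at its zero value.

// Sketch only — with ValidateResolvedIP enabled and private hosts not allowed,
// each unique hostname is revalidated at most once per validatedHostTTL (30s);
// later requests within the TTL reuse the cached verdict.
func newValidatedClient() (*http.Client, error) {
	return GetClient(Options{ValidateResolvedIP: true})
}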
|
||||
|
||||
backend/internal/pkg/httpclient/pool_test.go (new file, 115 lines)
@@ -0,0 +1,115 @@
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestValidatedTransport_CacheHostValidation(t *testing.T) {
|
||||
originalValidate := validateResolvedIP
|
||||
defer func() { validateResolvedIP = originalValidate }()
|
||||
|
||||
var validateCalls int32
|
||||
validateResolvedIP = func(host string) error {
|
||||
atomic.AddInt32(&validateCalls, 1)
|
||||
require.Equal(t, "api.openai.com", host)
|
||||
return nil
|
||||
}
|
||||
|
||||
var baseCalls int32
|
||||
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&baseCalls, 1)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})
|
||||
|
||||
now := time.Unix(1730000000, 0)
|
||||
transport := newValidatedTransport(base)
|
||||
transport.now = func() time.Time { return now }
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls))
|
||||
require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls))
|
||||
}
|
||||
|
||||
func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) {
|
||||
originalValidate := validateResolvedIP
|
||||
defer func() { validateResolvedIP = originalValidate }()
|
||||
|
||||
var validateCalls int32
|
||||
validateResolvedIP = func(_ string) error {
|
||||
atomic.AddInt32(&validateCalls, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})
|
||||
|
||||
now := time.Unix(1730001000, 0)
|
||||
transport := newValidatedTransport(base)
|
||||
transport.now = func() time.Time { return now }
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
now = now.Add(validatedHostTTL + time.Second)
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls))
|
||||
}
|
||||
|
||||
func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) {
|
||||
originalValidate := validateResolvedIP
|
||||
defer func() { validateResolvedIP = originalValidate }()
|
||||
|
||||
expectedErr := errors.New("dns rebinding rejected")
|
||||
validateResolvedIP = func(_ string) error {
|
||||
return expectedErr
|
||||
}
|
||||
|
||||
var baseCalls int32
|
||||
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&baseCalls, 1)
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil
|
||||
})
|
||||
|
||||
transport := newValidatedTransport(base)
|
||||
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.ErrorIs(t, err, expectedErr)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls))
|
||||
}
|
||||
backend/internal/pkg/httputil/body.go (new file, 37 lines)
@@ -0,0 +1,37 @@
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
requestBodyReadInitCap = 512
|
||||
requestBodyReadMaxInitCap = 1 << 20
|
||||
)
|
||||
|
||||
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
|
||||
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
|
||||
if req == nil || req.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
capHint := requestBodyReadInitCap
|
||||
if req.ContentLength > 0 {
|
||||
switch {
|
||||
case req.ContentLength < int64(requestBodyReadInitCap):
|
||||
capHint = requestBodyReadInitCap
|
||||
case req.ContentLength > int64(requestBodyReadMaxInitCap):
|
||||
capHint = requestBodyReadMaxInitCap
|
||||
default:
|
||||
capHint = int(req.ContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(make([]byte, 0, capHint))
|
||||
if _, err := io.Copy(buf, req.Body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
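A hedged sketch of pairing the helper with a request-size cap so oversized bodies still surface as *http.MaxBytesError (which the gateway handlers above detect via extractMaxBytesError). The 10 MiB limit and the wrapper itself are illustrative assumptions, not code from this repository.

// Sketch only: cap the body before the preallocating read.
func readLimitedBody(w http.ResponseWriter, req *http.Request) ([]byte, error) {
	req.Body = http.MaxBytesReader(w, req.Body, 10<<20) // assumed 10 MiB limit
	return ReadRequestBodyWithPrealloc(req)
}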
|
||||
@@ -67,6 +67,14 @@ func normalizeIP(ip string) string {
|
||||
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
|
||||
var privateNets []*net.IPNet
|
||||
|
||||
// CompiledIPRules 表示预编译的 IP 匹配规则。
|
||||
// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。
|
||||
type CompiledIPRules struct {
|
||||
CIDRs []*net.IPNet
|
||||
IPs []net.IP
|
||||
PatternCount int
|
||||
}
|
||||
|
||||
func init() {
|
||||
for _, cidr := range []string{
|
||||
"10.0.0.0/8",
|
||||
@@ -84,6 +92,53 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。
|
||||
// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。
|
||||
func CompileIPRules(patterns []string) *CompiledIPRules {
|
||||
compiled := &CompiledIPRules{
|
||||
CIDRs: make([]*net.IPNet, 0, len(patterns)),
|
||||
IPs: make([]net.IP, 0, len(patterns)),
|
||||
PatternCount: len(patterns),
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
normalized := strings.TrimSpace(pattern)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(normalized, "/") {
|
||||
_, cidr, err := net.ParseCIDR(normalized)
|
||||
if err != nil || cidr == nil {
|
||||
continue
|
||||
}
|
||||
compiled.CIDRs = append(compiled.CIDRs, cidr)
|
||||
continue
|
||||
}
|
||||
parsedIP := net.ParseIP(normalized)
|
||||
if parsedIP == nil {
|
||||
continue
|
||||
}
|
||||
compiled.IPs = append(compiled.IPs, parsedIP)
|
||||
}
|
||||
return compiled
|
||||
}
|
||||
|
||||
func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool {
|
||||
if parsedIP == nil || rules == nil {
|
||||
return false
|
||||
}
|
||||
for _, cidr := range rules.CIDRs {
|
||||
if cidr.Contains(parsedIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, ruleIP := range rules.IPs {
|
||||
if parsedIP.Equal(ruleIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isPrivateIP 检查 IP 是否为私有地址。
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
@@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool {
|
||||
// 2. 如果白名单不为空,IP 必须在白名单中
|
||||
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
|
||||
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
|
||||
return CheckIPRestrictionWithCompiledRules(
|
||||
clientIP,
|
||||
CompileIPRules(whitelist),
|
||||
CompileIPRules(blacklist),
|
||||
)
|
||||
}
|
||||
|
||||
// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。
|
||||
func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) {
|
||||
// 规范化 IP
|
||||
clientIP = normalizeIP(clientIP)
|
||||
if clientIP == "" {
|
||||
return false, "access denied"
|
||||
}
|
||||
parsedIP := net.ParseIP(clientIP)
|
||||
if parsedIP == nil {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
// 1. 检查黑名单
|
||||
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
|
||||
if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
|
||||
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
|
||||
if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
|
||||
@@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Equal(t, "9.9.9.9", w.Body.String())
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictionWithCompiledRules(t *testing.T) {
|
||||
whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"})
|
||||
blacklist := CompileIPRules([]string{"10.1.1.1"})
|
||||
|
||||
allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist)
|
||||
require.True(t, allowed)
|
||||
require.Equal(t, "", reason)
|
||||
|
||||
allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist)
|
||||
require.False(t, allowed)
|
||||
require.Equal(t, "access denied", reason)
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) {
|
||||
// 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。
|
||||
invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"})
|
||||
allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil)
|
||||
require.False(t, allowed)
|
||||
require.Equal(t, "access denied", reason)
|
||||
}
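A hedged usage sketch of the precompiled path: compile the patterns once (for example when an API key's IP restrictions are loaded) and reuse the compiled structures on every request, which is the point of splitting CheckIPRestrictionWithCompiledRules out of CheckIPRestriction. The wrapper below is illustrative, not repository code.

// Sketch only: compile once, check cheaply per request.
func exampleRestrictionCheck(clientIP string, whitelistPatterns, blacklistPatterns []string) (bool, string) {
	wl := CompileIPRules(whitelistPatterns) // invalid entries are skipped; PatternCount keeps the original count
	bl := CompileIPRules(blacklistPatterns)
	return CheckIPRestrictionWithCompiledRules(clientIP, wl, bl)
}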
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
@@ -42,15 +43,19 @@ type LogEvent struct {
|
||||
|
||||
var (
|
||||
mu sync.RWMutex
|
||||
global *zap.Logger
|
||||
sugar *zap.SugaredLogger
|
||||
global atomic.Pointer[zap.Logger]
|
||||
sugar atomic.Pointer[zap.SugaredLogger]
|
||||
atomicLevel zap.AtomicLevel
|
||||
initOptions InitOptions
|
||||
currentSink Sink
|
||||
currentSink atomic.Value // sinkState
|
||||
stdLogUndo func()
|
||||
bootstrapOnce sync.Once
|
||||
)
|
||||
|
||||
type sinkState struct {
|
||||
sink Sink
|
||||
}
|
||||
|
||||
func InitBootstrap() {
|
||||
bootstrapOnce.Do(func() {
|
||||
if err := Init(bootstrapOptions()); err != nil {
|
||||
@@ -72,9 +77,9 @@ func initLocked(options InitOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
prev := global
|
||||
global = zl
|
||||
sugar = zl.Sugar()
|
||||
prev := global.Load()
|
||||
global.Store(zl)
|
||||
sugar.Store(zl.Sugar())
|
||||
atomicLevel = al
|
||||
initOptions = normalized
|
||||
|
||||
@@ -115,24 +120,32 @@ func SetLevel(level string) error {
|
||||
func CurrentLevel() string {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global == nil {
|
||||
if global.Load() == nil {
|
||||
return "info"
|
||||
}
|
||||
return atomicLevel.Level().String()
|
||||
}
|
||||
|
||||
func SetSink(sink Sink) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
currentSink = sink
|
||||
currentSink.Store(sinkState{sink: sink})
|
||||
}
|
||||
|
||||
func loadSink() Sink {
|
||||
v := currentSink.Load()
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
state, ok := v.(sinkState)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return state.sink
|
||||
}
|
||||
|
||||
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
|
||||
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
|
||||
func WriteSinkEvent(level, component, message string, fields map[string]any) {
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
sink := loadSink()
|
||||
if sink == nil {
|
||||
return
|
||||
}
|
||||
@@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) {
|
||||
}
|
||||
|
||||
func L() *zap.Logger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global != nil {
|
||||
return global
|
||||
if l := global.Load(); l != nil {
|
||||
return l
|
||||
}
|
||||
return zap.NewNop()
|
||||
}
|
||||
|
||||
func S() *zap.SugaredLogger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if sugar != nil {
|
||||
return sugar
|
||||
if s := sugar.Load(); s != nil {
|
||||
return s
|
||||
}
|
||||
return zap.NewNop().Sugar()
|
||||
}
|
||||
@@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger {
|
||||
}
|
||||
|
||||
func Sync() {
|
||||
mu.RLock()
|
||||
l := global
|
||||
mu.RUnlock()
|
||||
l := global.Load()
|
||||
if l != nil {
|
||||
_ = l.Sync()
|
||||
}
|
||||
@@ -210,7 +217,11 @@ func bridgeStdLogLocked() {
|
||||
|
||||
log.SetFlags(0)
|
||||
log.SetPrefix("")
|
||||
log.SetOutput(newStdLogBridge(global.Named("stdlog")))
|
||||
base := global.Load()
|
||||
if base == nil {
|
||||
base = zap.NewNop()
|
||||
}
|
||||
log.SetOutput(newStdLogBridge(base.Named("stdlog")))
|
||||
|
||||
stdLogUndo = func() {
|
||||
log.SetOutput(prevWriter)
|
||||
@@ -220,7 +231,11 @@ func bridgeStdLogLocked() {
|
||||
}
|
||||
|
||||
func bridgeSlogLocked() {
|
||||
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
|
||||
base := global.Load()
|
||||
if base == nil {
|
||||
base = zap.NewNop()
|
||||
}
|
||||
slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog"))))
|
||||
}
|
||||
|
||||
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
|
||||
@@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore
|
||||
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
|
||||
// Only handle sink forwarding — the inner cores write via their own
|
||||
// Write methods (added to CheckedEntry by s.core.Check above).
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
sink := loadSink()
|
||||
if sink == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level {
|
||||
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
|
||||
return LevelError
|
||||
}
|
||||
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
|
||||
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
|
||||
return LevelWarn
|
||||
}
|
||||
return LevelInfo
|
||||
@@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) {
|
||||
return
|
||||
}
|
||||
|
||||
mu.RLock()
|
||||
initialized := global != nil
|
||||
mu.RUnlock()
|
||||
initialized := global.Load() != nil
|
||||
if !initialized {
|
||||
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
|
||||
log.Print(msg)
|
||||
|
||||
@@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
|
||||
return true
|
||||
})
|
||||
|
||||
entry := h.logger.With(fields...)
|
||||
switch {
|
||||
case record.Level >= slog.LevelError:
|
||||
entry.Error(record.Message)
|
||||
h.logger.Error(record.Message, fields...)
|
||||
case record.Level >= slog.LevelWarn:
|
||||
entry.Warn(record.Message)
|
||||
h.logger.Warn(record.Message, fields...)
|
||||
case record.Level <= slog.LevelDebug:
|
||||
entry.Debug(record.Message)
|
||||
h.logger.Debug(record.Message, fields...)
|
||||
default:
|
||||
entry.Info(record.Message)
|
||||
h.logger.Info(record.Message, fields...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) {
|
||||
{msg: "Warning: queue full", want: LevelWarn},
|
||||
{msg: "Forward request failed: timeout", want: LevelError},
|
||||
{msg: "[ERROR] upstream unavailable", want: LevelError},
|
||||
{msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo},
|
||||
{msg: "service started", want: LevelInfo},
|
||||
{msg: "debug: cache miss", want: LevelDebug},
|
||||
}
|
||||
|
||||
@@ -36,10 +36,18 @@ const (
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
|
||||
OAuthPlatformOpenAI = "openai"
|
||||
// OAuthPlatformSora uses Sora OAuth client.
|
||||
OAuthPlatformSora = "sora"
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state for OpenAI
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
@@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string {
|
||||
|
||||
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
|
||||
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||
return BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, OAuthPlatformOpenAI)
|
||||
}
|
||||
|
||||
// BuildAuthorizationURLForPlatform builds authorization URL by platform.
|
||||
func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
|
||||
clientID, codexFlow := OAuthClientConfigByPlatform(platform)
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_id", clientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", DefaultScopes)
|
||||
params.Set("state", state)
|
||||
@@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||
params.Set("code_challenge_method", "S256")
|
||||
// OpenAI specific parameters
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
if codexFlow {
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
|
||||
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
|
||||
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
|
||||
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(platform)) {
|
||||
case OAuthPlatformSora:
|
||||
return ClientID, false
|
||||
default:
|
||||
return ClientID, true
|
||||
}
|
||||
}
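A hedged sketch of calling the platform-aware builder for Sora; state and code_challenge generation is elided, and the wrapper itself is illustrative rather than repository code.

// Sketch only: Sora reuses the Codex CLI client_id but does not set
// codex_cli_simplified_flow (see OAuthClientConfigByPlatform above).
func buildSoraAuthURL(state, codeChallenge string) string {
	return BuildAuthorizationURLForPlatform(state, codeChallenge, DefaultRedirectURI, OAuthPlatformSora)
}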
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
@@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string {
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims
|
||||
// Note: This does NOT verify the signature - it only decodes the payload
|
||||
// For production, you should verify the token signature using OpenAI's public keys
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
||||
const clockSkewTolerance = 120 // 秒
|
||||
now := time.Now().Unix()
|
||||
if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance {
|
||||
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) {
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
|
||||
authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI)
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse URL failed: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
if got := q.Get("client_id"); got != ClientID {
|
||||
t.Fatalf("client_id mismatch: got=%q want=%q", got, ClientID)
|
||||
}
|
||||
if got := q.Get("codex_cli_simplified_flow"); got != "true" {
|
||||
t.Fatalf("codex flow mismatch: got=%q want=true", got)
|
||||
}
|
||||
if got := q.Get("id_token_add_organizations"); got != "true" {
|
||||
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
|
||||
// 但不启用 codex_cli_simplified_flow。
|
||||
func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
|
||||
authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse URL failed: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
if got := q.Get("client_id"); got != ClientID {
|
||||
t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
|
||||
}
|
||||
if got := q.Get("codex_cli_simplified_flow"); got != "" {
|
||||
t.Fatalf("codex flow should be empty for sora, got=%q", got)
|
||||
}
|
||||
if got := q.Get("id_token_add_organizations"); got != "true" {
|
||||
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,10 +29,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P
|
||||
t.Helper()
|
||||
// 先用 raw json 解析,因为 Data 是 any 类型
|
||||
var raw struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
|
||||
|
||||
|
||||
@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions),
|
||||
"compression_methods", spec.CompressionMethods,
|
||||
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
|
||||
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
|
||||
"tls_vers_max", spec.TLSVersMax,
|
||||
"tls_vers_min", spec.TLSVersMin)
|
||||
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"version", state.Version,
|
||||
"cipher_suite", state.CipherSuite,
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"version", state.Version,
|
||||
"cipher_suite", state.CipherSuite,
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.
|
||||
// Log successful handshake details
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"version", state.Version,
|
||||
"cipher_suite", state.CipherSuite,
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
|
||||
@@ -139,6 +139,7 @@ type UsageLogFilters struct {
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
RequestType *int16
|
||||
Stream *bool
|
||||
BillingType *int8
|
||||
StartTime *time.Time
|
||||
|
||||
@@ -50,11 +50,6 @@ type accountRepository struct {
|
||||
schedulerCache service.SchedulerCache
|
||||
}
|
||||
|
||||
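// tempUnschedSnapshot 保存账号的临时不可调度状态快照(截止时间与原因),供批量加载后回填到 service.Account。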
type tempUnschedSnapshot struct {
|
||||
until *time.Time
|
||||
reason string
|
||||
}
|
||||
|
||||
// NewAccountRepository 创建账户仓储实例。
|
||||
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
|
||||
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
|
||||
@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
|
||||
accountIDs = append(accountIDs, acc.ID)
|
||||
}
|
||||
|
||||
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
|
||||
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
|
||||
out.AccountGroups = ags
|
||||
}
|
||||
if snap, ok := tempUnschedMap[entAcc.ID]; ok {
|
||||
out.TempUnschedulableUntil = snap.until
|
||||
out.TempUnschedulableReason = snap.reason
|
||||
}
|
||||
outByID[entAcc.ID] = out
|
||||
}
|
||||
|
||||
@@ -611,6 +597,43 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
|
||||
}
|
||||
}
|
||||
|
||||
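// syncSchedulerAccountSnapshots 批量同步调度器账号快照:对 ID 去重并跳过非正数;
// 读取失败时记录日志并返回,单个写入失败仅记录日志,不影响其余账号。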
func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) {
|
||||
if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
uniqueIDs := make([]int64, 0, len(accountIDs))
|
||||
seen := make(map[int64]struct{}, len(accountIDs))
|
||||
for _, id := range accountIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[id]; exists {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
uniqueIDs = append(uniqueIDs, id)
|
||||
}
|
||||
if len(uniqueIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
accounts, err := r.GetByIDs(ctx, uniqueIDs)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(uniqueIDs), err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, account := range accounts {
|
||||
if account == nil {
|
||||
continue
|
||||
}
|
||||
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", account.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Account.Update().
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
@@ -1197,9 +1220,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
shouldSync = true
|
||||
}
|
||||
if shouldSync {
|
||||
for _, id := range ids {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshots(ctx, ids)
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
@@ -1291,10 +1312,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1320,10 +1337,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
|
||||
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
|
||||
out.AccountGroups = ags
|
||||
}
|
||||
if snap, ok := tempUnschedMap[acc.ID]; ok {
|
||||
out.TempUnschedulableUntil = snap.until
|
||||
out.TempUnschedulableReason = snap.reason
|
||||
}
|
||||
outAccounts = append(outAccounts, *out)
|
||||
}
|
||||
|
||||
@@ -1348,48 +1361,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account {
|
||||
)
|
||||
}
|
||||
|
||||
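// loadTempUnschedStates 批量读取 accounts 表中的 temp_unschedulable_until / temp_unschedulable_reason,
// 返回按账号 ID 索引的快照。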
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
|
||||
out := make(map[int64]tempUnschedSnapshot)
|
||||
if len(accountIDs) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
|
||||
FROM accounts
|
||||
WHERE id = ANY($1)
|
||||
`, pq.Array(accountIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var until sql.NullTime
|
||||
var reason sql.NullString
|
||||
if err := rows.Scan(&id, &until, &reason); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var untilPtr *time.Time
|
||||
if until.Valid {
|
||||
tmp := until.Time
|
||||
untilPtr = &tmp
|
||||
}
|
||||
if reason.Valid {
|
||||
out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
|
||||
} else {
|
||||
out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
|
||||
proxyMap := make(map[int64]*service.Proxy)
|
||||
if len(proxyIDs) == 0 {
|
||||
@@ -1500,31 +1471,33 @@ func accountEntityToService(m *dbent.Account) *service.Account {
|
||||
rateMultiplier := m.RateMultiplier
|
||||
|
||||
return &service.Account{
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Notes: m.Notes,
|
||||
Platform: m.Platform,
|
||||
Type: m.Type,
|
||||
Credentials: copyJSONMap(m.Credentials),
|
||||
Extra: copyJSONMap(m.Extra),
|
||||
ProxyID: m.ProxyID,
|
||||
Concurrency: m.Concurrency,
|
||||
Priority: m.Priority,
|
||||
RateMultiplier: &rateMultiplier,
|
||||
Status: m.Status,
|
||||
ErrorMessage: derefString(m.ErrorMessage),
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
AutoPauseOnExpired: m.AutoPauseOnExpired,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
Schedulable: m.Schedulable,
|
||||
RateLimitedAt: m.RateLimitedAt,
|
||||
RateLimitResetAt: m.RateLimitResetAt,
|
||||
OverloadUntil: m.OverloadUntil,
|
||||
SessionWindowStart: m.SessionWindowStart,
|
||||
SessionWindowEnd: m.SessionWindowEnd,
|
||||
SessionWindowStatus: derefString(m.SessionWindowStatus),
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Notes: m.Notes,
|
||||
Platform: m.Platform,
|
||||
Type: m.Type,
|
||||
Credentials: copyJSONMap(m.Credentials),
|
||||
Extra: copyJSONMap(m.Extra),
|
||||
ProxyID: m.ProxyID,
|
||||
Concurrency: m.Concurrency,
|
||||
Priority: m.Priority,
|
||||
RateMultiplier: &rateMultiplier,
|
||||
Status: m.Status,
|
||||
ErrorMessage: derefString(m.ErrorMessage),
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
AutoPauseOnExpired: m.AutoPauseOnExpired,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
Schedulable: m.Schedulable,
|
||||
RateLimitedAt: m.RateLimitedAt,
|
||||
RateLimitResetAt: m.RateLimitResetAt,
|
||||
OverloadUntil: m.OverloadUntil,
|
||||
TempUnschedulableUntil: m.TempUnschedulableUntil,
|
||||
TempUnschedulableReason: derefString(m.TempUnschedulableReason),
|
||||
SessionWindowStart: m.SessionWindowStart,
|
||||
SessionWindowEnd: m.SessionWindowEnd,
|
||||
SessionWindowStatus: derefString(m.SessionWindowStatus),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
|
||||
s.Require().Nil(got.OverloadUntil)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() {
|
||||
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"})
|
||||
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"})
|
||||
|
||||
until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second)
|
||||
reason := `{"rule":"429","matched_keyword":"too many requests"}`
|
||||
s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason))
|
||||
|
||||
gotByID, err := s.repo.GetByID(s.ctx, acc1.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(gotByID.TempUnschedulableUntil)
|
||||
s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second)
|
||||
s.Require().Equal(reason, gotByID.TempUnschedulableReason)
|
||||
|
||||
gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(gotByIDs, 2)
|
||||
s.Require().Equal(acc2.ID, gotByIDs[0].ID)
|
||||
s.Require().Nil(gotByIDs[0].TempUnschedulableUntil)
|
||||
s.Require().Equal("", gotByIDs[0].TempUnschedulableReason)
|
||||
s.Require().Equal(acc1.ID, gotByIDs[1].ID)
|
||||
s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil)
|
||||
s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second)
|
||||
s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason)
|
||||
|
||||
s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID))
|
||||
cleared, err := s.repo.GetByID(s.ctx, acc1.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(cleared.TempUnschedulableUntil)
|
||||
s.Require().Equal("", cleared.TempUnschedulableReason)
|
||||
}
|
||||
|
||||
// --- UpdateLastUsed ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateLastUsed() {
|
||||
|
||||
@@ -445,20 +445,22 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
return nil
|
||||
}
|
||||
return &service.User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||
TotpEnabled: u.TotpEnabled,
|
||||
TotpEnabledAt: u.TotpEnabledAt,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
|
||||
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
|
||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||
TotpEnabled: u.TotpEnabled,
|
||||
TotpEnabledAt: u.TotpEnabledAt,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -486,6 +488,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
SoraImagePrice540: g.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
|
||||
@@ -227,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
|
||||
return result, nil
|
||||
}
|
||||
|
||||
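// GetAccountConcurrencyBatch 在单个 Redis pipeline 中先按 TTL 清理过期槽位(ZREMRANGEBYSCORE),
// 再用 ZCARD 统计各账号当前占用的并发槽位数。
// 用法示例(仅为示意,变量为假设):
//
//	counts, err := cache.GetAccountConcurrencyBatch(ctx, []int64{101, 102})
//	// err == nil 时 counts 形如 map[int64]int{101: 3, 102: 0}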
func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]int{}, nil
|
||||
}
|
||||
|
||||
now, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("redis TIME: %w", err)
|
||||
}
|
||||
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
type accountCmd struct {
|
||||
accountID int64
|
||||
zcardCmd *redis.IntCmd
|
||||
}
|
||||
cmds := make([]accountCmd, 0, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10)
|
||||
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
|
||||
cmds = append(cmds, accountCmd{
|
||||
accountID: accountID,
|
||||
zcardCmd: pipe.ZCard(ctx, slotKey),
|
||||
})
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
||||
return nil, fmt.Errorf("pipeline exec: %w", err)
|
||||
}
|
||||
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, cmd := range cmds {
|
||||
result[cmd.accountID] = int(cmd.zcardCmd.Val())
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// User slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
|
||||
@@ -104,7 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
|
||||
func TestGatewayCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheSuite))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
@@ -56,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject)
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -121,7 +124,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject)
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
|
||||
// 处理 FallbackGroupID:nil 时清除,否则设置
|
||||
if groupIn.FallbackGroupID != nil {
|
||||
@@ -281,6 +285,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
|
||||
return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
|
||||
}
|
||||
|
||||
// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。
|
||||
// 返回结构:map[groupID]exists。
|
||||
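// 用法示例(仅为示意,调用方变量为假设):
//
//	exists, err := repo.ExistsByIDs(ctx, []int64{1, 2, 2, -5})
//	// err == nil 时 exists 形如 map[int64]bool{1: true, 2: false};
//	// 重复 ID 只查询一次,非正数 ID 被跳过,不会出现在返回的 map 中。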
func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) {
|
||||
result := make(map[int64]bool, len(ids))
|
||||
if len(ids) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
uniqueIDs := make([]int64, 0, len(ids))
|
||||
seen := make(map[int64]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
uniqueIDs = append(uniqueIDs, id)
|
||||
result[id] = false
|
||||
}
|
||||
if len(uniqueIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT id
|
||||
FROM groups
|
||||
WHERE id = ANY($1) AND deleted_at IS NULL
|
||||
`, pq.Array(uniqueIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[id] = true
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
|
||||
@@ -512,22 +564,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用事务批量更新
|
||||
tx, err := r.client.Tx(ctx)
|
||||
// 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。
|
||||
sortOrderByID := make(map[int64]int, len(updates))
|
||||
groupIDs := make([]int64, 0, len(updates))
|
||||
for _, u := range updates {
|
||||
if u.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := sortOrderByID[u.ID]; !exists {
|
||||
groupIDs = append(groupIDs, u.ID)
|
||||
}
|
||||
sortOrderByID[u.ID] = u.SortOrder
|
||||
}
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 与旧实现保持一致:任何不存在/已删除的分组都返回 not found,且不执行更新。
|
||||
var existingCount int
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
`SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`,
|
||||
[]any{pq.Array(groupIDs)},
|
||||
&existingCount,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if existingCount != len(groupIDs) {
|
||||
return service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
args := make([]any, 0, len(groupIDs)*2+1)
|
||||
caseClauses := make([]string, 0, len(groupIDs))
|
||||
placeholder := 1
|
||||
for _, id := range groupIDs {
|
||||
caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1))
|
||||
args = append(args, id, sortOrderByID[id])
|
||||
placeholder += 2
|
||||
}
|
||||
args = append(args, pq.Array(groupIDs))
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE groups
|
||||
SET sort_order = CASE id
|
||||
%s
|
||||
ELSE sort_order
|
||||
END
|
||||
WHERE deleted_at IS NULL AND id = ANY($%d)
|
||||
`, strings.Join(caseClauses, "\n\t\t\t"), placeholder)
|
||||
|
||||
result, err := r.sql.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
for _, u := range updates {
|
||||
if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected != int64(len(groupIDs)) {
|
||||
return service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
for _, id := range groupIDs {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
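To make the placeholder layout above easier to follow, here is a small standalone sketch (editorial aside, not part of the patch; the IDs and sort values are hypothetical) that builds the same CASE expression for two updates and prints the query and arguments. The repository code binds the final argument with pq.Array and executes the statement via r.sql.ExecContext instead of printing it:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// 两条假设的更新:group 7 -> sort_order 30,group 9 -> sort_order 10。
	ids := []int64{7, 9}
	sortByID := map[int64]int{7: 30, 9: 10}

	caseClauses := make([]string, 0, len(ids))
	args := make([]any, 0, len(ids)*2+1)
	placeholder := 1
	for _, id := range ids {
		caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1))
		args = append(args, id, sortByID[id])
		placeholder += 2
	}
	args = append(args, ids) // 真实实现此处传入 pq.Array(groupIDs)

	query := fmt.Sprintf(`UPDATE groups
SET sort_order = CASE id
	%s
	ELSE sort_order
END
WHERE deleted_at IS NULL AND id = ANY($%d)`, strings.Join(caseClauses, "\n\t"), placeholder)

	fmt.Println(query) // WHEN $1 THEN $2 / WHEN $3 THEN $4 ... ANY($5)
	fmt.Println(args)  // [7 30 9 10 [7 9]]
}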
|
||||
|
||||
@@ -352,6 +352,81 @@ func (s *GroupRepoSuite) TestListWithFilters_Search() {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdateSortOrders_BatchCaseWhen() {
|
||||
g1 := &service.Group{
|
||||
Name: "sort-g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
g2 := &service.Group{
|
||||
Name: "sort-g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
g3 := &service.Group{
|
||||
Name: "sort-g3",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g1))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g2))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g3))
|
||||
|
||||
err := s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
|
||||
{ID: g1.ID, SortOrder: 30},
|
||||
{ID: g2.ID, SortOrder: 10},
|
||||
{ID: g3.ID, SortOrder: 20},
|
||||
{ID: g2.ID, SortOrder: 15}, // 重复 ID 应以最后一次为准
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
got1, err := s.repo.GetByID(s.ctx, g1.ID)
|
||||
s.Require().NoError(err)
|
||||
got2, err := s.repo.GetByID(s.ctx, g2.ID)
|
||||
s.Require().NoError(err)
|
||||
got3, err := s.repo.GetByID(s.ctx, g3.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(30, got1.SortOrder)
|
||||
s.Require().Equal(15, got2.SortOrder)
|
||||
s.Require().Equal(20, got3.SortOrder)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdateSortOrders_MissingGroupNoPartialUpdate() {
|
||||
g1 := &service.Group{
|
||||
Name: "sort-no-partial",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g1))
|
||||
|
||||
before, err := s.repo.GetByID(s.ctx, g1.ID)
|
||||
s.Require().NoError(err)
|
||||
beforeSort := before.SortOrder
|
||||
|
||||
err = s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
|
||||
{ID: g1.ID, SortOrder: 99},
|
||||
{ID: 99999999, SortOrder: 1},
|
||||
})
|
||||
s.Require().Error(err)
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
|
||||
after, err := s.repo.GetByID(s.ctx, g1.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(beforeSort, after.SortOrder)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
g1 := &service.Group{
|
||||
Name: "g1",
|
||||
|
||||
@@ -147,4 +147,3 @@ func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
|
||||
require.Equal(t, `{"ok":true}`, *got.ResponseBody)
|
||||
require.Nil(t, got.LockedUntil)
|
||||
}
|
||||
|
||||
|
||||
@@ -50,6 +50,23 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
|
||||
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
|
||||
const migrationsAdvisoryLockID int64 = 694208311321144027
|
||||
const migrationsLockRetryInterval = 500 * time.Millisecond
|
||||
const nonTransactionalMigrationSuffix = "_notx.sql"
|
||||
|
||||
type migrationChecksumCompatibilityRule struct {
|
||||
fileChecksum string
|
||||
acceptedDBChecksum map[string]struct{}
|
||||
}
|
||||
|
||||
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
|
||||
// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。
|
||||
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
|
||||
"054_drop_legacy_cache_columns.sql": {
|
||||
fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
|
||||
acceptedDBChecksum: map[string]struct{}{
|
||||
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
|
||||
//
|
||||
@@ -147,6 +164,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
if rowErr == nil {
|
||||
// 迁移已应用,验证校验和是否匹配
|
||||
if existing != checksum {
|
||||
// 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。
|
||||
if isMigrationChecksumCompatible(name, existing, checksum) {
|
||||
continue
|
||||
}
|
||||
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
|
||||
// 正确的做法是创建新的迁移文件来进行变更。
|
||||
return fmt.Errorf(
|
||||
@@ -165,8 +186,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
return fmt.Errorf("check migration %s: %w", name, rowErr)
|
||||
}
|
||||
|
||||
// 迁移未应用,在事务中执行。
|
||||
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
|
||||
nonTx, err := validateMigrationExecutionMode(name, content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
if nonTx {
|
||||
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
|
||||
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
|
||||
statements := splitSQLStatements(content)
|
||||
for i, stmt := range statements {
|
||||
trimmed := strings.TrimSpace(stmt)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if stripSQLLineComment(trimmed) == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := db.ExecContext(ctx, trimmed); err != nil {
|
||||
return fmt.Errorf("apply migration %s (non-tx statement %d): %w", name, i+1, err)
|
||||
}
|
||||
}
|
||||
if _, err := db.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
|
||||
return fmt.Errorf("record migration %s (non-tx): %w", name, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin migration %s: %w", name, err)
|
||||
@@ -268,6 +315,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
|
||||
return version, version, hash, nil
|
||||
}
|
||||
|
||||
func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
|
||||
rule, ok := migrationChecksumCompatibilityRules[name]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if rule.fileChecksum != fileChecksum {
|
||||
return false
|
||||
}
|
||||
_, ok = rule.acceptedDBChecksum[dbChecksum]
|
||||
return ok
|
||||
}
|
||||
|
||||
func validateMigrationExecutionMode(name, content string) (bool, error) {
|
||||
normalizedName := strings.ToLower(strings.TrimSpace(name))
|
||||
upperContent := strings.ToUpper(content)
|
||||
nonTx := strings.HasSuffix(normalizedName, nonTransactionalMigrationSuffix)
|
||||
|
||||
if !nonTx {
|
||||
if strings.Contains(upperContent, "CONCURRENTLY") {
|
||||
return false, errors.New("CONCURRENTLY statements must be placed in *_notx.sql migrations")
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if strings.Contains(upperContent, "BEGIN") || strings.Contains(upperContent, "COMMIT") || strings.Contains(upperContent, "ROLLBACK") {
|
||||
return false, errors.New("*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)")
|
||||
}
|
||||
|
||||
statements := splitSQLStatements(content)
|
||||
for _, stmt := range statements {
|
||||
normalizedStmt := strings.ToUpper(stripSQLLineComment(strings.TrimSpace(stmt)))
|
||||
if normalizedStmt == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(normalizedStmt, "CONCURRENTLY") {
|
||||
isCreateIndex := strings.Contains(normalizedStmt, "CREATE") && strings.Contains(normalizedStmt, "INDEX")
|
||||
isDropIndex := strings.Contains(normalizedStmt, "DROP") && strings.Contains(normalizedStmt, "INDEX")
|
||||
if !isCreateIndex && !isDropIndex {
|
||||
return false, errors.New("*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements")
|
||||
}
|
||||
if isCreateIndex && !strings.Contains(normalizedStmt, "IF NOT EXISTS") {
|
||||
return false, errors.New("CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency")
|
||||
}
|
||||
if isDropIndex && !strings.Contains(normalizedStmt, "IF EXISTS") {
|
||||
return false, errors.New("DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
return false, errors.New("*_notx.sql must not mix non-CONCURRENTLY SQL statements")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
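// splitSQLStatements 以分号为界粗略拆分迁移 SQL 并丢弃空片段;迁移文件内容受控,故不处理字符串字面量中的分号。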
func splitSQLStatements(content string) []string {
|
||||
parts := strings.Split(content, ";")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if strings.TrimSpace(part) == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, part)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
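// stripSQLLineComment 去掉每行 "--" 之后的行注释并整体 TrimSpace,用于识别仅包含注释的语句。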
func stripSQLLineComment(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
for i, line := range lines {
|
||||
if idx := strings.Index(line, "--"); idx >= 0 {
|
||||
lines[i] = line[:idx]
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
|
||||
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
|
||||
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
|
||||
|
||||
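The body of pgAdvisoryLock is not included in this hunk. As a minimal sketch only (assuming the file's existing imports of context, database/sql, fmt and time), a polling acquire consistent with the constants above and with the error strings asserted by the tests later in this sync could look like:

// pgAdvisoryLockSketch 仅为示意(函数名为假设),真实实现未出现在本 hunk 中。
func pgAdvisoryLockSketch(ctx context.Context, db *sql.DB) error {
	for {
		var locked bool
		if err := db.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", migrationsAdvisoryLockID).Scan(&locked); err != nil {
			return fmt.Errorf("acquire migrations lock: %w", err)
		}
		if locked {
			return nil
		}
		select {
		case <-ctx.Done():
			return fmt.Errorf("acquire migrations lock: %w", ctx.Err())
		case <-time.After(migrationsLockRetryInterval):
			// 锁被其他实例持有,等待一个重试间隔后再次尝试。
		}
	}
}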
@@ -0,0 +1,36 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsMigrationChecksumCompatible(t *testing.T) {
|
||||
t.Run("054历史checksum可兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"054_drop_legacy_cache_columns.sql",
|
||||
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
|
||||
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
|
||||
)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("054在未知文件checksum下不兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"054_drop_legacy_cache_columns.sql",
|
||||
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
|
||||
"0000000000000000000000000000000000000000000000000000000000000000",
|
||||
)
|
||||
require.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("非白名单迁移不兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"001_init.sql",
|
||||
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
|
||||
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
|
||||
)
|
||||
require.False(t, ok)
|
||||
})
|
||||
}
|
||||
368
backend/internal/repository/migrations_runner_extra_test.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
sqlmock "github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyMigrations_NilDB(t *testing.T) {
|
||||
err := ApplyMigrations(context.Background(), nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "nil sql db")
|
||||
}
|
||||
|
||||
func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnError(errors.New("lock failed"))
|
||||
|
||||
err = ApplyMigrations(context.Background(), db)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "acquire migrations lock")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestLatestMigrationBaseline(t *testing.T) {
|
||||
t.Run("empty_fs_returns_baseline", func(t *testing.T) {
|
||||
version, description, hash, err := latestMigrationBaseline(fstest.MapFS{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "baseline", version)
|
||||
require.Equal(t, "baseline", description)
|
||||
require.Equal(t, "", hash)
|
||||
})
|
||||
|
||||
t.Run("uses_latest_sorted_sql_file", func(t *testing.T) {
|
||||
fsys := fstest.MapFS{
|
||||
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
|
||||
"010_final.sql": &fstest.MapFile{
|
||||
Data: []byte("CREATE TABLE t2(id int);"),
|
||||
},
|
||||
}
|
||||
version, description, hash, err := latestMigrationBaseline(fsys)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "010_final", version)
|
||||
require.Equal(t, "010_final", description)
|
||||
require.Len(t, hash, 64)
|
||||
})
|
||||
|
||||
t.Run("read_file_error", func(t *testing.T) {
|
||||
fsys := fstest.MapFS{
|
||||
"010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
|
||||
}
|
||||
_, _, _, err := latestMigrationBaseline(fsys)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
|
||||
require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file"))
|
||||
|
||||
var (
|
||||
name string
|
||||
rule migrationChecksumCompatibilityRule
|
||||
)
|
||||
for n, r := range migrationChecksumCompatibilityRules {
|
||||
name = n
|
||||
rule = r
|
||||
break
|
||||
}
|
||||
require.NotEmpty(t, name)
|
||||
|
||||
require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match"))
|
||||
require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum))
|
||||
|
||||
var accepted string
|
||||
for checksum := range rule.acceptedDBChecksum {
|
||||
accepted = checksum
|
||||
break
|
||||
}
|
||||
require.NotEmpty(t, accepted)
|
||||
require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
|
||||
}
|
||||
|
||||
func TestEnsureAtlasBaselineAligned(t *testing.T) {
|
||||
t.Run("skip_when_no_legacy_table", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec("INSERT INTO atlas_schema_revisions").
|
||||
WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
|
||||
"002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")},
|
||||
}
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("error_when_checking_legacy_table", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnError(errors.New("exists failed"))
|
||||
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "check schema_migrations")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("error_when_counting_atlas_rows", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
|
||||
WillReturnError(errors.New("count failed"))
|
||||
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "count atlas_schema_revisions")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("error_when_creating_atlas_table", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
|
||||
WillReturnError(errors.New("create failed"))
|
||||
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "create atlas_schema_revisions")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("error_when_inserting_baseline", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec("INSERT INTO atlas_schema_revisions").
|
||||
WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()).
|
||||
WillReturnError(errors.New("insert failed"))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
|
||||
}
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "insert atlas baseline")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_init.sql").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum"))
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
|
||||
}
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "checksum mismatch")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_err.sql").
|
||||
WillReturnError(errors.New("query failed"))
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")},
|
||||
}
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "check migration 001_err.sql")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
|
||||
alreadySQL := "CREATE TABLE t(id int);"
|
||||
checksum := migrationChecksum(alreadySQL)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_already.sql").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum))
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"000_empty.sql": &fstest.MapFile{Data: []byte(" \n\t ")},
|
||||
"001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)},
|
||||
}
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
|
||||
}
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "read migration 001_bad.sql")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) {
|
||||
t.Run("context_cancelled_while_not_locked", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
defer cancel()
|
||||
err = pgAdvisoryLock(ctx, db)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "acquire migrations lock")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("unlock_exec_error", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnError(errors.New("unlock failed"))
|
||||
|
||||
err = pgAdvisoryUnlock(context.Background(), db)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "release migrations lock")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("acquire_lock_after_retry", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
|
||||
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3)
|
||||
defer cancel()
|
||||
start := time.Now()
|
||||
err = pgAdvisoryLock(ctx, db)
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
|
||||
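// migrationChecksum 计算迁移内容 TrimSpace 后的 SHA-256 十六进制摘要,与 runner 记录到 schema_migrations 的校验和一致。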
func migrationChecksum(content string) string {
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(content)))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
164
backend/internal/repository/migrations_runner_notx_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
sqlmock "github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateMigrationExecutionMode(t *testing.T) {
|
||||
t.Run("事务迁移包含CONCURRENTLY会被拒绝", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_add_idx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
|
||||
require.False(t, nonTx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("notx迁移要求CREATE使用IF NOT EXISTS", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
|
||||
require.False(t, nonTx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("notx迁移要求DROP使用IF EXISTS", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_drop_idx_notx.sql", "DROP INDEX CONCURRENTLY idx_a;")
|
||||
require.False(t, nonTx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("notx迁移禁止事务控制语句", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;")
|
||||
require.False(t, nonTx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("notx迁移禁止混用非CONCURRENTLY语句", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;")
|
||||
require.False(t, nonTx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("notx迁移允许幂等并发索引语句", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", `
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a);
|
||||
DROP INDEX CONCURRENTLY IF EXISTS idx_b;
|
||||
`)
|
||||
require.True(t, nonTx)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_add_idx_notx.sql").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
|
||||
WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_add_idx_notx.sql": &fstest.MapFile{
|
||||
Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"),
|
||||
},
|
||||
}
|
||||
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_add_multi_idx_notx.sql").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
|
||||
WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_add_multi_idx_notx.sql": &fstest.MapFile{
|
||||
Data: []byte(`
|
||||
-- first
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);
|
||||
-- second
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
|
||||
`),
|
||||
},
|
||||
}
|
||||
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_add_col.sql").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
|
||||
WithArgs("001_add_col.sql", sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_add_col.sql": &fstest.MapFile{
|
||||
Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"),
|
||||
},
|
||||
}
|
||||
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
|
||||
mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
}
|
||||
@@ -42,6 +42,8 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
|
||||
// usage_logs: billing_type used by filters/stats
|
||||
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
||||
|
||||
// settings table should exist
|
||||
var settingsRegclass sql.NullString
|
||||
|
||||
@@ -22,16 +22,20 @@ type openaiOAuthService struct {
|
||||
tokenURL string
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
if clientID == "" {
|
||||
clientID = openai.ClientID
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("client_id", openai.ClientID)
|
||||
formData.Set("client_id", clientID)
|
||||
formData.Set("code", code)
|
||||
formData.Set("redirect_uri", redirectURI)
|
||||
formData.Set("code_verifier", codeVerifier)
|
||||
@@ -61,36 +65,12 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
if strings.TrimSpace(clientID) != "" {
|
||||
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
|
||||
// 调用方应始终传入正确的 client_id;为兼容旧数据,未指定时默认使用 OpenAI ClientID
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
if clientID == "" {
|
||||
clientID = openai.ClientID
|
||||
}
|
||||
|
||||
clientIDs := []string{
|
||||
openai.ClientID,
|
||||
openai.SoraClientID,
|
||||
}
|
||||
seen := make(map[string]struct{}, len(clientIDs))
|
||||
var lastErr error
|
||||
for _, clientID := range clientIDs {
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[clientID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[clientID] = struct{}{}
|
||||
|
||||
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||
if err == nil {
|
||||
return tokenResp, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
|
||||
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
|
||||
@@ -81,7 +81,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
|
||||
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "", "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
@@ -136,7 +136,9 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
||||
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
|
||||
// TestRefreshToken_DefaultsToOpenAIClientID 验证未指定 client_id 时默认使用 OpenAI ClientID,
|
||||
// 且只发送一次请求(不再盲猜多个 client_id)。
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
@@ -145,11 +147,27 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID == openai.ClientID {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.NoError(s.T(), err, "RefreshToken")
|
||||
require.Equal(s.T(), "at", resp.AccessToken)
|
||||
// 只发送了一次请求,使用默认的 OpenAI ClientID
|
||||
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, "invalid_grant")
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID == openai.SoraClientID {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
|
||||
@@ -158,11 +176,10 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
|
||||
require.Equal(s.T(), "at-sora", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt-sora", resp.RefreshToken)
|
||||
require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
|
||||
@@ -196,7 +213,7 @@ func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
||||
_, _ = io.WriteString(w, "bad")
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status 400")
|
||||
require.ErrorContains(s.T(), err, "bad")
|
||||
@@ -206,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.srv.Close()
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "request failed")
|
||||
}
|
||||
@@ -223,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||
done <- err
|
||||
}()
|
||||
|
||||
@@ -249,7 +266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "", "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
|
||||
wantClientID := openai.SoraClientID
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
if got := r.PostForm.Get("client_id"); got != wantClientID {
|
||||
errCh <- "client_id mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", wantClientID)
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
@@ -267,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
|
||||
}))
|
||||
s.svc.tokenURL = s.srv.URL + "?x=1"
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case <-s.received:
|
||||
@@ -283,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
|
||||
_, _ = io.WriteString(w, "not-valid-json")
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||
require.Error(s.T(), err, "expected error for invalid JSON response")
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,11 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
opsRawLatencyQueryTimeout = 2 * time.Second
|
||||
opsRawPeakQueryTimeout = 1500 * time.Millisecond
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
@@ -45,15 +50,24 @@ func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *servic
|
||||
func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
degraded := false
|
||||
|
||||
successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
cancelLatency()
if err != nil {
if isQueryTimeoutErr(err) {
degraded = true
duration = service.OpsPercentiles{}
ttft = service.OpsPercentiles{}
} else {
return nil, err
}
}
|
||||
|
||||
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
|
||||
@@ -75,20 +89,40 @@ func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *ser
|
||||
|
||||
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
|
||||
if err != nil {
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
|
||||
|
||||
peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
cancelPeak()
if err != nil {
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
|
||||
|
||||
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
|
||||
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
|
||||
if degraded {
|
||||
if qpsCurrent <= 0 {
|
||||
qpsCurrent = qpsAvg
|
||||
}
|
||||
if tpsCurrent <= 0 {
|
||||
tpsCurrent = tpsAvg
|
||||
}
|
||||
if qpsPeak <= 0 {
|
||||
qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
|
||||
}
|
||||
if tpsPeak <= 0 {
|
||||
tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
|
||||
}
|
||||
}
|
||||
|
||||
return &service.OpsDashboardOverview{
|
||||
StartTime: start,
|
||||
@@ -230,26 +264,45 @@ func (r *opsRepository) getDashboardOverviewPreaggregated(ctx context.Context, f
|
||||
sla := safeDivideFloat64(float64(successCount), float64(requestCountSLA))
|
||||
errorRate := safeDivideFloat64(float64(errorCountSLA), float64(requestCountSLA))
|
||||
upstreamErrorRate := safeDivideFloat64(float64(upstreamExcl), float64(requestCountSLA))
|
||||
degraded := false
|
||||
|
||||
// Keep "current" rates as raw, to preserve realtime semantics.
|
||||
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
|
||||
if err != nil {
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
|
||||
|
||||
// NOTE: peak still uses raw logs (minute granularity). This is typically cheaper than percentile_cont
|
||||
// and keeps semantics consistent across modes.
|
||||
peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
cancelPeak()
if err != nil {
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
|
||||
|
||||
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
|
||||
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
|
||||
if degraded {
|
||||
if qpsCurrent <= 0 {
|
||||
qpsCurrent = qpsAvg
|
||||
}
|
||||
if tpsCurrent <= 0 {
|
||||
tpsCurrent = tpsAvg
|
||||
}
|
||||
if qpsPeak <= 0 {
|
||||
qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
|
||||
}
|
||||
if tpsPeak <= 0 {
|
||||
tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
|
||||
}
|
||||
}
|
||||
|
||||
return &service.OpsDashboardOverview{
|
||||
StartTime: start,
|
||||
@@ -577,9 +630,16 @@ func (r *opsRepository) queryRawPartial(ctx context.Context, filter *service.Ops
|
||||
return nil, err
|
||||
}
|
||||
|
||||
latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
cancelLatency()
if err != nil {
if isQueryTimeoutErr(err) {
duration = service.OpsPercentiles{}
ttft = service.OpsPercentiles{}
} else {
return nil, err
}
}
|
||||
|
||||
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
|
||||
@@ -735,68 +795,56 @@ FROM usage_logs ul
|
||||
}
|
||||
|
||||
func (r *opsRepository) queryUsageLatency(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (duration service.OpsPercentiles, ttft service.OpsPercentiles, err error) {
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
SELECT
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99,
AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg,
MAX(duration_ms) AS duration_max,
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99,
AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg,
MAX(first_token_ms) AS ttft_max
FROM usage_logs ul
` + join + `
` + where

var dP50, dP90, dP95, dP99 sql.NullFloat64
var dAvg sql.NullFloat64
var dMax sql.NullInt64
var tP50, tP90, tP95, tP99 sql.NullFloat64
var tAvg sql.NullFloat64
var tMax sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(
&dP50, &dP90, &dP95, &dP99, &dAvg, &dMax,
&tP50, &tP90, &tP95, &tP99, &tAvg, &tMax,
); err != nil {
return service.OpsPercentiles{}, service.OpsPercentiles{}, err
}

duration.P50 = floatToIntPtr(dP50)
duration.P90 = floatToIntPtr(dP90)
duration.P95 = floatToIntPtr(dP95)
duration.P99 = floatToIntPtr(dP99)
duration.Avg = floatToIntPtr(dAvg)
if dMax.Valid {
v := int(dMax.Int64)
duration.Max = &v
}

ttft.P50 = floatToIntPtr(tP50)
ttft.P90 = floatToIntPtr(tP90)
ttft.P95 = floatToIntPtr(tP95)
ttft.P99 = floatToIntPtr(tP99)
ttft.Avg = floatToIntPtr(tAvg)
if tMax.Valid {
v := int(tMax.Int64)
ttft.Max = &v
}
|
||||
|
||||
return duration, ttft, nil
|
||||
@@ -854,20 +902,23 @@ func (r *opsRepository) queryCurrentRates(ctx context.Context, filter *service.O
|
||||
return qpsCurrent, tpsCurrent, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) queryPeakRates(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (qpsPeak float64, tpsPeak float64, err error) {
|
||||
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
|
||||
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
|
||||
|
||||
q := `
|
||||
WITH usage_buckets AS (
|
||||
SELECT
date_trunc('minute', ul.created_at) AS bucket,
COUNT(*) AS req_cnt,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_cnt
|
||||
FROM usage_logs ul
|
||||
` + usageJoin + `
|
||||
` + usageWhere + `
|
||||
GROUP BY 1
|
||||
),
|
||||
error_buckets AS (
|
||||
SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS err_cnt
|
||||
FROM ops_error_logs
|
||||
` + errorWhere + `
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
@@ -875,47 +926,33 @@ error_buckets AS (
|
||||
),
|
||||
combined AS (
|
||||
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
|
||||
COALESCE(u.req_cnt, 0) + COALESCE(e.err_cnt, 0) AS total_req,
COALESCE(u.token_cnt, 0) AS total_tokens
|
||||
FROM usage_buckets u
|
||||
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
|
||||
)
|
||||
SELECT
COALESCE(MAX(total_req), 0) AS max_req_per_min,
COALESCE(MAX(total_tokens), 0) AS max_tokens_per_min
FROM combined`
|
||||
|
||||
args := append(usageArgs, errorArgs...)
|
||||
|
||||
var maxReqPerMinute, maxTokensPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxReqPerMinute, &maxTokensPerMinute); err != nil {
return 0, 0, err
|
||||
}
|
||||
if maxReqPerMinute.Valid && maxReqPerMinute.Int64 > 0 {
qpsPeak = roundTo1DP(float64(maxReqPerMinute.Int64) / 60.0)
}
if maxTokensPerMinute.Valid && maxTokensPerMinute.Int64 > 0 {
tpsPeak = roundTo1DP(float64(maxTokensPerMinute.Int64) / 60.0)
}
return qpsPeak, tpsPeak, nil
|
||||
}
|
||||
func isQueryTimeoutErr(err error) bool {
|
||||
return errors.Is(err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func buildUsageWhere(filter *service.OpsDashboardFilter, start, end time.Time, startIndex int) (join string, where string, args []any, nextIndex int) {
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsQueryTimeoutErr(t *testing.T) {
|
||||
if !isQueryTimeoutErr(context.DeadlineExceeded) {
|
||||
t.Fatalf("context.DeadlineExceeded should be treated as query timeout")
|
||||
}
|
||||
if !isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.DeadlineExceeded)) {
|
||||
t.Fatalf("wrapped context.DeadlineExceeded should be treated as query timeout")
|
||||
}
|
||||
if isQueryTimeoutErr(context.Canceled) {
|
||||
t.Fatalf("context.Canceled should not be treated as query timeout")
|
||||
}
|
||||
if isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.Canceled)) {
|
||||
t.Fatalf("wrapped context.Canceled should not be treated as query timeout")
|
||||
}
|
||||
}
|
||||
419
backend/internal/repository/sora_generation_repo.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
|
||||
// 使用原生 SQL 操作 sora_generations 表。
|
||||
type soraGenerationRepository struct {
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
|
||||
func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
|
||||
return &soraGenerationRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
|
||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
||||
|
||||
err := r.sql.QueryRowContext(ctx, `
|
||||
INSERT INTO sora_generations (
|
||||
user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
RETURNING id, created_at
|
||||
`,
|
||||
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
|
||||
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
|
||||
).Scan(&gen.ID, &gen.CreatedAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
|
||||
func (r *soraGenerationRepository) CreatePendingWithLimit(
|
||||
ctx context.Context,
|
||||
gen *service.SoraGeneration,
|
||||
activeStatuses []string,
|
||||
maxActive int64,
|
||||
) error {
|
||||
if gen == nil {
|
||||
return fmt.Errorf("generation is nil")
|
||||
}
|
||||
if maxActive <= 0 {
|
||||
return r.Create(ctx, gen)
|
||||
}
|
||||
if len(activeStatuses) == 0 {
|
||||
activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
|
||||
}
|
||||
|
||||
tx, err := r.sql.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
|
||||
if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
placeholders := make([]string, len(activeStatuses))
|
||||
args := make([]any, 0, 1+len(activeStatuses))
|
||||
args = append(args, gen.UserID)
|
||||
for i, s := range activeStatuses {
|
||||
placeholders[i] = fmt.Sprintf("$%d", i+2)
|
||||
args = append(args, s)
|
||||
}
|
||||
countQuery := fmt.Sprintf(
|
||||
`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
|
||||
strings.Join(placeholders, ","),
|
||||
)
|
||||
var activeCount int64
|
||||
if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
|
||||
return err
|
||||
}
|
||||
if activeCount >= maxActive {
|
||||
return service.ErrSoraGenerationConcurrencyLimit
|
||||
}
|
||||
|
||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
||||
if err := tx.QueryRowContext(ctx, `
|
||||
INSERT INTO sora_generations (
|
||||
user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
RETURNING id, created_at
|
||||
`,
|
||||
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
|
||||
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
|
||||
).Scan(&gen.ID, &gen.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
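A brief caller-side sketch of the transactional limit check above. Only CreatePendingWithLimit, the status constants, and ErrSoraGenerationConcurrencyLimit come from this sync; the soraService type, its repo/maxActivePerUser fields, and the HTTP error mapping are illustrative assumptions.

// Illustrative caller (assumed service wiring): gate new generations with the
// transactional limit check and surface the limit error to the client.
func (s *soraService) StartGeneration(ctx context.Context, gen *service.SoraGeneration) error {
	active := []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
	if err := s.repo.CreatePendingWithLimit(ctx, gen, active, s.maxActivePerUser); err != nil {
		if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
			// Too many in-flight generations for this user; report as a client error.
			return infraerrors.New(http.StatusTooManyRequests, "SORA_GENERATION_LIMIT", "too many active generations")
		}
		return err
	}
	return nil
}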
|
||||
|
||||
func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
|
||||
gen := &service.SoraGeneration{}
|
||||
var mediaURLsJSON, s3KeysJSON []byte
|
||||
var completedAt sql.NullTime
|
||||
var apiKeyID sql.NullInt64
|
||||
|
||||
err := r.sql.QueryRowContext(ctx, `
|
||||
SELECT id, user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message,
|
||||
created_at, completed_at
|
||||
FROM sora_generations WHERE id = $1
|
||||
`, id).Scan(
|
||||
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
|
||||
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
|
||||
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
|
||||
&gen.CreatedAt, &completedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("生成记录不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if apiKeyID.Valid {
|
||||
gen.APIKeyID = &apiKeyID.Int64
|
||||
}
|
||||
if completedAt.Valid {
|
||||
gen.CompletedAt = &completedAt.Time
|
||||
}
|
||||
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
|
||||
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
|
||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
||||
|
||||
var completedAt *time.Time
|
||||
if gen.CompletedAt != nil {
|
||||
completedAt = gen.CompletedAt
|
||||
}
|
||||
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations SET
|
||||
status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
|
||||
storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
|
||||
error_message = $9, completed_at = $10
|
||||
WHERE id = $1
|
||||
`,
|
||||
gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
|
||||
gen.ErrorMessage, completedAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
|
||||
func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2, upstream_task_id = $3
|
||||
WHERE id = $1 AND status = $4
|
||||
`,
|
||||
id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
|
||||
func (r *soraGenerationRepository) UpdateCompletedIfActive(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
mediaURL string,
|
||||
mediaURLs []string,
|
||||
storageType string,
|
||||
s3Keys []string,
|
||||
fileSizeBytes int64,
|
||||
completedAt time.Time,
|
||||
) (bool, error) {
|
||||
mediaURLsJSON, _ := json.Marshal(mediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(s3Keys)
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2,
|
||||
media_url = $3,
|
||||
media_urls = $4,
|
||||
file_size_bytes = $5,
|
||||
storage_type = $6,
|
||||
s3_object_keys = $7,
|
||||
error_message = '',
|
||||
completed_at = $8
|
||||
WHERE id = $1 AND status IN ($9, $10)
|
||||
`,
|
||||
id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
|
||||
storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
|
||||
func (r *soraGenerationRepository) UpdateFailedIfActive(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
errMsg string,
|
||||
completedAt time.Time,
|
||||
) (bool, error) {
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2,
|
||||
error_message = $3,
|
||||
completed_at = $4
|
||||
WHERE id = $1 AND status IN ($5, $6)
|
||||
`,
|
||||
id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
|
||||
func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2, completed_at = $3
|
||||
WHERE id = $1 AND status IN ($4, $5)
|
||||
`,
|
||||
id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
|
||||
func (r *soraGenerationRepository) UpdateStorageIfCompleted(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
mediaURL string,
|
||||
mediaURLs []string,
|
||||
storageType string,
|
||||
s3Keys []string,
|
||||
fileSizeBytes int64,
|
||||
) (bool, error) {
|
||||
mediaURLsJSON, _ := json.Marshal(mediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(s3Keys)
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET media_url = $2,
|
||||
media_urls = $3,
|
||||
file_size_bytes = $4,
|
||||
storage_type = $5,
|
||||
s3_object_keys = $6
|
||||
WHERE id = $1 AND status = $7
|
||||
`,
|
||||
id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
|
||||
// 构建 WHERE 条件
|
||||
conditions := []string{"user_id = $1"}
|
||||
args := []any{params.UserID}
|
||||
argIdx := 2
|
||||
|
||||
if params.Status != "" {
|
||||
// 支持逗号分隔的多状态
|
||||
statuses := strings.Split(params.Status, ",")
|
||||
placeholders := make([]string, len(statuses))
|
||||
for i, s := range statuses {
|
||||
placeholders[i] = fmt.Sprintf("$%d", argIdx)
|
||||
args = append(args, strings.TrimSpace(s))
|
||||
argIdx++
|
||||
}
|
||||
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
|
||||
}
|
||||
if params.StorageType != "" {
|
||||
storageTypes := strings.Split(params.StorageType, ",")
|
||||
placeholders := make([]string, len(storageTypes))
|
||||
for i, s := range storageTypes {
|
||||
placeholders[i] = fmt.Sprintf("$%d", argIdx)
|
||||
args = append(args, strings.TrimSpace(s))
|
||||
argIdx++
|
||||
}
|
||||
conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
|
||||
}
|
||||
if params.MediaType != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
|
||||
args = append(args, params.MediaType)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
whereClause := "WHERE " + strings.Join(conditions, " AND ")
|
||||
|
||||
// 计数
|
||||
var total int64
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
|
||||
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (params.Page - 1) * params.PageSize
|
||||
listQuery := fmt.Sprintf(`
|
||||
SELECT id, user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message,
|
||||
created_at, completed_at
|
||||
FROM sora_generations %s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argIdx, argIdx+1)
|
||||
args = append(args, params.PageSize, offset)
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, listQuery, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
var results []*service.SoraGeneration
|
||||
for rows.Next() {
|
||||
gen := &service.SoraGeneration{}
|
||||
var mediaURLsJSON, s3KeysJSON []byte
|
||||
var completedAt sql.NullTime
|
||||
var apiKeyID sql.NullInt64
|
||||
|
||||
if err := rows.Scan(
|
||||
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
|
||||
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
|
||||
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
|
||||
&gen.CreatedAt, &completedAt,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if apiKeyID.Valid {
|
||||
gen.APIKeyID = &apiKeyID.Int64
|
||||
}
|
||||
if completedAt.Valid {
|
||||
gen.CompletedAt = &completedAt.Time
|
||||
}
|
||||
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
|
||||
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
|
||||
results = append(results, gen)
|
||||
}
|
||||
|
||||
return results, total, rows.Err()
|
||||
}
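A minimal usage sketch for the list filters above; the repo variable, userID, and page values are illustrative assumptions.

// Illustrative call: Status (and StorageType) accept comma-separated values,
// matching the strings.Split logic in List above.
gens, total, err := repo.List(ctx, service.SoraGenerationListParams{
	UserID:   userID,
	Status:   "pending,generating",
	Page:     1,
	PageSize: 20,
})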
|
||||
|
||||
func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
|
||||
if len(statuses) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
placeholders := make([]string, len(statuses))
|
||||
args := []any{userID}
|
||||
for i, s := range statuses {
|
||||
placeholders[i] = fmt.Sprintf("$%d", i+2)
|
||||
args = append(args, s)
|
||||
}
|
||||
|
||||
var count int64
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
|
||||
err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
@@ -362,7 +362,12 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
if filters.RequestType != nil {
condition, conditionArgs := buildRequestTypeFilterCondition(idx, *filters.RequestType)
conditions = append(conditions, condition)
args = append(args, conditionArgs...)
idx += len(conditionArgs)
} else if filters.Stream != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
|
||||
args = append(args, *filters.Stream)
|
||||
idx++
|
||||
|
||||
@@ -466,6 +466,38 @@ func TestBuildUsageCleanupWhere(t *testing.T) {
|
||||
require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args)
|
||||
}
|
||||
|
||||
func TestBuildUsageCleanupWhereRequestTypePriority(t *testing.T) {
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
requestType := int16(service.RequestTypeWSV2)
|
||||
stream := false
|
||||
|
||||
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
RequestType: &requestType,
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", where)
|
||||
require.Equal(t, []any{start, end, requestType}, args)
|
||||
}
|
||||
|
||||
func TestBuildUsageCleanupWhereRequestTypeLegacyFallback(t *testing.T) {
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
requestType := int16(service.RequestTypeStream)
|
||||
|
||||
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
RequestType: &requestType,
|
||||
})
|
||||
|
||||
require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", where)
|
||||
require.Equal(t, []any{start, end, requestType}, args)
|
||||
}
|
||||
|
||||
func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) {
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
@@ -98,6 +98,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
log.RequestID = requestID
|
||||
|
||||
rateMultiplier := log.RateMultiplier
|
||||
log.SyncRequestTypeAndLegacyFields()
|
||||
requestType := int16(log.RequestType)
|
||||
|
||||
query := `
|
||||
INSERT INTO usage_logs (
|
||||
@@ -123,7 +125,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
@@ -140,7 +144,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -184,7 +188,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
rateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
requestType,
|
||||
log.Stream,
|
||||
log.OpenAIWSMode,
|
||||
duration,
|
||||
firstToken,
|
||||
userAgent,
|
||||
@@ -492,25 +498,46 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
|
||||
todayEnd := todayUTC.Add(24 * time.Hour)
combinedStatsQuery := `
|
||||
WITH scoped AS (
|
||||
SELECT
|
||||
created_at,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
COALESCE(duration_ms, 0) AS duration_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
|
||||
AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
|
||||
)
|
||||
SELECT
|
||||
COUNT(*) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz) AS total_requests,
|
||||
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_input_tokens,
|
||||
COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost,
|
||||
COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms,
|
||||
COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests,
|
||||
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens,
|
||||
COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost
|
||||
FROM scoped
|
||||
`
|
||||
var totalDurationMs int64
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
combinedStatsQuery,
[]any{startUTC, endUTC, todayUTC, todayEnd},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
@@ -519,32 +546,6 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&totalDurationMs,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
todayEnd := todayUTC.Add(24 * time.Hour)
|
||||
todayStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as today_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as today_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{todayUTC, todayEnd},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
@@ -555,25 +556,28 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
|
||||
activeUsersQuery := `
|
||||
SELECT COUNT(DISTINCT user_id) as active_users
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
|
||||
return err
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
|
||||
hourStart := now.UTC().Truncate(time.Hour)
|
||||
hourEnd := hourStart.Add(time.Hour)
|
||||
hourlyActiveQuery := `
|
||||
SELECT COUNT(DISTINCT user_id) as active_users
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
activeUsersQuery := `
|
||||
WITH scoped AS (
|
||||
SELECT user_id, created_at
|
||||
FROM usage_logs
|
||||
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
|
||||
AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
|
||||
)
|
||||
SELECT
|
||||
COUNT(DISTINCT CASE WHEN created_at >= $1::timestamptz AND created_at < $2::timestamptz THEN user_id END) AS active_users,
|
||||
COUNT(DISTINCT CASE WHEN created_at >= $3::timestamptz AND created_at < $4::timestamptz THEN user_id END) AS hourly_active_users
|
||||
FROM scoped
|
||||
`
|
||||
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd, hourStart, hourEnd}, &stats.ActiveUsers, &stats.HourlyActiveUsers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -968,6 +972,61 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetGeminiUsageTotalsBatch 批量聚合 Gemini 账号在窗口内的 Pro/Flash 请求与用量。
|
||||
// 模型分类规则与 service.geminiModelClassFromName 一致:model 包含 flash/lite 视为 flash,其余视为 pro。
|
||||
func (r *usageLogRepository) GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]service.GeminiUsageTotals, error) {
|
||||
result := make(map[int64]service.GeminiUsageTotals, len(accountIDs))
|
||||
if len(accountIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
account_id,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 1 ELSE 0 END), 0) AS flash_requests,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE 1 END), 0) AS pro_requests,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) ELSE 0 END), 0) AS flash_tokens,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) END), 0) AS pro_tokens,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN actual_cost ELSE 0 END), 0) AS flash_cost,
|
||||
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE actual_cost END), 0) AS pro_cost
|
||||
FROM usage_logs
|
||||
WHERE account_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
GROUP BY account_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var accountID int64
|
||||
var totals service.GeminiUsageTotals
|
||||
if err := rows.Scan(
|
||||
&accountID,
|
||||
&totals.FlashRequests,
|
||||
&totals.ProRequests,
|
||||
&totals.FlashTokens,
|
||||
&totals.ProTokens,
|
||||
&totals.FlashCost,
|
||||
&totals.ProCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[accountID] = totals
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, accountID := range accountIDs {
|
||||
if _, ok := result[accountID]; !ok {
|
||||
result[accountID] = service.GeminiUsageTotals{}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
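The classification rule referenced in the comment above (service.geminiModelClassFromName) is not included in this sync; a sketch that mirrors the SQL CASE expressions would look roughly like the following, with the helper's real name and signature assumed.

// Illustrative only: mirrors the LIKE '%flash%' / '%lite%' buckets used in the SQL above.
func geminiModelClass(model string) string {
	m := strings.ToLower(model)
	if strings.Contains(m, "flash") || strings.Contains(m, "lite") {
		return "flash"
	}
	return "pro"
}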
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint = usagestats.TrendDataPoint
|
||||
|
||||
@@ -1399,10 +1458,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||||
args = append(args, filters.Model)
|
||||
}
|
||||
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
|
||||
if filters.BillingType != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
args = append(args, int16(*filters.BillingType))
|
||||
@@ -1598,7 +1654,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
}
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
@@ -1636,10 +1692,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
@@ -1667,7 +1720,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
}
|
||||
|
||||
// GetModelStatsWithFilters returns model statistics with optional filters
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
@@ -1704,10 +1757,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
|
||||
if billingType != nil {
|
||||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||||
args = append(args, int16(*billingType))
|
||||
@@ -1794,10 +1844,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||||
args = append(args, filters.Model)
|
||||
}
|
||||
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
|
||||
if filters.BillingType != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||||
args = append(args, int16(*filters.BillingType))
|
||||
@@ -2017,7 +2064,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
}
|
||||
}
|
||||
|
||||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil, nil)
|
||||
if err != nil {
|
||||
models = []ModelStat{}
|
||||
}
|
||||
@@ -2267,7 +2314,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
rateMultiplier float64
|
||||
accountRateMultiplier sql.NullFloat64
|
||||
billingType int16
|
||||
requestTypeRaw int16
|
||||
stream bool
|
||||
openaiWSMode bool
|
||||
durationMs sql.NullInt64
|
||||
firstTokenMs sql.NullInt64
|
||||
userAgent sql.NullString
|
||||
@@ -2304,7 +2353,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&rateMultiplier,
|
||||
&accountRateMultiplier,
|
||||
&billingType,
|
||||
&requestTypeRaw,
|
||||
&stream,
|
||||
&openaiWSMode,
|
||||
&durationMs,
|
||||
&firstTokenMs,
|
||||
&userAgent,
|
||||
@@ -2340,11 +2391,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
RateMultiplier: rateMultiplier,
|
||||
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
|
||||
BillingType: int8(billingType),
|
||||
Stream: stream,
|
||||
RequestType: service.RequestTypeFromInt16(requestTypeRaw),
|
||||
ImageCount: imageCount,
|
||||
CacheTTLOverridden: cacheTTLOverridden,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
// 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。
|
||||
log.Stream = stream
|
||||
log.OpenAIWSMode = openaiWSMode
|
||||
log.RequestType = log.EffectiveRequestType()
|
||||
log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode)
|
||||
|
||||
if requestID.Valid {
|
||||
log.RequestID = requestID.String
|
||||
@@ -2438,6 +2494,50 @@ func buildWhere(conditions []string) string {
|
||||
return "WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
func appendRequestTypeOrStreamWhereCondition(conditions []string, args []any, requestType *int16, stream *bool) ([]string, []any) {
	if requestType != nil {
		condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
		conditions = append(conditions, condition)
		args = append(args, conditionArgs...)
		return conditions, args
	}
	if stream != nil {
		conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
		args = append(args, *stream)
	}
	return conditions, args
}

func appendRequestTypeOrStreamQueryFilter(query string, args []any, requestType *int16, stream *bool) (string, []any) {
	if requestType != nil {
		condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
		query += " AND " + condition
		args = append(args, conditionArgs...)
		return query, args
	}
	if stream != nil {
		query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
		args = append(args, *stream)
	}
	return query, args
}

// buildRequestTypeFilterCondition keeps request_type filters compatible with legacy rows so historical data is not missed.
func buildRequestTypeFilterCondition(startArgIndex int, requestType int16) (string, []any) {
	normalized := service.RequestTypeFromInt16(requestType)
	requestTypeArg := int16(normalized)
	switch normalized {
	case service.RequestTypeSync:
		return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = FALSE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
	case service.RequestTypeStream:
		return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = TRUE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
	case service.RequestTypeWSV2:
		return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND openai_ws_mode = TRUE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
	default:
		return fmt.Sprintf("request_type = $%d", startArgIndex), []any{requestTypeArg}
	}
}
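// For illustration, a minimal sketch of how the helpers above compose a filter.
// The base query, time range, and surrounding function are hypothetical placeholders,
// not the repository's actual SQL.
func exampleRequestTypeFilter(startTime, endTime time.Time) (string, []any) {
	query := "SELECT model, COUNT(*) FROM usage_logs WHERE created_at >= $1 AND created_at <= $2"
	args := []any{startTime, endTime}

	wsV2 := int16(service.RequestTypeWSV2)
	query, args = appendRequestTypeOrStreamQueryFilter(query, args, &wsV2, nil)
	// query now ends with:
	//   AND (request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))
	// so legacy rows written before request_type existed are still matched.
	return query, args
}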
func nullInt64(v *int64) sql.NullInt64 {
|
||||
if v == nil {
|
||||
return sql.NullInt64{}
|
||||
|
||||
@@ -130,6 +130,62 @@ func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
|
||||
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID_ReturnsOpenAIWSMode() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-ws@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-ws", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-ws"})
|
||||
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.New().String(),
|
||||
Model: "gpt-5.3-codex",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 1.0,
|
||||
ActualCost: 1.0,
|
||||
OpenAIWSMode: true,
|
||||
CreatedAt: timezone.Today().Add(3 * time.Hour),
|
||||
}
|
||||
_, err := s.repo.Create(s.ctx, log)
|
||||
s.Require().NoError(err)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, log.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(got.OpenAIWSMode)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID_ReturnsRequestTypeAndLegacyFallback() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-request-type@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-request-type", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-request-type"})
|
||||
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.New().String(),
|
||||
Model: "gpt-5.3-codex",
|
||||
RequestType: service.RequestTypeWSV2,
|
||||
Stream: true,
|
||||
OpenAIWSMode: false,
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 1.0,
|
||||
ActualCost: 1.0,
|
||||
CreatedAt: timezone.Today().Add(4 * time.Hour),
|
||||
}
|
||||
_, err := s.repo.Create(s.ctx, log)
|
||||
s.Require().NoError(err)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, log.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.RequestTypeWSV2, got.RequestType)
|
||||
s.Require().True(got.Stream)
|
||||
s.Require().True(got.OpenAIWSMode)
|
||||
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDelete() {
|
||||
@@ -944,17 +1000,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
endTime := base.Add(48 * time.Hour)
|
||||
|
||||
// Test with user filter
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil)
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil, nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
|
||||
s.Require().Len(trend, 2)
|
||||
|
||||
// Test with apiKey filter
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil)
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil, nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
|
||||
s.Require().Len(trend, 2)
|
||||
|
||||
// Test with both filters
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil)
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil, nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
@@ -971,7 +1027,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(3 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil)
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil, nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
@@ -1017,17 +1073,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
|
||||
// Test with user filter
|
||||
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil)
|
||||
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil, nil)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
|
||||
s.Require().Len(stats, 2)
|
||||
|
||||
// Test with apiKey filter
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil)
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil, nil)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
|
||||
s.Require().Len(stats, 2)
|
||||
|
||||
// Test with account filter
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil)
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil, nil)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
|
||||
s.Require().Len(stats, 2)
|
||||
}
|
||||
|
||||
327 backend/internal/repository/usage_log_repo_request_type_test.go — Normal file
@@ -0,0 +1,327 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-1",
|
||||
Model: "gpt-5",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 1,
|
||||
ActualCost: 1,
|
||||
BillingType: service.BillingTypeBalance,
|
||||
RequestType: service.RequestTypeWSV2,
|
||||
Stream: false,
|
||||
OpenAIWSMode: false,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("INSERT INTO usage_logs").
|
||||
WithArgs(
|
||||
log.UserID,
|
||||
log.APIKeyID,
|
||||
log.AccountID,
|
||||
log.RequestID,
|
||||
log.Model,
|
||||
sqlmock.AnyArg(), // group_id
|
||||
sqlmock.AnyArg(), // subscription_id
|
||||
log.InputTokens,
|
||||
log.OutputTokens,
|
||||
log.CacheCreationTokens,
|
||||
log.CacheReadTokens,
|
||||
log.CacheCreation5mTokens,
|
||||
log.CacheCreation1hTokens,
|
||||
log.InputCost,
|
||||
log.OutputCost,
|
||||
log.CacheCreationCost,
|
||||
log.CacheReadCost,
|
||||
log.TotalCost,
|
||||
log.ActualCost,
|
||||
log.RateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
int16(service.RequestTypeWSV2),
|
||||
true,
|
||||
true,
|
||||
sqlmock.AnyArg(), // duration_ms
|
||||
sqlmock.AnyArg(), // first_token_ms
|
||||
sqlmock.AnyArg(), // user_agent
|
||||
sqlmock.AnyArg(), // ip_address
|
||||
log.ImageCount,
|
||||
sqlmock.AnyArg(), // image_size
|
||||
sqlmock.AnyArg(), // media_type
|
||||
sqlmock.AnyArg(), // reasoning_effort
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
|
||||
|
||||
inserted, err := repo.Create(context.Background(), log)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inserted)
|
||||
require.Equal(t, int64(99), log.ID)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
requestType := int16(service.RequestTypeWSV2)
|
||||
stream := false
|
||||
filters := usagestats.UsageLogFilters{
|
||||
RequestType: &requestType,
|
||||
Stream: &stream,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
||||
WithArgs(requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
|
||||
mock.ExpectQuery("SELECT .* FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\) ORDER BY id DESC LIMIT \\$2 OFFSET \\$3").
|
||||
WithArgs(requestType, 20, 0).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
|
||||
logs, page, err := repo.ListWithFilters(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}, filters)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, logs)
|
||||
require.NotNil(t, page)
|
||||
require.Equal(t, int64(0), page.Total)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
requestType := int16(service.RequestTypeStream)
|
||||
stream := true
|
||||
|
||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
|
||||
WithArgs(start, end, requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
|
||||
trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, trend)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
requestType := int16(service.RequestTypeWSV2)
|
||||
stream := false
|
||||
|
||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
||||
WithArgs(start, end, requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
|
||||
|
||||
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, stats)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &usageLogRepository{sql: db}
|
||||
|
||||
requestType := int16(service.RequestTypeSync)
|
||||
stream := true
|
||||
filters := usagestats.UsageLogFilters{
|
||||
RequestType: &requestType,
|
||||
Stream: &stream,
|
||||
}
|
||||
|
||||
mock.ExpectQuery("FROM usage_logs\\s+WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE\\)\\)").
|
||||
WithArgs(requestType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"total_requests",
|
||||
"total_input_tokens",
|
||||
"total_output_tokens",
|
||||
"total_cache_tokens",
|
||||
"total_cost",
|
||||
"total_actual_cost",
|
||||
"total_account_cost",
|
||||
"avg_duration_ms",
|
||||
}).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0))
|
||||
|
||||
stats, err := repo.GetStatsWithFilters(context.Background(), filters)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), stats.TotalRequests)
|
||||
require.Equal(t, int64(9), stats.TotalTokens)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request int16
|
||||
wantWhere string
|
||||
wantArg int16
|
||||
}{
|
||||
{
|
||||
name: "sync_with_legacy_fallback",
|
||||
request: int16(service.RequestTypeSync),
|
||||
wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE))",
|
||||
wantArg: int16(service.RequestTypeSync),
|
||||
},
|
||||
{
|
||||
name: "stream_with_legacy_fallback",
|
||||
request: int16(service.RequestTypeStream),
|
||||
wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))",
|
||||
wantArg: int16(service.RequestTypeStream),
|
||||
},
|
||||
{
|
||||
name: "ws_v2_with_legacy_fallback",
|
||||
request: int16(service.RequestTypeWSV2),
|
||||
wantWhere: "(request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))",
|
||||
wantArg: int16(service.RequestTypeWSV2),
|
||||
},
|
||||
{
|
||||
name: "invalid_request_type_normalized_to_unknown",
|
||||
request: int16(99),
|
||||
wantWhere: "request_type = $3",
|
||||
wantArg: int16(service.RequestTypeUnknown),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
where, args := buildRequestTypeFilterCondition(3, tt.request)
|
||||
require.Equal(t, tt.wantWhere, where)
|
||||
require.Equal(t, []any{tt.wantArg}, args)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type usageLogScannerStub struct {
|
||||
values []any
|
||||
}
|
||||
|
||||
func (s usageLogScannerStub) Scan(dest ...any) error {
|
||||
if len(dest) != len(s.values) {
|
||||
return fmt.Errorf("scan arg count mismatch: got %d want %d", len(dest), len(s.values))
|
||||
}
|
||||
for i := range dest {
|
||||
dv := reflect.ValueOf(dest[i])
|
||||
if dv.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dest[%d] is not pointer", i)
|
||||
}
|
||||
dv.Elem().Set(reflect.ValueOf(s.values[i]))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
||||
t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
log, err := scanUsageLog(usageLogScannerStub{values: []any{
|
||||
int64(1), // id
|
||||
int64(10), // user_id
|
||||
int64(20), // api_key_id
|
||||
int64(30), // account_id
|
||||
sql.NullString{Valid: true, String: "req-1"},
|
||||
"gpt-5", // model
|
||||
sql.NullInt64{}, // group_id
|
||||
sql.NullInt64{}, // subscription_id
|
||||
1, // input_tokens
|
||||
2, // output_tokens
|
||||
3, // cache_creation_tokens
|
||||
4, // cache_read_tokens
|
||||
5, // cache_creation_5m_tokens
|
||||
6, // cache_creation_1h_tokens
|
||||
0.1, // input_cost
|
||||
0.2, // output_cost
|
||||
0.3, // cache_creation_cost
|
||||
0.4, // cache_read_cost
|
||||
1.0, // total_cost
|
||||
0.9, // actual_cost
|
||||
1.0, // rate_multiplier
|
||||
sql.NullFloat64{}, // account_rate_multiplier
|
||||
int16(service.BillingTypeBalance),
|
||||
int16(service.RequestTypeWSV2),
|
||||
false, // legacy stream
|
||||
false, // legacy openai ws
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, service.RequestTypeWSV2, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.True(t, log.OpenAIWSMode)
|
||||
})
|
||||
|
||||
t.Run("request_type_unknown_falls_back_to_legacy", func(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
log, err := scanUsageLog(usageLogScannerStub{values: []any{
|
||||
int64(2),
|
||||
int64(11),
|
||||
int64(21),
|
||||
int64(31),
|
||||
sql.NullString{Valid: true, String: "req-2"},
|
||||
"gpt-5",
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
1, 2, 3, 4, 5, 6,
|
||||
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
|
||||
1.0,
|
||||
sql.NullFloat64{},
|
||||
int16(service.BillingTypeBalance),
|
||||
int16(service.RequestTypeUnknown),
|
||||
true,
|
||||
false,
|
||||
sql.NullInt64{},
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
0,
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
sql.NullString{},
|
||||
false,
|
||||
now,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, service.RequestTypeStream, log.RequestType)
|
||||
require.True(t, log.Stream)
|
||||
require.False(t, log.OpenAIWSMode)
|
||||
})
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type userGroupRateRepository struct {
|
||||
@@ -41,6 +42,59 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserIDs fetches the dedicated group rate multipliers for multiple users in one query.
// Result shape: map[userID]map[groupID]rate
|
||||
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
|
||||
result := make(map[int64]map[int64]float64, len(userIDs))
|
||||
if len(userIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
uniqueIDs := make([]int64, 0, len(userIDs))
|
||||
seen := make(map[int64]struct{}, len(userIDs))
|
||||
for _, userID := range userIDs {
|
||||
if userID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[userID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[userID] = struct{}{}
|
||||
uniqueIDs = append(uniqueIDs, userID)
|
||||
result[userID] = make(map[int64]float64)
|
||||
}
|
||||
if len(uniqueIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT user_id, group_id, rate_multiplier
|
||||
FROM user_group_rate_multipliers
|
||||
WHERE user_id = ANY($1)
|
||||
`, pq.Array(uniqueIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
var groupID int64
|
||||
var rate float64
|
||||
if err := rows.Scan(&userID, &groupID, &rate); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := result[userID]; !ok {
|
||||
result[userID] = make(map[int64]float64)
|
||||
}
|
||||
result[userID][groupID] = rate
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
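// exampleBatchGroupRates is an illustrative sketch (the repository value and IDs are
// placeholders, not part of this file) showing how the batch result is consumed.
func exampleBatchGroupRates(ctx context.Context, repo *userGroupRateRepository) (float64, bool) {
	rates, err := repo.GetByUserIDs(ctx, []int64{101, 102, 101, 0})
	if err != nil {
		return 0, false
	}
	// Duplicate and non-positive IDs are skipped; every accepted ID gets an entry,
	// even when the user has no overrides stored.
	rate, ok := rates[101][7]
	return rate, ok
}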
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
@@ -65,33 +119,43 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
|
||||
|
||||
// Split entries into records to delete and records to upsert
|
||||
var toDelete []int64
|
||||
toUpsert := make(map[int64]float64)
|
||||
upsertGroupIDs := make([]int64, 0, len(rates))
|
||||
upsertRates := make([]float64, 0, len(rates))
|
||||
for groupID, rate := range rates {
|
||||
if rate == nil {
|
||||
toDelete = append(toDelete, groupID)
|
||||
} else {
|
||||
toUpsert[groupID] = *rate
|
||||
upsertGroupIDs = append(upsertGroupIDs, groupID)
|
||||
upsertRates = append(upsertRates, *rate)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the records marked for removal
|
||||
for _, groupID := range toDelete {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
|
||||
userID, groupID)
|
||||
if err != nil {
|
||||
if len(toDelete) > 0 {
|
||||
if _, err := r.sql.ExecContext(ctx,
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
|
||||
userID, pq.Array(toDelete)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert the remaining records
|
||||
now := time.Now()
|
||||
for groupID, rate := range toUpsert {
|
||||
if len(upsertGroupIDs) > 0 {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $4)
|
||||
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
|
||||
`, userID, groupID, rate, now)
|
||||
SELECT
|
||||
$1::bigint,
|
||||
data.group_id,
|
||||
data.rate_multiplier,
|
||||
$2::timestamptz,
|
||||
$2::timestamptz
|
||||
FROM unnest($3::bigint[], $4::double precision[]) AS data(group_id, rate_multiplier)
|
||||
ON CONFLICT (user_id, group_id)
|
||||
DO UPDATE SET
|
||||
rate_multiplier = EXCLUDED.rate_multiplier,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`, userID, now, pq.Array(upsertGroupIDs), pq.Array(upsertRates))
|
||||
if err != nil {
|
||||
return err
|
||||
}
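// The per-row DELETE/INSERT loop on the old side is replaced by two set-based statements:
// one DELETE ... WHERE group_id = ANY($2) and one INSERT ... SELECT ... FROM unnest(...)
// upsert. A hedged usage sketch, assuming the truncated signature is
// SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
// (the IDs and rates below are illustrative):
//
//	premium, discounted := 1.2, 0.8
//	err := repo.SyncUserGroupRates(ctx, 101, map[int64]*float64{
//		7: &premium,    // upsert group 7 at 1.2x
//		8: &discounted, // upsert group 8 at 0.8x
//		9: nil,         // remove any override for group 9
//	})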
|
||||
|
||||
@@ -61,6 +61,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
@@ -143,6 +144,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
|
||||
SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
@@ -363,6 +366,65 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddSoraStorageUsageWithQuota atomically adds Sora storage usage and, when a quota is set, rejects increments that would exceed it.
|
||||
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
|
||||
if deltaBytes <= 0 {
|
||||
user, err := r.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return user.SoraStorageUsedBytes, nil
|
||||
}
|
||||
var newUsed int64
|
||||
err := scanSingleRow(ctx, r.sql, `
|
||||
UPDATE users
|
||||
SET sora_storage_used_bytes = sora_storage_used_bytes + $2
|
||||
WHERE id = $1
|
||||
AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
|
||||
RETURNING sora_storage_used_bytes
|
||||
`, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
|
||||
if err == nil {
|
||||
return newUsed, nil
|
||||
}
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// Distinguish a missing user from a quota violation
|
||||
exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
|
||||
if existsErr != nil {
|
||||
return 0, existsErr
|
||||
}
|
||||
if !exists {
|
||||
return 0, service.ErrUserNotFound
|
||||
}
|
||||
return 0, service.ErrSoraStorageQuotaExceeded
|
||||
}
|
||||
return 0, err
|
||||
}
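// Quota semantics sketch (user ID and sizes are illustrative; a quota of 0 means
// unlimited, matching the "$3 = 0" branch in the WHERE clause above):
//
//	newUsed, err := repo.AddSoraStorageUsageWithQuota(ctx, 42, 5<<20, 100<<20)
//	switch {
//	case errors.Is(err, service.ErrSoraStorageQuotaExceeded):
//		// the conditional UPDATE matched no row: the increment would exceed the quota
//	case errors.Is(err, service.ErrUserNotFound):
//		// user 42 does not exist
//	case err == nil:
//		_ = newUsed // sora_storage_used_bytes after the atomic increment
//	}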
|
||||
|
||||
// ReleaseSoraStorageUsageAtomic atomically releases Sora storage usage, never letting it drop below 0.
|
||||
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
|
||||
if deltaBytes <= 0 {
|
||||
user, err := r.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return user.SoraStorageUsedBytes, nil
|
||||
}
|
||||
var newUsed int64
|
||||
err := scanSingleRow(ctx, r.sql, `
|
||||
UPDATE users
|
||||
SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
|
||||
WHERE id = $1
|
||||
RETURNING sora_storage_used_bytes
|
||||
`, []any{userID, deltaBytes}, &newUsed)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, service.ErrUserNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return newUsed, nil
|
||||
}
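// Release-side sketch: over-releasing clamps at zero instead of erroring
// (the ID and size are illustrative):
//
//	newUsed, err := repo.ReleaseSoraStorageUsageAtomic(ctx, 42, 10<<20)
//	if err == nil && newUsed == 0 {
//		// previous usage was at most 10 MiB; GREATEST(..., 0) prevented a negative value
//	}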
|
||||
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
|
||||
}
|
||||
|
||||
@@ -186,11 +186,12 @@ func TestAPIContracts(t *testing.T) {
|
||||
"image_price_1k": null,
|
||||
"image_price_2k": null,
|
||||
"image_price_4k": null,
|
||||
"sora_image_price_360": null,
|
||||
"sora_image_price_540": null,
|
||||
"sora_video_price_per_request": null,
|
||||
"sora_video_price_per_request_hd": null,
|
||||
"claude_code_only": false,
|
||||
"sora_image_price_360": null,
|
||||
"sora_image_price_540": null,
|
||||
"sora_storage_quota_bytes": 0,
|
||||
"sora_video_price_per_request": null,
|
||||
"sora_video_price_per_request_hd": null,
|
||||
"claude_code_only": false,
|
||||
"fallback_group_id": null,
|
||||
"fallback_group_id_on_invalid_request": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
@@ -384,10 +385,12 @@ func TestAPIContracts(t *testing.T) {
|
||||
"user_id": 1,
|
||||
"api_key_id": 100,
|
||||
"account_id": 200,
|
||||
"request_id": "req_123",
|
||||
"model": "claude-3",
|
||||
"group_id": null,
|
||||
"subscription_id": null,
|
||||
"request_id": "req_123",
|
||||
"model": "claude-3",
|
||||
"request_type": "stream",
|
||||
"openai_ws_mode": false,
|
||||
"group_id": null,
|
||||
"subscription_id": null,
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
"cache_creation_tokens": 1,
|
||||
@@ -500,11 +503,12 @@ func TestAPIContracts(t *testing.T) {
|
||||
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
|
||||
"fallback_model_antigravity": "gemini-2.5-pro",
|
||||
"fallback_model_gemini": "gemini-2.5-pro",
|
||||
"fallback_model_openai": "gpt-4o",
|
||||
"enable_identity_patch": true,
|
||||
"identity_patch_prompt": "",
|
||||
"invitation_code_enabled": false,
|
||||
"home_content": "",
|
||||
"fallback_model_openai": "gpt-4o",
|
||||
"enable_identity_patch": true,
|
||||
"identity_patch_prompt": "",
|
||||
"sora_client_enabled": false,
|
||||
"invitation_code_enabled": false,
|
||||
"home_content": "",
|
||||
"hide_ccs_import_button": false,
|
||||
"purchase_subscription_enabled": false,
|
||||
"purchase_subscription_url": ""
|
||||
@@ -619,7 +623,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
|
||||
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
jwtAuth := func(c *gin.Context) {
|
||||
@@ -1555,11 +1559,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
// Note: the error message is intentionally vague to avoid revealing the specific IP restriction mechanism
|
||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||
clientIP := ip.GetTrustedClientIP(c)
|
||||
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
||||
allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
|
||||
if !allowed {
|
||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||
return
|
||||
|
||||
@@ -80,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
abortWithGoogleError(c, 403, "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
|
||||
abortWithGoogleError(c, 403, err.Error())
|
||||
return
|
||||
}
|
||||
_ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription)
|
||||
_ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription)
|
||||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
|
||||
abortWithGoogleError(c, 429, err.Error())
|
||||
|
||||
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if err != nil {
|
||||
status := 403
|
||||
if errors.Is(err, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
|
||||
status = 429
|
||||
}
|
||||
abortWithGoogleError(c, status, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
if apiKey.User.Balance <= 0 {
|
||||
abortWithGoogleError(c, 403, "Insufficient account balance")
|
||||
|
||||
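// The status selection above maps subscription limit errors to 429 and everything else
// to 403. A hedged sketch of that mapping as a standalone helper (the helper name is
// hypothetical, not part of the codebase):
func googleStatusForSubscriptionError(err error) int {
	if errors.Is(err, service.ErrDailyLimitExceeded) ||
		errors.Is(err, service.ErrWeeklyLimitExceeded) ||
		errors.Is(err, service.ErrMonthlyLimitExceeded) {
		return http.StatusTooManyRequests // 429
	}
	return http.StatusForbidden // 403
}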
@@ -23,6 +23,15 @@ type fakeAPIKeyRepo struct {
|
||||
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
|
||||
}
|
||||
|
||||
type fakeGoogleSubscriptionRepo struct {
|
||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||
activateWindow func(ctx context.Context, id int64, start time.Time) error
|
||||
resetDaily func(ctx context.Context, id int64, start time.Time) error
|
||||
resetWeekly func(ctx context.Context, id int64, start time.Time) error
|
||||
resetMonthly func(ctx context.Context, id int64, start time.Time) error
|
||||
}
|
||||
|
||||
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -87,6 +96,85 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if f.getActive != nil {
|
||||
return f.getActive(ctx, userID, groupID)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
|
||||
if f.updateStatus != nil {
|
||||
return f.updateStatus(ctx, subscriptionID, status)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.activateWindow != nil {
|
||||
return f.activateWindow(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.resetDaily != nil {
|
||||
return f.resetDaily(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.resetWeekly != nil {
|
||||
return f.resetWeekly(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, start time.Time) error {
|
||||
if f.resetMonthly != nil {
|
||||
return f.resetMonthly(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeGoogleSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type googleErrorResponse struct {
|
||||
Error struct {
|
||||
Code int `json:"code"`
|
||||
@@ -505,3 +593,85 @@ func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testi
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 1, touchCalls)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
limit := 1.0
|
||||
group := &service.Group{
|
||||
ID: 77,
|
||||
Name: "gemini-sub",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformGemini,
|
||||
Hydrated: true,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
DailyLimitUSD: &limit,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 999,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 501,
|
||||
UserID: user.ID,
|
||||
Key: "google-sub-limit",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
})
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
ID: 601,
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
DailyWindowStart: &now,
|
||||
DailyUsageUSD: 10,
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, fakeGoogleSubscriptionRepo{
|
||||
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if userID != user.ID || groupID != group.ID {
|
||||
return nil, service.ErrSubscriptionNotFound
|
||||
}
|
||||
clone := *sub
|
||||
return &clone, nil
|
||||
},
|
||||
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
|
||||
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}, nil, nil, &config.Config{RunMode: config.RunModeStandard})
|
||||
|
||||
r := gin.New()
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, &config.Config{RunMode: config.RunModeStandard}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-goog-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusTooManyRequests, resp.Error.Code)
|
||||
require.Equal(t, "RESOURCE_EXHAUSTED", resp.Error.Status)
|
||||
require.Contains(t, resp.Error.Message, "daily usage limit exceeded")
|
||||
}
|
||||
|
||||
@@ -54,6 +54,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
if isAPIRoutePath(c) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.Enabled {
|
||||
// Generate nonce for this request
|
||||
@@ -73,6 +77,18 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func isAPIRoutePath(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||
return false
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
return strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/v1beta/") ||
|
||||
strings.HasPrefix(path, "/antigravity/") ||
|
||||
strings.HasPrefix(path, "/sora/") ||
|
||||
strings.HasPrefix(path, "/responses")
|
||||
}
|
||||
|
||||
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
|
||||
// This allows the application to work correctly even if the config file has an older CSP policy.
|
||||
func enhanceCSPPolicy(policy string) string {
|
||||
|
||||
@@ -131,6 +131,26 @@ func TestSecurityHeaders(t *testing.T) {
|
||||
assert.Contains(t, csp, CloudflareInsightsDomain)
|
||||
})
|
||||
|
||||
t.Run("api_route_skips_csp_nonce_generation", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
|
||||
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
|
||||
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
|
||||
assert.Empty(t, w.Header().Get("Content-Security-Policy"))
|
||||
assert.Empty(t, GetNonceFromContext(c))
|
||||
})
|
||||
|
||||
t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{
|
||||
Enabled: true,
|
||||
|
||||
@@ -75,6 +75,7 @@ func registerRoutes(
|
||||
// Register routes for each module
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
|
||||
}
|
||||
|
||||
@@ -55,6 +55,9 @@ func RegisterAdminRoutes(
|
||||
// System settings
|
||||
registerSettingsRoutes(admin, h)
|
||||
|
||||
// Data management
|
||||
registerDataManagementRoutes(admin, h)
|
||||
|
||||
// Ops monitoring
|
||||
registerOpsRoutes(admin, h)
|
||||
|
||||
@@ -231,6 +234,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
|
||||
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
|
||||
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
|
||||
@@ -370,6 +374,38 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// Stream timeout handling settings
|
||||
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||
// Sora S3 storage settings
|
||||
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
||||
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
||||
adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection)
|
||||
adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles)
|
||||
adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile)
|
||||
adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile)
|
||||
adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile)
|
||||
adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile)
|
||||
}
|
||||
}
|
||||
|
||||
func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dataManagement := admin.Group("/data-management")
|
||||
{
|
||||
dataManagement.GET("/agent/health", h.Admin.DataManagement.GetAgentHealth)
|
||||
dataManagement.GET("/config", h.Admin.DataManagement.GetConfig)
|
||||
dataManagement.PUT("/config", h.Admin.DataManagement.UpdateConfig)
|
||||
dataManagement.GET("/sources/:source_type/profiles", h.Admin.DataManagement.ListSourceProfiles)
|
||||
dataManagement.POST("/sources/:source_type/profiles", h.Admin.DataManagement.CreateSourceProfile)
|
||||
dataManagement.PUT("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.UpdateSourceProfile)
|
||||
dataManagement.DELETE("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.DeleteSourceProfile)
|
||||
dataManagement.POST("/sources/:source_type/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveSourceProfile)
|
||||
dataManagement.POST("/s3/test", h.Admin.DataManagement.TestS3)
|
||||
dataManagement.GET("/s3/profiles", h.Admin.DataManagement.ListS3Profiles)
|
||||
dataManagement.POST("/s3/profiles", h.Admin.DataManagement.CreateS3Profile)
|
||||
dataManagement.PUT("/s3/profiles/:profile_id", h.Admin.DataManagement.UpdateS3Profile)
|
||||
dataManagement.DELETE("/s3/profiles/:profile_id", h.Admin.DataManagement.DeleteS3Profile)
|
||||
dataManagement.POST("/s3/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveS3Profile)
|
||||
dataManagement.POST("/backups", h.Admin.DataManagement.CreateBackupJob)
|
||||
dataManagement.GET("/backups", h.Admin.DataManagement.ListBackupJobs)
|
||||
dataManagement.GET("/backups/:job_id", h.Admin.DataManagement.GetBackupJob)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ func RegisterGatewayRoutes(
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
||||
// Explicitly block the legacy endpoint: OpenAI only supports the Responses API here, so clients do not assume requests are auto-routed to other platforms.
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
@@ -69,6 +70,7 @@ func RegisterGatewayRoutes(
|
||||
|
||||
// OpenAI Responses API (alias without the /v1 prefix)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket)
|
||||
|
||||
// Antigravity 模型列表
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
|
||||
|
||||
33 backend/internal/server/routes/sora_client.go — Normal file
@@ -0,0 +1,33 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterSoraClientRoutes registers the Sora client API routes (user authentication required).
|
||||
func RegisterSoraClientRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
) {
|
||||
if h.SoraClient == nil {
|
||||
return
|
||||
}
|
||||
|
||||
authenticated := v1.Group("/sora")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
{
|
||||
authenticated.POST("/generate", h.SoraClient.Generate)
|
||||
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
||||
authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
|
||||
authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
|
||||
authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
|
||||
authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
|
||||
authenticated.GET("/quota", h.SoraClient.GetQuota)
|
||||
authenticated.GET("/models", h.SoraClient.GetModels)
|
||||
authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,8 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"hash/fnv"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -50,6 +52,14 @@ type Account struct {
|
||||
AccountGroups []AccountGroup
|
||||
GroupIDs []int64
|
||||
Groups []*Group
|
||||
|
||||
// Hot-path cache for model_mapping (not persisted)
|
||||
modelMappingCache map[string]string
|
||||
modelMappingCacheReady bool
|
||||
modelMappingCacheCredentialsPtr uintptr
|
||||
modelMappingCacheRawPtr uintptr
|
||||
modelMappingCacheRawLen int
|
||||
modelMappingCacheRawSig uint64
|
||||
}
|
||||
|
||||
type TempUnschedulableRule struct {
|
||||
@@ -349,6 +359,39 @@ func parseTempUnschedInt(value any) int {
|
||||
}
|
||||
|
||||
func (a *Account) GetModelMapping() map[string]string {
|
||||
credentialsPtr := mapPtr(a.Credentials)
|
||||
rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
|
||||
rawPtr := mapPtr(rawMapping)
|
||||
rawLen := len(rawMapping)
|
||||
rawSig := uint64(0)
|
||||
rawSigReady := false
|
||||
|
||||
if a.modelMappingCacheReady &&
|
||||
a.modelMappingCacheCredentialsPtr == credentialsPtr &&
|
||||
a.modelMappingCacheRawPtr == rawPtr &&
|
||||
a.modelMappingCacheRawLen == rawLen {
|
||||
rawSig = modelMappingSignature(rawMapping)
|
||||
rawSigReady = true
|
||||
if a.modelMappingCacheRawSig == rawSig {
|
||||
return a.modelMappingCache
|
||||
}
|
||||
}
|
||||
|
||||
mapping := a.resolveModelMapping(rawMapping)
|
||||
if !rawSigReady {
|
||||
rawSig = modelMappingSignature(rawMapping)
|
||||
}
|
||||
|
||||
a.modelMappingCache = mapping
|
||||
a.modelMappingCacheReady = true
|
||||
a.modelMappingCacheCredentialsPtr = credentialsPtr
|
||||
a.modelMappingCacheRawPtr = rawPtr
|
||||
a.modelMappingCacheRawLen = rawLen
|
||||
a.modelMappingCacheRawSig = rawSig
|
||||
return mapping
|
||||
}
|
||||
|
||||
func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
|
||||
if a.Credentials == nil {
|
||||
// The Antigravity platform falls back to the default mapping
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
@@ -356,32 +399,31 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["model_mapping"]
|
||||
if !ok || raw == nil {
|
||||
if len(rawMapping) == 0 {
|
||||
// The Antigravity platform falls back to the default mapping
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
result := make(map[string]string)
|
||||
for k, v := range m {
|
||||
if s, ok := v.(string); ok {
|
||||
result[k] = s
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
ensureAntigravityDefaultPassthroughs(result, []string{
|
||||
"gemini-3-flash",
|
||||
"gemini-3.1-pro-high",
|
||||
"gemini-3.1-pro-low",
|
||||
})
|
||||
}
|
||||
return result
|
||||
|
||||
result := make(map[string]string)
|
||||
for k, v := range rawMapping {
|
||||
if s, ok := v.(string); ok {
|
||||
result[k] = s
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
ensureAntigravityDefaultPassthroughs(result, []string{
|
||||
"gemini-3-flash",
|
||||
"gemini-3.1-pro-high",
|
||||
"gemini-3.1-pro-low",
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// The Antigravity platform falls back to the default mapping
|
||||
if a.Platform == domain.PlatformAntigravity {
|
||||
return domain.DefaultAntigravityModelMapping
|
||||
@@ -389,6 +431,37 @@ func (a *Account) GetModelMapping() map[string]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapPtr(m map[string]any) uintptr {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
return reflect.ValueOf(m).Pointer()
|
||||
}
|
||||
|
||||
func modelMappingSignature(rawMapping map[string]any) uint64 {
|
||||
if len(rawMapping) == 0 {
|
||||
return 0
|
||||
}
|
||||
keys := make([]string, 0, len(rawMapping))
|
||||
for k := range rawMapping {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
h := fnv.New64a()
|
||||
for _, k := range keys {
|
||||
_, _ = h.Write([]byte(k))
|
||||
_, _ = h.Write([]byte{0})
|
||||
if v, ok := rawMapping[k].(string); ok {
|
||||
_, _ = h.Write([]byte(v))
|
||||
} else {
|
||||
_, _ = h.Write([]byte{1})
|
||||
}
|
||||
_, _ = h.Write([]byte{0xff})
|
||||
}
|
||||
return h.Sum64()
|
||||
}
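// Cache behaviour sketch for GetModelMapping (the account value and mapping payload are
// illustrative): the fast path compares the credentials/raw map pointers and length, and
// the FNV signature catches in-place edits to the same map.
//
//	acc.Credentials = map[string]any{"model_mapping": map[string]any{"gpt-4o": "gpt-4o-mini"}}
//	m1 := acc.GetModelMapping() // resolves and caches
//	m2 := acc.GetModelMapping() // same pointers, length, and signature: cached map returned
//	acc.Credentials["model_mapping"].(map[string]any)["gpt-4o"] = "gpt-4.1-mini"
//	m3 := acc.GetModelMapping() // same pointers and length, but signature differs: rebuilt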
|
||||
|
||||
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
|
||||
if mapping == nil || model == "" {
|
||||
return
|
||||
@@ -742,6 +815,159 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOpenAIResponsesWebSocketV2Enabled reports whether Responses WebSocket v2 is enabled for an OpenAI account.
//
// Type-specific fields (new):
// - OAuth accounts:   accounts.extra.openai_oauth_responses_websockets_v2_enabled
// - API Key accounts: accounts.extra.openai_apikey_responses_websockets_v2_enabled
//
// Compatibility fields:
// - accounts.extra.responses_websockets_v2_enabled
// - accounts.extra.openai_ws_enabled (historical switch)
//
// Priority:
// 1. Read the type-specific field for the account type
// 2. Fall back to the compatibility fields when the type-specific field is absent
func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
|
||||
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if a.IsOpenAIOAuth() {
|
||||
if enabled, ok := a.Extra["openai_oauth_responses_websockets_v2_enabled"].(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
if a.IsOpenAIApiKey() {
|
||||
if enabled, ok := a.Extra["openai_apikey_responses_websockets_v2_enabled"].(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
if enabled, ok := a.Extra["responses_websockets_v2_enabled"].(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
if enabled, ok := a.Extra["openai_ws_enabled"].(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
return false
|
||||
}
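
// Illustrative standalone sketch (not part of this diff): the per-account WSv2
// switch reads the per-type field first and only then the compatibility keys.
// demoWSV2Enabled mirrors that priority over a plain extra map; the isOAuth
// flag and all names here are hypothetical stand-ins for the Account methods.
package main

import "fmt"

func demoWSV2Enabled(extra map[string]any, isOAuth bool) bool {
    typedKey := "openai_apikey_responses_websockets_v2_enabled"
    if isOAuth {
        typedKey = "openai_oauth_responses_websockets_v2_enabled"
    }
    // 1. the per-type field wins when present
    if enabled, ok := extra[typedKey].(bool); ok {
        return enabled
    }
    // 2. otherwise fall back to the compatibility fields
    if enabled, ok := extra["responses_websockets_v2_enabled"].(bool); ok {
        return enabled
    }
    if enabled, ok := extra["openai_ws_enabled"].(bool); ok {
        return enabled
    }
    return false
}

func main() {
    // The per-type field overrides the legacy switch.
    fmt.Println(demoWSV2Enabled(map[string]any{
        "openai_oauth_responses_websockets_v2_enabled": false,
        "openai_ws_enabled":                            true,
    }, true)) // false

    // The legacy switch is honored when no per-type field is set.
    fmt.Println(demoWSV2Enabled(map[string]any{"openai_ws_enabled": true}, false)) // true
}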

const (
    OpenAIWSIngressModeOff       = "off"
    OpenAIWSIngressModeShared    = "shared"
    OpenAIWSIngressModeDedicated = "dedicated"
)

func normalizeOpenAIWSIngressMode(mode string) string {
    switch strings.ToLower(strings.TrimSpace(mode)) {
    case OpenAIWSIngressModeOff:
        return OpenAIWSIngressModeOff
    case OpenAIWSIngressModeShared:
        return OpenAIWSIngressModeShared
    case OpenAIWSIngressModeDedicated:
        return OpenAIWSIngressModeDedicated
    default:
        return ""
    }
}

func normalizeOpenAIWSIngressDefaultMode(mode string) string {
    if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
        return normalized
    }
    return OpenAIWSIngressModeShared
}
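
// Illustrative standalone sketch (not part of this diff): only off/shared/
// dedicated survive normalization (case- and whitespace-insensitive), and the
// default resolver falls back to shared for anything else. Demo code only.
package main

import (
    "fmt"
    "strings"
)

func demoNormalize(mode string) string {
    switch m := strings.ToLower(strings.TrimSpace(mode)); m {
    case "off", "shared", "dedicated":
        return m
    default:
        return ""
    }
}

func demoDefaultMode(mode string) string {
    if n := demoNormalize(mode); n != "" {
        return n
    }
    return "shared" // invalid or empty defaults fall back to shared
}

func main() {
    fmt.Println(demoDefaultMode("  Dedicated ")) // "dedicated"
    fmt.Println(demoDefaultMode("bogus"))        // "shared"
    fmt.Println(demoDefaultMode(""))             // "shared"
}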

// ResolveOpenAIResponsesWebSocketV2Mode returns the account's effective WSv2 ingress mode (off/shared/dedicated).
//
// Priority:
// 1. New per-type mode field (string)
// 2. Legacy per-type enabled field (bool)
// 3. Legacy compatibility enabled field (bool)
// 4. defaultMode (falls back to shared when invalid)
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
    resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
    if a == nil || !a.IsOpenAI() {
        return OpenAIWSIngressModeOff
    }
    if a.Extra == nil {
        return resolvedDefault
    }

    resolveModeString := func(key string) (string, bool) {
        raw, ok := a.Extra[key]
        if !ok {
            return "", false
        }
        mode, ok := raw.(string)
        if !ok {
            return "", false
        }
        normalized := normalizeOpenAIWSIngressMode(mode)
        if normalized == "" {
            return "", false
        }
        return normalized, true
    }
    resolveBoolMode := func(key string) (string, bool) {
        raw, ok := a.Extra[key]
        if !ok {
            return "", false
        }
        enabled, ok := raw.(bool)
        if !ok {
            return "", false
        }
        if enabled {
            return OpenAIWSIngressModeShared, true
        }
        return OpenAIWSIngressModeOff, true
    }

    if a.IsOpenAIOAuth() {
        if mode, ok := resolveModeString("openai_oauth_responses_websockets_v2_mode"); ok {
            return mode
        }
        if mode, ok := resolveBoolMode("openai_oauth_responses_websockets_v2_enabled"); ok {
            return mode
        }
    }
    if a.IsOpenAIApiKey() {
        if mode, ok := resolveModeString("openai_apikey_responses_websockets_v2_mode"); ok {
            return mode
        }
        if mode, ok := resolveBoolMode("openai_apikey_responses_websockets_v2_enabled"); ok {
            return mode
        }
    }
    if mode, ok := resolveBoolMode("responses_websockets_v2_enabled"); ok {
        return mode
    }
    if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
        return mode
    }
    return resolvedDefault
}
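
// Illustrative standalone sketch (not part of this diff): ingress mode
// resolution tries the typed string mode first, then the typed and
// compatibility boolean switches (true => shared, false => off), and finally
// the normalized default. demoResolveMode models an OAuth account over a plain
// extra map; it is a hypothetical stand-in, not the Account method.
package main

import (
    "fmt"
    "strings"
)

func demoNormalizeMode(mode string) string {
    switch m := strings.ToLower(strings.TrimSpace(mode)); m {
    case "off", "shared", "dedicated":
        return m
    default:
        return ""
    }
}

func demoResolveMode(extra map[string]any, defaultMode string) string {
    // 1. typed string mode field
    if raw, ok := extra["openai_oauth_responses_websockets_v2_mode"].(string); ok {
        if m := demoNormalizeMode(raw); m != "" {
            return m
        }
    }
    // 2./3. typed then compatibility boolean switches
    for _, key := range []string{
        "openai_oauth_responses_websockets_v2_enabled",
        "responses_websockets_v2_enabled",
        "openai_ws_enabled",
    } {
        if enabled, ok := extra[key].(bool); ok {
            if enabled {
                return "shared"
            }
            return "off"
        }
    }
    // 4. fall back to the (normalized) default
    if m := demoNormalizeMode(defaultMode); m != "" {
        return m
    }
    return "shared"
}

func main() {
    fmt.Println(demoResolveMode(map[string]any{
        "openai_oauth_responses_websockets_v2_mode": "dedicated",
        "openai_ws_enabled":                         false,
    }, "shared")) // "dedicated": the string mode wins over the booleans

    fmt.Println(demoResolveMode(map[string]any{"openai_ws_enabled": true}, "off")) // "shared"
    fmt.Println(demoResolveMode(map[string]any{}, "bogus"))                        // "shared": invalid default falls back
}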

// IsOpenAIWSForceHTTPEnabled returns the account-level "force HTTP" switch.
// Field: accounts.extra.openai_ws_force_http.
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
    if a == nil || !a.IsOpenAI() || a.Extra == nil {
        return false
    }
    enabled, ok := a.Extra["openai_ws_force_http"].(bool)
    return ok && enabled
}

// IsOpenAIWSAllowStoreRecoveryEnabled returns the account-level store recovery switch.
// Field: accounts.extra.openai_ws_allow_store_recovery.
func (a *Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool {
    if a == nil || !a.IsOpenAI() || a.Extra == nil {
        return false
    }
    enabled, ok := a.Extra["openai_ws_allow_store_recovery"].(bool)
    return ok && enabled
}

// IsOpenAIOAuthPassthroughEnabled keeps the legacy interface; it is equivalent to IsOpenAIPassthroughEnabled for OAuth accounts.
func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
    return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
}