Merge tag 'v0.1.90' into merge/upstream-v0.1.90

注册邮箱域名白名单策略上线,后台大数据场景性能大幅优化。

- 注册邮箱域名白名单:支持管理员配置允许注册的邮箱域名策略
- Keys 页面表单筛选:用户 /keys 页面支持按条件筛选 API Key
- Settings 页面分 Tab 拆分:管理后台设置页面按功能模块分 Tab 展示

- 后台大数据场景加载性能优化:仪表盘/用户/账号/Ops 页面大数据集加载显著提速
- Usage 大表分页优化:默认避免全量 COUNT(*),大幅降低分页查询耗时
- 消除重复的 normalizeAccountIDList,补充新增组件的单元测试
- 清理无用文件和过时文档,精简项目结构
- EmailVerifyView 硬编码英文字符串替换为 i18n 调用

- 修复 Anthropic 平台无限流重置时间的 429 误标记账号限流问题
- 修复自定义菜单页面管理员视角菜单不生效问题
- 修复 Ops 错误详情弹窗未展示真实上游 payload 的问题
- 修复充值/订阅菜单 icon 显示问题

# Conflicts:
#	.gitignore
#	backend/cmd/server/VERSION
#	backend/ent/group.go
#	backend/ent/runtime/runtime.go
#	backend/ent/schema/group.go
#	backend/go.sum
#	backend/internal/handler/admin/account_handler.go
#	backend/internal/handler/admin/dashboard_handler.go
#	backend/internal/pkg/usagestats/usage_log_types.go
#	backend/internal/repository/group_repo.go
#	backend/internal/repository/usage_log_repo.go
#	backend/internal/server/middleware/security_headers.go
#	backend/internal/server/router.go
#	backend/internal/service/account_usage_service.go
#	backend/internal/service/admin_service_bulk_update_test.go
#	backend/internal/service/dashboard_service.go
#	backend/internal/service/gateway_service.go
#	frontend/src/api/admin/dashboard.ts
#	frontend/src/components/account/BulkEditAccountModal.vue
#	frontend/src/components/charts/GroupDistributionChart.vue
#	frontend/src/components/layout/AppSidebar.vue
#	frontend/src/i18n/locales/en.ts
#	frontend/src/i18n/locales/zh.ts
#	frontend/src/views/admin/GroupsView.vue
#	frontend/src/views/admin/SettingsView.vue
#	frontend/src/views/admin/UsageView.vue
#	frontend/src/views/user/PurchaseSubscriptionView.vue
This commit is contained in:
erio
2026-03-04 19:58:38 +08:00
461 changed files with 63392 additions and 6617 deletions

View File

@@ -30,6 +30,14 @@ const (
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// UMQ用户消息队列模式常量
const (
// UMQModeSerialize: 账号级串行锁 + RPM 自适应延迟
UMQModeSerialize = "serialize"
// UMQModeThrottle: 仅 RPM 自适应前置延迟,不阻塞并发
UMQModeThrottle = "throttle"
)
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
const (
@@ -265,8 +273,13 @@ type CSPConfig struct {
}
type ProxyFallbackConfig struct {
// AllowDirectOnError 当代理初始化失败时是否允许回退直连。
// 默认 false避免因代理配置错误导致 IP 泄露/关联。
// AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。
// 仅影响以下非 AI 账号连接的辅助服务:
// - GitHub Release 更新检查
// - 定价数据拉取
// 不影响 AI 账号网关连接Claude/OpenAI/Gemini/Antigravity
// 这些关键路径的代理失败始终返回错误,不会回退直连。
// 默认 false避免因代理配置错误导致服务器真实 IP 泄露。
AllowDirectOnError bool `mapstructure:"allow_direct_on_error"`
}
@@ -364,6 +377,8 @@ type GatewayConfig struct {
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
@@ -448,6 +463,147 @@ type GatewayConfig struct {
UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"`
// ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
// UserMessageQueue: 用户消息串行队列配置
// 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
}
// UserMessageQueueConfig 用户消息串行队列配置
// 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送
type UserMessageQueueConfig struct {
// Mode: 模式选择
// "serialize" = 账号级串行锁 + RPM 自适应延迟
// "throttle" = 仅 RPM 自适应前置延迟,不阻塞并发
// "" = 禁用(默认)
Mode string `mapstructure:"mode"`
// Enabled: 已废弃,仅向后兼容(等同于 mode: "serialize"
Enabled bool `mapstructure:"enabled"`
// LockTTLMs: 串行锁 TTL毫秒应大于最长请求时间
LockTTLMs int `mapstructure:"lock_ttl_ms"`
// WaitTimeoutMs: 等待获取锁的超时时间(毫秒)
WaitTimeoutMs int `mapstructure:"wait_timeout_ms"`
// MinDelayMs: RPM 自适应延迟下限(毫秒)
MinDelayMs int `mapstructure:"min_delay_ms"`
// MaxDelayMs: RPM 自适应延迟上限(毫秒)
MaxDelayMs int `mapstructure:"max_delay_ms"`
// CleanupIntervalSeconds: 孤儿锁清理间隔0 表示禁用
CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"`
}
// WaitTimeout 返回等待超时的 time.Duration
func (c *UserMessageQueueConfig) WaitTimeout() time.Duration {
if c.WaitTimeoutMs <= 0 {
return 30 * time.Second
}
return time.Duration(c.WaitTimeoutMs) * time.Millisecond
}
// GetEffectiveMode 返回生效的模式
// 注意Mode 字段已在 load() 中做过白名单校验和规范化,此处无需重复验证
func (c *UserMessageQueueConfig) GetEffectiveMode() string {
if c.Mode == UMQModeSerialize || c.Mode == UMQModeThrottle {
return c.Mode
}
if c.Enabled {
return UMQModeSerialize // 向后兼容
}
return ""
}
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
type GatewayOpenAIWSConfig struct {
// ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false关闭时保持 legacy 行为)
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
// IngressModeDefault: ingress 默认模式off/shared/dedicated
IngressModeDefault string `mapstructure:"ingress_mode_default"`
// Enabled: 全局总开关(默认 true
Enabled bool `mapstructure:"enabled"`
// OAuthEnabled: 是否允许 OpenAI OAuth 账号使用 WS
OAuthEnabled bool `mapstructure:"oauth_enabled"`
// APIKeyEnabled: 是否允许 OpenAI API Key 账号使用 WS
APIKeyEnabled bool `mapstructure:"apikey_enabled"`
// ForceHTTP: 全局强制 HTTP用于紧急回滚
ForceHTTP bool `mapstructure:"force_http"`
// AllowStoreRecovery: 允许在 WSv2 下按策略恢复 store=true默认 false
AllowStoreRecovery bool `mapstructure:"allow_store_recovery"`
// IngressPreviousResponseRecoveryEnabled: ingress 模式收到 previous_response_not_found 时,是否允许自动去掉 previous_response_id 重试一次(默认 true
IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"`
// StoreDisabledConnMode: store=false 且无可复用会话连接时的建连策略strict/adaptive/off
// - strict: 强制新建连接(隔离优先)
// - adaptive: 仅在高风险失败后强制新建连接(性能与隔离折中)
// - off: 不强制新建连接(复用优先)
StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"`
// StoreDisabledForceNewConn: store=false 且无可复用粘连连接时是否强制新建连接(默认 true保障会话隔离
// 兼容旧配置;当 StoreDisabledConnMode 为空时才生效。
StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"`
// PrewarmGenerateEnabled: 是否启用 WSv2 generate=false 预热(默认 false
PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"`
// Feature 开关v2 优先于 v1
ResponsesWebsockets bool `mapstructure:"responses_websockets"`
ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"`
// 连接池参数
MaxConnsPerAccount int `mapstructure:"max_conns_per_account"`
MinIdlePerAccount int `mapstructure:"min_idle_per_account"`
MaxIdlePerAccount int `mapstructure:"max_idle_per_account"`
// DynamicMaxConnsByAccountConcurrencyEnabled: 是否按账号并发动态计算连接池上限
DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"`
// OAuthMaxConnsFactor: OAuth 账号连接池系数effective=ceil(concurrency*factor)
OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"`
// APIKeyMaxConnsFactor: API Key 账号连接池系数effective=ceil(concurrency*factor)
APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"`
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"`
QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"`
// EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数)
EventFlushBatchSize int `mapstructure:"event_flush_batch_size"`
// EventFlushIntervalMS: WS 流式写出最大等待时间毫秒0 表示仅按 batch 触发
EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"`
// PrewarmCooldownMS: 连接池预热触发冷却时间(毫秒)
PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"`
// FallbackCooldownSeconds: WS 回退冷却窗口,避免 WS/HTTP 抖动0 表示关闭冷却
FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"`
// RetryBackoffInitialMS: WS 重试初始退避(毫秒);<=0 表示关闭退避
RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"`
// RetryBackoffMaxMS: WS 重试最大退避(毫秒)
RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"`
// RetryJitterRatio: WS 重试退避抖动比例0-1
RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"`
// RetryTotalBudgetMS: WS 单次请求重试总预算毫秒0 表示关闭预算限制
RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"`
// PayloadLogSampleRate: payload_schema 日志采样率0-1
PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"`
// 账号调度与粘连参数
LBTopK int `mapstructure:"lb_top_k"`
// StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL
StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"`
// SessionHashReadOldFallback: 会话哈希迁移期是否允许“新 key 未命中时回退读旧 SHA-256 key”
SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"`
// SessionHashDualWriteOld: 会话哈希迁移期是否双写旧 SHA-256 key短 TTL
SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"`
// MetadataBridgeEnabled: RequestMetadata 迁移期是否保留旧 ctxkey.* 兼容桥接
MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"`
// StickyResponseIDTTLSeconds: response_id -> account_id 粘连 TTL
StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"`
// StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退)
StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
}
// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。
type GatewayOpenAIWSSchedulerScoreWeights struct {
Priority float64 `mapstructure:"priority"`
Load float64 `mapstructure:"load"`
Queue float64 `mapstructure:"queue"`
ErrorRate float64 `mapstructure:"error_rate"`
TTFT float64 `mapstructure:"ttft"`
}
// GatewayUsageRecordConfig 使用量记录异步队列配置
@@ -716,7 +872,8 @@ type DefaultConfig struct {
}
type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401临时不可调度冷却(分钟)
}
// APIKeyAuthCacheConfig API Key 认证缓存配置
@@ -886,6 +1043,20 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
// 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。
// 新键未配置(<=0时回退旧键新键优先。
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
}
// Normalize UMQ mode: 白名单校验,非法值在加载时一次性 warn 并清空
if m := cfg.Gateway.UserMessageQueue.Mode; m != "" && m != UMQModeSerialize && m != UMQModeThrottle {
slog.Warn("invalid user_message_queue mode, disabling",
"mode", m,
"valid_modes", []string{UMQModeSerialize, UMQModeThrottle})
cfg.Gateway.UserMessageQueue.Mode = ""
}
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
if cfg.Totp.EncryptionKey == "" {
@@ -945,7 +1116,7 @@ func setDefaults() {
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{})
viper.SetDefault("server.max_request_body_size", int64(100*1024*1024))
viper.SetDefault("server.max_request_body_size", int64(256*1024*1024))
// H2C 默认配置
viper.SetDefault("server.h2c.enabled", false)
viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流
@@ -1002,6 +1173,9 @@ func setDefaults() {
viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
// Security - disable direct fallback on proxy error
viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
// Billing
viper.SetDefault("billing.circuit_breaker.enabled", true)
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
@@ -1053,7 +1227,7 @@ func setDefaults() {
// Ops (vNext)
viper.SetDefault("ops.enabled", true)
viper.SetDefault("ops.use_preaggregated_tables", false)
viper.SetDefault("ops.use_preaggregated_tables", true)
viper.SetDefault("ops.cleanup.enabled", true)
viper.SetDefault("ops.cleanup.schedule", "0 2 * * *")
// Retention days: vNext defaults to 30 days across ops datasets.
@@ -1087,10 +1261,11 @@ func setDefaults() {
// RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据的配置
viper.SetDefault("pricing.remote_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256")
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit避免分支漂移
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256")
viper.SetDefault("pricing.data_dir", "./data")
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
viper.SetDefault("pricing.update_interval_hours", 24)
@@ -1157,9 +1332,55 @@ func setDefaults() {
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
// OpenAI Responses WebSocket默认开启可通过 force_http 紧急回滚)
viper.SetDefault("gateway.openai_ws.enabled", true)
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
viper.SetDefault("gateway.openai_ws.force_http", false)
viper.SetDefault("gateway.openai_ws.allow_store_recovery", false)
viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true)
viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict")
viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true)
viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false)
viper.SetDefault("gateway.openai_ws.responses_websockets", false)
viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true)
viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128)
viper.SetDefault("gateway.openai_ws.min_idle_per_account", 4)
viper.SetDefault("gateway.openai_ws.max_idle_per_account", 12)
viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true)
viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0)
viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0)
viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10)
viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900)
viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120)
viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7)
viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64)
viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1)
viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10)
viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300)
viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30)
viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120)
viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000)
viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2)
viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000)
viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2)
viper.SetDefault("gateway.openai_ws.lb_top_k", 7)
viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true)
viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true)
viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true)
viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
@@ -1215,6 +1436,14 @@ func setDefaults() {
viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30)
viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15)
// TLS指纹伪装配置默认关闭需要账号级别单独启用
// 用户消息串行队列默认值
viper.SetDefault("gateway.user_message_queue.enabled", false)
viper.SetDefault("gateway.user_message_queue.lock_ttl_ms", 120000)
viper.SetDefault("gateway.user_message_queue.wait_timeout_ms", 30000)
viper.SetDefault("gateway.user_message_queue.min_delay_ms", 200)
viper.SetDefault("gateway.user_message_queue.max_delay_ms", 2000)
viper.SetDefault("gateway.user_message_queue.cleanup_interval_seconds", 60)
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
viper.SetDefault("concurrency.ping_interval", 10)
@@ -1266,9 +1495,6 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
// Security - proxy fallback
viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
// Subscription Maintenance (bounded queue + worker pool)
viper.SetDefault("subscription_maintenance.worker_count", 2)
viper.SetDefault("subscription_maintenance.queue_size", 1024)
@@ -1747,6 +1973,118 @@ func (c *Config) Validate() error {
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
// 兼容旧键 sticky_previous_response_ttl_seconds
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
}
if c.Gateway.OpenAIWS.MaxConnsPerAccount <= 0 {
return fmt.Errorf("gateway.openai_ws.max_conns_per_account must be positive")
}
if c.Gateway.OpenAIWS.MinIdlePerAccount < 0 {
return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be non-negative")
}
if c.Gateway.OpenAIWS.MaxIdlePerAccount < 0 {
return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be non-negative")
}
if c.Gateway.OpenAIWS.MinIdlePerAccount > c.Gateway.OpenAIWS.MaxIdlePerAccount {
return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account")
}
if c.Gateway.OpenAIWS.MaxIdlePerAccount > c.Gateway.OpenAIWS.MaxConnsPerAccount {
return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account")
}
if c.Gateway.OpenAIWS.OAuthMaxConnsFactor <= 0 {
return fmt.Errorf("gateway.openai_ws.oauth_max_conns_factor must be positive")
}
if c.Gateway.OpenAIWS.APIKeyMaxConnsFactor <= 0 {
return fmt.Errorf("gateway.openai_ws.apikey_max_conns_factor must be positive")
}
if c.Gateway.OpenAIWS.DialTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.dial_timeout_seconds must be positive")
}
if c.Gateway.OpenAIWS.ReadTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.read_timeout_seconds must be positive")
}
if c.Gateway.OpenAIWS.WriteTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.write_timeout_seconds must be positive")
}
if c.Gateway.OpenAIWS.PoolTargetUtilization <= 0 || c.Gateway.OpenAIWS.PoolTargetUtilization > 1 {
return fmt.Errorf("gateway.openai_ws.pool_target_utilization must be within (0,1]")
}
if c.Gateway.OpenAIWS.QueueLimitPerConn <= 0 {
return fmt.Errorf("gateway.openai_ws.queue_limit_per_conn must be positive")
}
if c.Gateway.OpenAIWS.EventFlushBatchSize <= 0 {
return fmt.Errorf("gateway.openai_ws.event_flush_batch_size must be positive")
}
if c.Gateway.OpenAIWS.EventFlushIntervalMS < 0 {
return fmt.Errorf("gateway.openai_ws.event_flush_interval_ms must be non-negative")
}
if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 {
return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative")
}
if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 {
return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative")
}
if c.Gateway.OpenAIWS.RetryBackoffInitialMS < 0 {
return fmt.Errorf("gateway.openai_ws.retry_backoff_initial_ms must be non-negative")
}
if c.Gateway.OpenAIWS.RetryBackoffMaxMS < 0 {
return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be non-negative")
}
if c.Gateway.OpenAIWS.RetryBackoffInitialMS > 0 && c.Gateway.OpenAIWS.RetryBackoffMaxMS > 0 &&
c.Gateway.OpenAIWS.RetryBackoffMaxMS < c.Gateway.OpenAIWS.RetryBackoffInitialMS {
return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be >= retry_backoff_initial_ms")
}
if c.Gateway.OpenAIWS.RetryJitterRatio < 0 || c.Gateway.OpenAIWS.RetryJitterRatio > 1 {
return fmt.Errorf("gateway.openai_ws.retry_jitter_ratio must be within [0,1]")
}
if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 {
return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative")
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
switch mode {
case "off", "shared", "dedicated":
default:
return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
}
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
switch mode {
case "strict", "adaptive", "off":
default:
return fmt.Errorf("gateway.openai_ws.store_disabled_conn_mode must be one of strict|adaptive|off")
}
}
if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 {
return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]")
}
if c.Gateway.OpenAIWS.LBTopK <= 0 {
return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive")
}
if c.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.sticky_session_ttl_seconds must be positive")
}
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 {
return fmt.Errorf("gateway.openai_ws.sticky_response_id_ttl_seconds must be positive")
}
if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 {
return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative")
}
if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 ||
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 {
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative")
}
weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority +
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load +
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue +
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate +
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT
if weightSum <= 0 {
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero")
}
if c.Gateway.MaxLineSize < 0 {
return fmt.Errorf("gateway.max_line_size must be non-negative")
}

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
)
func resetViperWithJWTSecret(t *testing.T) {
@@ -75,6 +76,103 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
}
}
func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.Gateway.OpenAIWS.Enabled {
t.Fatalf("Gateway.OpenAIWS.Enabled = false, want true")
}
if !cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 {
t.Fatalf("Gateway.OpenAIWS.ResponsesWebsocketsV2 = false, want true")
}
if cfg.Gateway.OpenAIWS.ResponsesWebsockets {
t.Fatalf("Gateway.OpenAIWS.ResponsesWebsockets = true, want false")
}
if !cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled {
t.Fatalf("Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false, want true")
}
if cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor != 1.0 {
t.Fatalf("Gateway.OpenAIWS.OAuthMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor)
}
if cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor != 1.0 {
t.Fatalf("Gateway.OpenAIWS.APIKeyMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor)
}
if cfg.Gateway.OpenAIWS.StickySessionTTLSeconds != 3600 {
t.Fatalf("Gateway.OpenAIWS.StickySessionTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickySessionTTLSeconds)
}
if !cfg.Gateway.OpenAIWS.SessionHashReadOldFallback {
t.Fatalf("Gateway.OpenAIWS.SessionHashReadOldFallback = false, want true")
}
if !cfg.Gateway.OpenAIWS.SessionHashDualWriteOld {
t.Fatalf("Gateway.OpenAIWS.SessionHashDualWriteOld = false, want true")
}
if !cfg.Gateway.OpenAIWS.MetadataBridgeEnabled {
t.Fatalf("Gateway.OpenAIWS.MetadataBridgeEnabled = false, want true")
}
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 3600 {
t.Fatalf("Gateway.OpenAIWS.StickyResponseIDTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
}
if cfg.Gateway.OpenAIWS.FallbackCooldownSeconds != 30 {
t.Fatalf("Gateway.OpenAIWS.FallbackCooldownSeconds = %d, want 30", cfg.Gateway.OpenAIWS.FallbackCooldownSeconds)
}
if cfg.Gateway.OpenAIWS.EventFlushBatchSize != 1 {
t.Fatalf("Gateway.OpenAIWS.EventFlushBatchSize = %d, want 1", cfg.Gateway.OpenAIWS.EventFlushBatchSize)
}
if cfg.Gateway.OpenAIWS.EventFlushIntervalMS != 10 {
t.Fatalf("Gateway.OpenAIWS.EventFlushIntervalMS = %d, want 10", cfg.Gateway.OpenAIWS.EventFlushIntervalMS)
}
if cfg.Gateway.OpenAIWS.PrewarmCooldownMS != 300 {
t.Fatalf("Gateway.OpenAIWS.PrewarmCooldownMS = %d, want 300", cfg.Gateway.OpenAIWS.PrewarmCooldownMS)
}
if cfg.Gateway.OpenAIWS.RetryBackoffInitialMS != 120 {
t.Fatalf("Gateway.OpenAIWS.RetryBackoffInitialMS = %d, want 120", cfg.Gateway.OpenAIWS.RetryBackoffInitialMS)
}
if cfg.Gateway.OpenAIWS.RetryBackoffMaxMS != 2000 {
t.Fatalf("Gateway.OpenAIWS.RetryBackoffMaxMS = %d, want 2000", cfg.Gateway.OpenAIWS.RetryBackoffMaxMS)
}
if cfg.Gateway.OpenAIWS.RetryJitterRatio != 0.2 {
t.Fatalf("Gateway.OpenAIWS.RetryJitterRatio = %v, want 0.2", cfg.Gateway.OpenAIWS.RetryJitterRatio)
}
if cfg.Gateway.OpenAIWS.RetryTotalBudgetMS != 5000 {
t.Fatalf("Gateway.OpenAIWS.RetryTotalBudgetMS = %d, want 5000", cfg.Gateway.OpenAIWS.RetryTotalBudgetMS)
}
if cfg.Gateway.OpenAIWS.PayloadLogSampleRate != 0.2 {
t.Fatalf("Gateway.OpenAIWS.PayloadLogSampleRate = %v, want 0.2", cfg.Gateway.OpenAIWS.PayloadLogSampleRate)
}
if !cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
t.Fatalf("Gateway.OpenAIWS.StoreDisabledForceNewConn = false, want true")
}
if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" {
t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict")
}
if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
}
if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
}
}
func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0")
t.Setenv("GATEWAY_OPENAI_WS_STICKY_PREVIOUS_RESPONSE_TTL_SECONDS", "7200")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 7200 {
t.Fatalf("StickyResponseIDTTLSeconds = %d, want 7200", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
}
}
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
resetViperWithJWTSecret(t)
@@ -993,6 +1091,16 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
wantErr: "gateway.stream_keepalive_interval",
},
{
name: "gateway openai ws oauth max conns factor",
mutate: func(c *Config) { c.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0 },
wantErr: "gateway.openai_ws.oauth_max_conns_factor",
},
{
name: "gateway openai ws apikey max conns factor",
mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 },
wantErr: "gateway.openai_ws.apikey_max_conns_factor",
},
{
name: "gateway stream data interval range",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
@@ -1174,6 +1282,165 @@ func TestValidateConfigErrors(t *testing.T) {
}
}
func TestValidateConfig_OpenAIWSRules(t *testing.T) {
buildValid := func(t *testing.T) *Config {
t.Helper()
resetViperWithJWTSecret(t)
cfg, err := Load()
require.NoError(t, err)
return cfg
}
t.Run("sticky response id ttl 兼容旧键回填", func(t *testing.T) {
cfg := buildValid(t)
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 7200
require.NoError(t, cfg.Validate())
require.Equal(t, 7200, cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
})
cases := []struct {
name string
mutate func(*Config)
wantErr string
}{
{
name: "max_conns_per_account 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxConnsPerAccount = 0 },
wantErr: "gateway.openai_ws.max_conns_per_account",
},
{
name: "min_idle_per_account 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.MinIdlePerAccount = -1 },
wantErr: "gateway.openai_ws.min_idle_per_account",
},
{
name: "max_idle_per_account 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxIdlePerAccount = -1 },
wantErr: "gateway.openai_ws.max_idle_per_account",
},
{
name: "min_idle_per_account 不能大于 max_idle_per_account",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.MinIdlePerAccount = 3
c.Gateway.OpenAIWS.MaxIdlePerAccount = 2
},
wantErr: "gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account",
},
{
name: "max_idle_per_account 不能大于 max_conns_per_account",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.MaxConnsPerAccount = 2
c.Gateway.OpenAIWS.MinIdlePerAccount = 1
c.Gateway.OpenAIWS.MaxIdlePerAccount = 3
},
wantErr: "gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account",
},
{
name: "dial_timeout_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.DialTimeoutSeconds = 0 },
wantErr: "gateway.openai_ws.dial_timeout_seconds",
},
{
name: "read_timeout_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.ReadTimeoutSeconds = 0 },
wantErr: "gateway.openai_ws.read_timeout_seconds",
},
{
name: "write_timeout_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.WriteTimeoutSeconds = 0 },
wantErr: "gateway.openai_ws.write_timeout_seconds",
},
{
name: "pool_target_utilization 必须在 (0,1]",
mutate: func(c *Config) { c.Gateway.OpenAIWS.PoolTargetUtilization = 0 },
wantErr: "gateway.openai_ws.pool_target_utilization",
},
{
name: "queue_limit_per_conn 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.QueueLimitPerConn = 0 },
wantErr: "gateway.openai_ws.queue_limit_per_conn",
},
{
name: "fallback_cooldown_seconds 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.FallbackCooldownSeconds = -1 },
wantErr: "gateway.openai_ws.fallback_cooldown_seconds",
},
{
name: "store_disabled_conn_mode 必须为 strict|adaptive|off",
mutate: func(c *Config) { c.Gateway.OpenAIWS.StoreDisabledConnMode = "invalid" },
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
},
{
name: "ingress_mode_default 必须为 off|shared|dedicated",
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
wantErr: "gateway.openai_ws.ingress_mode_default",
},
{
name: "payload_log_sample_rate 必须在 [0,1] 范围内",
mutate: func(c *Config) { c.Gateway.OpenAIWS.PayloadLogSampleRate = 1.2 },
wantErr: "gateway.openai_ws.payload_log_sample_rate",
},
{
name: "retry_total_budget_ms 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 },
wantErr: "gateway.openai_ws.retry_total_budget_ms",
},
{
name: "lb_top_k 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 },
wantErr: "gateway.openai_ws.lb_top_k",
},
{
name: "sticky_session_ttl_seconds 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.StickySessionTTLSeconds = 0 },
wantErr: "gateway.openai_ws.sticky_session_ttl_seconds",
},
{
name: "sticky_response_id_ttl_seconds 必须为正数",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 0
},
wantErr: "gateway.openai_ws.sticky_response_id_ttl_seconds",
},
{
name: "sticky_previous_response_ttl_seconds 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = -1 },
wantErr: "gateway.openai_ws.sticky_previous_response_ttl_seconds",
},
{
name: "scheduler_score_weights 不能为负数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = -0.1 },
wantErr: "gateway.openai_ws.scheduler_score_weights.* must be non-negative",
},
{
name: "scheduler_score_weights 不能全为 0",
mutate: func(c *Config) {
c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0
},
wantErr: "gateway.openai_ws.scheduler_score_weights must not all be zero",
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
cfg := buildValid(t)
tc.mutate(cfg)
err := cfg.Validate()
require.Error(t, err)
require.Contains(t, err.Error(), tc.wantErr)
})
}
}
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()

View File

@@ -104,6 +104,9 @@ var DefaultAntigravityModelMapping = map[string]string{
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
// Gemini 3.1 image preview 映射
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
// Gemini 3 image 兼容映射(向 3.1 image 迁移)
"gemini-3-pro-image": "gemini-3.1-flash-image",
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
// 其他官方模型
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",

View File

@@ -0,0 +1,24 @@
package domain
import "testing"
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
t.Parallel()
cases := map[string]string{
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
"gemini-3-pro-image": "gemini-3.1-flash-image",
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
}
for from, want := range cases {
got, ok := DefaultAntigravityModelMapping[from]
if !ok {
t.Fatalf("expected mapping for %q to exist", from)
}
if got != want {
t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want)
}
}
}

View File

@@ -64,6 +64,7 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
nil,
nil,
nil,
nil,
)
router.GET("/api/v1/admin/accounts/data", h.ExportData)

View File

@@ -53,6 +53,7 @@ type AccountHandler struct {
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
rpmCache service.RPMCache
tokenCacheInvalidator service.TokenCacheInvalidator
}
@@ -69,6 +70,7 @@ func NewAccountHandler(
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
rpmCache service.RPMCache,
tokenCacheInvalidator service.TokenCacheInvalidator,
) *AccountHandler {
return &AccountHandler{
@@ -83,6 +85,7 @@ func NewAccountHandler(
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
tokenCacheInvalidator: tokenCacheInvalidator,
}
}
@@ -154,6 +157,7 @@ type AccountWithConcurrency struct {
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
}
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
@@ -189,6 +193,12 @@ func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, ac
}
}
}
if h.rpmCache != nil && account.GetBaseRPM() > 0 {
if rpm, err := h.rpmCache.GetRPM(ctx, account.ID); err == nil {
item.CurrentRPM = &rpm
}
}
}
return item
@@ -207,6 +217,7 @@ func (h *AccountHandler) List(c *gin.Context) {
if len(search) > 100 {
search = search[:100]
}
lite := parseBoolQueryWithDefault(c.Query("lite"), false)
var groupID int64
if groupIDStr := c.Query("group"); groupIDStr != "" {
@@ -225,67 +236,81 @@ func (h *AccountHandler) List(c *gin.Context) {
accountIDs[i] = acc.ID
}
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
if err != nil {
// Log error but don't fail the request, just use 0 for all
concurrencyCounts = make(map[int64]int)
}
// 识别需要查询窗口费用和会话数的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
if acc.IsAnthropicOAuthOrSetupToken() {
if acc.GetWindowCostLimit() > 0 {
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
}
if acc.GetMaxSessions() > 0 {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
}
}
// 并行获取窗口费用和活跃会话数
concurrencyCounts := make(map[int64]int)
var windowCosts map[int64]float64
var activeSessions map[int64]int
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
if activeSessions == nil {
activeSessions = make(map[int64]int)
var rpmCounts map[int64]int
if !lite {
// Get current concurrency counts for all accounts
if h.concurrencyService != nil {
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
concurrencyCounts = cc
}
}
}
// 获取窗口费用(并行查询)
if len(windowCostAccountIDs) > 0 {
windowCosts = make(map[int64]float64)
var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context())
g.SetLimit(10) // 限制并发数
// 识别需要查询窗口费用、会话数和 RPM 的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
rpmAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
continue
}
accCopy := acc // 闭包捕获
g.Go(func() error {
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := accCopy.GetCurrentWindowStartTime()
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
mu.Unlock()
if acc.IsAnthropicOAuthOrSetupToken() {
if acc.GetWindowCostLimit() > 0 {
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
}
return nil // 不返回错误,允许部分失败
})
if acc.GetMaxSessions() > 0 {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
if acc.GetBaseRPM() > 0 {
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
}
}
}
// 获取 RPM 计数(批量查询)
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
if rpmCounts == nil {
rpmCounts = make(map[int64]int)
}
}
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
if activeSessions == nil {
activeSessions = make(map[int64]int)
}
}
// 获取窗口费用(并行查询)
if len(windowCostAccountIDs) > 0 {
windowCosts = make(map[int64]float64)
var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context())
g.SetLimit(10) // 限制并发数
for i := range accounts {
acc := &accounts[i]
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
continue
}
accCopy := acc // 闭包捕获
g.Go(func() error {
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := accCopy.GetCurrentWindowStartTime()
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
mu.Unlock()
}
return nil // 不返回错误,允许部分失败
})
}
_ = g.Wait()
}
_ = g.Wait()
}
// Build response with concurrency info
@@ -311,10 +336,17 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
// 添加 RPM 计数(仅当启用时)
if rpmCounts != nil {
if rpm, ok := rpmCounts[acc.ID]; ok {
item.CurrentRPM = &rpm
}
}
result[i] = item
}
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search)
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite)
if etag != "" {
c.Header("ETag", etag)
c.Header("Vary", "If-None-Match")
@@ -332,6 +364,7 @@ func buildAccountsListETag(
total int64,
page, pageSize int,
platform, accountType, status, search string,
lite bool,
) string {
payload := struct {
Total int64 `json:"total"`
@@ -341,6 +374,7 @@ func buildAccountsListETag(
AccountType string `json:"type"`
Status string `json:"status"`
Search string `json:"search"`
Lite bool `json:"lite"`
Items []AccountWithConcurrency `json:"items"`
}{
Total: total,
@@ -350,6 +384,7 @@ func buildAccountsListETag(
AccountType: accountType,
Status: status,
Search: search,
Lite: lite,
Items: items,
}
raw, err := json.Marshal(payload)
@@ -453,6 +488,8 @@ func (h *AccountHandler) Create(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -522,6 +559,8 @@ func (h *AccountHandler) Update(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -904,6 +943,9 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
continue
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(item.Extra)
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
@@ -1048,6 +1090,8 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -1351,6 +1395,57 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
response.Success(c, stats)
}
// BatchTodayStatsRequest 批量今日统计请求体。
type BatchTodayStatsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required"`
}
// GetBatchTodayStats 批量获取多个账号的今日统计。
// POST /api/v1/admin/accounts/today-stats/batch
func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
var req BatchTodayStatsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
accountIDs := normalizeInt64IDList(req.AccountIDs)
if len(accountIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
cacheKey := buildAccountTodayStatsBatchCacheKey(accountIDs)
if cached, ok := accountTodayStatsBatchCache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), accountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
payload := gin.H{"stats": stats}
cached := accountTodayStatsBatchCache.Set(cacheKey, payload)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// SetSchedulableRequest represents the request body for setting schedulable status
type SetSchedulableRequest struct {
Schedulable bool `json:"schedulable"`
@@ -1692,3 +1787,22 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultAntigravityModelMapping)
}
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
func sanitizeExtraBaseRPM(extra map[string]any) {
if extra == nil {
return
}
raw, ok := extra["base_rpm"]
if !ok {
return
}
v := service.ParseExtraInt(raw)
if v < 0 {
v = 0
} else if v > 10000 {
v = 10000
}
extra["base_rpm"] = v
}

View File

@@ -15,7 +15,7 @@ import (
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
router.POST("/api/v1/admin/accounts", accountHandler.Create)
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)

View File

@@ -28,6 +28,7 @@ func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testi
nil,
nil,
nil,
nil,
)
router := gin.New()

View File

@@ -0,0 +1,25 @@
package admin
import (
"strconv"
"strings"
"time"
)
var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second)
func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string {
if len(accountIDs) == 0 {
return "accounts_today_stats_empty"
}
var b strings.Builder
b.Grow(len(accountIDs) * 6)
_, _ = b.WriteString("accounts_today_stats:")
for i, id := range accountIDs {
if i > 0 {
_ = b.WriteByte(',')
}
_, _ = b.WriteString(strconv.FormatInt(id, 10))
}
return b.String()
}

View File

@@ -407,5 +407,23 @@ func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []
return nil
}
func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
for i := range s.apiKeys {
if s.apiKeys[i].ID == keyID {
k := s.apiKeys[i]
if groupID != nil {
if *groupID == 0 {
k.GroupID = nil
} else {
gid := *groupID
k.GroupID = &gid
}
}
return &service.AdminUpdateAPIKeyGroupIDResult{APIKey: &k}, nil
}
}
return nil, service.ErrAPIKeyNotFound
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -0,0 +1,63 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AdminAPIKeyHandler handles admin API key management
type AdminAPIKeyHandler struct {
adminService service.AdminService
}
// NewAdminAPIKeyHandler creates a new admin API key handler
func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandler {
return &AdminAPIKeyHandler{
adminService: adminService,
}
}
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
type AdminUpdateAPIKeyGroupRequest struct {
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
}
// UpdateGroup handles updating an API key's group binding
// PUT /api/v1/admin/api-keys/:id
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid API key ID")
return
}
var req AdminUpdateAPIKeyGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := struct {
APIKey *dto.APIKey `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
GrantedGroupID *int64 `json:"granted_group_id,omitempty"`
GrantedGroupName string `json:"granted_group_name,omitempty"`
}{
APIKey: dto.APIKeyFromService(result.APIKey),
AutoGrantedGroupAccess: result.AutoGrantedGroupAccess,
GrantedGroupID: result.GrantedGroupID,
GrantedGroupName: result.GrantedGroupName,
}
response.Success(c, resp)
}

View File

@@ -0,0 +1,202 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAPIKeyHandler(adminSvc service.AdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
h := NewAdminAPIKeyHandler(adminSvc)
router.PUT("/api/v1/admin/api-keys/:id", h.UpdateGroup)
return router
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/abc", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid API key ID")
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{bad json`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid request")
}
func TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/999", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
// ErrAPIKeyNotFound maps to 404
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestAdminAPIKeyHandler_UpdateGroup_BindGroup(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data json.RawMessage `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
var data struct {
APIKey struct {
ID int64 `json:"id"`
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
}
require.NoError(t, json.Unmarshal(resp.Data, &data))
require.Equal(t, int64(10), data.APIKey.ID)
require.NotNil(t, data.APIKey.GroupID)
require.Equal(t, int64(2), *data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
svc := newStubAdminService()
gid := int64(2)
svc.apiKeys[0].GroupID = &gid
router := setupAPIKeyHandler(svc)
body := `{"group_id": 0}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data struct {
APIKey struct {
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Nil(t, resp.Data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: errors.New("internal failure"),
}
router := setupAPIKeyHandler(svc)
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// H2: empty body → group_id is nil → no-op, returns original key
func TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data struct {
APIKey struct {
ID int64 `json:"id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, int64(10), resp.Data.APIKey.ID)
}
// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": 5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "GROUP_NOT_ACTIVE")
}
// M2: service returns INVALID_GROUP_ID → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": -5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "INVALID_GROUP_ID")
}
// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error.
type failingUpdateGroupService struct {
*stubAdminService
err error
}
func (f *failingUpdateGroupService) AdminUpdateAPIKeyGroupID(_ context.Context, _ int64, _ *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
return nil, f.err
}

View File

@@ -36,7 +36,7 @@ func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}

View File

@@ -1,8 +1,10 @@
package admin
import (
"encoding/json"
"errors"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -186,7 +188,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
@@ -194,6 +196,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var model string
var requestType *int16
var stream *bool
var billingType *int8
@@ -220,9 +223,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
if modelStr := c.Query("model"); modelStr != "" {
model = modelStr
}
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
@@ -235,7 +249,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -251,12 +265,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var requestType *int16
var stream *bool
var billingType *int8
@@ -280,9 +295,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
groupID = id
}
}
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
@@ -295,7 +321,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
@@ -310,11 +336,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
// GetGroupStats handles getting group usage statistics
// GET /api/v1/admin/dashboard/groups
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
var userID, apiKeyID, accountID, groupID int64
var requestType *int16
var stream *bool
var billingType *int8
@@ -338,9 +365,20 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
groupID = id
}
}
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
@@ -353,7 +391,7 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
}
}
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
@@ -423,6 +461,9 @@ type BatchUsersUsageRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
@@ -432,18 +473,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
if len(req.UserIDs) == 0 {
userIDs := normalizeInt64IDList(req.UserIDs)
if len(userIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
keyRaw, _ := json.Marshal(struct {
UserIDs []int64 `json:"user_ids"`
}{
UserIDs: userIDs,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
payload := gin.H{"stats": stats}
dashboardBatchUsersUsageCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
@@ -460,16 +517,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return
}
if len(req.APIKeyIDs) == 0 {
apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs)
if len(apiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
keyRaw, _ := json.Marshal(struct {
APIKeyIDs []int64 `json:"api_key_ids"`
}{
APIKeyIDs: apiKeyIDs,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
payload := gin.H{"stats": stats}
dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}

View File

@@ -0,0 +1,132 @@
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dashboardUsageRepoCapture struct {
service.UsageLogRepository
trendRequestType *int16
trendStream *bool
modelRequestType *int16
modelStream *bool
}
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, error) {
s.trendRequestType = requestType
s.trendStream = stream
return []usagestats.TrendDataPoint{}, nil
}
func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.ModelStat, error) {
s.modelRequestType = requestType
s.modelStream = stream
return []usagestats.ModelStat{}, nil
}
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
router.GET("/admin/dashboard/models", handler.GetModelStats)
return router
}
func TestDashboardTrendRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.trendRequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType)
require.Nil(t, repo.trendStream)
}
func TestDashboardTrendInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardTrendInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.modelRequestType)
require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType)
require.Nil(t, repo.modelStream)
}
func TestDashboardModelStatsInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}

View File

@@ -0,0 +1,292 @@
package admin
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
type dashboardSnapshotV2Stats struct {
usagestats.DashboardStats
Uptime int64 `json:"uptime"`
}
type dashboardSnapshotV2Response struct {
GeneratedAt string `json:"generated_at"`
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
Granularity string `json:"granularity"`
Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"`
Trend []usagestats.TrendDataPoint `json:"trend,omitempty"`
Models []usagestats.ModelStat `json:"models,omitempty"`
Groups []usagestats.GroupStat `json:"groups,omitempty"`
UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"`
}
type dashboardSnapshotV2Filters struct {
UserID int64
APIKeyID int64
AccountID int64
GroupID int64
Model string
RequestType *int16
Stream *bool
BillingType *int8
}
type dashboardSnapshotV2CacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Granularity string `json:"granularity"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Model string `json:"model"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
IncludeStats bool `json:"include_stats"`
IncludeTrend bool `json:"include_trend"`
IncludeModels bool `json:"include_models"`
IncludeGroups bool `json:"include_groups"`
IncludeUsersTrend bool `json:"include_users_trend"`
UsersTrendLimit int `json:"users_trend_limit"`
}
func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day"))
if granularity != "hour" {
granularity = "day"
}
includeStats := parseBoolQueryWithDefault(c.Query("include_stats"), true)
includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true)
includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true)
includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false)
includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false)
usersTrendLimit := 12
if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" {
if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 {
usersTrendLimit = parsed
}
}
filters, err := parseDashboardSnapshotV2Filters(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
UserID: filters.UserID,
APIKeyID: filters.APIKeyID,
AccountID: filters.AccountID,
GroupID: filters.GroupID,
Model: filters.Model,
RequestType: filters.RequestType,
Stream: filters.Stream,
BillingType: filters.BillingType,
IncludeStats: includeStats,
IncludeTrend: includeTrend,
IncludeModels: includeModels,
IncludeGroups: includeGroups,
IncludeUsersTrend: includeUsersTrend,
UsersTrendLimit: usersTrendLimit,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
resp := &dashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
StartDate: startTime.Format("2006-01-02"),
EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"),
Granularity: granularity,
}
if includeStats {
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics")
return
}
resp.Stats = &dashboardSnapshotV2Stats{
DashboardStats: *stats,
Uptime: int64(time.Since(h.startTime).Seconds()),
}
}
if includeTrend {
trend, err := h.dashboardService.GetUsageTrendWithFilters(
c.Request.Context(),
startTime,
endTime,
granularity,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.Model,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
}
resp.Trend = trend
}
if includeModels {
models, err := h.dashboardService.GetModelStatsWithFilters(
c.Request.Context(),
startTime,
endTime,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
}
resp.Models = models
}
if includeGroups {
groups, err := h.dashboardService.GetGroupStatsWithFilters(
c.Request.Context(),
startTime,
endTime,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
}
resp.Groups = groups
}
if includeUsersTrend {
usersTrend, err := h.dashboardService.GetUserUsageTrend(
c.Request.Context(),
startTime,
endTime,
granularity,
usersTrendLimit,
)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
}
resp.UsersTrend = usersTrend
}
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, resp)
}
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
filters := &dashboardSnapshotV2Filters{
Model: strings.TrimSpace(c.Query("model")),
}
if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.UserID = id
}
if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.APIKeyID = id
}
if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" {
id, err := strconv.ParseInt(accountIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.AccountID = id
}
if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != "" {
id, err := strconv.ParseInt(groupIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.GroupID = id
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
return nil, err
}
value := int16(parsed)
filters.RequestType = &value
} else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" {
streamVal, err := strconv.ParseBool(streamStr)
if err != nil {
return nil, err
}
filters.Stream = &streamVal
}
if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" {
v, err := strconv.ParseInt(billingTypeStr, 10, 8)
if err != nil {
return nil, err
}
bt := int8(v)
filters.BillingType = &bt
}
return filters, nil
}

View File

@@ -0,0 +1,545 @@
package admin
import (
"context"
"strconv"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type DataManagementHandler struct {
dataManagementService dataManagementService
}
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
return &DataManagementHandler{dataManagementService: dataManagementService}
}
type dataManagementService interface {
GetConfig(ctx context.Context) (service.DataManagementConfig, error)
UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error)
ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error)
CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error)
ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error)
CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error)
UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error)
DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error
SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error)
ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error)
CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error)
UpdateS3Profile(ctx context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error)
DeleteS3Profile(ctx context.Context, profileID string) error
SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error)
ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error)
GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error)
EnsureAgentEnabled(ctx context.Context) error
GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth
}
type TestS3ConnectionRequest struct {
Endpoint string `json:"endpoint"`
Region string `json:"region" binding:"required"`
Bucket string `json:"bucket" binding:"required"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
type CreateBackupJobRequest struct {
BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"`
UploadToS3 bool `json:"upload_to_s3"`
S3ProfileID string `json:"s3_profile_id"`
PostgresID string `json:"postgres_profile_id"`
RedisID string `json:"redis_profile_id"`
IdempotencyKey string `json:"idempotency_key"`
}
type CreateSourceProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
SetActive bool `json:"set_active"`
}
type UpdateSourceProfileRequest struct {
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
}
type CreateS3ProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
SetActive bool `json:"set_active"`
}
type UpdateS3ProfileRequest struct {
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) {
health := h.getAgentHealth(c)
payload := gin.H{
"enabled": health.Enabled,
"reason": health.Reason,
"socket_path": health.SocketPath,
}
if health.Agent != nil {
payload["agent"] = gin.H{
"status": health.Agent.Status,
"version": health.Agent.Version,
"uptime_seconds": health.Agent.UptimeSeconds,
}
}
response.Success(c, payload)
}
func (h *DataManagementHandler) GetConfig(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
var req service.DataManagementConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) TestS3(c *gin.Context) {
var req TestS3ConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
Enabled: true,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
}
func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
var req CreateBackupJobRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey)
if !h.requireAgentEnabled(c) {
return
}
triggeredBy := "admin:unknown"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
BackupType: req.BackupType,
UploadToS3: req.UploadToS3,
S3ProfileID: req.S3ProfileID,
PostgresID: req.PostgresID,
RedisID: req.RedisID,
TriggeredBy: triggeredBy,
IdempotencyKey: req.IdempotencyKey,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
}
func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType == "" {
response.BadRequest(c, "Invalid source_type")
return
}
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
var req CreateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
SourceType: sourceType,
ProfileID: req.ProfileID,
Name: req.Name,
Config: req.Config,
SetActive: req.SetActive,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
var req UpdateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
SourceType: sourceType,
ProfileID: profileID,
Name: req.Name,
Config: req.Config,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
var req CreateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
ProfileID: req.ProfileID,
Name: req.Name,
SetActive: req.SetActive,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
var req UpdateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
ProfileID: profileID,
Name: req.Name,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
pageSize := int32(20)
if raw := strings.TrimSpace(c.Query("page_size")); raw != "" {
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
response.BadRequest(c, "Invalid page_size")
return
}
pageSize = int32(v)
}
result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
PageSize: pageSize,
PageToken: c.Query("page_token"),
Status: c.Query("status"),
BackupType: c.Query("backup_type"),
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
jobID := strings.TrimSpace(c.Param("job_id"))
if jobID == "" {
response.BadRequest(c, "Invalid backup job ID")
return
}
if !h.requireAgentEnabled(c) {
return
}
job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, job)
}
func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
if h.dataManagementService == nil {
err := infraerrors.ServiceUnavailable(
service.DataManagementAgentUnavailableReason,
"data management agent service is not configured",
).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath})
response.ErrorFrom(c, err)
return false
}
if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return false
}
return true
}
func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
if h.dataManagementService == nil {
return service.DataManagementAgentHealth{
Enabled: false,
Reason: service.DataManagementAgentUnavailableReason,
SocketPath: service.DefaultDataManagementAgentSocketPath,
}
}
return h.dataManagementService.GetAgentHealth(c.Request.Context())
}
func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string {
headerKey := strings.TrimSpace(headerValue)
if headerKey != "" {
return headerKey
}
return strings.TrimSpace(bodyValue)
}

View File

@@ -0,0 +1,78 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type apiEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason"`
Data json.RawMessage `json:"data"`
}
func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, 0, envelope.Code)
var data struct {
Enabled bool `json:"enabled"`
Reason string `json:"reason"`
SocketPath string `json:"socket_path"`
}
require.NoError(t, json.Unmarshal(envelope.Data, &data))
require.False(t, data.Enabled)
require.Equal(t, service.DataManagementDeprecatedReason, data.Reason)
require.Equal(t, svc.SocketPath(), data.SocketPath)
}
func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/config", h.GetConfig)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, http.StatusServiceUnavailable, envelope.Code)
require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason)
}
func TestNormalizeBackupIdempotencyKey(t *testing.T) {
require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body"))
require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body "))
require.Equal(t, "", normalizeBackupIdempotencyKey("", ""))
}

View File

@@ -52,6 +52,8 @@ type CreateGroupRequest struct {
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -86,6 +88,8 @@ type UpdateGroupRequest struct {
SimulateClaudeMaxEnabled *bool `json:"simulate_claude_max_enabled"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -201,6 +205,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
MCPXMLInject: req.MCPXMLInject,
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -252,6 +257,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
MCPXMLInject: req.MCPXMLInject,
SimulateClaudeMaxEnabled: req.SimulateClaudeMaxEnabled,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {

View File

@@ -0,0 +1,25 @@
package admin
import "sort"
func normalizeInt64IDList(ids []int64) []int64 {
if len(ids) == 0 {
return nil
}
out := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
return out
}

View File

@@ -0,0 +1,57 @@
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNormalizeInt64IDList(t *testing.T) {
tests := []struct {
name string
in []int64
want []int64
}{
{"nil input", nil, nil},
{"empty input", []int64{}, nil},
{"single element", []int64{5}, []int64{5}},
{"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}},
{"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}},
{"zero filtered", []int64{0, 1, 2}, []int64{1, 2}},
{"negative filtered", []int64{-5, -1, 3}, []int64{3}},
{"all invalid", []int64{0, -1, -2}, []int64{}},
{"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := normalizeInt64IDList(tc.in)
if tc.want == nil {
require.Nil(t, got)
} else {
require.Equal(t, tc.want, got)
}
})
}
}
func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) {
tests := []struct {
name string
ids []int64
want string
}{
{"empty", nil, "accounts_today_stats_empty"},
{"single", []int64{42}, "accounts_today_stats:42"},
{"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := buildAccountTodayStatsBatchCacheKey(tc.ids)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -5,6 +5,7 @@ import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -47,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
req = OpenAIGenerateAuthURLRequest{}
}
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
result, err := h.openaiOAuthService.GenerateAuthURL(
c.Request.Context(),
req.ProxyID,
req.RedirectURI,
oauthPlatformFromPath(c),
)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -123,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
}
}
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
// 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜
clientID := strings.TrimSpace(req.ClientID)
if clientID == "" {
platform := oauthPlatformFromPath(c)
clientID, _ = openai.OAuthClientConfigByPlatform(platform)
}
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -0,0 +1,145 @@
package admin
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
)
var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
type opsDashboardSnapshotV2Response struct {
GeneratedAt string `json:"generated_at"`
Overview *service.OpsDashboardOverview `json:"overview"`
ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"`
ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"`
}
type opsDashboardSnapshotV2CacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Platform string `json:"platform"`
GroupID *int64 `json:"group_id"`
QueryMode service.OpsQueryMode `json:"mode"`
BucketSecond int `json:"bucket_second"`
}
// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request.
// GET /api/v1/admin/ops/dashboard/snapshot-v2
func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Platform: filter.Platform,
GroupID: filter.GroupID,
QueryMode: filter.QueryMode,
BucketSecond: bucketSeconds,
})
cacheKey := string(keyRaw)
if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
var (
overview *service.OpsDashboardOverview
trend *service.OpsThroughputTrendResponse
errTrend *service.OpsErrorTrendResponse
)
g, gctx := errgroup.WithContext(c.Request.Context())
g.Go(func() error {
f := *filter
result, err := h.opsService.GetDashboardOverview(gctx, &f)
if err != nil {
return err
}
overview = result
return nil
})
g.Go(func() error {
f := *filter
result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds)
if err != nil {
return err
}
trend = result
return nil
})
g.Go(func() error {
f := *filter
result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds)
if err != nil {
return err
}
errTrend = result
return nil
})
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
resp := &opsDashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
Overview: overview,
ThroughputTrend: trend,
ErrorTrend: errTrend,
}
cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, resp)
}

View File

@@ -62,7 +62,8 @@ const (
)
var wsConnCount atomic.Int32
var wsConnCountByIP sync.Map // map[string]*atomic.Int32
var wsConnCountByIPMu sync.Mutex
var wsConnCountByIP = make(map[string]int32)
const qpsWSIdleStopDelay = 30 * time.Second
@@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool {
if strings.TrimSpace(clientIP) == "" || limit <= 0 {
return true
}
v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{})
counter, ok := v.(*atomic.Int32)
if !ok {
wsConnCountByIPMu.Lock()
defer wsConnCountByIPMu.Unlock()
current := wsConnCountByIP[clientIP]
if current >= limit {
return false
}
for {
current := counter.Load()
if current >= limit {
return false
}
if counter.CompareAndSwap(current, current+1) {
return true
}
}
wsConnCountByIP[clientIP] = current + 1
return true
}
func releaseOpsWSIPSlot(clientIP string) {
if strings.TrimSpace(clientIP) == "" {
return
}
v, ok := wsConnCountByIP.Load(clientIP)
wsConnCountByIPMu.Lock()
defer wsConnCountByIPMu.Unlock()
current, ok := wsConnCountByIP[clientIP]
if !ok {
return
}
counter, ok := v.(*atomic.Int32)
if !ok {
if current <= 1 {
delete(wsConnCountByIP, clientIP)
return
}
next := counter.Add(-1)
if next <= 0 {
// Best-effort cleanup; safe even if a new slot was acquired concurrently.
wsConnCountByIP.Delete(clientIP)
}
wsConnCountByIP[clientIP] = current - 1
}
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {

View File

@@ -64,9 +64,9 @@ func (h *ProxyHandler) List(c *gin.Context) {
return
}
out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -83,9 +83,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
}
response.Success(c, out)
return
@@ -97,9 +97,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
return
}
out := make([]dto.Proxy, 0, len(proxies))
out := make([]dto.AdminProxy, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i]))
out = append(out, *dto.ProxyFromServiceAdmin(&proxies[i]))
}
response.Success(c, out)
}
@@ -119,7 +119,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.ProxyFromService(proxy))
response.Success(c, dto.ProxyFromServiceAdmin(proxy))
}
// Create handles creating a new proxy
@@ -143,7 +143,7 @@ func (h *ProxyHandler) Create(c *gin.Context) {
if err != nil {
return nil, err
}
return dto.ProxyFromService(proxy), nil
return dto.ProxyFromServiceAdmin(proxy), nil
})
}
@@ -176,7 +176,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.ProxyFromService(proxy))
response.Success(c, dto.ProxyFromServiceAdmin(proxy))
}
// Delete handles deleting a proxy

View File

@@ -1,7 +1,13 @@
package admin
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"net/http"
"regexp"
"strings"
"time"
@@ -14,21 +20,38 @@ import (
"github.com/gin-gonic/gin"
)
// semverPattern 预编译 semver 格式校验正则
var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`)
// menuItemIDPattern validates custom menu item IDs: alphanumeric, hyphens, underscores only.
var menuItemIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
// generateMenuItemID generates a short random hex ID for a custom menu item.
func generateMenuItemID() (string, error) {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("generate menu item ID: %w", err)
}
return hex.EncodeToString(b), nil
}
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
emailService *service.EmailService
turnstileService *service.TurnstileService
opsService *service.OpsService
soraS3Storage *service.SoraS3Storage
}
// NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
opsService: opsService,
soraS3Storage: soraS3Storage,
}
}
@@ -43,10 +66,18 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions))
for _, sub := range settings.DefaultSubscriptions {
defaultSubscriptions = append(defaultSubscriptions, dto.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
@@ -76,8 +107,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
@@ -89,18 +123,21 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: settings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
})
}
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@@ -123,20 +160,23 @@ type UpdateSettingsRequest struct {
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -154,6 +194,11 @@ type UpdateSettingsRequest struct {
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault *string `json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
MinClaudeCodeVersion string `json:"min_claude_code_version"`
// 分组隔离
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
}
// UpdateSettings 更新系统设置
@@ -181,6 +226,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.SMTPPort <= 0 {
req.SMTPPort = 587
}
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
// Turnstile 参数验证
if req.TurnstileEnabled {
@@ -276,6 +322,84 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// 自定义菜单项验证
const (
maxCustomMenuItems = 20
maxMenuItemLabelLen = 50
maxMenuItemURLLen = 2048
maxMenuItemIconSVGLen = 10 * 1024 // 10KB
maxMenuItemIDLen = 32
)
customMenuJSON := previousSettings.CustomMenuItems
if req.CustomMenuItems != nil {
items := *req.CustomMenuItems
if len(items) > maxCustomMenuItems {
response.BadRequest(c, "Too many custom menu items (max 20)")
return
}
for i, item := range items {
if strings.TrimSpace(item.Label) == "" {
response.BadRequest(c, "Custom menu item label is required")
return
}
if len(item.Label) > maxMenuItemLabelLen {
response.BadRequest(c, "Custom menu item label is too long (max 50 characters)")
return
}
if strings.TrimSpace(item.URL) == "" {
response.BadRequest(c, "Custom menu item URL is required")
return
}
if len(item.URL) > maxMenuItemURLLen {
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
return
}
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil {
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL")
return
}
if item.Visibility != "user" && item.Visibility != "admin" {
response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'")
return
}
if len(item.IconSVG) > maxMenuItemIconSVGLen {
response.BadRequest(c, "Custom menu item icon SVG is too large (max 10KB)")
return
}
// Auto-generate ID if missing
if strings.TrimSpace(item.ID) == "" {
id, err := generateMenuItemID()
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to generate menu item ID")
return
}
items[i].ID = id
} else if len(item.ID) > maxMenuItemIDLen {
response.BadRequest(c, "Custom menu item ID is too long (max 32 characters)")
return
} else if !menuItemIDPattern.MatchString(item.ID) {
response.BadRequest(c, "Custom menu item ID contains invalid characters (only a-z, A-Z, 0-9, - and _ are allowed)")
return
}
}
// ID uniqueness check
seen := make(map[string]struct{}, len(items))
for _, item := range items {
if _, exists := seen[item.ID]; exists {
response.BadRequest(c, "Duplicate custom menu item ID: "+item.ID)
return
}
seen[item.ID] = struct{}{}
}
menuBytes, err := json.Marshal(items)
if err != nil {
response.BadRequest(c, "Failed to serialize custom menu items")
return
}
customMenuJSON = string(menuBytes)
}
// Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds
@@ -287,47 +411,68 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
req.OpsMetricsIntervalSeconds = &v
}
defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions))
for _, sub := range req.DefaultSubscriptions {
defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
// 验证最低版本号格式(空字符串=禁用,或合法 semver
if req.MinClaudeCodeVersion != "" {
if !semverPattern.MatchString(req.MinClaudeCodeVersion) {
response.Error(c, http.StatusBadRequest, "min_claude_code_version must be empty or a valid semver (e.g. 2.1.63)")
return
}
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
SoraClientEnabled: req.SoraClientEnabled,
CustomMenuItems: customMenuJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -367,10 +512,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
@@ -400,8 +553,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
SoraClientEnabled: updatedSettings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
@@ -413,6 +569,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
})
}
@@ -444,6 +602,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
changed = append(changed, "registration_email_suffix_whitelist")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
@@ -522,6 +683,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}
if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback")
}
@@ -555,9 +719,65 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds {
changed = append(changed, "ops_metrics_interval_seconds")
}
if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion {
changed = append(changed, "min_claude_code_version")
}
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
changed = append(changed, "allow_ungrouped_key_scheduling")
}
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
changed = append(changed, "purchase_subscription_enabled")
}
if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL {
changed = append(changed, "purchase_subscription_url")
}
if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items")
}
return changed
}
func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting {
if len(input) == 0 {
return nil
}
normalized := make([]dto.DefaultSubscriptionSetting, 0, len(input))
for _, item := range input {
if item.GroupID <= 0 || item.ValidityDays <= 0 {
continue
}
if item.ValidityDays > service.MaxValidityDays {
item.ValidityDays = service.MaxValidityDays
}
normalized = append(normalized, item)
}
return normalized
}
func equalStringSlice(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays {
return false
}
}
return true
}
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`
@@ -750,6 +970,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
})
}
func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
if settings == nil {
return dto.SoraS3Settings{}
}
return dto.SoraS3Settings{
Enabled: settings.Enabled,
Endpoint: settings.Endpoint,
Region: settings.Region,
Bucket: settings.Bucket,
AccessKeyID: settings.AccessKeyID,
SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
Prefix: settings.Prefix,
ForcePathStyle: settings.ForcePathStyle,
CDNURL: settings.CDNURL,
DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
}
}
func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
return dto.SoraS3Profile{
ProfileID: profile.ProfileID,
Name: profile.Name,
IsActive: profile.IsActive,
Enabled: profile.Enabled,
Endpoint: profile.Endpoint,
Region: profile.Region,
Bucket: profile.Bucket,
AccessKeyID: profile.AccessKeyID,
SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
Prefix: profile.Prefix,
ForcePathStyle: profile.ForcePathStyle,
CDNURL: profile.CDNURL,
DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
UpdatedAt: profile.UpdatedAt,
}
}
func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
if !enabled {
return nil
}
if strings.TrimSpace(endpoint) == "" {
return fmt.Errorf("S3 Endpoint is required when enabled")
}
if strings.TrimSpace(bucket) == "" {
return fmt.Errorf("S3 Bucket is required when enabled")
}
if strings.TrimSpace(accessKeyID) == "" {
return fmt.Errorf("S3 Access Key ID is required when enabled")
}
if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret {
return nil
}
return fmt.Errorf("S3 Secret Access Key is required when enabled")
}
func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
for idx := range items {
if items[idx].ProfileID == profileID {
return &items[idx]
}
}
return nil
}
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
// GET /api/v1/admin/settings/sora-s3
func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3SettingsDTO(settings))
}
// ListSoraS3Profiles 获取 Sora S3 多配置
// GET /api/v1/admin/settings/sora-s3/profiles
func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
items := make([]dto.SoraS3Profile, 0, len(result.Items))
for idx := range result.Items {
items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
}
response.Success(c, dto.ListSoraS3ProfilesResponse{
ActiveProfileID: result.ActiveProfileID,
Items: items,
})
}
// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
type UpdateSoraS3SettingsRequest struct {
ProfileID string `json:"profile_id"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
type CreateSoraS3ProfileRequest struct {
ProfileID string `json:"profile_id"`
Name string `json:"name"`
SetActive bool `json:"set_active"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
type UpdateSoraS3ProfileRequest struct {
Name string `json:"name"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
// CreateSoraS3Profile 创建 Sora S3 配置
// POST /api/v1/admin/settings/sora-s3/profiles
func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
var req CreateSoraS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.DefaultStorageQuotaBytes < 0 {
req.DefaultStorageQuotaBytes = 0
}
if strings.TrimSpace(req.Name) == "" {
response.BadRequest(c, "Name is required")
return
}
if strings.TrimSpace(req.ProfileID) == "" {
response.BadRequest(c, "Profile ID is required")
return
}
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
response.BadRequest(c, err.Error())
return
}
created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
ProfileID: req.ProfileID,
Name: req.Name,
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
}, req.SetActive)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3ProfileDTO(*created))
}
// UpdateSoraS3Profile 更新 Sora S3 配置
// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Profile ID is required")
return
}
var req UpdateSoraS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.DefaultStorageQuotaBytes < 0 {
req.DefaultStorageQuotaBytes = 0
}
if strings.TrimSpace(req.Name) == "" {
response.BadRequest(c, "Name is required")
return
}
existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
existing := findSoraS3ProfileByID(existingList.Items, profileID)
if existing == nil {
response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
return
}
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
response.BadRequest(c, err.Error())
return
}
updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
Name: req.Name,
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
})
if updateErr != nil {
response.ErrorFrom(c, updateErr)
return
}
response.Success(c, toSoraS3ProfileDTO(*updated))
}
// DeleteSoraS3Profile 删除 Sora S3 配置
// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Profile ID is required")
return
}
if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
// SetActiveSoraS3Profile 切换激活 Sora S3 配置
// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Profile ID is required")
return
}
active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3ProfileDTO(*active))
}
// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
// PUT /api/v1/admin/settings/sora-s3
func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
var req UpdateSoraS3SettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
if req.DefaultStorageQuotaBytes < 0 {
req.DefaultStorageQuotaBytes = 0
}
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
response.BadRequest(c, err.Error())
return
}
settings := &service.SoraS3Settings{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
}
if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
response.ErrorFrom(c, err)
return
}
updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3SettingsDTO(updatedSettings))
}
// TestSoraS3Connection 测试 Sora S3 连接HeadBucket
// POST /api/v1/admin/settings/sora-s3/test
func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
if h.soraS3Storage == nil {
response.Error(c, 500, "S3 存储服务未初始化")
return
}
var req UpdateSoraS3SettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !req.Enabled {
response.BadRequest(c, "S3 未启用,无法测试连接")
return
}
if req.SecretAccessKey == "" {
if req.ProfileID != "" {
profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
if err == nil {
profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
if profile != nil {
req.SecretAccessKey = profile.SecretAccessKey
}
}
}
if req.SecretAccessKey == "" {
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err == nil {
req.SecretAccessKey = existing.SecretAccessKey
}
}
}
testCfg := &service.SoraS3Settings{
Enabled: true,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
}
if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
response.Error(c, 400, "S3 连接测试失败: "+err.Error())
return
}
response.Success(c, gin.H{"message": "S3 连接成功"})
}
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
type UpdateStreamTimeoutSettingsRequest struct {
Enabled bool `json:"enabled"`

View File

@@ -0,0 +1,95 @@
package admin
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"strings"
"sync"
"time"
)
type snapshotCacheEntry struct {
ETag string
Payload any
ExpiresAt time.Time
}
type snapshotCache struct {
mu sync.RWMutex
ttl time.Duration
items map[string]snapshotCacheEntry
}
func newSnapshotCache(ttl time.Duration) *snapshotCache {
if ttl <= 0 {
ttl = 30 * time.Second
}
return &snapshotCache{
ttl: ttl,
items: make(map[string]snapshotCacheEntry),
}
}
func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) {
if c == nil || key == "" {
return snapshotCacheEntry{}, false
}
now := time.Now()
c.mu.RLock()
entry, ok := c.items[key]
c.mu.RUnlock()
if !ok {
return snapshotCacheEntry{}, false
}
if now.After(entry.ExpiresAt) {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
return snapshotCacheEntry{}, false
}
return entry, true
}
func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
if c == nil {
return snapshotCacheEntry{}
}
entry := snapshotCacheEntry{
ETag: buildETagFromAny(payload),
Payload: payload,
ExpiresAt: time.Now().Add(c.ttl),
}
if key == "" {
return entry
}
c.mu.Lock()
c.items[key] = entry
c.mu.Unlock()
return entry
}
func buildETagFromAny(payload any) string {
raw, err := json.Marshal(payload)
if err != nil {
return ""
}
sum := sha256.Sum256(raw)
return "\"" + hex.EncodeToString(sum[:]) + "\""
}
func parseBoolQueryWithDefault(raw string, def bool) bool {
value := strings.TrimSpace(strings.ToLower(raw))
if value == "" {
return def
}
switch value {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
default:
return def
}
}

View File

@@ -0,0 +1,128 @@
//go:build unit
package admin
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestSnapshotCache_SetAndGet(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
entry := c.Set("key1", map[string]string{"hello": "world"})
require.NotEmpty(t, entry.ETag)
require.NotNil(t, entry.Payload)
got, ok := c.Get("key1")
require.True(t, ok)
require.Equal(t, entry.ETag, got.ETag)
}
func TestSnapshotCache_Expiration(t *testing.T) {
c := newSnapshotCache(1 * time.Millisecond)
c.Set("key1", "value")
time.Sleep(5 * time.Millisecond)
_, ok := c.Get("key1")
require.False(t, ok, "expired entry should not be returned")
}
func TestSnapshotCache_GetEmptyKey(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
_, ok := c.Get("")
require.False(t, ok)
}
func TestSnapshotCache_GetMiss(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
_, ok := c.Get("nonexistent")
require.False(t, ok)
}
func TestSnapshotCache_NilReceiver(t *testing.T) {
var c *snapshotCache
_, ok := c.Get("key")
require.False(t, ok)
entry := c.Set("key", "value")
require.Empty(t, entry.ETag)
}
func TestSnapshotCache_SetEmptyKey(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
// Set with empty key should return entry but not store it
entry := c.Set("", "value")
require.NotEmpty(t, entry.ETag)
_, ok := c.Get("")
require.False(t, ok)
}
func TestSnapshotCache_DefaultTTL(t *testing.T) {
c := newSnapshotCache(0)
require.Equal(t, 30*time.Second, c.ttl)
c2 := newSnapshotCache(-1 * time.Second)
require.Equal(t, 30*time.Second, c2.ttl)
}
func TestSnapshotCache_ETagDeterministic(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
payload := map[string]int{"a": 1, "b": 2}
entry1 := c.Set("k1", payload)
entry2 := c.Set("k2", payload)
require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag")
}
func TestSnapshotCache_ETagFormat(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
entry := c.Set("k", "test")
// ETag should be quoted hex string: "abcdef..."
require.True(t, len(entry.ETag) > 2)
require.Equal(t, byte('"'), entry.ETag[0])
require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1])
}
func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
// channels are not JSON-serializable
etag := buildETagFromAny(make(chan int))
require.Empty(t, etag)
}
func TestParseBoolQueryWithDefault(t *testing.T) {
tests := []struct {
name string
raw string
def bool
want bool
}{
{"empty returns default true", "", true, true},
{"empty returns default false", "", false, false},
{"1", "1", false, true},
{"true", "true", false, true},
{"TRUE", "TRUE", false, true},
{"yes", "yes", false, true},
{"on", "on", false, true},
{"0", "0", true, false},
{"false", "false", true, false},
{"FALSE", "FALSE", true, false},
{"no", "no", true, false},
{"off", "off", true, false},
{"whitespace trimmed", " true ", false, true},
{"unknown returns default true", "maybe", true, true},
{"unknown returns default false", "maybe", false, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := parseBoolQueryWithDefault(tc.raw, tc.def)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"request_type": "invalid",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"request_type": "ws_v2",
"stream": false,
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.NotNil(t, created.Filters.RequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType)
require.Nil(t, created.Filters.Stream)
}
func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"stream": true,
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.Nil(t, created.Filters.RequestType)
require.NotNil(t, created.Filters.Stream)
require.True(t, *created.Filters.Stream)
}
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}

View File

@@ -51,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct {
AccountID *int64 `json:"account_id"`
GroupID *int64 `json:"group_id"`
Model *string `json:"model"`
RequestType *string `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
Timezone string `json:"timezone"`
@@ -60,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct {
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
exactTotal := false
if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" {
parsed, err := strconv.ParseBool(exactTotalRaw)
if err != nil {
response.BadRequest(c, "Invalid exact_total value, use true or false")
return
}
exactTotal = parsed
}
// Parse filters
var userID, apiKeyID, accountID, groupID int64
@@ -101,8 +111,17 @@ func (h *UsageHandler) List(c *gin.Context) {
model := c.Query("model")
var requestType *int16
var stream *bool
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
@@ -152,10 +171,12 @@ func (h *UsageHandler) List(c *gin.Context) {
AccountID: accountID,
GroupID: groupID,
Model: model,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
StartTime: startTime,
EndTime: endTime,
ExactTotal: exactTotal,
}
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
@@ -214,8 +235,17 @@ func (h *UsageHandler) Stats(c *gin.Context) {
model := c.Query("model")
var requestType *int16
var stream *bool
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
@@ -278,6 +308,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
AccountID: accountID,
GroupID: groupID,
Model: model,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
StartTime: &startTime,
@@ -432,6 +463,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
var requestType *int16
stream := req.Stream
if req.RequestType != nil {
parsed, err := service.ParseUsageRequestType(*req.RequestType)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
stream = nil
}
filters := service.UsageCleanupFilters{
StartTime: startTime,
EndTime: endTime,
@@ -440,7 +484,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
AccountID: req.AccountID,
GroupID: req.GroupID,
Model: req.Model,
Stream: req.Stream,
RequestType: requestType,
Stream: stream,
BillingType: req.BillingType,
}
@@ -464,9 +509,13 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
if filters.Model != nil {
model = *filters.Model
}
var stream any
var streamValue any
if filters.Stream != nil {
stream = *filters.Stream
streamValue = *filters.Stream
}
var requestTypeName any
if filters.RequestType != nil {
requestTypeName = service.RequestTypeFromInt16(*filters.RequestType).String()
}
var billingType any
if filters.BillingType != nil {
@@ -481,7 +530,7 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
Body: req,
}
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
@@ -490,7 +539,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
accountID,
groupID,
model,
stream,
requestTypeName,
streamValue,
billingType,
req.Timezone,
)

View File

@@ -0,0 +1,140 @@
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type adminUsageRepoCapture struct {
service.UsageLogRepository
listFilters usagestats.UsageLogFilters
statsFilters usagestats.UsageLogFilters
}
func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0,
Page: params.Page,
PageSize: params.PageSize,
Pages: 0,
}, nil
}
func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
s.statsFilters = filters
return &usagestats.UsageStats{}, nil
}
func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
usageSvc := service.NewUsageService(repo, nil, nil, nil)
handler := NewUsageHandler(usageSvc, nil, nil, nil)
router := gin.New()
router.GET("/admin/usage", handler.List)
router.GET("/admin/usage/stats", handler.Stats)
return router
}
func TestAdminUsageListRequestTypePriority(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.listFilters.RequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
require.Nil(t, repo.listFilters.Stream)
}
func TestAdminUsageListInvalidRequestType(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageListInvalidStream(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageListExactTotalTrue(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.True(t, repo.listFilters.ExactTotal)
}
func TestAdminUsageListInvalidExactTotal(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.statsFilters.RequestType)
require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType)
require.Nil(t, repo.statsFilters.Stream)
}
func TestAdminUsageStatsInvalidRequestType(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageStatsInvalidStream(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}

View File

@@ -1,7 +1,9 @@
package admin
import (
"encoding/json"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -67,6 +69,8 @@ type BatchUserAttributesResponse struct {
Attributes map[int64]map[int64]string `json:"attributes"`
}
var userAttributesBatchCache = newSnapshotCache(30 * time.Second)
// AttributeDefinitionResponse represents attribute definition response
type AttributeDefinitionResponse struct {
ID int64 `json:"id"`
@@ -327,16 +331,32 @@ func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
return
}
if len(req.UserIDs) == 0 {
userIDs := normalizeInt64IDList(req.UserIDs)
if len(userIDs) == 0 {
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
return
}
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
keyRaw, _ := json.Marshal(struct {
UserIDs []int64 `json:"user_ids"`
}{
UserIDs: userIDs,
})
cacheKey := string(keyRaw)
if cached, ok := userAttributesBatchCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), userIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
payload := BatchUserAttributesResponse{Attributes: attrs}
userAttributesBatchCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}

View File

@@ -34,13 +34,14 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi
// CreateUserRequest represents admin create user request
type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"`
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
AllowedGroups []int64 `json:"allowed_groups"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"`
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
AllowedGroups []int64 `json:"allowed_groups"`
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
}
// UpdateUserRequest represents admin update user request
@@ -56,7 +57,8 @@ type UpdateUserRequest struct {
AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64 `json:"group_rates"`
GroupRates map[int64]*float64 `json:"group_rates"`
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
}
// UpdateBalanceRequest represents balance update request
@@ -89,6 +91,10 @@ func (h *UserHandler) List(c *gin.Context) {
Search: search,
Attributes: parseAttributeFilters(c),
}
if raw, ok := c.GetQuery("include_subscriptions"); ok {
includeSubscriptions := parseBoolQueryWithDefault(raw, true)
filters.IncludeSubscriptions = &includeSubscriptions
}
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
if err != nil {
@@ -174,13 +180,14 @@ func (h *UserHandler) Create(c *gin.Context) {
}
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
Email: req.Email,
Password: req.Password,
Username: req.Username,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
AllowedGroups: req.AllowedGroups,
Email: req.Email,
Password: req.Password,
Username: req.Username,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
AllowedGroups: req.AllowedGroups,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -207,15 +214,16 @@ func (h *UserHandler) Update(c *gin.Context) {
// 使用指针类型直接传递nil 表示未提供该字段
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
Email: req.Email,
Password: req.Password,
Username: req.Username,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
Email: req.Email,
Password: req.Password,
Username: req.Username,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
})
if err != nil {
response.ErrorFrom(c, err)

View File

@@ -4,6 +4,7 @@ package handler
import (
"context"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -36,6 +37,11 @@ type CreateAPIKeyRequest struct {
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
Quota *float64 `json:"quota"` // 配额限制 (USD)
ExpiresInDays *int `json:"expires_in_days"` // 过期天数
// Rate limit fields (0 = unlimited)
RateLimit5h *float64 `json:"rate_limit_5h"`
RateLimit1d *float64 `json:"rate_limit_1d"`
RateLimit7d *float64 `json:"rate_limit_7d"`
}
// UpdateAPIKeyRequest represents the update API key request payload
@@ -48,6 +54,12 @@ type UpdateAPIKeyRequest struct {
Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制
ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601)
ResetQuota *bool `json:"reset_quota"` // 重置已用配额
// Rate limit fields (nil = no change, 0 = unlimited)
RateLimit5h *float64 `json:"rate_limit_5h"`
RateLimit1d *float64 `json:"rate_limit_1d"`
RateLimit7d *float64 `json:"rate_limit_7d"`
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // 重置限速用量
}
// List handles listing user's API keys with pagination
@@ -62,7 +74,23 @@ func (h *APIKeyHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
// Parse filter parameters
var filters service.APIKeyListFilters
if search := strings.TrimSpace(c.Query("search")); search != "" {
if len(search) > 100 {
search = search[:100]
}
filters.Search = search
}
filters.Status = c.Query("status")
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
gid, err := strconv.ParseInt(groupIDStr, 10, 64)
if err == nil {
filters.GroupID = &gid
}
}
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params, filters)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -131,6 +159,15 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
if req.Quota != nil {
svcReq.Quota = *req.Quota
}
if req.RateLimit5h != nil {
svcReq.RateLimit5h = *req.RateLimit5h
}
if req.RateLimit1d != nil {
svcReq.RateLimit1d = *req.RateLimit1d
}
if req.RateLimit7d != nil {
svcReq.RateLimit7d = *req.RateLimit7d
}
executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq)
@@ -163,10 +200,14 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
}
svcReq := service.UpdateAPIKeyRequest{
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
Quota: req.Quota,
ResetQuota: req.ResetQuota,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
Quota: req.Quota,
ResetQuota: req.ResetQuota,
RateLimit5h: req.RateLimit5h,
RateLimit1d: req.RateLimit1d,
RateLimit7d: req.RateLimit7d,
ResetRateLimitUsage: req.ResetRateLimitUsage,
}
if req.Name != "" {
svcReq.Name = &req.Name

View File

@@ -59,9 +59,11 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return nil
}
return &AdminUser{
User: *base,
Notes: u.Notes,
GroupRates: u.GroupRates,
User: *base,
Notes: u.Notes,
GroupRates: u.GroupRates,
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
}
}
@@ -70,22 +72,31 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
return nil
}
return &APIKey{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
IPWhitelist: k.IPWhitelist,
IPBlacklist: k.IPBlacklist,
LastUsedAt: k.LastUsedAt,
Quota: k.Quota,
QuotaUsed: k.QuotaUsed,
ExpiresAt: k.ExpiresAt,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
IPWhitelist: k.IPWhitelist,
IPBlacklist: k.IPBlacklist,
LastUsedAt: k.LastUsedAt,
Quota: k.Quota,
QuotaUsed: k.QuotaUsed,
ExpiresAt: k.ExpiresAt,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
RateLimit5h: k.RateLimit5h,
RateLimit1d: k.RateLimit1d,
RateLimit7d: k.RateLimit7d,
Usage5h: k.Usage5h,
Usage1d: k.Usage1d,
Usage7d: k.Usage7d,
Window5hStart: k.Window5hStart,
Window1dStart: k.Window1dStart,
Window7dStart: k.Window7dStart,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
}
}
@@ -153,6 +164,7 @@ func groupFromServiceBase(g *service.Group) Group {
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
@@ -207,6 +219,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
out.SessionIdleTimeoutMin = &idleTimeout
}
if rpm := a.GetBaseRPM(); rpm > 0 {
out.BaseRPM = &rpm
strategy := a.GetRPMStrategy()
out.RPMStrategy = &strategy
buffer := a.GetRPMStickyBuffer()
out.RPMStickyBuffer = &buffer
}
// 用户消息队列模式
if mode := a.GetUserMsgQueueMode(); mode != "" {
out.UserMsgQueueMode = &mode
}
// TLS指纹伪装开关
if a.IsTLSFingerprintEnabled() {
enabled := true
@@ -284,7 +307,6 @@ func ProxyFromService(p *service.Proxy) *Proxy {
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
@@ -314,6 +336,51 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
}
}
// ProxyFromServiceAdmin converts a service Proxy to AdminProxy DTO for admin users.
// It includes the password field - user-facing endpoints must not use this.
func ProxyFromServiceAdmin(p *service.Proxy) *AdminProxy {
if p == nil {
return nil
}
base := ProxyFromService(p)
if base == nil {
return nil
}
return &AdminProxy{
Proxy: *base,
Password: p.Password,
}
}
// ProxyWithAccountCountFromServiceAdmin converts a service ProxyWithAccountCount to AdminProxyWithAccountCount DTO.
// It includes the password field - user-facing endpoints must not use this.
func ProxyWithAccountCountFromServiceAdmin(p *service.ProxyWithAccountCount) *AdminProxyWithAccountCount {
if p == nil {
return nil
}
admin := ProxyFromServiceAdmin(&p.Proxy)
if admin == nil {
return nil
}
return &AdminProxyWithAccountCount{
AdminProxy: *admin,
AccountCount: p.AccountCount,
LatencyMs: p.LatencyMs,
LatencyStatus: p.LatencyStatus,
LatencyMessage: p.LatencyMessage,
IPAddress: p.IPAddress,
Country: p.Country,
CountryCode: p.CountryCode,
Region: p.Region,
City: p.City,
QualityStatus: p.QualityStatus,
QualityScore: p.QualityScore,
QualityGrade: p.QualityGrade,
QualitySummary: p.QualitySummary,
QualityChecked: p.QualityChecked,
}
}
func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
if a == nil {
return nil
@@ -386,6 +453,8 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
// 普通用户 DTO严禁包含管理员字段例如 account_rate_multiplier、ip_address、account
requestType := l.EffectiveRequestType()
stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode)
return UsageLog{
ID: l.ID,
UserID: l.UserID,
@@ -410,7 +479,9 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
RequestType: requestType.String(),
Stream: stream,
OpenAIWSMode: openAIWSMode,
DurationMs: l.DurationMs,
FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount,
@@ -465,6 +536,7 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
AccountID: task.Filters.AccountID,
GroupID: task.Filters.GroupID,
Model: task.Filters.Model,
RequestType: requestTypeStringPtr(task.Filters.RequestType),
Stream: task.Filters.Stream,
BillingType: task.Filters.BillingType,
},
@@ -480,6 +552,14 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
}
}
func requestTypeStringPtr(requestType *int16) *string {
if requestType == nil {
return nil
}
value := service.RequestTypeFromInt16(*requestType).String()
return &value
}
func SettingFromService(s *service.Setting) *Setting {
if s == nil {
return nil

View File

@@ -0,0 +1,73 @@
package dto
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestUsageLogFromService_IncludesOpenAIWSMode(t *testing.T) {
t.Parallel()
wsLog := &service.UsageLog{
RequestID: "req_1",
Model: "gpt-5.3-codex",
OpenAIWSMode: true,
}
httpLog := &service.UsageLog{
RequestID: "resp_1",
Model: "gpt-5.3-codex",
OpenAIWSMode: false,
}
require.True(t, UsageLogFromService(wsLog).OpenAIWSMode)
require.False(t, UsageLogFromService(httpLog).OpenAIWSMode)
require.True(t, UsageLogFromServiceAdmin(wsLog).OpenAIWSMode)
require.False(t, UsageLogFromServiceAdmin(httpLog).OpenAIWSMode)
}
func TestUsageLogFromService_PrefersRequestTypeForLegacyFields(t *testing.T) {
t.Parallel()
log := &service.UsageLog{
RequestID: "req_2",
Model: "gpt-5.3-codex",
RequestType: service.RequestTypeWSV2,
Stream: false,
OpenAIWSMode: false,
}
userDTO := UsageLogFromService(log)
adminDTO := UsageLogFromServiceAdmin(log)
require.Equal(t, "ws_v2", userDTO.RequestType)
require.True(t, userDTO.Stream)
require.True(t, userDTO.OpenAIWSMode)
require.Equal(t, "ws_v2", adminDTO.RequestType)
require.True(t, adminDTO.Stream)
require.True(t, adminDTO.OpenAIWSMode)
}
func TestUsageCleanupTaskFromService_RequestTypeMapping(t *testing.T) {
t.Parallel()
requestType := int16(service.RequestTypeStream)
task := &service.UsageCleanupTask{
ID: 1,
Status: service.UsageCleanupStatusPending,
Filters: service.UsageCleanupFilters{
RequestType: &requestType,
},
}
dtoTask := UsageCleanupTaskFromService(task)
require.NotNil(t, dtoTask)
require.NotNil(t, dtoTask.Filters.RequestType)
require.Equal(t, "stream", *dtoTask.Filters.RequestType)
}
func TestRequestTypeStringPtrNil(t *testing.T) {
t.Parallel()
require.Nil(t, requestTypeStringPtr(nil))
}

View File

@@ -1,14 +1,30 @@
package dto
import (
"encoding/json"
"strings"
)
// CustomMenuItem represents a user-configured custom menu entry.
type CustomMenuItem struct {
ID string `json:"id"`
Label string `json:"label"`
IconSVG string `json:"icon_svg"`
URL string `json:"url"`
Visibility string `json:"visibility"` // "user" or "admin"
SortOrder int `json:"sort_order"`
}
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -27,19 +43,22 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -57,29 +76,80 @@ type SystemSettings struct {
OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault string `json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"`
MinClaudeCodeVersion string `json:"min_claude_code_version"`
// 分组隔离
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
}
type DefaultSubscriptionSetting struct {
GroupID int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"`
Version string `json:"version"`
}
// SoraS3Settings Sora S3 存储配置 DTO响应用不含敏感字段
type SoraS3Settings struct {
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
// SoraS3Profile Sora S3 存储配置项 DTO响应用不含敏感字段
type SoraS3Profile struct {
ProfileID string `json:"profile_id"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
UpdatedAt string `json:"updated_at"`
}
// ListSoraS3ProfilesResponse Sora S3 配置列表响应
type ListSoraS3ProfilesResponse struct {
ActiveProfileID string `json:"active_profile_id"`
Items []SoraS3Profile `json:"items"`
}
// StreamTimeoutSettings 流超时处理配置 DTO
@@ -90,3 +160,29 @@ type StreamTimeoutSettings struct {
ThresholdCount int `json:"threshold_count"`
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
}
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input.
func ParseCustomMenuItems(raw string) []CustomMenuItem {
raw = strings.TrimSpace(raw)
if raw == "" || raw == "[]" {
return []CustomMenuItem{}
}
var items []CustomMenuItem
if err := json.Unmarshal([]byte(raw), &items); err != nil {
return []CustomMenuItem{}
}
return items
}
// ParseUserVisibleMenuItems parses custom menu items and filters out admin-only entries.
func ParseUserVisibleMenuItems(raw string) []CustomMenuItem {
items := ParseCustomMenuItems(raw)
filtered := make([]CustomMenuItem, 0, len(items))
for _, item := range items {
if item.Visibility != "admin" {
filtered = append(filtered, item)
}
}
return filtered
}

View File

@@ -26,7 +26,9 @@ type AdminUser struct {
Notes string `json:"notes"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"`
}
type APIKey struct {
@@ -45,6 +47,17 @@ type APIKey struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// Rate limit fields
RateLimit5h float64 `json:"rate_limit_5h"`
RateLimit1d float64 `json:"rate_limit_1d"`
RateLimit7d float64 `json:"rate_limit_7d"`
Usage5h float64 `json:"usage_5h"`
Usage1d float64 `json:"usage_1d"`
Usage7d float64 `json:"usage_7d"`
Window5hStart *time.Time `json:"window_5h_start"`
Window1dStart *time.Time `json:"window_1d_start"`
Window7dStart *time.Time `json:"window_7d_start"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
@@ -80,6 +93,9 @@ type Group struct {
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// Sora 存储配额
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
@@ -150,6 +166,13 @@ type Account struct {
MaxSessions *int `json:"max_sessions,omitempty"`
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
// RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
BaseRPM *int `json:"base_rpm,omitempty"`
RPMStrategy *string `json:"rpm_strategy,omitempty"`
RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"`
UserMsgQueueMode *string `json:"user_msg_queue_mode,omitempty"`
// TLS指纹伪装仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
@@ -212,6 +235,32 @@ type ProxyWithAccountCount struct {
QualityChecked *int64 `json:"quality_checked,omitempty"`
}
// AdminProxy 是管理员接口使用的 proxy DTO包含密码等敏感字段
// 注意:普通接口不得使用此 DTO。
type AdminProxy struct {
Proxy
Password string `json:"password,omitempty"`
}
// AdminProxyWithAccountCount 是管理员接口使用的带账号统计的 proxy DTO。
type AdminProxyWithAccountCount struct {
AdminProxy
AccountCount int64 `json:"account_count"`
LatencyMs *int64 `json:"latency_ms,omitempty"`
LatencyStatus string `json:"latency_status,omitempty"`
LatencyMessage string `json:"latency_message,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
QualityStatus string `json:"quality_status,omitempty"`
QualityScore *int `json:"quality_score,omitempty"`
QualityGrade string `json:"quality_grade,omitempty"`
QualitySummary string `json:"quality_summary,omitempty"`
QualityChecked *int64 `json:"quality_checked,omitempty"`
}
type ProxyAccountSummary struct {
ID int64 `json:"id"`
Name string `json:"name"`
@@ -280,10 +329,12 @@ type UsageLog struct {
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"`
BillingType int8 `json:"billing_type"`
RequestType string `json:"request_type"`
Stream bool `json:"stream"`
OpenAIWSMode bool `json:"openai_ws_mode"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"`
// 图片生成字段
ImageCount int `json:"image_count"`
@@ -326,6 +377,7 @@ type UsageCleanupFilters struct {
AccountID *int64 `json:"account_id,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
Model *string `json:"model,omitempty"`
RequestType *string `json:"request_type,omitempty"`
Stream *bool `json:"stream,omitempty"`
BillingType *int8 `json:"billing_type,omitempty"`
}

View File

@@ -2,11 +2,12 @@ package handler
import (
"context"
"log"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
"go.uber.org/zap"
)
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
@@ -78,8 +79,12 @@ func (s *FailoverState) HandleFailoverError(
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
s.SameAccountRetryCount[accountID]++
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
logger.FromContext(ctx).Warn("gateway.failover_same_account_retry",
zap.Int64("account_id", accountID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("same_account_retry_count", s.SameAccountRetryCount[accountID]),
zap.Int("same_account_retry_max", maxSameAccountRetries),
)
if !sleepWithContext(ctx, sameAccountRetryDelay) {
return FailoverCanceled
}
@@ -101,8 +106,12 @@ func (s *FailoverState) HandleFailoverError(
// 递增切换计数
s.SwitchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d",
accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
logger.FromContext(ctx).Warn("gateway.failover_switch_account",
zap.Int64("account_id", accountID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", s.SwitchCount),
zap.Int("max_switches", s.MaxSwitches),
)
// Antigravity 平台换号线性递增延时
if platform == service.PlatformAntigravity {
@@ -127,13 +136,18 @@ func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAc
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
s.SwitchCount <= s.MaxSwitches {
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
singleAccountBackoffDelay, s.SwitchCount)
logger.FromContext(ctx).Warn("gateway.failover_single_account_backoff",
zap.Duration("backoff_delay", singleAccountBackoffDelay),
zap.Int("switch_count", s.SwitchCount),
zap.Int("max_switches", s.MaxSwitches),
)
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
return FailoverCanceled
}
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
s.SwitchCount, s.MaxSwitches)
logger.FromContext(ctx).Warn("gateway.failover_single_account_retry",
zap.Int("switch_count", s.SwitchCount),
zap.Int("max_switches", s.MaxSwitches),
)
s.FailedAccountIDs = make(map[int64]struct{})
return FailoverContinue
}

View File

@@ -6,9 +6,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -17,9 +18,11 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -27,6 +30,10 @@ import (
"go.uber.org/zap"
)
const gatewayCompatibilityMetricsLogInterval = 1024
var gatewayCompatibilityMetricsLogCounter atomic.Uint64
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
@@ -39,9 +46,11 @@ type GatewayHandler struct {
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
userMsgQueueHelper *UserMsgQueueHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
cfg *config.Config
settingService *service.SettingService
}
// NewGatewayHandler creates a new GatewayHandler
@@ -56,7 +65,9 @@ func NewGatewayHandler(
apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
userMsgQueueService *service.UserMessageQueueService,
cfg *config.Config,
settingService *service.SettingService,
) *GatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 10
@@ -70,6 +81,13 @@ func NewGatewayHandler(
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
}
}
// 初始化用户消息串行队列 helper
var umqHelper *UserMsgQueueHelper
if userMsgQueueService != nil && cfg != nil {
umqHelper = NewUserMsgQueueHelper(userMsgQueueService, SSEPingFormatClaude, pingInterval)
}
return &GatewayHandler{
gatewayService: gatewayService,
geminiCompatService: geminiCompatService,
@@ -81,9 +99,11 @@ func NewGatewayHandler(
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
userMsgQueueHelper: umqHelper,
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
cfg: cfg,
settingService: settingService,
}
}
@@ -109,9 +129,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -140,16 +161,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
ctx := service.WithIsMaxTokensOneHaikuRequest(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
// 检查是否为 Claude Code 客户端,设置到 context 中(复用已解析请求,避免二次反序列化)。
SetClaudeCodeClientContext(c, body, parsedReq)
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
// 版本检查:仅对 Claude Code 客户端,拒绝低于最低版本的请求
if !h.checkClaudeCodeVersion(c) {
return
}
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
setOpsRequestContext(c, reqModel, reqStream, body)
@@ -247,8 +273,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
}
@@ -261,7 +286,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
@@ -275,7 +300,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
@@ -364,7 +389,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
@@ -397,6 +422,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// RPM 计数递增Forward 成功后)
// 注意TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil {
reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
@@ -440,7 +474,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
@@ -459,7 +493,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
@@ -544,18 +578,78 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// ===== 用户消息串行队列 START =====
var queueRelease func()
umqMode := h.getUserMsgQueueMode(account, parsedReq)
switch umqMode {
case config.UMQModeSerialize:
// 串行模式:获取锁 + RPM 延迟 + 释放(当前行为不变)
baseRPM := account.GetBaseRPM()
release, qErr := h.userMsgQueueHelper.AcquireWithWait(
c, account.ID, baseRPM, reqStream, &streamStarted,
h.cfg.Gateway.UserMessageQueue.WaitTimeout(),
reqLog,
)
if qErr != nil {
// fail-open: 记录 warn不阻止请求
reqLog.Warn("gateway.umq_acquire_failed",
zap.Int64("account_id", account.ID),
zap.Error(qErr),
)
} else {
queueRelease = release
}
case config.UMQModeThrottle:
// 软性限速:仅施加 RPM 自适应延迟,不阻塞并发
baseRPM := account.GetBaseRPM()
if tErr := h.userMsgQueueHelper.ThrottleWithPing(
c, account.ID, baseRPM, reqStream, &streamStarted,
h.cfg.Gateway.UserMessageQueue.WaitTimeout(),
reqLog,
); tErr != nil {
reqLog.Warn("gateway.umq_throttle_failed",
zap.Int64("account_id", account.ID),
zap.Error(tErr),
)
}
default:
if umqMode != "" {
reqLog.Warn("gateway.umq_unknown_mode",
zap.String("mode", umqMode),
zap.Int64("account_id", account.ID),
)
}
}
// 用 wrapReleaseOnDone 确保 context 取消时自动释放(仅 serialize 模式有 queueRelease
queueRelease = wrapReleaseOnDone(c.Request.Context(), queueRelease)
// 注入回调到 ParsedRequest使用外层 wrapper 以便提前清理 AfterFunc
parsedReq.OnUpstreamAccepted = queueRelease
// ===== 用户消息串行队列 END =====
// 转发请求 - 根据账号平台分流
c.Set("parsed_request", parsedReq)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
} else {
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
}
// 兜底释放串行锁(正常情况已通过回调提前释放)
if queueRelease != nil {
queueRelease()
}
// 清理回调引用,防止 failover 重试时旧回调被错误调用
parsedReq.OnUpstreamAccepted = nil
if accountReleaseFunc != nil {
accountReleaseFunc()
}
@@ -591,7 +685,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
// 兜底重试按直接请求兜底分组处理:清除强制平台,允许按分组平台调度
// 兜底重试按"直接请求兜底分组"处理:清除强制平台,允许按分组平台调度
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "")
c.Request = c.Request.WithContext(ctx)
currentAPIKey = fallbackAPIKey
@@ -625,6 +719,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// RPM 计数递增Forward 成功后)
// 注意TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil {
reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
@@ -745,6 +848,10 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service
// Usage handles getting account balance and usage statistics for CC Switch integration
// GET /v1/usage
//
// Two modes:
// - quota_limited: API Key has quota or rate limits configured. Returns key-level limits/usage.
// - unrestricted: No key-level limits. Returns subscription or wallet balance info.
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
@@ -758,54 +865,183 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return
}
ctx := c.Request.Context()
// 解析可选的日期范围参数(用于 model_stats 查询)
startTime, endTime := h.parseUsageDateRange(c)
// Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应
var usageData gin.H
usageData := h.buildUsageData(ctx, apiKey.ID)
// Best-effort: 获取模型统计
var modelStats any
if h.usageService != nil {
dashStats, err := h.usageService.GetAPIKeyDashboardStats(c.Request.Context(), apiKey.ID)
if err == nil && dashStats != nil {
usageData = gin.H{
"today": gin.H{
"requests": dashStats.TodayRequests,
"input_tokens": dashStats.TodayInputTokens,
"output_tokens": dashStats.TodayOutputTokens,
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
"cache_read_tokens": dashStats.TodayCacheReadTokens,
"total_tokens": dashStats.TodayTokens,
"cost": dashStats.TodayCost,
"actual_cost": dashStats.TodayActualCost,
},
"total": gin.H{
"requests": dashStats.TotalRequests,
"input_tokens": dashStats.TotalInputTokens,
"output_tokens": dashStats.TotalOutputTokens,
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
"cache_read_tokens": dashStats.TotalCacheReadTokens,
"total_tokens": dashStats.TotalTokens,
"cost": dashStats.TotalCost,
"actual_cost": dashStats.TotalActualCost,
},
"average_duration_ms": dashStats.AverageDurationMs,
"rpm": dashStats.Rpm,
"tpm": dashStats.Tpm,
if stats, err := h.usageService.GetAPIKeyModelStats(ctx, apiKey.ID, startTime, endTime); err == nil && len(stats) > 0 {
modelStats = stats
}
}
// 判断模式: key 有总额度或速率限制 → quota_limited否则 → unrestricted
isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits()
if isQuotaLimited {
h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats)
return
}
h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats)
}
// parseUsageDateRange 解析 start_date / end_date query params默认返回近 30 天范围
func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Time) {
now := timezone.Now()
endTime := now
startTime := now.AddDate(0, 0, -30)
if s := c.Query("start_date"); s != "" {
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
startTime = t
}
}
if s := c.Query("end_date"); s != "" {
if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil {
endTime = t.Add(24*time.Hour - time.Second) // end of day
}
}
return startTime, endTime
}
// buildUsageData 构建 today/total 用量摘要
func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin.H {
if h.usageService == nil {
return nil
}
dashStats, err := h.usageService.GetAPIKeyDashboardStats(ctx, apiKeyID)
if err != nil || dashStats == nil {
return nil
}
return gin.H{
"today": gin.H{
"requests": dashStats.TodayRequests,
"input_tokens": dashStats.TodayInputTokens,
"output_tokens": dashStats.TodayOutputTokens,
"cache_creation_tokens": dashStats.TodayCacheCreationTokens,
"cache_read_tokens": dashStats.TodayCacheReadTokens,
"total_tokens": dashStats.TodayTokens,
"cost": dashStats.TodayCost,
"actual_cost": dashStats.TodayActualCost,
},
"total": gin.H{
"requests": dashStats.TotalRequests,
"input_tokens": dashStats.TotalInputTokens,
"output_tokens": dashStats.TotalOutputTokens,
"cache_creation_tokens": dashStats.TotalCacheCreationTokens,
"cache_read_tokens": dashStats.TotalCacheReadTokens,
"total_tokens": dashStats.TotalTokens,
"cost": dashStats.TotalCost,
"actual_cost": dashStats.TotalActualCost,
},
"average_duration_ms": dashStats.AverageDurationMs,
"rpm": dashStats.Rpm,
"tpm": dashStats.Tpm,
}
}
// usageQuotaLimited 处理 quota_limited 模式的响应
func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) {
resp := gin.H{
"mode": "quota_limited",
"isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired,
"status": apiKey.Status,
}
// 总额度信息
if apiKey.Quota > 0 {
remaining := apiKey.GetQuotaRemaining()
resp["quota"] = gin.H{
"limit": apiKey.Quota,
"used": apiKey.QuotaUsed,
"remaining": remaining,
"unit": "USD",
}
resp["remaining"] = remaining
resp["unit"] = "USD"
}
// 速率限制信息(从 DB 获取实时用量)
if apiKey.HasRateLimits() && h.apiKeyService != nil {
rateLimitData, err := h.apiKeyService.GetRateLimitData(ctx, apiKey.ID)
if err == nil && rateLimitData != nil {
var rateLimits []gin.H
if apiKey.RateLimit5h > 0 {
used := rateLimitData.Usage5h
rateLimits = append(rateLimits, gin.H{
"window": "5h",
"limit": apiKey.RateLimit5h,
"used": used,
"remaining": max(0, apiKey.RateLimit5h-used),
"window_start": rateLimitData.Window5hStart,
})
}
if apiKey.RateLimit1d > 0 {
used := rateLimitData.Usage1d
rateLimits = append(rateLimits, gin.H{
"window": "1d",
"limit": apiKey.RateLimit1d,
"used": used,
"remaining": max(0, apiKey.RateLimit1d-used),
"window_start": rateLimitData.Window1dStart,
})
}
if apiKey.RateLimit7d > 0 {
used := rateLimitData.Usage7d
rateLimits = append(rateLimits, gin.H{
"window": "7d",
"limit": apiKey.RateLimit7d,
"used": used,
"remaining": max(0, apiKey.RateLimit7d-used),
"window_start": rateLimitData.Window7dStart,
})
}
if len(rateLimits) > 0 {
resp["rate_limits"] = rateLimits
}
}
}
// 订阅模式:返回订阅限额信息 + 用量统计
// 过期时间
if apiKey.ExpiresAt != nil {
resp["expires_at"] = apiKey.ExpiresAt
resp["days_until_expiry"] = apiKey.GetDaysUntilExpiry()
}
if usageData != nil {
resp["usage"] = usageData
}
if modelStats != nil {
resp["model_stats"] = modelStats
}
c.JSON(http.StatusOK, resp)
}
// usageUnrestricted 处理 unrestricted 模式的响应(向后兼容)
func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) {
// 订阅模式
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware2.GetSubscriptionFromContext(c)
if !ok {
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription")
return
resp := gin.H{
"mode": "unrestricted",
"isValid": true,
"planName": apiKey.Group.Name,
"unit": "USD",
}
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
resp := gin.H{
"isValid": true,
"planName": apiKey.Group.Name,
"remaining": remaining,
"unit": "USD",
"subscription": gin.H{
// 订阅信息可能不在 context 中(/v1/usage 路径跳过了中间件的计费检查)
subscription, ok := middleware2.GetSubscriptionFromContext(c)
if ok {
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
resp["remaining"] = remaining
resp["subscription"] = gin.H{
"daily_usage_usd": subscription.DailyUsageUSD,
"weekly_usage_usd": subscription.WeeklyUsageUSD,
"monthly_usage_usd": subscription.MonthlyUsageUSD,
@@ -813,23 +1049,28 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
"weekly_limit_usd": apiKey.Group.WeeklyLimitUSD,
"monthly_limit_usd": apiKey.Group.MonthlyLimitUSD,
"expires_at": subscription.ExpiresAt,
},
}
}
if usageData != nil {
resp["usage"] = usageData
}
if modelStats != nil {
resp["model_stats"] = modelStats
}
c.JSON(http.StatusOK, resp)
return
}
// 余额模式:返回钱包余额 + 用量统计
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
// 余额模式
latestUser, err := h.userService.GetByID(ctx, subject.UserID)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return
}
resp := gin.H{
"mode": "unrestricted",
"isValid": true,
"planName": "钱包余额",
"remaining": latestUser.Balance,
@@ -839,6 +1080,9 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
if usageData != nil {
resp["usage"] = usageData
}
if modelStats != nil {
resp["model_stats"] = modelStats
}
c.JSON(http.StatusOK, resp)
}
@@ -959,20 +1203,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in SSE format with proper JSON marshaling
errorData := map[string]any{
"type": "error",
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
// SSE 错误事件固定 schema使用 Quote 直拼可避免额外 Marshal 分配。
errorEvent := `data: {"type":"error","error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
@@ -994,6 +1226,41 @@ func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarte
return true
}
// checkClaudeCodeVersion 检查 Claude Code 客户端版本是否满足最低要求
// 仅对已识别的 Claude Code 客户端执行count_tokens 路径除外
func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool {
ctx := c.Request.Context()
if !service.IsClaudeCodeClient(ctx) {
return true
}
// 排除 count_tokens 子路径
if strings.HasSuffix(c.Request.URL.Path, "/count_tokens") {
return true
}
minVersion := h.settingService.GetMinClaudeCodeVersion(ctx)
if minVersion == "" {
return true // 未设置,不检查
}
clientVersion := service.GetClaudeCodeVersion(ctx)
if clientVersion == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error",
"Unable to determine Claude Code version. Please update Claude Code: npm update -g @anthropic-ai/claude-code")
return false
}
if service.CompareVersions(clientVersion, minVersion) < 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error",
fmt.Sprintf("Your Claude Code version (%s) is below the minimum required version (%s). Please update: npm update -g @anthropic-ai/claude-code",
clientVersion, minVersion))
return false
}
return true
}
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
@@ -1027,9 +1294,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -1044,9 +1312,6 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
@@ -1054,9 +1319,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// count_tokens 走 messages 严格校验时,复用已解析请求,避免二次反序列化。
SetClaudeCodeClientContext(c, body, parsedReq)
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
// 验证 model 必填
if parsedReq.Model == "" {
@@ -1220,24 +1487,8 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
textDeltas = []string{"New", " Conversation"}
}
// Build message_start event with proper JSON marshaling
messageStart := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
"content": []any{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]int{
"input_tokens": 10,
"output_tokens": 0,
},
},
}
messageStartJSON, _ := json.Marshal(messageStart)
// Build message_start event with fixed schema.
messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}`
// Build events
events := []string{
@@ -1247,31 +1498,12 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
// Add text deltas
for _, text := range textDeltas {
delta := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]string{
"type": "text_delta",
"text": text,
},
}
deltaJSON, _ := json.Marshal(delta)
deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}`
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
}
// Add final events
messageDelta := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": "end_turn",
"stop_sequence": nil,
},
"usage": map[string]int{
"input_tokens": 10,
"output_tokens": outputTokens,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}`
events = append(events,
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
@@ -1358,6 +1590,18 @@ func billingErrorDetails(err error) (status int, code, message string) {
}
return http.StatusServiceUnavailable, "billing_service_error", msg
}
if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) {
msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
}
if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) {
msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
}
if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) {
msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
}
msg := pkgerrors.Message(err)
if msg == "" {
logger.L().With(
@@ -1369,6 +1613,30 @@ func billingErrorDetails(err error) (status int, code, message string) {
return http.StatusForbidden, "billing_error", msg
}
func (h *GatewayHandler) metadataBridgeEnabled() bool {
if h == nil || h.cfg == nil {
return true
}
return h.cfg.Gateway.OpenAIWS.MetadataBridgeEnabled
}
func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger) {
if reqLog == nil {
return
}
if gatewayCompatibilityMetricsLogCounter.Add(1)%gatewayCompatibilityMetricsLogInterval != 0 {
return
}
metrics := service.SnapshotOpenAICompatibilityFallbackMetrics()
reqLog.Info("gateway.compatibility_fallback_metrics",
zap.Int64("session_hash_legacy_read_fallback_total", metrics.SessionHashLegacyReadFallbackTotal),
zap.Int64("session_hash_legacy_read_fallback_hit", metrics.SessionHashLegacyReadFallbackHit),
zap.Int64("session_hash_legacy_dual_write_total", metrics.SessionHashLegacyDualWriteTotal),
zap.Float64("session_hash_legacy_read_hit_rate", metrics.SessionHashLegacyReadHitRate),
zap.Int64("metadata_legacy_fallback_total", metrics.MetadataLegacyFallbackTotal),
)
}
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
@@ -1380,5 +1648,34 @@ func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
// 回退路径worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
defer func() {
if recovered := recover(); recovered != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),
zap.Any("panic", recovered),
).Error("gateway.usage_record_task_panic_recovered")
}
}()
task(ctx)
}
// getUserMsgQueueMode 获取当前请求的 UMQ 模式
// 返回 "serialize" | "throttle" | ""
func (h *GatewayHandler) getUserMsgQueueMode(account *service.Account, parsed *service.ParsedRequest) string {
if h.userMsgQueueHelper == nil {
return ""
}
// 仅适用于 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() {
return ""
}
if !service.IsRealUserMessage(parsed) {
return ""
}
// 账号级模式优先fallback 到全局配置
mode := account.GetUserMsgQueueMode()
if mode == "" {
mode = h.cfg.Gateway.UserMessageQueue.GetEffectiveMode()
}
return mode
}

View File

@@ -119,6 +119,13 @@ func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.A
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
return map[int64]*service.UserLoadInfo{}, nil
}
func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int, len(accountIDs))
for _, id := range accountIDs {
result[id] = 0
}
return result, nil
}
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
@@ -146,12 +153,13 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // deferredService
nil, // claudeTokenProvider
nil, // sessionLimitCache
nil, // rpmCache
nil, // digestStore
)
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。
cfg := &config.Config{RunMode: config.RunModeSimple}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)

View File

@@ -18,14 +18,21 @@ import (
// claudeCodeValidator is a singleton validator for Claude Code client detection
var claudeCodeValidator = service.NewClaudeCodeValidator()
const claudeCodeParsedRequestContextKey = "claude_code_parsed_request"
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) {
if c == nil || c.Request == nil {
return
}
if parsedReq != nil {
c.Set(claudeCodeParsedRequestContextKey, parsedReq)
}
ua := c.GetHeader("User-Agent")
// Fast path非 Claude CLI UA 直接判定 false避免热路径二次 JSON 反序列化。
if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
if !claudeCodeValidator.ValidateUserAgent(ua) {
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
c.Request = c.Request.WithContext(ctx)
return
@@ -37,8 +44,11 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
isClaudeCode = true
} else {
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
var bodyMap map[string]any
if len(body) > 0 {
bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq)
if bodyMap == nil {
bodyMap = claudeCodeBodyMapFromContextCache(c)
}
if bodyMap == nil && len(body) > 0 {
_ = json.Unmarshal(body, &bodyMap)
}
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
@@ -46,9 +56,53 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
// 更新 request context
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
// 仅在确认为 Claude Code 客户端时提取版本号写入 context
if isClaudeCode {
if version := claudeCodeValidator.ExtractVersion(ua); version != "" {
ctx = service.SetClaudeCodeVersion(ctx, version)
}
}
c.Request = c.Request.WithContext(ctx)
}
func claudeCodeBodyMapFromParsedRequest(parsedReq *service.ParsedRequest) map[string]any {
if parsedReq == nil {
return nil
}
bodyMap := map[string]any{
"model": parsedReq.Model,
}
if parsedReq.System != nil || parsedReq.HasSystem {
bodyMap["system"] = parsedReq.System
}
if parsedReq.MetadataUserID != "" {
bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID}
}
return bodyMap
}
func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any {
if c == nil {
return nil
}
if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok {
if bodyMap, ok := cached.(map[string]any); ok {
return bodyMap
}
}
if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok {
switch v := cached.(type) {
case *service.ParsedRequest:
return claudeCodeBodyMapFromParsedRequest(v)
case service.ParsedRequest:
return claudeCodeBodyMapFromParsedRequest(&v)
}
}
return nil
}
// 并发槽位等待相关常量
//
// 性能优化说明:

View File

@@ -33,6 +33,14 @@ func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accoun
return 0, nil
}
func (m *concurrencyCacheMock) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
result[accountID] = 0
}
return result, nil
}
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}

View File

@@ -49,6 +49,14 @@ func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context,
return 0, nil
}
func (s *helperConcurrencyCacheStub) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
out := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
out[accountID] = 0
}
return out, nil
}
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
@@ -133,7 +141,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "curl/8.6.0")
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
})
@@ -141,7 +149,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
SetClaudeCodeClientContext(c, nil)
SetClaudeCodeClientContext(c, nil, nil)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
@@ -152,7 +160,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
c.Request.Header.Set("anthropic-version", "2023-06-01")
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
@@ -160,11 +168,51 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
// 缺少严格校验所需 header + body 字段
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`), nil)
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
})
}
func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing.T) {
t.Run("reuse parsed request without body unmarshal", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
c.Request.Header.Set("X-App", "claude-code")
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
c.Request.Header.Set("anthropic-version", "2023-06-01")
parsedReq := &service.ParsedRequest{
Model: "claude-3-5-sonnet-20241022",
System: []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
},
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
}
// body 非法 JSON如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
SetClaudeCodeClientContext(c, []byte(`{invalid`), parsedReq)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
t.Run("reuse context cache without body unmarshal", func(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
c.Request.Header.Set("X-App", "claude-code")
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
c.Request.Header.Set("anthropic-version", "2023-06-01")
c.Set(service.OpenAIParsedRequestBodyKey, map[string]any{
"model": "claude-3-5-sonnet-20241022",
"system": []any{
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
},
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
})
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
}
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false, true},

View File

@@ -7,16 +7,15 @@ import (
"encoding/hex"
"encoding/json"
"errors"
"io"
"net/http"
"regexp"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -168,7 +167,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
stream := action == "streamGenerateContent"
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
body, err := io.ReadAll(c.Request.Body)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
@@ -268,8 +267,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
}
@@ -349,7 +347,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
@@ -363,7 +361,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
@@ -456,7 +454,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)

View File

@@ -11,6 +11,7 @@ type AdminHandlers struct {
Group *admin.GroupHandler
Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
DataManagement *admin.DataManagementHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
@@ -25,6 +26,7 @@ type AdminHandlers struct {
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
APIKey *admin.AdminAPIKeyHandler
}
// Handlers contains all HTTP handlers
@@ -40,6 +42,7 @@ type Handlers struct {
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler
SoraClient *SoraClientHandler
Setting *SettingHandler
Totp *TotpHandler
}

View File

@@ -5,17 +5,20 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"runtime/debug"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
@@ -64,6 +67,11 @@ func NewOpenAIGatewayHandler(
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
streamStarted := false
defer h.recoverResponsesPanic(c, &streamStarted)
setOpenAIClientTransportHTTP(c)
requestStart := time.Now()
// Get apiKey and user from context (set by ApiKeyAuth middleware)
@@ -85,9 +93,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
// Read request body
body, err := io.ReadAll(c.Request.Body)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -125,43 +136,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
reqStream := streamResult.Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
if previousResponseID != "" {
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
reqLog = reqLog.With(
zap.Bool("has_previous_response_id", true),
zap.String("previous_response_id_kind", previousResponseIDKind),
zap.Int("previous_response_id_len", len(previousResponseID)),
)
if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "previous_response_id_looks_like_message_id"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
return
}
}
setOpsRequestContext(c, reqModel, reqStream, body)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
// 要求 previous_response_id或 input 内存在带 call_id 的 tool_call/function_call
// 或带 id 且与 call_id 匹配的 item_reference。
// 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err == nil {
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_call_id"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
}
}
}
}
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
return
}
// Track if we've started streaming (for error handling)
streamStarted := false
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
@@ -173,51 +171,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
// 0. 先尝试直接抢占用户槽位(快速路径)
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
if err != nil {
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
waitCounted := false
if !userAcquired {
// 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if waitErr != nil {
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
// 按现有降级语义:等待计数异常时放行后续抢槽流程
} else if !canWait {
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if waitErr == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
}
// 用户槽位已获取:退出等待队列计数。
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
@@ -241,7 +199,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for {
// Select account supporting the requested model
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
previousResponseID,
sessionHash,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err != nil {
reqLog.Warn("openai.account_select_failed",
zap.Error(err),
@@ -258,80 +224,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
return
}
if selection == nil || selection.Account == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
if previousResponseID != "" && selection != nil && selection.Account != nil {
reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID))
}
reqLog.Debug("openai.account_schedule_decision",
zap.String("layer", scheduleDecision.Layer),
zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit),
zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
zap.Int("top_k", scheduleDecision.TopK),
zap.Int64("latency_ms", scheduleDecision.LatencyMs),
zap.Float64("load_skew", scheduleDecision.LoadSkew),
)
account := selection.Account
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
setOpsSelectedAccount(c, account.ID, account.Platform)
// 3. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
// 先快速尝试一次账号槽位,命中则跳过等待计数写入。
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
c.Request.Context(),
account.ID,
selection.WaitPlan.MaxConcurrency,
)
if err != nil {
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if fastAcquired {
accountReleaseFunc = fastReleaseFunc
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
} else {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
} else if !canWait {
reqLog.Info("openai.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
@@ -353,6 +269,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
@@ -368,14 +286,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("openai.forward_failed",
fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
}
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
reqLog.Warn("openai.forward_failed", fields...)
return
}
reqLog.Error("openai.forward_failed", fields...)
return
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
} else {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
@@ -411,6 +340,525 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
}
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
return true
}
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
// 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。
return true
}
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
validation := service.ValidateFunctionCallOutputContext(reqBody)
if !validation.HasFunctionCallOutput {
return true
}
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext {
return true
}
if validation.HasFunctionCallOutputMissingCallID {
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_call_id"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return false
}
if validation.HasItemReferenceForAllCallIDs {
return true
}
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return false
}
func (h *OpenAIGatewayHandler) acquireResponsesUserSlot(
c *gin.Context,
userID int64,
userConcurrency int,
reqStream bool,
streamStarted *bool,
reqLog *zap.Logger,
) (func(), bool) {
ctx := c.Request.Context()
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency)
if err != nil {
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
h.handleConcurrencyError(c, err, "user", *streamStarted)
return nil, false
}
if userAcquired {
return wrapReleaseOnDone(ctx, userReleaseFunc), true
}
maxWait := service.CalculateMaxWait(userConcurrency)
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait)
if waitErr != nil {
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
// 按现有降级语义:等待计数异常时放行后续抢槽流程
} else if !canWait {
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return nil, false
}
waitCounted := waitErr == nil && canWait
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
}
}()
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted)
if err != nil {
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
h.handleConcurrencyError(c, err, "user", *streamStarted)
return nil, false
}
// 槽位获取成功后,立刻退出等待计数。
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
waitCounted = false
}
return wrapReleaseOnDone(ctx, userReleaseFunc), true
}
func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
c *gin.Context,
groupID *int64,
sessionHash string,
selection *service.AccountSelectionResult,
reqStream bool,
streamStarted *bool,
reqLog *zap.Logger,
) (func(), bool) {
if selection == nil || selection.Account == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
return nil, false
}
ctx := c.Request.Context()
account := selection.Account
if selection.Acquired {
return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
}
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
return nil, false
}
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
ctx,
account.ID,
selection.WaitPlan.MaxConcurrency,
)
if err != nil {
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", *streamStarted)
return nil, false
}
if fastAcquired {
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
return wrapReleaseOnDone(ctx, fastReleaseFunc), true
}
canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting)
if waitErr != nil {
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr))
} else if !canWait {
reqLog.Info("openai.account_wait_queue_full",
zap.Int64("account_id", account.ID),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted)
return nil, false
}
accountWaitCounted := waitErr == nil && canWait
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID)
accountWaitCounted = false
}
}
defer releaseWait()
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
streamStarted,
)
if err != nil {
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
h.handleConcurrencyError(c, err, "account", *streamStarted)
return nil, false
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
return wrapReleaseOnDone(ctx, accountReleaseFunc), true
}
// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint
// GET /openai/v1/responses (Upgrade: websocket)
func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
if !isOpenAIWSUpgradeRequest(c.Request) {
h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)")
return
}
setOpenAIClientTransportWS(c)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.responses_ws",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.Bool("openai_ws_mode", true),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
reqLog.Info("openai.websocket_ingress_started")
clientIP := ip.GetClientIP(c)
userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))
wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
reqLog.Warn("openai.websocket_accept_failed",
zap.Error(err),
zap.String("client_ip", clientIP),
zap.String("request_user_agent", userAgent),
zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))),
zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))),
zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))),
zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""),
)
return
}
defer func() {
_ = wsConn.CloseNow()
}()
wsConn.SetReadLimit(16 * 1024 * 1024)
ctx := c.Request.Context()
readCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
msgType, firstMessage, err := wsConn.Read(readCtx)
cancel()
if err != nil {
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_read_first_message_failed",
zap.Error(err),
zap.String("client_ip", clientIP),
zap.String("close_status", closeStatus),
zap.String("close_reason", closeReason),
zap.Duration("read_timeout", 30*time.Second),
)
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message")
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type")
return
}
if !gjson.ValidBytes(firstMessage) {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload")
return
}
reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String())
if reqModel == "" {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload")
return
}
previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String())
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
return
}
reqLog = reqLog.With(
zap.Bool("ws_ingress", true),
zap.String("model", reqModel),
zap.Bool("has_previous_response_id", previousResponseID != ""),
zap.String("previous_response_id_kind", previousResponseIDKind),
)
setOpsRequestContext(c, reqModel, true, firstMessage)
var currentUserRelease func()
var currentAccountRelease func()
releaseTurnSlots := func() {
if currentAccountRelease != nil {
currentAccountRelease()
currentAccountRelease = nil
}
if currentUserRelease != nil {
currentUserRelease()
currentUserRelease = nil
}
}
// 必须尽早注册,确保任何 early return 都能释放已获取的并发槽位。
defer releaseTurnSlots()
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
return
}
if !userAcquired {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
return
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
return
}
sessionHash := h.gatewayService.GenerateSessionHashWithFallback(
c,
firstMessage,
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
)
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
ctx,
apiKey.GroupID,
previousResponseID,
sessionHash,
reqModel,
nil,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
)
if err != nil {
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
return
}
if selection == nil || selection.Account == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
return
}
account := selection.Account
accountMaxConcurrency := account.Concurrency
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
}
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
return
}
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
ctx,
account.ID,
selection.WaitPlan.MaxConcurrency,
)
if err != nil {
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
return
}
if !fastAcquired {
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
return
}
accountReleaseFunc = fastReleaseFunc
}
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
if err != nil {
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
return
}
reqLog.Debug("openai.websocket_account_selected",
zap.Int64("account_id", account.ID),
zap.String("account_name", account.Name),
zap.String("schedule_layer", scheduleDecision.Layer),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
)
hooks := &service.OpenAIWSIngressHooks{
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
}
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
releaseTurnSlots()
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
if err != nil {
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
}
if !userAcquired {
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
}
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
if err != nil {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
}
if !accountAcquired {
if userReleaseFunc != nil {
userReleaseFunc()
}
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
}
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
return nil
},
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots()
if turnErr != nil || result == nil {
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
h.submitUsageRecordTask(func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
zap.String("request_id", result.RequestID),
zap.Error(err),
)
}
})
},
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed",
zap.Int64("account_id", account.ID),
zap.Error(err),
zap.String("close_status", closeStatus),
zap.String("close_reason", closeReason),
)
var closeErr *service.OpenAIWSClientCloseError
if errors.As(err, &closeErr) {
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
return
}
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
return
}
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
}
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
recovered := recover()
if recovered == nil {
return
}
started := false
if streamStarted != nil {
started = *streamStarted
}
wroteFallback := h.ensureForwardErrorResponse(c, started)
requestLogger(c, "handler.openai_gateway.responses").Error(
"openai.responses_panic_recovered",
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Any("panic", recovered),
zap.ByteString("stack", debug.Stack()),
)
}
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
missing := h.missingResponsesDependencies()
if len(missing) == 0 {
return true
}
if reqLog == nil {
reqLog = requestLogger(c, "handler.openai_gateway.responses")
}
reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing))
if c != nil && c.Writer != nil && !c.Writer.Written() {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"type": "api_error",
"message": "Service temporarily unavailable",
},
})
}
return false
}
func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string {
missing := make([]string, 0, 5)
if h == nil {
return append(missing, "handler")
}
if h.gatewayService == nil {
missing = append(missing, "gatewayService")
}
if h.billingCacheService == nil {
missing = append(missing, "billingCacheService")
}
if h.apiKeyService == nil {
missing = append(missing, "apiKeyService")
}
if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil {
missing = append(missing, "concurrencyHelper")
}
return missing
}
func getContextInt64(c *gin.Context, key string) (int64, bool) {
if c == nil || key == "" {
return 0, false
@@ -444,6 +892,14 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
// 回退路径worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
defer func() {
if recovered := recover(); recovered != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
zap.Any("panic", recovered),
).Error("openai.usage_record_task_panic_recovered")
}
}()
task(ctx)
}
@@ -515,19 +971,8 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in OpenAI SSE format with proper JSON marshaling
errorData := map[string]any{
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
// SSE 错误事件固定 schema使用 Quote 直拼可避免额外 Marshal 分配。
errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
@@ -549,6 +994,16 @@ func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, stream
return true
}
func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool {
if wroteFallback {
return false
}
if c == nil || c.Writer == nil {
return false
}
return c.Writer.Written()
}
// errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
@@ -558,3 +1013,61 @@ func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType
},
})
}
func setOpenAIClientTransportHTTP(c *gin.Context) {
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP)
}
func setOpenAIClientTransportWS(c *gin.Context) {
service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
}
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
gid := int64(0)
if groupID != nil {
gid = *groupID
}
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
}
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
if r == nil {
return false
}
if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") {
return false
}
return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade")
}
func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) {
if conn == nil {
return
}
reason = strings.TrimSpace(reason)
if len(reason) > 120 {
reason = reason[:120]
}
_ = conn.Close(status, reason)
_ = conn.CloseNow()
}
func summarizeWSCloseErrorForLog(err error) (string, string) {
if err == nil {
return "-", "-"
}
statusCode := coderws.CloseStatus(err)
if statusCode == -1 {
return "-", "-"
}
closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
closeReason := "-"
var closeErr coderws.CloseError
if errors.As(err, &closeErr) {
reason := strings.TrimSpace(closeErr.Reason)
if reason != "" {
closeReason = reason
}
}
return closeStatus, closeReason
}

View File

@@ -1,12 +1,19 @@
package handler
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -105,6 +112,27 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
assert.Equal(t, "test error", errorObj["message"])
}
func TestReadRequestBodyWithPrealloc(t *testing.T) {
payload := `{"model":"gpt-5","input":"hello"}`
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload))
req.ContentLength = int64(len(payload))
body, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
require.NoError(t, err)
require.Equal(t, payload, string(body))
}
func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8)))
req.Body = http.MaxBytesReader(rec, req.Body, 4)
_, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
require.Error(t, err)
var maxErr *http.MaxBytesError
require.ErrorAs(t, err, &maxErr)
}
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
@@ -141,6 +169,387 @@ func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *test
assert.Equal(t, "already written", w.Body.String())
}
func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("fallback_written_should_not_downgrade", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true))
})
t.Run("context_nil_should_not_downgrade", func(t *testing.T) {
require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false))
})
t.Run("response_not_written_should_not_downgrade", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
})
t.Run("response_already_written_should_downgrade", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.String(http.StatusForbidden, "already written")
require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
})
}
func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
h := &OpenAIGatewayHandler{}
streamStarted := false
require.NotPanics(t, func() {
func() {
defer h.recoverResponsesPanic(c, &streamStarted)
panic("test panic")
}()
})
require.Equal(t, http.StatusBadGateway, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "Upstream request failed", errorObj["message"])
}
func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
h := &OpenAIGatewayHandler{}
streamStarted := false
require.NotPanics(t, func() {
func() {
defer h.recoverResponsesPanic(c, &streamStarted)
}()
})
require.False(t, c.Writer.Written())
assert.Equal(t, "", w.Body.String())
}
func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
c.String(http.StatusTeapot, "already written")
h := &OpenAIGatewayHandler{}
streamStarted := false
require.NotPanics(t, func() {
func() {
defer h.recoverResponsesPanic(c, &streamStarted)
panic("test panic")
}()
})
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
}
func TestOpenAIMissingResponsesDependencies(t *testing.T) {
t.Run("nil_handler", func(t *testing.T) {
var h *OpenAIGatewayHandler
require.Equal(t, []string{"handler"}, h.missingResponsesDependencies())
})
t.Run("all_dependencies_missing", func(t *testing.T) {
h := &OpenAIGatewayHandler{}
require.Equal(t,
[]string{"gatewayService", "billingCacheService", "apiKeyService", "concurrencyHelper"},
h.missingResponsesDependencies(),
)
})
t.Run("all_dependencies_present", func(t *testing.T) {
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
concurrencyHelper: &ConcurrencyHelper{
concurrencyService: &service.ConcurrencyService{},
},
}
require.Empty(t, h.missingResponsesDependencies())
})
}
func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
t.Run("missing_dependencies_returns_503", func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
h := &OpenAIGatewayHandler{}
ok := h.ensureResponsesDependencies(c, nil)
require.False(t, ok)
require.Equal(t, http.StatusServiceUnavailable, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, exists := parsed["error"].(map[string]any)
require.True(t, exists)
assert.Equal(t, "api_error", errorObj["type"])
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
})
t.Run("already_written_response_not_overridden", func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
c.String(http.StatusTeapot, "already written")
h := &OpenAIGatewayHandler{}
ok := h.ensureResponsesDependencies(c, nil)
require.False(t, ok)
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
})
t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
concurrencyHelper: &ConcurrencyHelper{
concurrencyService: &service.ConcurrencyService{},
},
}
ok := h.ensureResponsesDependencies(c, nil)
require.True(t, ok)
require.False(t, c.Writer.Written())
assert.Equal(t, "", w.Body.String())
})
}
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`))
c.Request.Header.Set("Content-Type", "application/json")
groupID := int64(2)
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
ID: 10,
GroupID: &groupID,
})
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
UserID: 1,
Concurrency: 1,
})
// 故意使用未初始化依赖,验证快速失败而不是崩溃。
h := &OpenAIGatewayHandler{}
require.NotPanics(t, func() {
h.Responses(c)
})
require.Equal(t, http.StatusServiceUnavailable, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "api_error", errorObj["type"])
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
}
func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(`{"model":"gpt-5"}`))
c.Request.Header.Set("Content-Type", "application/json")
h := &OpenAIGatewayHandler{}
h.Responses(c)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
}
func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
`{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`,
))
c.Request.Header.Set("Content-Type", "application/json")
groupID := int64(2)
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
ID: 101,
GroupID: &groupID,
User: &service.User{ID: 1},
})
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
UserID: 1,
Concurrency: 1,
})
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
h.Responses(c)
require.Equal(t, http.StatusBadRequest, w.Code)
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
}
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
c.Request.Header.Set("Upgrade", "websocket")
c.Request.Header.Set("Connection", "Upgrade")
h := &OpenAIGatewayHandler{}
h.ResponsesWebSocket(c)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
}
func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
h := &OpenAIGatewayHandler{}
h.ResponsesWebSocket(c)
require.Equal(t, http.StatusUpgradeRequired, w.Code)
require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c))
}
func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
gin.SetMode(gin.TestMode)
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`,
))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, _, err = clientConn.Read(readCtx)
cancelRead()
require.Error(t, err)
var closeErr coderws.CloseError
require.ErrorAs(t, err, &closeErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id")
}
func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) {
gin.SetMode(gin.TestMode)
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return false, errors.New("user slot unavailable")
},
}
h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache)
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
defer wsServer.Close()
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`,
))
cancelWrite()
require.NoError(t, err)
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
_, _, err = clientConn.Read(readCtx)
cancelRead()
require.Error(t, err)
var closeErr coderws.CloseError
require.ErrorAs(t, err, &closeErr)
require.Equal(t, coderws.StatusInternalError, closeErr.Code)
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
}
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
setOpenAIClientTransportHTTP(c)
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
}
func TestSetOpenAIClientTransportWS(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
setOpenAIClientTransportWS(c)
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
}
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
tests := []struct {
@@ -228,3 +637,41 @@ func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
require.NoError(t, setErr)
require.True(t, gjson.ValidBytes(result))
}
func newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler {
t.Helper()
if cache == nil {
cache = &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
}
}
return &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: &service.BillingCacheService{},
apiKeyService: &service.APIKeyService{},
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
}
}
func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server {
t.Helper()
groupID := int64(2)
apiKey := &service.APIKey{
ID: 101,
GroupID: &groupID,
User: &service.User{ID: subject.UserID},
}
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
c.Set(string(middleware.ContextKeyUser), subject)
c.Next()
})
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
return httptest.NewServer(router)
}

View File

@@ -311,6 +311,35 @@ type opsCaptureWriter struct {
buf bytes.Buffer
}
const opsCaptureWriterLimit = 64 * 1024
var opsCaptureWriterPool = sync.Pool{
New: func() any {
return &opsCaptureWriter{limit: opsCaptureWriterLimit}
},
}
func acquireOpsCaptureWriter(rw gin.ResponseWriter) *opsCaptureWriter {
w, ok := opsCaptureWriterPool.Get().(*opsCaptureWriter)
if !ok || w == nil {
w = &opsCaptureWriter{}
}
w.ResponseWriter = rw
w.limit = opsCaptureWriterLimit
w.buf.Reset()
return w
}
func releaseOpsCaptureWriter(w *opsCaptureWriter) {
if w == nil {
return
}
w.ResponseWriter = nil
w.limit = opsCaptureWriterLimit
w.buf.Reset()
opsCaptureWriterPool.Put(w)
}
func (w *opsCaptureWriter) Write(b []byte) (int, error) {
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
remaining := w.limit - w.buf.Len()
@@ -342,7 +371,16 @@ func (w *opsCaptureWriter) WriteString(s string) (int, error) {
// - Streaming errors after the response has started (SSE) may still need explicit logging.
func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
return func(c *gin.Context) {
w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024}
originalWriter := c.Writer
w := acquireOpsCaptureWriter(originalWriter)
defer func() {
// Restore the original writer before returning so outer middlewares
// don't observe a pooled wrapper that has been released.
if c.Writer == w {
c.Writer = originalWriter
}
releaseOpsCaptureWriter(w)
}()
c.Writer = w
c.Next()
@@ -624,8 +662,10 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
requestID = c.Writer.Header().Get("x-request-id")
}
phase := classifyOpsPhase(parsed.ErrorType, parsed.Message, parsed.Code)
isBusinessLimited := classifyOpsIsBusinessLimited(parsed.ErrorType, phase, parsed.Code, status, parsed.Message)
normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code)
phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code)
isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message)
errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
errorSource := classifyOpsErrorSource(phase, parsed.Message)
@@ -647,8 +687,8 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: phase,
ErrorType: normalizeOpsErrorType(parsed.ErrorType, parsed.Code),
Severity: classifyOpsSeverity(parsed.ErrorType, status),
ErrorType: normalizedType,
Severity: classifyOpsSeverity(normalizedType, status),
StatusCode: status,
IsBusinessLimited: isBusinessLimited,
IsCountTokens: isCountTokensRequest(c),
@@ -660,7 +700,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
ErrorSource: errorSource,
ErrorOwner: errorOwner,
IsRetryable: classifyOpsIsRetryable(parsed.ErrorType, status),
IsRetryable: classifyOpsIsRetryable(normalizedType, status),
RetryCount: 0,
CreatedAt: time.Now(),
}
@@ -901,8 +941,29 @@ func guessPlatformFromPath(path string) string {
}
}
// isKnownOpsErrorType returns true if t is a recognized error type used by the
// ops classification pipeline. Upstream proxies sometimes return garbage values
// (e.g. the Go-serialized literal "<nil>") which would pollute phase/severity
// classification if accepted blindly.
func isKnownOpsErrorType(t string) bool {
switch t {
case "invalid_request_error",
"authentication_error",
"rate_limit_error",
"billing_error",
"subscription_error",
"upstream_error",
"overloaded_error",
"api_error",
"not_found_error",
"forbidden_error":
return true
}
return false
}
func normalizeOpsErrorType(errType string, code string) string {
if errType != "" {
if errType != "" && isKnownOpsErrorType(errType) {
return errType
}
switch strings.TrimSpace(code) {

View File

@@ -6,6 +6,7 @@ import (
"sync"
"testing"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
@@ -173,3 +174,103 @@ func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
enqueueOpsErrorLog(ops, entry)
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
}
func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/test", nil)
writer := acquireOpsCaptureWriter(c.Writer)
require.NotNil(t, writer)
_, err := writer.buf.WriteString("temp-error-body")
require.NoError(t, err)
releaseOpsCaptureWriter(writer)
reused := acquireOpsCaptureWriter(c.Writer)
defer releaseOpsCaptureWriter(reused)
require.Zero(t, reused.buf.Len(), "writer should be reset before reuse")
}
func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(middleware2.Recovery())
r.Use(middleware2.RequestLogger())
r.Use(middleware2.Logger())
r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) {
c.Status(http.StatusNoContent)
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
require.NotPanics(t, func() {
r.ServeHTTP(rec, req)
})
require.Equal(t, http.StatusNoContent, rec.Code)
}
func TestIsKnownOpsErrorType(t *testing.T) {
known := []string{
"invalid_request_error",
"authentication_error",
"rate_limit_error",
"billing_error",
"subscription_error",
"upstream_error",
"overloaded_error",
"api_error",
"not_found_error",
"forbidden_error",
}
for _, k := range known {
require.True(t, isKnownOpsErrorType(k), "expected known: %s", k)
}
unknown := []string{"<nil>", "null", "", "random_error", "some_new_type", "<nil>\u003e"}
for _, u := range unknown {
require.False(t, isKnownOpsErrorType(u), "expected unknown: %q", u)
}
}
func TestNormalizeOpsErrorType(t *testing.T) {
tests := []struct {
name string
errType string
code string
want string
}{
// Known types pass through.
{"known invalid_request_error", "invalid_request_error", "", "invalid_request_error"},
{"known rate_limit_error", "rate_limit_error", "", "rate_limit_error"},
{"known upstream_error", "upstream_error", "", "upstream_error"},
// Unknown/garbage types are rejected and fall through to code-based or default.
{"nil literal from upstream", "<nil>", "", "api_error"},
{"null string", "null", "", "api_error"},
{"random string", "something_weird", "", "api_error"},
// Unknown type but known code still maps correctly.
{"nil with INSUFFICIENT_BALANCE code", "<nil>", "INSUFFICIENT_BALANCE", "billing_error"},
{"nil with USAGE_LIMIT_EXCEEDED code", "<nil>", "USAGE_LIMIT_EXCEEDED", "subscription_error"},
// Empty type falls through to code-based mapping.
{"empty type with balance code", "", "INSUFFICIENT_BALANCE", "billing_error"},
{"empty type with subscription code", "", "SUBSCRIPTION_NOT_FOUND", "subscription_error"},
{"empty type no code", "", "", "api_error"},
// Known type overrides conflicting code-based mapping.
{"known type overrides conflicting code", "rate_limit_error", "INSUFFICIENT_BALANCE", "rate_limit_error"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeOpsErrorType(tt.errType, tt.code)
require.Equal(t, tt.want, got)
})
}
}

View File

@@ -32,25 +32,28 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled,
Version: h.version,
})
}

View File

@@ -0,0 +1,979 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// 上游模型缓存 TTL
modelCacheTTL = 1 * time.Hour // 上游获取成功
modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
)
// SoraClientHandler 处理 Sora 客户端 API 请求。
type SoraClientHandler struct {
genService *service.SoraGenerationService
quotaService *service.SoraQuotaService
s3Storage *service.SoraS3Storage
soraGatewayService *service.SoraGatewayService
gatewayService *service.GatewayService
mediaStorage *service.SoraMediaStorage
apiKeyService *service.APIKeyService
// 上游模型缓存
modelCacheMu sync.RWMutex
cachedFamilies []service.SoraModelFamily
modelCacheTime time.Time
modelCacheUpstream bool // 是否来自上游(决定 TTL
}
// NewSoraClientHandler 创建 Sora 客户端 Handler。
func NewSoraClientHandler(
genService *service.SoraGenerationService,
quotaService *service.SoraQuotaService,
s3Storage *service.SoraS3Storage,
soraGatewayService *service.SoraGatewayService,
gatewayService *service.GatewayService,
mediaStorage *service.SoraMediaStorage,
apiKeyService *service.APIKeyService,
) *SoraClientHandler {
return &SoraClientHandler{
genService: genService,
quotaService: quotaService,
s3Storage: s3Storage,
soraGatewayService: soraGatewayService,
gatewayService: gatewayService,
mediaStorage: mediaStorage,
apiKeyService: apiKeyService,
}
}
// GenerateRequest 生成请求。
type GenerateRequest struct {
Model string `json:"model" binding:"required"`
Prompt string `json:"prompt" binding:"required"`
MediaType string `json:"media_type"` // video / image默认 video
VideoCount int `json:"video_count,omitempty"` // 视频数量1-3
ImageInput string `json:"image_input,omitempty"` // 参考图base64 或 URL
APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
}
// Generate 异步生成 — 创建 pending 记录后立即返回。
// POST /api/v1/sora/generate
func (h *SoraClientHandler) Generate(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
var req GenerateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
return
}
if req.MediaType == "" {
req.MediaType = "video"
}
req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
// 并发数检查(最多 3 个)
activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if activeCount >= 3 {
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
return
}
// 配额检查(粗略检查,实际文件大小在上传后才知道)
if h.quotaService != nil {
if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
return
}
response.Error(c, http.StatusForbidden, err.Error())
return
}
}
// 获取 API Key ID 和 Group ID
var apiKeyID *int64
var groupID *int64
if req.APIKeyID != nil && h.apiKeyService != nil {
// 前端传递了 api_key_id需要校验
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
if err != nil {
response.Error(c, http.StatusBadRequest, "API Key 不存在")
return
}
if apiKey.UserID != userID {
response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
return
}
if apiKey.Status != service.StatusAPIKeyActive {
response.Error(c, http.StatusForbidden, "API Key 不可用")
return
}
apiKeyID = &apiKey.ID
groupID = apiKey.GroupID
} else if id, ok := c.Get("api_key_id"); ok {
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
if v, ok := id.(int64); ok {
apiKeyID = &v
}
}
gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
if err != nil {
if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
return
}
response.ErrorFrom(c, err)
return
}
// 启动后台异步生成 goroutine
go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
response.Success(c, gin.H{
"generation_id": gen.ID,
"status": gen.Status,
})
}
// processGeneration 后台异步执行 Sora 生成任务。
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储S3 → 本地 → 上游)→ 更新记录。
func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
// 标记为生成中
if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
return
}
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
genID,
userID,
groupIDForLog(groupID),
model,
mediaType,
videoCount,
strings.TrimSpace(imageInput) != "",
len(strings.TrimSpace(prompt)),
)
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
if groupID == nil {
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
}
if h.gatewayService == nil {
_ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
return
}
// 选择 Sora 账号
account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
if err != nil {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
genID,
userID,
groupIDForLog(groupID),
model,
err,
)
_ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
return
}
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
genID,
userID,
groupIDForLog(groupID),
model,
account.ID,
account.Name,
account.Platform,
account.Type,
)
// 构建 chat completions 请求体(非流式)
body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
if h.soraGatewayService == nil {
_ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
return
}
// 创建 mock gin 上下文用于 Forward捕获响应以提取媒体 URL
recorder := httptest.NewRecorder()
mockGinCtx, _ := gin.CreateTestContext(recorder)
mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
// 调用 Forward非流式
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
if err != nil {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
genID,
account.ID,
model,
recorder.Code,
trimForLog(recorder.Body.String(), 400),
err,
)
// 检查是否已取消
gen, _ := h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
return
}
_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
return
}
// 提取媒体 URL优先从 ForwardResult其次从响应体解析
mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
if mediaURL == "" {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
genID,
account.ID,
model,
recorder.Code,
trimForLog(recorder.Body.String(), 400),
)
_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
return
}
// 检查任务是否已被取消
gen, _ := h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
return
}
// 三层降级存储S3 → 本地 → 上游临时 URL
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
usageAdded := false
if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
return
}
_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
return
}
usageAdded = true
}
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
gen, _ = h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
}
return
}
// 标记完成
if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
}
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
}
// storeMediaWithDegradation 实现三层降级存储链S3 → 本地 → 上游。
func (h *SoraClientHandler) storeMediaWithDegradation(
ctx context.Context, userID int64, mediaType string,
mediaURL string, mediaURLs []string,
) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
urls := mediaURLs
if len(urls) == 0 {
urls = []string{mediaURL}
}
// 第一层:尝试 S3
if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
keys := make([]string, 0, len(urls))
var totalSize int64
allOK := true
for _, u := range urls {
key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
allOK = false
// 清理已上传的文件
if len(keys) > 0 {
_ = h.s3Storage.DeleteObjects(ctx, keys)
}
break
}
keys = append(keys, key)
totalSize += size
}
if allOK && len(keys) > 0 {
accessURLs := make([]string, 0, len(keys))
for _, key := range keys {
accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
_ = h.s3Storage.DeleteObjects(ctx, keys)
allOK = false
break
}
accessURLs = append(accessURLs, accessURL)
}
if allOK && len(accessURLs) > 0 {
return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
}
}
}
// 第二层:尝试本地存储
if h.mediaStorage != nil && h.mediaStorage.Enabled() {
storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
if err == nil && len(storedPaths) > 0 {
firstPath := storedPaths[0]
totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
if sizeErr != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
}
return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
}
// 第三层:保留上游临时 URL
return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
}
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
body := map[string]any{
"model": model,
"messages": []map[string]string{
{"role": "user", "content": prompt},
},
"stream": false,
}
if imageInput != "" {
body["image_input"] = imageInput
}
if videoCount > 1 {
body["video_count"] = videoCount
}
b, _ := json.Marshal(body)
return b
}
func normalizeVideoCount(mediaType string, videoCount int) int {
if mediaType != "video" {
return 1
}
if videoCount <= 0 {
return 1
}
if videoCount > 3 {
return 3
}
return videoCount
}
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
// OAuth 路径ForwardResult.MediaURL 已填充。
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
// 优先从 ForwardResult 获取OAuth 路径)
if result != nil && result.MediaURL != "" {
// 尝试从响应体获取完整 URL 列表
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
return urls[0], urls
}
return result.MediaURL, []string{result.MediaURL}
}
// 从响应体解析APIKey 路径)
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
return urls[0], urls
}
return "", nil
}
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
func parseMediaURLsFromBody(body []byte) []string {
if len(body) == 0 {
return nil
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return nil
}
// 优先 media_urls多图数组
if rawURLs, ok := resp["media_urls"]; ok {
if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
urls := make([]string, 0, len(arr))
for _, item := range arr {
if s, ok := item.(string); ok && s != "" {
urls = append(urls, s)
}
}
if len(urls) > 0 {
return urls
}
}
}
// 回退到 media_url单个 URL
if url, ok := resp["media_url"].(string); ok && url != "" {
return []string{url}
}
return nil
}
// ListGenerations 查询生成记录列表。
// GET /api/v1/sora/generations
func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
params := service.SoraGenerationListParams{
UserID: userID,
Status: c.Query("status"),
StorageType: c.Query("storage_type"),
MediaType: c.Query("media_type"),
Page: page,
PageSize: pageSize,
}
gens, total, err := h.genService.List(c.Request.Context(), params)
if err != nil {
response.ErrorFrom(c, err)
return
}
// 为 S3 记录动态生成预签名 URL
for _, gen := range gens {
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
}
response.Success(c, gin.H{
"data": gens,
"total": total,
"page": page,
})
}
// GetGeneration 查询生成记录详情。
// GET /api/v1/sora/generations/:id
func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
response.Success(c, gen)
}
// DeleteGeneration 删除生成记录。
// DELETE /api/v1/sora/generations/:id
func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
paths := gen.MediaURLs
if len(paths) == 0 && gen.MediaURL != "" {
paths = []string{gen.MediaURL}
}
if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
}
}
if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
response.Success(c, gin.H{"message": "已删除"})
}
// GetQuota 查询用户存储配额。
// GET /api/v1/sora/quota
func (h *SoraClientHandler) GetQuota(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
if h.quotaService == nil {
response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
return
}
quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, quota)
}
// CancelGeneration 取消生成任务。
// POST /api/v1/sora/generations/:id/cancel
func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
// 权限校验
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
_ = gen
if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
if errors.Is(err, service.ErrSoraGenerationNotActive) {
response.Error(c, http.StatusConflict, "任务已结束,无法取消")
return
}
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, gin.H{"message": "已取消"})
}
// SaveToStorage 手动保存 upstream 记录到 S3。
// POST /api/v1/sora/generations/:id/save
func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
if gen.StorageType != service.SoraStorageTypeUpstream {
response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
return
}
if gen.MediaURL == "" {
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
return
}
if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
return
}
sourceURLs := gen.MediaURLs
if len(sourceURLs) == 0 && gen.MediaURL != "" {
sourceURLs = []string{gen.MediaURL}
}
if len(sourceURLs) == 0 {
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
return
}
uploadedKeys := make([]string, 0, len(sourceURLs))
accessURLs := make([]string, 0, len(sourceURLs))
var totalSize int64
for _, sourceURL := range sourceURLs {
objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
if uploadErr != nil {
if len(uploadedKeys) > 0 {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
}
var upstreamErr *service.UpstreamDownloadError
if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
return
}
response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
return
}
accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
if err != nil {
uploadedKeys = append(uploadedKeys, objectKey)
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
return
}
uploadedKeys = append(uploadedKeys, objectKey)
accessURLs = append(accessURLs, accessURL)
totalSize += fileSize
}
usageAdded := false
if totalSize > 0 && h.quotaService != nil {
if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
return
}
response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
return
}
usageAdded = true
}
if err := h.genService.UpdateStorageForCompleted(
c.Request.Context(),
id,
accessURLs[0],
accessURLs,
service.SoraStorageTypeS3,
uploadedKeys,
totalSize,
); err != nil {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
}
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"message": "已保存到 S3",
"object_key": uploadedKeys[0],
"object_keys": uploadedKeys,
})
}
// GetStorageStatus 返回存储状态。
// GET /api/v1/sora/storage-status
func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
s3Healthy := false
if s3Enabled {
s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
}
localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
response.Success(c, gin.H{
"s3_enabled": s3Enabled,
"s3_healthy": s3Healthy,
"local_enabled": localEnabled,
})
}
func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
switch storageType {
case service.SoraStorageTypeS3:
if h.s3Storage != nil && len(s3Keys) > 0 {
if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
}
}
case service.SoraStorageTypeLocal:
if h.mediaStorage != nil && len(localPaths) > 0 {
if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
}
}
}
}
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
func getUserIDFromContext(c *gin.Context) int64 {
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
return subject.UserID
}
if id, ok := c.Get("user_id"); ok {
switch v := id.(type) {
case int64:
return v
case float64:
return int64(v)
case string:
n, _ := strconv.ParseInt(v, 10, 64)
return n
}
}
// 尝试从 JWT claims 获取
if id, ok := c.Get("userID"); ok {
if v, ok := id.(int64); ok {
return v
}
}
return 0
}
func groupIDForLog(groupID *int64) int64 {
if groupID == nil {
return 0
}
return *groupID
}
func trimForLog(raw string, maxLen int) string {
trimmed := strings.TrimSpace(raw)
if maxLen <= 0 || len(trimmed) <= maxLen {
return trimmed
}
return trimmed[:maxLen] + "...(truncated)"
}
// GetModels 获取可用 Sora 模型家族列表。
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
// GET /api/v1/sora/models
func (h *SoraClientHandler) GetModels(c *gin.Context) {
families := h.getModelFamilies(c.Request.Context())
response.Success(c, families)
}
// getModelFamilies 获取模型家族列表(带缓存)。
func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
// 读锁检查缓存
h.modelCacheMu.RLock()
ttl := modelCacheTTL
if !h.modelCacheUpstream {
ttl = modelCacheFailedTTL
}
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
families := h.cachedFamilies
h.modelCacheMu.RUnlock()
return families
}
h.modelCacheMu.RUnlock()
// 写锁更新缓存
h.modelCacheMu.Lock()
defer h.modelCacheMu.Unlock()
// double-check
ttl = modelCacheTTL
if !h.modelCacheUpstream {
ttl = modelCacheFailedTTL
}
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
return h.cachedFamilies
}
// 尝试从上游获取
families, err := h.fetchUpstreamModels(ctx)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
families = service.BuildSoraModelFamilies()
h.cachedFamilies = families
h.modelCacheTime = time.Now()
h.modelCacheUpstream = false
return families
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
h.cachedFamilies = families
h.modelCacheTime = time.Now()
h.modelCacheUpstream = true
return families
}
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
if h.gatewayService == nil {
return nil, fmt.Errorf("gatewayService 未初始化")
}
// 设置 ForcePlatform 用于 Sora 账号选择
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
// 选择一个 Sora 账号
account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
if err != nil {
return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
}
// 仅支持 API Key 类型账号
if account.Type != service.AccountTypeAPIKey {
return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
}
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return nil, fmt.Errorf("账号缺少 api_key")
}
baseURL := account.GetBaseURL()
if baseURL == "" {
return nil, fmt.Errorf("账号缺少 base_url")
}
// 构建上游模型列表请求
modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+apiKey)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("请求上游失败: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
// 解析 OpenAI 格式的模型列表
var modelsResp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
if err := json.Unmarshal(body, &modelsResp); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
if len(modelsResp.Data) == 0 {
return nil, fmt.Errorf("上游返回空模型列表")
}
// 提取模型 ID
modelIDs := make([]string, 0, len(modelsResp.Data))
for _, m := range modelsResp.Data {
modelIDs = append(modelIDs, m.ID)
}
// 转换为模型家族
families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
if len(families) == 0 {
return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
}
return families, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"path"
@@ -17,6 +16,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -107,7 +107,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
zap.Any("group_id", apiKey.GroupID),
)
body, err := io.ReadAll(c.Request.Body)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -461,6 +461,14 @@ func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask)
// 回退路径worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
defer func() {
if recovered := recover(); recovered != nil {
logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"),
zap.Any("panic", recovered),
).Error("sora.usage_record_task_panic_recovered")
}
}()
task(ctx)
}

View File

@@ -182,6 +182,12 @@ func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platfo
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
return r.ListSchedulableByPlatforms(ctx, platforms)
}
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
@@ -314,10 +320,13 @@ func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID i
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
@@ -405,7 +414,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
deferredService := service.NewDeferredService(accountRepo, nil, 0)
billingService := service.NewBillingService(cfg, nil)
concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{})
billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg)
billingCacheService := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
t.Cleanup(func() {
billingCacheService.Stop()
})
@@ -429,7 +438,8 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
deferredService,
nil,
testutil.StubSessionLimitCache{},
nil,
nil, // rpmCache
nil, // digestStore
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}

View File

@@ -2,6 +2,7 @@ package handler
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -65,8 +66,17 @@ func (h *UsageHandler) List(c *gin.Context) {
// Parse additional filters
model := c.Query("model")
var requestType *int16
var stream *bool
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
@@ -114,6 +124,7 @@ func (h *UsageHandler) List(c *gin.Context) {
UserID: subject.UserID, // Always filter by current user for security
APIKeyID: apiKeyID,
Model: model,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
StartTime: startTime,

View File

@@ -0,0 +1,80 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type userUsageRepoCapture struct {
service.UsageLogRepository
listFilters usagestats.UsageLogFilters
}
func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0,
Page: params.Page,
PageSize: params.PageSize,
Pages: 0,
}, nil
}
func newUserUsageRequestTypeTestRouter(repo *userUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
usageSvc := service.NewUsageService(repo, nil, nil, nil)
handler := NewUsageHandler(usageSvc, nil)
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
c.Next()
})
router.GET("/usage", handler.List)
return router
}
func TestUserUsageListRequestTypePriority(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=ws_v2&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, int64(42), repo.listFilters.UserID)
require.NotNil(t, repo.listFilters.RequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
require.Nil(t, repo.listFilters.Stream)
}
func TestUserUsageListInvalidRequestType(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=invalid", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestUserUsageListInvalidStream(t *testing.T) {
repo := &userUsageRepoCapture{}
router := newUserUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/usage?stream=invalid", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}

View File

@@ -61,6 +61,22 @@ func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
})
}
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
h := &GatewayHandler{}
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
@@ -98,6 +114,22 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
})
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
h := &OpenAIGatewayHandler{}
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
@@ -134,3 +166,19 @@ func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h.submitUsageRecordTask(nil)
})
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
h := &SoraGatewayHandler{}
var called atomic.Bool
require.NotPanics(t, func() {
h.submitUsageRecordTask(func(ctx context.Context) {
panic("usage task panic")
})
})
h.submitUsageRecordTask(func(ctx context.Context) {
called.Store(true)
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}

View File

@@ -0,0 +1,237 @@
package handler
import (
"context"
"fmt"
"net/http"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// UserMsgQueueHelper 用户消息串行队列 Handler 层辅助
// 复用 ConcurrencyHelper 的退避 + SSE ping 模式
type UserMsgQueueHelper struct {
queueService *service.UserMessageQueueService
pingFormat SSEPingFormat
pingInterval time.Duration
}
// NewUserMsgQueueHelper 创建用户消息串行队列辅助
func NewUserMsgQueueHelper(
queueService *service.UserMessageQueueService,
pingFormat SSEPingFormat,
pingInterval time.Duration,
) *UserMsgQueueHelper {
if pingInterval <= 0 {
pingInterval = defaultPingInterval
}
return &UserMsgQueueHelper{
queueService: queueService,
pingFormat: pingFormat,
pingInterval: pingInterval,
}
}
// AcquireWithWait 等待获取串行锁,流式请求期间发送 SSE ping
// 返回的 releaseFunc 内部使用 sync.Once确保只执行一次释放
func (h *UserMsgQueueHelper) AcquireWithWait(
c *gin.Context,
accountID int64,
baseRPM int,
isStream bool,
streamStarted *bool,
timeout time.Duration,
reqLog *zap.Logger,
) (releaseFunc func(), err error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// 先尝试立即获取
result, err := h.queueService.TryAcquire(ctx, accountID)
if err != nil {
return nil, err // fail-open 已在 service 层处理
}
if result.Acquired {
// 获取成功,执行 RPM 自适应延迟
if err := h.queueService.EnforceDelay(ctx, accountID, baseRPM); err != nil {
if ctx.Err() != nil {
// 延迟期间 context 取消,释放锁
bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = h.queueService.Release(bgCtx, accountID, result.RequestID)
bgCancel()
return nil, ctx.Err()
}
}
reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID))
return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil
}
// 需要等待:指数退避轮询
return h.waitForLockWithPing(c, ctx, accountID, baseRPM, isStream, streamStarted, reqLog)
}
// waitForLockWithPing 等待获取锁,流式请求期间发送 SSE ping
func (h *UserMsgQueueHelper) waitForLockWithPing(
c *gin.Context,
ctx context.Context,
accountID int64,
baseRPM int,
isStream bool,
streamStarted *bool,
reqLog *zap.Logger,
) (func(), error) {
needPing := isStream && h.pingFormat != ""
var flusher http.Flusher
if needPing {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
needPing = false
}
}
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}
backoff := initialBackoff
timer := time.NewTimer(backoff)
defer timer.Stop()
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("umq wait timeout for account %d", accountID)
case <-pingCh:
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
return nil, err
}
flusher.Flush()
case <-timer.C:
result, err := h.queueService.TryAcquire(ctx, accountID)
if err != nil {
return nil, err
}
if result.Acquired {
// 获取成功,执行 RPM 自适应延迟
if delayErr := h.queueService.EnforceDelay(ctx, accountID, baseRPM); delayErr != nil {
if ctx.Err() != nil {
bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = h.queueService.Release(bgCtx, accountID, result.RequestID)
bgCancel()
return nil, ctx.Err()
}
}
reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID))
return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil
}
backoff = nextBackoff(backoff)
timer.Reset(backoff)
}
}
}
// makeReleaseFunc 创建锁释放函数(使用 sync.Once 确保只执行一次)
func (h *UserMsgQueueHelper) makeReleaseFunc(accountID int64, requestID string, reqLog *zap.Logger) func() {
var once sync.Once
return func() {
once.Do(func() {
bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer bgCancel()
if err := h.queueService.Release(bgCtx, accountID, requestID); err != nil {
reqLog.Warn("gateway.umq_release_failed",
zap.Int64("account_id", accountID),
zap.Error(err),
)
} else {
reqLog.Debug("gateway.umq_lock_released", zap.Int64("account_id", accountID))
}
})
}
}
// ThrottleWithPing 软性限速模式:施加 RPM 自适应延迟,流式期间发送 SSE ping
// 不获取串行锁,不阻塞并发。返回后即可转发请求。
func (h *UserMsgQueueHelper) ThrottleWithPing(
c *gin.Context,
accountID int64,
baseRPM int,
isStream bool,
streamStarted *bool,
timeout time.Duration,
reqLog *zap.Logger,
) error {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
delay := h.queueService.CalculateRPMAwareDelay(ctx, accountID, baseRPM)
if delay <= 0 {
return nil
}
reqLog.Debug("gateway.umq_throttle_delay",
zap.Int64("account_id", accountID),
zap.Duration("delay", delay),
)
// 延迟期间发送 SSE ping复用 waitForLockWithPing 的 ping 逻辑)
needPing := isStream && h.pingFormat != ""
var flusher http.Flusher
if needPing {
flusher, _ = c.Writer.(http.Flusher)
if flusher == nil {
needPing = false
}
}
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}
timer := time.NewTimer(delay)
defer timer.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-pingCh:
// SSE ping 逻辑(与 waitForLockWithPing 一致)
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
return err
}
flusher.Flush()
case <-timer.C:
return nil
}
}
}

View File

@@ -14,6 +14,7 @@ func ProvideAdminHandlers(
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
announcementHandler *admin.AnnouncementHandler,
dataManagementHandler *admin.DataManagementHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
@@ -28,6 +29,7 @@ func ProvideAdminHandlers(
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
apiKeyHandler *admin.AdminAPIKeyHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -35,6 +37,7 @@ func ProvideAdminHandlers(
Group: groupHandler,
Account: accountHandler,
Announcement: announcementHandler,
DataManagement: dataManagementHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
@@ -49,6 +52,7 @@ func ProvideAdminHandlers(
Usage: usageHandler,
UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
APIKey: apiKeyHandler,
}
}
@@ -75,6 +79,7 @@ func ProvideHandlers(
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler,
soraClientHandler *SoraClientHandler,
settingHandler *SettingHandler,
totpHandler *TotpHandler,
_ *service.IdempotencyCoordinator,
@@ -92,6 +97,7 @@ func ProvideHandlers(
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler,
SoraClient: soraClientHandler,
Setting: settingHandler,
Totp: totpHandler,
}
@@ -119,6 +125,7 @@ var ProviderSet = wire.NewSet(
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewAnnouncementHandler,
admin.NewDataManagementHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,
@@ -133,6 +140,7 @@ var ProviderSet = wire.NewSet(
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
admin.NewAdminAPIKeyHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,

View File

@@ -152,6 +152,7 @@ var claudeModels = []modelDef{
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
}
@@ -165,6 +166,8 @@ var geminiModels = []modelDef{
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
}

View File

@@ -0,0 +1,26 @@
package antigravity
import "testing"
func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
t.Parallel()
models := DefaultModels()
byID := make(map[string]ClaudeModel, len(models))
for _, m := range models {
byID[m.ID] = m
}
requiredIDs := []string{
"claude-opus-4-6-thinking",
"gemini-3.1-flash-image",
"gemini-3.1-flash-image-preview",
"gemini-3-pro-image", // legacy compatibility
}
for _, id := range requiredIDs {
if _, ok := byID[id]; !ok {
t.Fatalf("expected model %q to be exposed in DefaultModels", id)
}
}
}

View File

@@ -14,6 +14,9 @@ import (
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
)
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求v1internal 端点)
@@ -149,22 +152,26 @@ type Client struct {
httpClient *http.Client
}
func NewClient(proxyURL string) *Client {
func NewClient(proxyURL string) (*Client, error) {
client := &http.Client{
Timeout: 30 * time.Second,
}
if strings.TrimSpace(proxyURL) != "" {
if proxyURLParsed, err := url.Parse(proxyURL); err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURLParsed),
}
_, parsed, err := proxyurl.Parse(proxyURL)
if err != nil {
return nil, err
}
if parsed != nil {
transport := &http.Transport{}
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
return nil, fmt.Errorf("configure proxy: %w", err)
}
client.Transport = transport
}
return &Client{
httpClient: client,
}
}, nil
}
// isConnectionError 判断是否为连接错误网络超时、DNS 失败、连接拒绝)

View File

@@ -228,8 +228,20 @@ func TestGetTier_两者都为nil(t *testing.T) {
// NewClient
// ---------------------------------------------------------------------------
func mustNewClient(t *testing.T, proxyURL string) *Client {
t.Helper()
client, err := NewClient(proxyURL)
if err != nil {
t.Fatalf("NewClient(%q) failed: %v", proxyURL, err)
}
return client
}
func TestNewClient_无代理(t *testing.T) {
client := NewClient("")
client, err := NewClient("")
if err != nil {
t.Fatalf("NewClient 返回错误: %v", err)
}
if client == nil {
t.Fatal("NewClient 返回 nil")
}
@@ -246,7 +258,10 @@ func TestNewClient_无代理(t *testing.T) {
}
func TestNewClient_有代理(t *testing.T) {
client := NewClient("http://proxy.example.com:8080")
client, err := NewClient("http://proxy.example.com:8080")
if err != nil {
t.Fatalf("NewClient 返回错误: %v", err)
}
if client == nil {
t.Fatal("NewClient 返回 nil")
}
@@ -256,7 +271,10 @@ func TestNewClient_有代理(t *testing.T) {
}
func TestNewClient_空格代理(t *testing.T) {
client := NewClient(" ")
client, err := NewClient(" ")
if err != nil {
t.Fatalf("NewClient 返回错误: %v", err)
}
if client == nil {
t.Fatal("NewClient 返回 nil")
}
@@ -267,15 +285,13 @@ func TestNewClient_空格代理(t *testing.T) {
}
func TestNewClient_无效代理URL(t *testing.T) {
// 无效 URL 时 url.Parse 不一定返回错误Go 的 url.Parse 很宽容),
// 但 ://invalid 会导致解析错误
client := NewClient("://invalid")
if client == nil {
t.Fatal("NewClient 返回 nil")
// 无效 URL 应返回 error
_, err := NewClient("://invalid")
if err == nil {
t.Fatal("无效代理 URL 应返回错误")
}
// 无效 URL 解析失败时Transport 应保持 nil
if client.httpClient.Transport != nil {
t.Error("无效代理 URL 时 Transport 应为 nil")
if !strings.Contains(err.Error(), "invalid proxy URL") {
t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error())
}
}
@@ -499,7 +515,7 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
client := NewClient("")
client := mustNewClient(t, "")
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
if err == nil {
t.Fatal("缺少 client_secret 时应返回错误")
@@ -602,7 +618,7 @@ func TestClient_RefreshToken_无ClientSecret(t *testing.T) {
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
client := NewClient("")
client := mustNewClient(t, "")
_, err := client.RefreshToken(context.Background(), "refresh-tok")
if err == nil {
t.Fatal("缺少 client_secret 时应返回错误")
@@ -1242,7 +1258,7 @@ func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
client := mustNewClient(t, "")
resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token")
if err != nil {
t.Fatalf("LoadCodeAssist 失败: %v", err)
@@ -1277,7 +1293,7 @@ func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
client := mustNewClient(t, "")
_, _, err := client.LoadCodeAssist(context.Background(), "bad-token")
if err == nil {
t.Fatal("服务器返回 403 时应返回错误")
@@ -1300,7 +1316,7 @@ func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
client := mustNewClient(t, "")
_, _, err := client.LoadCodeAssist(context.Background(), "token")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
@@ -1333,7 +1349,7 @@ func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
client := mustNewClient(t, "")
resp, _, err := client.LoadCodeAssist(context.Background(), "token")
if err != nil {
t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err)
@@ -1361,7 +1377,7 @@ func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
client := mustNewClient(t, "")
_, _, err := client.LoadCodeAssist(context.Background(), "token")
if err == nil {
t.Fatal("所有 URL 都失败时应返回错误")
@@ -1377,7 +1393,7 @@ func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
client := mustNewClient(t, "")
ctx, cancel := context.WithCancel(context.Background())
cancel()
@@ -1441,7 +1457,7 @@ func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
client := mustNewClient(t, "")
resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc")
if err != nil {
t.Fatalf("FetchAvailableModels 失败: %v", err)
@@ -1496,7 +1512,7 @@ func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
client := mustNewClient(t, "")
_, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj")
if err == nil {
t.Fatal("服务器返回 403 时应返回错误")
@@ -1516,7 +1532,7 @@ func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
client := mustNewClient(t, "")
_, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
@@ -1546,7 +1562,7 @@ func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
client := mustNewClient(t, "")
resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err)
@@ -1574,7 +1590,7 @@ func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
client := mustNewClient(t, "")
_, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err == nil {
t.Fatal("所有 URL 都失败时应返回错误")
@@ -1590,7 +1606,7 @@ func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
client := mustNewClient(t, "")
ctx, cancel := context.WithCancel(context.Background())
cancel()
@@ -1610,7 +1626,7 @@ func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
client := NewClient("")
client := mustNewClient(t, "")
resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 失败: %v", err)
@@ -1646,7 +1662,7 @@ func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
client := mustNewClient(t, "")
resp, _, err := client.LoadCodeAssist(context.Background(), "token")
if err != nil {
t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err)
@@ -1672,7 +1688,7 @@ func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
client := NewClient("")
client := mustNewClient(t, "")
resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err)

View File

@@ -70,7 +70,7 @@ type GeminiGenerationConfig struct {
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
}
// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image 支持)
// GeminiImageConfig Gemini 图片生成配置gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持)
type GeminiImageConfig struct {
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"

View File

@@ -612,14 +612,14 @@ func TestBuildAuthorizationURL_参数验证(t *testing.T) {
expectedParams := map[string]string{
"client_id": ClientID,
"redirect_uri": RedirectURI,
"response_type": "code",
"scope": Scopes,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "consent",
"redirect_uri": RedirectURI,
"response_type": "code",
"scope": Scopes,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "consent",
"include_granted_scopes": "true",
}

View File

@@ -52,4 +52,7 @@ const (
// PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。
// Service 层仅在分组匹配时复用 PrefetchedStickyAccountID避免分组切换重试误用旧 sticky。
PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id"
// ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22")
ClaudeCodeVersion Key = "ctx_claude_code_version"
)

View File

@@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) {
})
}
}
func TestToHTTP_MetadataDeepCopy(t *testing.T) {
md := map[string]string{"k": "v"}
appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md)
code, body := ToHTTP(appErr)
require.Equal(t, http.StatusBadRequest, code)
require.Equal(t, "v", body.Metadata["k"])
md["k"] = "changed"
require.Equal(t, "v", body.Metadata["k"])
appErr.Metadata["k"] = "changed-again"
require.Equal(t, "v", body.Metadata["k"])
}

View File

@@ -16,6 +16,16 @@ func ToHTTP(err error) (statusCode int, body Status) {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
cloned := Clone(appErr)
return int(cloned.Code), cloned.Status
body = Status{
Code: appErr.Code,
Reason: appErr.Reason,
Message: appErr.Message,
}
if appErr.Metadata != nil {
body.Metadata = make(map[string]string, len(appErr.Metadata))
for k, v := range appErr.Metadata {
body.Metadata[k] = v
}
}
return int(appErr.Code), body
}

View File

@@ -18,11 +18,11 @@ package httpclient
import (
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
@@ -32,6 +32,7 @@ const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL
)
// Options 定义共享 HTTP 客户端的构建参数
@@ -40,7 +41,6 @@ type Options struct {
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP bool // 是否校验解析后的 IP防止 DNS Rebinding
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
@@ -53,6 +53,9 @@ type Options struct {
// sharedClients 存储按配置参数缓存的 http.Client 实例
var sharedClients sync.Map
// 允许测试替换校验函数,生产默认指向真实实现。
var validateResolvedIP = urlvalidator.ValidateResolvedIP
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
@@ -84,7 +87,7 @@ func buildClient(opts Options) (*http.Client, error) {
var rt http.RoundTripper = transport
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
rt = &validatedTransport{base: transport}
rt = newValidatedTransport(transport)
}
return &http.Client{
Transport: rt,
@@ -116,15 +119,13 @@ func buildTransport(opts Options) (*http.Transport, error) {
return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead")
}
proxyURL := strings.TrimSpace(opts.ProxyURL)
if proxyURL == "" {
return transport, nil
}
parsed, err := url.Parse(proxyURL)
_, parsed, err := proxyurl.Parse(opts.ProxyURL)
if err != nil {
return nil, err
}
if parsed == nil {
return transport, nil
}
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
return nil, err
@@ -134,12 +135,11 @@ func buildTransport(opts Options) (*http.Transport, error) {
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.ValidateResolvedIP,
opts.AllowPrivateHosts,
opts.MaxIdleConns,
@@ -149,17 +149,56 @@ func buildClientKey(opts Options) string {
}
type validatedTransport struct {
base http.RoundTripper
base http.RoundTripper
validatedHosts sync.Map // map[string]time.Time, value 为过期时间
now func() time.Time
}
func newValidatedTransport(base http.RoundTripper) *validatedTransport {
return &validatedTransport{
base: base,
now: time.Now,
}
}
func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool {
if t == nil {
return false
}
raw, ok := t.validatedHosts.Load(host)
if !ok {
return false
}
expireAt, ok := raw.(time.Time)
if !ok {
t.validatedHosts.Delete(host)
return false
}
if now.Before(expireAt) {
return true
}
t.validatedHosts.Delete(host)
return false
}
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req != nil && req.URL != nil {
host := strings.TrimSpace(req.URL.Hostname())
host := strings.ToLower(strings.TrimSpace(req.URL.Hostname()))
if host != "" {
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return nil, err
now := time.Now()
if t != nil && t.now != nil {
now = t.now()
}
if !t.isValidatedHost(host, now) {
if err := validateResolvedIP(host); err != nil {
return nil, err
}
t.validatedHosts.Store(host, now.Add(validatedHostTTL))
}
}
}
if t == nil || t.base == nil {
return nil, fmt.Errorf("validated transport base is nil")
}
return t.base.RoundTrip(req)
}

View File

@@ -0,0 +1,115 @@
package httpclient
import (
"errors"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestValidatedTransport_CacheHostValidation(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
var validateCalls int32
validateResolvedIP = func(host string) error {
atomic.AddInt32(&validateCalls, 1)
require.Equal(t, "api.openai.com", host)
return nil
}
var baseCalls int32
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&baseCalls, 1)
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{}`)),
Header: make(http.Header),
}, nil
})
now := time.Unix(1730000000, 0)
transport := newValidatedTransport(base)
transport.now = func() time.Time { return now }
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls))
require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls))
}
func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
var validateCalls int32
validateResolvedIP = func(_ string) error {
atomic.AddInt32(&validateCalls, 1)
return nil
}
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{}`)),
Header: make(http.Header),
}, nil
})
now := time.Unix(1730001000, 0)
transport := newValidatedTransport(base)
transport.now = func() time.Time { return now }
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
now = now.Add(validatedHostTTL + time.Second)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls))
}
func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
expectedErr := errors.New("dns rebinding rejected")
validateResolvedIP = func(_ string) error {
return expectedErr
}
var baseCalls int32
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&baseCalls, 1)
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil
})
transport := newValidatedTransport(base)
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.ErrorIs(t, err, expectedErr)
require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls))
}

View File

@@ -0,0 +1,37 @@
package httputil
import (
"bytes"
"io"
"net/http"
)
const (
requestBodyReadInitCap = 512
requestBodyReadMaxInitCap = 1 << 20
)
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if req == nil || req.Body == nil {
return nil, nil
}
capHint := requestBodyReadInitCap
if req.ContentLength > 0 {
switch {
case req.ContentLength < int64(requestBodyReadInitCap):
capHint = requestBodyReadInitCap
case req.ContentLength > int64(requestBodyReadMaxInitCap):
capHint = requestBodyReadMaxInitCap
default:
capHint = int(req.ContentLength)
}
}
buf := bytes.NewBuffer(make([]byte, 0, capHint))
if _, err := io.Copy(buf, req.Body); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View File

@@ -67,6 +67,14 @@ func normalizeIP(ip string) string {
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
var privateNets []*net.IPNet
// CompiledIPRules 表示预编译的 IP 匹配规则。
// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。
type CompiledIPRules struct {
CIDRs []*net.IPNet
IPs []net.IP
PatternCount int
}
func init() {
for _, cidr := range []string{
"10.0.0.0/8",
@@ -84,6 +92,53 @@ func init() {
}
}
// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。
// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。
func CompileIPRules(patterns []string) *CompiledIPRules {
compiled := &CompiledIPRules{
CIDRs: make([]*net.IPNet, 0, len(patterns)),
IPs: make([]net.IP, 0, len(patterns)),
PatternCount: len(patterns),
}
for _, pattern := range patterns {
normalized := strings.TrimSpace(pattern)
if normalized == "" {
continue
}
if strings.Contains(normalized, "/") {
_, cidr, err := net.ParseCIDR(normalized)
if err != nil || cidr == nil {
continue
}
compiled.CIDRs = append(compiled.CIDRs, cidr)
continue
}
parsedIP := net.ParseIP(normalized)
if parsedIP == nil {
continue
}
compiled.IPs = append(compiled.IPs, parsedIP)
}
return compiled
}
func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool {
if parsedIP == nil || rules == nil {
return false
}
for _, cidr := range rules.CIDRs {
if cidr.Contains(parsedIP) {
return true
}
}
for _, ruleIP := range rules.IPs {
if parsedIP.Equal(ruleIP) {
return true
}
}
return false
}
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
@@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool {
// 2. 如果白名单不为空IP 必须在白名单中
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
return CheckIPRestrictionWithCompiledRules(
clientIP,
CompileIPRules(whitelist),
CompileIPRules(blacklist),
)
}
// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。
func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) {
// 规范化 IP
clientIP = normalizeIP(clientIP)
if clientIP == "" {
return false, "access denied"
}
parsedIP := net.ParseIP(clientIP)
if parsedIP == nil {
return false, "access denied"
}
// 1. 检查黑名单
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) {
return false, "access denied"
}
// 2. 检查白名单如果设置了白名单IP 必须在其中)
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) {
return false, "access denied"
}

View File

@@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
require.Equal(t, 200, w.Code)
require.Equal(t, "9.9.9.9", w.Body.String())
}
func TestCheckIPRestrictionWithCompiledRules(t *testing.T) {
whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"})
blacklist := CompileIPRules([]string{"10.1.1.1"})
allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist)
require.True(t, allowed)
require.Equal(t, "", reason)
allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist)
require.False(t, allowed)
require.Equal(t, "access denied", reason)
}
func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) {
// 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。
invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"})
allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil)
require.False(t, allowed)
require.Equal(t, "access denied", reason)
}

View File

@@ -10,6 +10,7 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"go.uber.org/zap"
@@ -42,15 +43,19 @@ type LogEvent struct {
var (
mu sync.RWMutex
global *zap.Logger
sugar *zap.SugaredLogger
global atomic.Pointer[zap.Logger]
sugar atomic.Pointer[zap.SugaredLogger]
atomicLevel zap.AtomicLevel
initOptions InitOptions
currentSink Sink
currentSink atomic.Value // sinkState
stdLogUndo func()
bootstrapOnce sync.Once
)
type sinkState struct {
sink Sink
}
func InitBootstrap() {
bootstrapOnce.Do(func() {
if err := Init(bootstrapOptions()); err != nil {
@@ -72,9 +77,9 @@ func initLocked(options InitOptions) error {
return err
}
prev := global
global = zl
sugar = zl.Sugar()
prev := global.Load()
global.Store(zl)
sugar.Store(zl.Sugar())
atomicLevel = al
initOptions = normalized
@@ -115,24 +120,32 @@ func SetLevel(level string) error {
func CurrentLevel() string {
mu.RLock()
defer mu.RUnlock()
if global == nil {
if global.Load() == nil {
return "info"
}
return atomicLevel.Level().String()
}
func SetSink(sink Sink) {
mu.Lock()
defer mu.Unlock()
currentSink = sink
currentSink.Store(sinkState{sink: sink})
}
func loadSink() Sink {
v := currentSink.Load()
if v == nil {
return nil
}
state, ok := v.(sinkState)
if !ok {
return nil
}
return state.sink
}
// WriteSinkEvent 直接写入日志 sink不经过全局日志级别门控。
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
func WriteSinkEvent(level, component, message string, fields map[string]any) {
mu.RLock()
sink := currentSink
mu.RUnlock()
sink := loadSink()
if sink == nil {
return
}
@@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) {
}
func L() *zap.Logger {
mu.RLock()
defer mu.RUnlock()
if global != nil {
return global
if l := global.Load(); l != nil {
return l
}
return zap.NewNop()
}
func S() *zap.SugaredLogger {
mu.RLock()
defer mu.RUnlock()
if sugar != nil {
return sugar
if s := sugar.Load(); s != nil {
return s
}
return zap.NewNop().Sugar()
}
@@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger {
}
func Sync() {
mu.RLock()
l := global
mu.RUnlock()
l := global.Load()
if l != nil {
_ = l.Sync()
}
@@ -210,7 +217,11 @@ func bridgeStdLogLocked() {
log.SetFlags(0)
log.SetPrefix("")
log.SetOutput(newStdLogBridge(global.Named("stdlog")))
base := global.Load()
if base == nil {
base = zap.NewNop()
}
log.SetOutput(newStdLogBridge(base.Named("stdlog")))
stdLogUndo = func() {
log.SetOutput(prevWriter)
@@ -220,7 +231,11 @@ func bridgeStdLogLocked() {
}
func bridgeSlogLocked() {
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
base := global.Load()
if base == nil {
base = zap.NewNop()
}
slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog"))))
}
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
@@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
// Only handle sink forwarding — the inner cores write via their own
// Write methods (added to CheckedEntry by s.core.Check above).
mu.RLock()
sink := currentSink
mu.RUnlock()
sink := loadSink()
if sink == nil {
return nil
}
@@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level {
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
return LevelError
}
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
return LevelWarn
}
return LevelInfo
@@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) {
return
}
mu.RLock()
initialized := global != nil
mu.RUnlock()
initialized := global.Load() != nil
if !initialized {
// 在日志系统未初始化前,回退到标准库 log避免测试/工具链丢日志。
log.Print(msg)

View File

@@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
return true
})
entry := h.logger.With(fields...)
switch {
case record.Level >= slog.LevelError:
entry.Error(record.Message)
h.logger.Error(record.Message, fields...)
case record.Level >= slog.LevelWarn:
entry.Warn(record.Message)
h.logger.Warn(record.Message, fields...)
case record.Level <= slog.LevelDebug:
entry.Debug(record.Message)
h.logger.Debug(record.Message, fields...)
default:
entry.Info(record.Message)
h.logger.Info(record.Message, fields...)
}
return nil
}

View File

@@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) {
{msg: "Warning: queue full", want: LevelWarn},
{msg: "Forward request failed: timeout", want: LevelError},
{msg: "[ERROR] upstream unavailable", want: LevelError},
{msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo},
{msg: "service started", want: LevelInfo},
{msg: "debug: cache miss", want: LevelDebug},
}

View File

@@ -36,10 +36,18 @@ const (
SessionTTL = 30 * time.Minute
)
const (
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
OAuthPlatformOpenAI = "openai"
// OAuthPlatformSora uses Sora OAuth client.
OAuthPlatformSora = "sora"
)
// OAuthSession stores OAuth flow state for OpenAI
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ClientID string `json:"client_id,omitempty"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
CreatedAt time.Time `json:"created_at"`
@@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string {
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
return BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, OAuthPlatformOpenAI)
}
// BuildAuthorizationURLForPlatform builds authorization URL by platform.
func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
clientID, codexFlow := OAuthClientConfigByPlatform(platform)
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("client_id", clientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", DefaultScopes)
params.Set("state", state)
@@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
params.Set("code_challenge_method", "S256")
// OpenAI specific parameters
params.Set("id_token_add_organizations", "true")
params.Set("codex_cli_simplified_flow", "true")
if codexFlow {
params.Set("codex_cli_simplified_flow", "true")
}
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
// Sora 授权流程复用 Codex CLI 的 client_id支持 localhost redirect_uri
// 但不启用 codex_cli_simplified_flow拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
switch strings.ToLower(strings.TrimSpace(platform)) {
case OAuthPlatformSora:
return ClientID, false
default:
return ClientID, true
}
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
@@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string {
return params.Encode()
}
// ParseIDToken parses the ID Token JWT and extracts claims
// Note: This does NOT verify the signature - it only decodes the payload
// For production, you should verify the token signature using OpenAI's public keys
// ParseIDToken parses the ID Token JWT and extracts claims.
// 注意:当前仅解码 payload 并校验 exp未验证 JWT 签名。
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
//
// https://auth.openai.com/.well-known/jwks.json
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
@@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
const clockSkewTolerance = 120 // 秒
now := time.Now().Unix()
if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance {
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
}
return &claims, nil
}

View File

@@ -1,6 +1,7 @@
package openai
import (
"net/url"
"sync"
"testing"
"time"
@@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) {
t.Fatal("stopCh 未关闭")
}
}
func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Parse URL failed: %v", err)
}
q := parsed.Query()
if got := q.Get("client_id"); got != ClientID {
t.Fatalf("client_id mismatch: got=%q want=%q", got, ClientID)
}
if got := q.Get("codex_cli_simplified_flow"); got != "true" {
t.Fatalf("codex flow mismatch: got=%q want=true", got)
}
if got := q.Get("id_token_add_organizations"); got != "true" {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id
// 但不启用 codex_cli_simplified_flow。
func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Parse URL failed: %v", err)
}
q := parsed.Query()
if got := q.Get("client_id"); got != ClientID {
t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
}
if got := q.Get("codex_cli_simplified_flow"); got != "" {
t.Fatalf("codex flow should be empty for sora, got=%q", got)
}
if got := q.Get("id_token_add_organizations"); got != "true" {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}

View File

@@ -0,0 +1,66 @@
// Package proxyurl 提供代理 URL 的统一验证fail-fast无效代理不回退直连
//
// 所有需要解析代理 URL 的地方必须通过此包的 Parse 函数。
// 直接使用 url.Parse 处理代理 URL 是被禁止的。
// 这确保了 fail-fast 行为:无效代理配置在创建时立即失败,
// 而不是在运行时静默回退到直连(产生 IP 关联风险)。
package proxyurl
import (
"fmt"
"net/url"
"strings"
)
// allowedSchemes 代理协议白名单
var allowedSchemes = map[string]bool{
"http": true,
"https": true,
"socks5": true,
"socks5h": true,
}
// Parse 解析并验证代理 URL。
//
// 语义:
// - 空字符串 → ("", nil, nil),表示直连
// - 非空且有效 → (trimmed, *url.URL, nil)
// - 非空但无效 → ("", nil, error)fail-fast 不回退
//
// 验证规则:
// - TrimSpace 后为空视为直连
// - url.Parse 失败返回 error不含原始 URL防凭据泄露
// - Host 为空返回 error用 Redacted() 脱敏)
// - Scheme 必须为 http/https/socks5/socks5h
// - socks5:// 自动升级为 socks5h://(确保 DNS 由代理端解析,防止 DNS 泄漏)
func Parse(raw string) (trimmed string, parsed *url.URL, err error) {
trimmed = strings.TrimSpace(raw)
if trimmed == "" {
return "", nil, nil
}
parsed, err = url.Parse(trimmed)
if err != nil {
// 不使用 %w 包装,避免 url.Parse 的底层错误消息泄漏原始 URL可能含凭据
return "", nil, fmt.Errorf("invalid proxy URL: %v", err)
}
if parsed.Host == "" || parsed.Hostname() == "" {
return "", nil, fmt.Errorf("proxy URL missing host: %s", parsed.Redacted())
}
scheme := strings.ToLower(parsed.Scheme)
if !allowedSchemes[scheme] {
return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, https, socks5, socks5h)", scheme)
}
// 自动升级 socks5 → socks5h确保 DNS 由代理端解析,防止 DNS 泄漏。
// Go 的 golang.org/x/net/proxy 对 socks5:// 默认在客户端本地解析 DNS
// 仅 socks5h:// 才将域名发送给代理端做远程 DNS 解析。
if scheme == "socks5" {
parsed.Scheme = "socks5h"
trimmed = parsed.String()
}
return trimmed, parsed, nil
}

View File

@@ -0,0 +1,215 @@
package proxyurl
import (
"strings"
"testing"
)
func TestParse_空字符串直连(t *testing.T) {
trimmed, parsed, err := Parse("")
if err != nil {
t.Fatalf("空字符串应直连: %v", err)
}
if trimmed != "" {
t.Errorf("trimmed 应为空: got %q", trimmed)
}
if parsed != nil {
t.Errorf("parsed 应为 nil: got %v", parsed)
}
}
func TestParse_空白字符串直连(t *testing.T) {
trimmed, parsed, err := Parse(" ")
if err != nil {
t.Fatalf("空白字符串应直连: %v", err)
}
if trimmed != "" {
t.Errorf("trimmed 应为空: got %q", trimmed)
}
if parsed != nil {
t.Errorf("parsed 应为 nil: got %v", parsed)
}
}
func TestParse_有效HTTP代理(t *testing.T) {
trimmed, parsed, err := Parse("http://proxy.example.com:8080")
if err != nil {
t.Fatalf("有效 HTTP 代理应成功: %v", err)
}
if trimmed != "http://proxy.example.com:8080" {
t.Errorf("trimmed 不匹配: got %q", trimmed)
}
if parsed == nil {
t.Fatal("parsed 不应为 nil")
}
if parsed.Host != "proxy.example.com:8080" {
t.Errorf("Host 不匹配: got %q", parsed.Host)
}
}
func TestParse_有效HTTPS代理(t *testing.T) {
_, parsed, err := Parse("https://proxy.example.com:443")
if err != nil {
t.Fatalf("有效 HTTPS 代理应成功: %v", err)
}
if parsed.Scheme != "https" {
t.Errorf("Scheme 不匹配: got %q", parsed.Scheme)
}
}
func TestParse_有效SOCKS5代理_自动升级为SOCKS5H(t *testing.T) {
trimmed, parsed, err := Parse("socks5://127.0.0.1:1080")
if err != nil {
t.Fatalf("有效 SOCKS5 代理应成功: %v", err)
}
// socks5 自动升级为 socks5h确保 DNS 由代理端解析
if trimmed != "socks5h://127.0.0.1:1080" {
t.Errorf("trimmed 应升级为 socks5h: got %q", trimmed)
}
if parsed.Scheme != "socks5h" {
t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme)
}
}
func TestParse_无效URL(t *testing.T) {
_, _, err := Parse("://invalid")
if err == nil {
t.Fatal("无效 URL 应返回错误")
}
if !strings.Contains(err.Error(), "invalid proxy URL") {
t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error())
}
}
func TestParse_缺少Host(t *testing.T) {
_, _, err := Parse("http://")
if err == nil {
t.Fatal("缺少 host 应返回错误")
}
if !strings.Contains(err.Error(), "missing host") {
t.Errorf("错误信息应包含 'missing host': got %s", err.Error())
}
}
func TestParse_不支持的Scheme(t *testing.T) {
_, _, err := Parse("ftp://proxy.example.com:21")
if err == nil {
t.Fatal("不支持的 scheme 应返回错误")
}
if !strings.Contains(err.Error(), "unsupported proxy scheme") {
t.Errorf("错误信息应包含 'unsupported proxy scheme': got %s", err.Error())
}
}
func TestParse_含密码URL脱敏(t *testing.T) {
// 场景 1: 带密码的 socks5 URL 应成功解析并升级为 socks5h
trimmed, parsed, err := Parse("socks5://user:secret_password@proxy.local:1080")
if err != nil {
t.Fatalf("含密码的有效 URL 应成功: %v", err)
}
if trimmed == "" || parsed == nil {
t.Fatal("应返回非空结果")
}
if parsed.Scheme != "socks5h" {
t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme)
}
if !strings.HasPrefix(trimmed, "socks5h://") {
t.Errorf("trimmed 应以 socks5h:// 开头: got %q", trimmed)
}
if parsed.User == nil {
t.Error("升级后应保留 UserInfo")
}
// 场景 2: 带密码但缺少 host触发 Redacted 脱敏路径)
_, _, err = Parse("http://user:secret_password@:0/")
if err == nil {
t.Fatal("缺少 host 应返回错误")
}
if strings.Contains(err.Error(), "secret_password") {
t.Error("错误信息不应包含明文密码")
}
if !strings.Contains(err.Error(), "missing host") {
t.Errorf("错误信息应包含 'missing host': got %s", err.Error())
}
}
func TestParse_带空白的有效URL(t *testing.T) {
trimmed, parsed, err := Parse(" http://proxy.example.com:8080 ")
if err != nil {
t.Fatalf("带空白的有效 URL 应成功: %v", err)
}
if trimmed != "http://proxy.example.com:8080" {
t.Errorf("trimmed 应去除空白: got %q", trimmed)
}
if parsed == nil {
t.Fatal("parsed 不应为 nil")
}
}
func TestParse_Scheme大小写不敏感(t *testing.T) {
// 大写 SOCKS5 应被接受并升级为 socks5h
trimmed, parsed, err := Parse("SOCKS5://proxy.example.com:1080")
if err != nil {
t.Fatalf("大写 SOCKS5 应被接受: %v", err)
}
if parsed.Scheme != "socks5h" {
t.Errorf("大写 SOCKS5 Scheme 应升级为 socks5h: got %q", parsed.Scheme)
}
if !strings.HasPrefix(trimmed, "socks5h://") {
t.Errorf("大写 SOCKS5 trimmed 应升级为 socks5h://: got %q", trimmed)
}
// 大写 HTTP 应被接受(不变)
_, _, err = Parse("HTTP://proxy.example.com:8080")
if err != nil {
t.Fatalf("大写 HTTP 应被接受: %v", err)
}
}
func TestParse_带认证的有效代理(t *testing.T) {
trimmed, parsed, err := Parse("http://user:pass@proxy.example.com:8080")
if err != nil {
t.Fatalf("带认证的代理 URL 应成功: %v", err)
}
if parsed.User == nil {
t.Error("应保留 UserInfo")
}
if trimmed != "http://user:pass@proxy.example.com:8080" {
t.Errorf("trimmed 不匹配: got %q", trimmed)
}
}
func TestParse_IPv6地址(t *testing.T) {
trimmed, parsed, err := Parse("http://[::1]:8080")
if err != nil {
t.Fatalf("IPv6 代理 URL 应成功: %v", err)
}
if parsed.Hostname() != "::1" {
t.Errorf("Hostname 不匹配: got %q", parsed.Hostname())
}
if trimmed != "http://[::1]:8080" {
t.Errorf("trimmed 不匹配: got %q", trimmed)
}
}
func TestParse_SOCKS5H保持不变(t *testing.T) {
trimmed, parsed, err := Parse("socks5h://proxy.local:1080")
if err != nil {
t.Fatalf("有效 SOCKS5H 代理应成功: %v", err)
}
// socks5h 不需要升级,应保持原样
if trimmed != "socks5h://proxy.local:1080" {
t.Errorf("trimmed 不应变化: got %q", trimmed)
}
if parsed.Scheme != "socks5h" {
t.Errorf("Scheme 应保持 socks5h: got %q", parsed.Scheme)
}
}
func TestParse_无Scheme裸地址(t *testing.T) {
// 无 scheme 的裸地址Go url.Parse 将其视为 pathHost 为空
_, _, err := Parse("proxy.example.com:8080")
if err == nil {
t.Fatal("无 scheme 的裸地址应返回错误")
}
}

View File

@@ -2,7 +2,11 @@
//
// 支持的代理协议:
// - HTTP/HTTPS: 通过 Transport.Proxy 设置
// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS
// - SOCKS5: 通过 Transport.DialContext 设置(客户端本地解析 DNS
// - SOCKS5H: 通过 Transport.DialContext 设置(代理端远程解析 DNS推荐
//
// 注意proxyurl.Parse() 会自动将 socks5:// 升级为 socks5h://
// 确保 DNS 也由代理端解析,防止 DNS 泄漏。
package proxyutil
import (
@@ -20,7 +24,8 @@ import (
//
// 支持的协议:
// - http/https: 设置 transport.Proxy
// - socks5/socks5h: 设置 transport.DialContext由代理服务端解析 DNS
// - socks5: 设置 transport.DialContext客户端本地解析 DNS
// - socks5h: 设置 transport.DialContext代理端远程解析 DNS推荐
//
// 参数:
// - transport: 需要配置的 http.Transport

View File

@@ -29,10 +29,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P
t.Helper()
// 先用 raw json 解析,因为 Data 是 any 类型
var raw struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))

View File

@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
"cipher_suites", len(spec.CipherSuites),
"extensions", len(spec.Extensions),
"compression_methods", spec.CompressionMethods,
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
"tls_vers_max", spec.TLSVersMax,
"tls_vers_min", spec.TLSVersMin)
if d.profile != nil {
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_socks5_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"version", state.Version,
"cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"version", state.Version,
"cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.
// Log successful handshake details
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"version", state.Version,
"cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil

View File

@@ -80,12 +80,12 @@ type ModelStat struct {
// GroupStat represents usage statistics for a single group
type GroupStat struct {
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point
@@ -149,10 +149,13 @@ type UsageLogFilters struct {
AccountID int64
GroupID int64
Model string
RequestType *int16
Stream *bool
BillingType *int8
StartTime *time.Time
EndTime *time.Time
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
ExactTotal bool
}
// UsageStats represents usage statistics

View File

@@ -50,11 +50,6 @@ type accountRepository struct {
schedulerCache service.SchedulerCache
}
type tempUnschedSnapshot struct {
until *time.Time
reason string
}
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
accountIDs = append(accountIDs, acc.ID)
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[entAcc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outByID[entAcc.ID] = out
}
@@ -611,6 +597,43 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
}
func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) {
if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 {
return
}
uniqueIDs := make([]int64, 0, len(accountIDs))
seen := make(map[int64]struct{}, len(accountIDs))
for _, id := range accountIDs {
if id <= 0 {
continue
}
if _, exists := seen[id]; exists {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
}
if len(uniqueIDs) == 0 {
return
}
accounts, err := r.GetByIDs(ctx, uniqueIDs)
if err != nil {
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(uniqueIDs), err)
return
}
for _, account := range accounts {
if account == nil {
continue
}
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", account.ID, err)
}
}
}
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -806,6 +829,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
now := time.Now()
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformEQ(platform),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Not(dbaccount.HasAccountGroups()),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
}
now := time.Now()
accounts, err := r.client.Account.Query().
Where(
dbaccount.PlatformIn(platforms...),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
dbaccount.Not(dbaccount.HasAccountGroups()),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
Order(dbent.Asc(dbaccount.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
return r.accountsToService(ctx, accounts)
}
func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
if len(platforms) == 0 {
return nil, nil
@@ -1197,9 +1265,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
shouldSync = true
}
if shouldSync {
for _, id := range ids {
r.syncSchedulerAccountSnapshot(ctx, id)
}
r.syncSchedulerAccountSnapshots(ctx, ids)
}
}
return rows, nil
@@ -1291,10 +1357,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if err != nil {
return nil, err
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -1320,10 +1382,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[acc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outAccounts = append(outAccounts, *out)
}
@@ -1348,48 +1406,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account {
)
}
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
out := make(map[int64]tempUnschedSnapshot)
if len(accountIDs) == 0 {
return out, nil
}
rows, err := r.sql.QueryContext(ctx, `
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
FROM accounts
WHERE id = ANY($1)
`, pq.Array(accountIDs))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var id int64
var until sql.NullTime
var reason sql.NullString
if err := rows.Scan(&id, &until, &reason); err != nil {
return nil, err
}
var untilPtr *time.Time
if until.Valid {
tmp := until.Time
untilPtr = &tmp
}
if reason.Valid {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
} else {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""}
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
proxyMap := make(map[int64]*service.Proxy)
if len(proxyIDs) == 0 {
@@ -1500,31 +1516,33 @@ func accountEntityToService(m *dbent.Account) *service.Account {
rateMultiplier := m.RateMultiplier
return &service.Account{
ID: m.ID,
Name: m.Name,
Notes: m.Notes,
Platform: m.Platform,
Type: m.Type,
Credentials: copyJSONMap(m.Credentials),
Extra: copyJSONMap(m.Extra),
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
ExpiresAt: m.ExpiresAt,
AutoPauseOnExpired: m.AutoPauseOnExpired,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: derefString(m.SessionWindowStatus),
ID: m.ID,
Name: m.Name,
Notes: m.Notes,
Platform: m.Platform,
Type: m.Type,
Credentials: copyJSONMap(m.Credentials),
Extra: copyJSONMap(m.Extra),
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
ExpiresAt: m.ExpiresAt,
AutoPauseOnExpired: m.AutoPauseOnExpired,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
TempUnschedulableUntil: m.TempUnschedulableUntil,
TempUnschedulableReason: derefString(m.TempUnschedulableReason),
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: derefString(m.SessionWindowStatus),
}
}

View File

@@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
s.Require().Nil(got.OverloadUntil)
}
func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() {
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"})
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"})
until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second)
reason := `{"rule":"429","matched_keyword":"too many requests"}`
s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason))
gotByID, err := s.repo.GetByID(s.ctx, acc1.ID)
s.Require().NoError(err)
s.Require().NotNil(gotByID.TempUnschedulableUntil)
s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second)
s.Require().Equal(reason, gotByID.TempUnschedulableReason)
gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID})
s.Require().NoError(err)
s.Require().Len(gotByIDs, 2)
s.Require().Equal(acc2.ID, gotByIDs[0].ID)
s.Require().Nil(gotByIDs[0].TempUnschedulableUntil)
s.Require().Equal("", gotByIDs[0].TempUnschedulableReason)
s.Require().Equal(acc1.ID, gotByIDs[1].ID)
s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil)
s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second)
s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason)
s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID))
cleared, err := s.repo.GetByID(s.ctx, acc1.ID)
s.Require().NoError(err)
s.Require().Nil(cleared.TempUnschedulableUntil)
s.Require().Equal("", cleared.TempUnschedulableReason)
}
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {

View File

@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *t
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewAPIKeyRepository(entClient)
apiKeyRepo := newAPIKeyRepositoryWithSQL(entClient, tx)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",

View File

@@ -2,6 +2,7 @@ package repository
import (
"context"
"database/sql"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -16,10 +17,15 @@ import (
type apiKeyRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
return &apiKeyRepository{client: client}
func NewAPIKeyRepository(client *dbent.Client, sqlDB *sql.DB) service.APIKeyRepository {
return newAPIKeyRepositoryWithSQL(client, sqlDB)
}
func newAPIKeyRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *apiKeyRepository {
return &apiKeyRepository{client: client, sql: sqlq}
}
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
@@ -37,7 +43,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
SetNillableLastUsedAt(key.LastUsedAt).
SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed).
SetNillableExpiresAt(key.ExpiresAt)
SetNillableExpiresAt(key.ExpiresAt).
SetRateLimit5h(key.RateLimit5h).
SetRateLimit1d(key.RateLimit1d).
SetRateLimit7d(key.RateLimit7d)
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
@@ -118,6 +127,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
apikey.FieldQuota,
apikey.FieldQuotaUsed,
apikey.FieldExpiresAt,
apikey.FieldRateLimit5h,
apikey.FieldRateLimit1d,
apikey.FieldRateLimit7d,
).
WithUser(func(q *dbent.UserQuery) {
q.Select(
@@ -172,13 +184,20 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at避免二次查询带来的并发可见性问题。
client := clientFromContext(ctx, r.client)
now := time.Now()
builder := r.client.APIKey.Update().
builder := client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
SetQuota(key.Quota).
SetQuotaUsed(key.QuotaUsed).
SetRateLimit5h(key.RateLimit5h).
SetRateLimit1d(key.RateLimit1d).
SetRateLimit7d(key.RateLimit7d).
SetUsage5h(key.Usage5h).
SetUsage1d(key.Usage1d).
SetUsage7d(key.Usage7d).
SetUpdatedAt(now)
if key.GroupID != nil {
builder.SetGroupID(*key.GroupID)
@@ -193,6 +212,23 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearExpiresAt()
}
// Rate limit window start times
if key.Window5hStart != nil {
builder.SetWindow5hStart(*key.Window5hStart)
} else {
builder.ClearWindow5hStart()
}
if key.Window1dStart != nil {
builder.SetWindow1dStart(*key.Window1dStart)
} else {
builder.ClearWindow1dStart()
}
if key.Window7dStart != nil {
builder.SetWindow7dStart(*key.Window7dStart)
} else {
builder.ClearWindow7dStart()
}
// IP 限制字段
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
@@ -246,9 +282,27 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
// Apply filters
if filters.Search != "" {
q = q.Where(apikey.Or(
apikey.NameContainsFold(filters.Search),
apikey.KeyContainsFold(filters.Search),
))
}
if filters.Status != "" {
q = q.Where(apikey.StatusEQ(filters.Status))
}
if filters.GroupID != nil {
if *filters.GroupID == 0 {
q = q.Where(apikey.GroupIDIsNil())
} else {
q = q.Where(apikey.GroupIDEQ(*filters.GroupID))
}
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
@@ -412,25 +466,92 @@ func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt
return nil
}
// IncrementRateLimitUsage atomically increments all rate limit usage counters and initializes
// window start times via COALESCE if not already set.
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = usage_5h + $1,
usage_1d = usage_1d + $1,
usage_7d = usage_7d + $1,
window_5h_start = COALESCE(window_5h_start, NOW()),
window_1d_start = COALESCE(window_1d_start, NOW()),
window_7d_start = COALESCE(window_7d_start, NOW()),
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL`,
cost, id)
return err
}
// ResetRateLimitWindows resets expired rate limit windows atomically.
func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE api_keys SET
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END,
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END,
updated_at = NOW()
WHERE id = $1 AND deleted_at IS NULL`,
id)
return err
}
// GetRateLimitData returns the current rate limit usage and window start times for an API key.
func (r *apiKeyRepository) GetRateLimitData(ctx context.Context, id int64) (result *service.APIKeyRateLimitData, err error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT usage_5h, usage_1d, usage_7d, window_5h_start, window_1d_start, window_7d_start
FROM api_keys
WHERE id = $1 AND deleted_at IS NULL`,
id)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
}
}()
if !rows.Next() {
return nil, service.ErrAPIKeyNotFound
}
data := &service.APIKeyRateLimitData{}
if err := rows.Scan(&data.Usage5h, &data.Usage1d, &data.Usage7d, &data.Window5hStart, &data.Window1dStart, &data.Window7dStart); err != nil {
return nil, err
}
return data, rows.Err()
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil {
return nil
}
out := &service.APIKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
Status: m.Status,
IPWhitelist: m.IPWhitelist,
IPBlacklist: m.IPBlacklist,
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
Quota: m.Quota,
QuotaUsed: m.QuotaUsed,
ExpiresAt: m.ExpiresAt,
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
Status: m.Status,
IPWhitelist: m.IPWhitelist,
IPBlacklist: m.IPBlacklist,
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
Quota: m.Quota,
QuotaUsed: m.QuotaUsed,
ExpiresAt: m.ExpiresAt,
RateLimit5h: m.RateLimit5h,
RateLimit1d: m.RateLimit1d,
RateLimit7d: m.RateLimit7d,
Usage5h: m.Usage5h,
Usage1d: m.Usage1d,
Usage7d: m.Usage7d,
Window5hStart: m.Window5hStart,
Window1dStart: m.Window1dStart,
Window7dStart: m.Window7dStart,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
@@ -446,20 +567,22 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
@@ -487,6 +610,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,

View File

@@ -26,7 +26,7 @@ func (s *APIKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
s.repo = newAPIKeyRepositoryWithSQL(s.client, tx)
}
func TestAPIKeyRepoSuite(t *testing.T) {
@@ -158,7 +158,7 @@ func (s *APIKeyRepoSuite) TestListByUserID() {
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
s.Require().NoError(err, "ListByUserID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
@@ -170,7 +170,7 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}, service.APIKeyListFilters{})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal(int64(5), page.Total)
@@ -314,7 +314,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
s.Require().NoError(err, "ListByUserID")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(keys, 1)
@@ -421,7 +421,7 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
// 注意:此测试使用 testEntClient非事务隔离数据会真正写入数据库。
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
client := testEntClient(t)
repo := NewAPIKeyRepository(client).(*apiKeyRepository)
repo := NewAPIKeyRepository(client, integrationDB).(*apiKeyRepository)
ctx := context.Background()
// 创建测试用户和 API Key

View File

@@ -14,10 +14,12 @@ import (
)
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
billingCacheJitter = 30 * time.Second
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingRateLimitKeyPrefix = "apikey:rate:"
billingCacheTTL = 5 * time.Minute
billingCacheJitter = 30 * time.Second
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
)
// jitteredTTL 返回带随机抖动的 TTL防止缓存雪崩
@@ -49,6 +51,20 @@ const (
subFieldVersion = "version"
)
// billingRateLimitKey generates the Redis key for API key rate limit cache.
func billingRateLimitKey(keyID int64) string {
return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID)
}
const (
rateLimitFieldUsage5h = "usage_5h"
rateLimitFieldUsage1d = "usage_1d"
rateLimitFieldUsage7d = "usage_7d"
rateLimitFieldWindow5h = "window_5h"
rateLimitFieldWindow1d = "window_1d"
rateLimitFieldWindow7d = "window_7d"
)
var (
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
@@ -73,6 +89,21 @@ var (
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
// updateRateLimitUsageScript atomically increments all three rate limit usage counters.
// Returns 0 if the key doesn't exist (cache miss), 1 on success.
updateRateLimitUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
type billingCache struct {
@@ -195,3 +226,69 @@ func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID,
key := billingSubKey(userID, groupID)
return c.rdb.Del(ctx, key).Err()
}
func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) {
key := billingRateLimitKey(keyID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, redis.Nil
}
data := &service.APIKeyRateLimitCacheData{}
if v, ok := result[rateLimitFieldUsage5h]; ok {
data.Usage5h, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldUsage1d]; ok {
data.Usage1d, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldUsage7d]; ok {
data.Usage7d, _ = strconv.ParseFloat(v, 64)
}
if v, ok := result[rateLimitFieldWindow5h]; ok {
data.Window5h, _ = strconv.ParseInt(v, 10, 64)
}
if v, ok := result[rateLimitFieldWindow1d]; ok {
data.Window1d, _ = strconv.ParseInt(v, 10, 64)
}
if v, ok := result[rateLimitFieldWindow7d]; ok {
data.Window7d, _ = strconv.ParseInt(v, 10, 64)
}
return data, nil
}
func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error {
if data == nil {
return nil
}
key := billingRateLimitKey(keyID)
fields := map[string]any{
rateLimitFieldUsage5h: data.Usage5h,
rateLimitFieldUsage1d: data.Usage1d,
rateLimitFieldUsage7d: data.Usage7d,
rateLimitFieldWindow5h: data.Window5h,
rateLimitFieldWindow1d: data.Window1d,
rateLimitFieldWindow7d: data.Window7d,
}
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, rateLimitCacheTTL)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
key := billingRateLimitKey(keyID)
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
return err
}
return nil
}
func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
key := billingRateLimitKey(keyID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
@@ -28,11 +29,14 @@ func NewClaudeOAuthClient() service.ClaudeOAuthClient {
type claudeOAuthService struct {
baseURL string
tokenURL string
clientFactory func(proxyURL string) *req.Client
clientFactory func(proxyURL string) (*req.Client, error)
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return "", fmt.Errorf("create HTTP client: %w", err)
}
var orgs []struct {
UUID string `json:"uuid"`
@@ -88,7 +92,10 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return "", fmt.Errorf("create HTTP client: %w", err)
}
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
@@ -165,7 +172,10 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
// Parse code which may contain state in format "authCode#state"
authCode := code
@@ -223,7 +233,10 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
client, err := s.clientFactory(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
reqBody := map[string]any{
"grant_type": "refresh_token",
@@ -253,16 +266,20 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return &tokenResp, nil
}
func createReqClient(proxyURL string) *req.Client {
func createReqClient(proxyURL string) (*req.Client, error) {
// 禁用 CookieJar确保每次授权都是干净的会话
client := req.C().
SetTimeout(60 * time.Second).
ImpersonateChrome().
SetCookieJar(nil) // 禁用 CookieJar
if strings.TrimSpace(proxyURL) != "" {
client.SetProxyURL(strings.TrimSpace(proxyURL))
trimmed, _, err := proxyurl.Parse(proxyURL)
if err != nil {
return nil, err
}
if trimmed != "" {
client.SetProxyURL(trimmed)
}
return client
return client, nil
}

View File

@@ -91,7 +91,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
@@ -169,7 +169,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
@@ -276,7 +276,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
@@ -372,7 +372,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
resp, err := s.client.RefreshToken(context.Background(), "rt", "")

View File

@@ -83,7 +83,7 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
return nil, fmt.Errorf("create http client failed: %w", err)
}
resp, err = client.Do(req)

View File

@@ -50,7 +50,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
allowPrivateHosts: true,
}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.NoError(s.T(), err, "FetchUsage")
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
@@ -112,6 +112,17 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
require.Error(s.T(), err, "expected error for cancelled context")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_InvalidProxyReturnsError() {
s.fetcher = &claudeUsageService{
usageURL: "http://example.com",
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "create http client failed")
}
func TestClaudeUsageServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeUsageServiceSuite))
}

View File

@@ -227,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
return result, nil
}
func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
}
now, err := c.rdb.Time(ctx).Result()
if err != nil {
return nil, fmt.Errorf("redis TIME: %w", err)
}
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
pipe := c.rdb.Pipeline()
type accountCmd struct {
accountID int64
zcardCmd *redis.IntCmd
}
cmds := make([]accountCmd, 0, len(accountIDs))
for _, accountID := range accountIDs {
slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10)
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
cmds = append(cmds, accountCmd{
accountID: accountID,
zcardCmd: pipe.ZCard(ctx, slotKey),
})
}
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("pipeline exec: %w", err)
}
result := make(map[int64]int, len(accountIDs))
for _, cmd := range cmds {
result[cmd.accountID] = int(cmd.zcardCmd.Val())
}
return result, nil
}
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {

View File

@@ -26,7 +26,10 @@ func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient {
}
func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
client, err := createGeminiReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
@@ -72,7 +75,10 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
}
func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
client, err := createGeminiReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
@@ -111,7 +117,7 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
return &tokenResp, nil
}
func createGeminiReqClient(proxyURL string) *req.Client {
func createGeminiReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,

View File

@@ -26,7 +26,11 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
}
var out geminicli.LoadCodeAssistResponse
resp, err := createGeminiCliReqClient(proxyURL).R().
client, err := createGeminiCliReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
@@ -66,7 +70,11 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody)
var out geminicli.OnboardUserResponse
resp, err := createGeminiCliReqClient(proxyURL).R().
client, err := createGeminiCliReqClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create HTTP client: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
@@ -98,7 +106,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
return &out, nil
}
func createGeminiCliReqClient(proxyURL string) *req.Client {
func createGeminiCliReqClient(proxyURL string) (*req.Client, error) {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,

View File

@@ -5,8 +5,10 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
@@ -24,13 +26,19 @@ type githubReleaseClientError struct {
// NewGitHubReleaseClient 创建 GitHub Release 客户端
// proxyURL 为空时直连 GitHub支持 http/https/socks5/socks5h 协议
// 代理配置失败时行为由 allowDirectOnProxyError 控制:
// - false默认返回错误占位客户端禁止回退到直连
// - true回退到直连仅限管理员显式开启
func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient {
// 安全说明httpclient.GetClient 的错误链url.Parse / proxyutil不含明文代理凭据
// 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
if proxyURL != "" && !allowDirectOnProxyError {
if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
slog.Warn("proxy client init failed, all requests will fail", "service", "github_release", "error", err)
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
sharedClient = &http.Client{Timeout: 30 * time.Second}
@@ -42,7 +50,8 @@ func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) servi
ProxyURL: proxyURL,
})
if err != nil {
if proxyURL != "" && !allowDirectOnProxyError {
if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
slog.Warn("proxy download client init failed, all requests will fail", "service", "github_release", "error", err)
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
downloadClient = &http.Client{Timeout: 10 * time.Minute}

View File

@@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
@@ -57,6 +59,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
// 设置模型路由配置
@@ -123,8 +126,41 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetSimulateClaudeMaxEnabled(groupIn.SimulateClaudeMaxEnabled)
// 显式处理可空字段nil 需要 clear非 nil 需要 set。
if groupIn.DailyLimitUSD != nil {
builder = builder.SetDailyLimitUsd(*groupIn.DailyLimitUSD)
} else {
builder = builder.ClearDailyLimitUsd()
}
if groupIn.WeeklyLimitUSD != nil {
builder = builder.SetWeeklyLimitUsd(*groupIn.WeeklyLimitUSD)
} else {
builder = builder.ClearWeeklyLimitUsd()
}
if groupIn.MonthlyLimitUSD != nil {
builder = builder.SetMonthlyLimitUsd(*groupIn.MonthlyLimitUSD)
} else {
builder = builder.ClearMonthlyLimitUsd()
}
if groupIn.ImagePrice1K != nil {
builder = builder.SetImagePrice1k(*groupIn.ImagePrice1K)
} else {
builder = builder.ClearImagePrice1k()
}
if groupIn.ImagePrice2K != nil {
builder = builder.SetImagePrice2k(*groupIn.ImagePrice2K)
} else {
builder = builder.ClearImagePrice2k()
}
if groupIn.ImagePrice4K != nil {
builder = builder.SetImagePrice4k(*groupIn.ImagePrice4K)
} else {
builder = builder.ClearImagePrice4k()
}
// 处理 FallbackGroupIDnil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
@@ -283,6 +319,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
}
// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。
// 返回结构map[groupID]exists。
func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) {
result := make(map[int64]bool, len(ids))
if len(ids) == 0 {
return result, nil
}
uniqueIDs := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
result[id] = false
}
if len(uniqueIDs) == 0 {
return result, nil
}
rows, err := r.sql.QueryContext(ctx, `
SELECT id
FROM groups
WHERE id = ANY($1) AND deleted_at IS NULL
`, pq.Array(uniqueIDs))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, err
}
result[id] = true
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
@@ -514,22 +598,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic
return nil
}
// 使用事务批量更新
tx, err := r.client.Tx(ctx)
// 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。
sortOrderByID := make(map[int64]int, len(updates))
groupIDs := make([]int64, 0, len(updates))
for _, u := range updates {
if u.ID <= 0 {
continue
}
if _, exists := sortOrderByID[u.ID]; !exists {
groupIDs = append(groupIDs, u.ID)
}
sortOrderByID[u.ID] = u.SortOrder
}
if len(groupIDs) == 0 {
return nil
}
// 与旧实现保持一致:任何不存在/已删除的分组都返回 not found且不执行更新。
var existingCount int
if err := scanSingleRow(
ctx,
r.sql,
`SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`,
[]any{pq.Array(groupIDs)},
&existingCount,
); err != nil {
return err
}
if existingCount != len(groupIDs) {
return service.ErrGroupNotFound
}
args := make([]any, 0, len(groupIDs)*2+1)
caseClauses := make([]string, 0, len(groupIDs))
placeholder := 1
for _, id := range groupIDs {
caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1))
args = append(args, id, sortOrderByID[id])
placeholder += 2
}
args = append(args, pq.Array(groupIDs))
query := fmt.Sprintf(`
UPDATE groups
SET sort_order = CASE id
%s
ELSE sort_order
END
WHERE deleted_at IS NULL AND id = ANY($%d)
`, strings.Join(caseClauses, "\n\t\t\t"), placeholder)
result, err := r.sql.ExecContext(ctx, query, args...)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for _, u := range updates {
if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
}
if err := tx.Commit(); err != nil {
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected != int64(len(groupIDs)) {
return service.ErrGroupNotFound
}
for _, id := range groupIDs {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err)
}
}
return nil
}

Some files were not shown because too many files have changed in this diff Show More