mirror of
https://gitee.com/wanwujie/sub2api
synced 2026-04-14 03:44:45 +08:00
Merge branch 'main' into feature/antigravity-user-agent-configurable
This commit is contained in:
@@ -5,7 +5,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -19,6 +19,13 @@ const (
|
||||
RunModeSimple = "simple"
|
||||
)
|
||||
|
||||
// 使用量记录队列溢出策略
|
||||
const (
|
||||
UsageRecordOverflowPolicyDrop = "drop"
|
||||
UsageRecordOverflowPolicySample = "sample"
|
||||
UsageRecordOverflowPolicySync = "sync"
|
||||
)
|
||||
|
||||
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
|
||||
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
|
||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
@@ -38,31 +45,68 @@ const (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Ops OpsConfig `mapstructure:"ops"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Totp TotpConfig `mapstructure:"totp"`
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
||||
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
|
||||
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
|
||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Log LogConfig `mapstructure:"log"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Ops OpsConfig `mapstructure:"ops"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Totp TotpConfig `mapstructure:"totp"`
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
||||
SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
|
||||
SubscriptionMaintenance SubscriptionMaintenanceConfig `mapstructure:"subscription_maintenance"`
|
||||
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
|
||||
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
|
||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
Sora SoraConfig `mapstructure:"sora"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
Idempotency IdempotencyConfig `mapstructure:"idempotency"`
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
Level string `mapstructure:"level"`
|
||||
Format string `mapstructure:"format"`
|
||||
ServiceName string `mapstructure:"service_name"`
|
||||
Environment string `mapstructure:"env"`
|
||||
Caller bool `mapstructure:"caller"`
|
||||
StacktraceLevel string `mapstructure:"stacktrace_level"`
|
||||
Output LogOutputConfig `mapstructure:"output"`
|
||||
Rotation LogRotationConfig `mapstructure:"rotation"`
|
||||
Sampling LogSamplingConfig `mapstructure:"sampling"`
|
||||
}
|
||||
|
||||
type LogOutputConfig struct {
|
||||
ToStdout bool `mapstructure:"to_stdout"`
|
||||
ToFile bool `mapstructure:"to_file"`
|
||||
FilePath string `mapstructure:"file_path"`
|
||||
}
|
||||
|
||||
type LogRotationConfig struct {
|
||||
MaxSizeMB int `mapstructure:"max_size_mb"`
|
||||
MaxBackups int `mapstructure:"max_backups"`
|
||||
MaxAgeDays int `mapstructure:"max_age_days"`
|
||||
Compress bool `mapstructure:"compress"`
|
||||
LocalTime bool `mapstructure:"local_time"`
|
||||
}
|
||||
|
||||
type LogSamplingConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Initial int `mapstructure:"initial"`
|
||||
Thereafter int `mapstructure:"thereafter"`
|
||||
}
|
||||
|
||||
type GeminiConfig struct {
|
||||
@@ -94,6 +138,25 @@ type UpdateConfig struct {
|
||||
ProxyURL string `mapstructure:"proxy_url"`
|
||||
}
|
||||
|
||||
type IdempotencyConfig struct {
|
||||
// ObserveOnly 为 true 时处于观察期:未携带 Idempotency-Key 的请求继续放行。
|
||||
ObserveOnly bool `mapstructure:"observe_only"`
|
||||
// DefaultTTLSeconds 关键写接口的幂等记录默认 TTL(秒)。
|
||||
DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"`
|
||||
// SystemOperationTTLSeconds 系统操作接口的幂等记录 TTL(秒)。
|
||||
SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"`
|
||||
// ProcessingTimeoutSeconds processing 状态锁超时(秒)。
|
||||
ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"`
|
||||
// FailedRetryBackoffSeconds 失败退避窗口(秒)。
|
||||
FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"`
|
||||
// MaxStoredResponseLen 持久化响应体最大长度(字节)。
|
||||
MaxStoredResponseLen int `mapstructure:"max_stored_response_len"`
|
||||
// CleanupIntervalSeconds 过期记录清理周期(秒)。
|
||||
CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"`
|
||||
// CleanupBatchSize 每次清理的最大记录数。
|
||||
CleanupBatchSize int `mapstructure:"cleanup_batch_size"`
|
||||
}
|
||||
|
||||
type LinuxDoConnectConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
@@ -126,6 +189,8 @@ type TokenRefreshConfig struct {
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
// 重试退避基础时间(秒)
|
||||
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||
// 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
|
||||
SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
|
||||
}
|
||||
|
||||
type PricingConfig struct {
|
||||
@@ -147,6 +212,7 @@ type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL,用于生成邮件中的外部链接
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
|
||||
@@ -173,6 +239,7 @@ type SecurityConfig struct {
|
||||
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
|
||||
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
|
||||
CSP CSPConfig `mapstructure:"csp"`
|
||||
ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"`
|
||||
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
|
||||
}
|
||||
|
||||
@@ -197,6 +264,12 @@ type CSPConfig struct {
|
||||
Policy string `mapstructure:"policy"`
|
||||
}
|
||||
|
||||
type ProxyFallbackConfig struct {
|
||||
// AllowDirectOnError 当代理初始化失败时是否允许回退直连。
|
||||
// 默认 false:避免因代理配置错误导致 IP 泄露/关联。
|
||||
AllowDirectOnError bool `mapstructure:"allow_direct_on_error"`
|
||||
}
|
||||
|
||||
type ProxyProbeConfig struct {
|
||||
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证
|
||||
}
|
||||
@@ -217,6 +290,59 @@ type ConcurrencyConfig struct {
|
||||
PingInterval int `mapstructure:"ping_interval"`
|
||||
}
|
||||
|
||||
// SoraConfig 直连 Sora 配置
|
||||
type SoraConfig struct {
|
||||
Client SoraClientConfig `mapstructure:"client"`
|
||||
Storage SoraStorageConfig `mapstructure:"storage"`
|
||||
}
|
||||
|
||||
// SoraClientConfig 直连 Sora 客户端配置
|
||||
type SoraClientConfig struct {
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
|
||||
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
||||
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
||||
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
||||
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
|
||||
Headers map[string]string `mapstructure:"headers"`
|
||||
UserAgent string `mapstructure:"user_agent"`
|
||||
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
|
||||
CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
|
||||
}
|
||||
|
||||
// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
|
||||
type SoraCurlCFFISidecarConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
Impersonate string `mapstructure:"impersonate"`
|
||||
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||
SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
|
||||
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
|
||||
}
|
||||
|
||||
// SoraStorageConfig 媒体存储配置
|
||||
type SoraStorageConfig struct {
|
||||
Type string `mapstructure:"type"`
|
||||
LocalPath string `mapstructure:"local_path"`
|
||||
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
|
||||
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
|
||||
DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
|
||||
MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
|
||||
}
|
||||
|
||||
// SoraStorageCleanupConfig 媒体清理配置
|
||||
type SoraStorageCleanupConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Schedule string `mapstructure:"schedule"`
|
||||
RetentionDays int `mapstructure:"retention_days"`
|
||||
}
|
||||
|
||||
// GatewayConfig API网关相关配置
|
||||
type GatewayConfig struct {
|
||||
// 等待上游响应头的超时时间(秒),0表示无超时
|
||||
@@ -224,8 +350,20 @@ type GatewayConfig struct {
|
||||
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
||||
// 请求体最大字节数,用于网关请求体大小限制
|
||||
MaxBodySize int64 `mapstructure:"max_body_size"`
|
||||
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
|
||||
UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"`
|
||||
// 代理探测响应体读取上限(字节)
|
||||
ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"`
|
||||
// Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销)
|
||||
GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"`
|
||||
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
|
||||
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
|
||||
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
||||
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
|
||||
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
|
||||
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
|
||||
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
|
||||
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
|
||||
|
||||
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
|
||||
// MaxIdleConns: 所有主机的最大空闲连接总数
|
||||
@@ -271,6 +409,24 @@ type GatewayConfig struct {
|
||||
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
||||
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
||||
|
||||
// Sora 专用配置
|
||||
// SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size)
|
||||
SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
|
||||
// SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制)
|
||||
SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
|
||||
// SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制)
|
||||
SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
|
||||
// SoraStreamMode: stream 强制策略(force/error)
|
||||
SoraStreamMode string `mapstructure:"sora_stream_mode"`
|
||||
// SoraModelFilters: 模型列表过滤配置
|
||||
SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
|
||||
// SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key
|
||||
SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
|
||||
// SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名)
|
||||
SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
|
||||
// SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用)
|
||||
SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
|
||||
|
||||
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
|
||||
MaxAccountSwitches int `mapstructure:"max_account_switches"`
|
||||
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
|
||||
@@ -284,6 +440,53 @@ type GatewayConfig struct {
|
||||
|
||||
// TLSFingerprint: TLS指纹伪装配置
|
||||
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
|
||||
|
||||
// UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker)
|
||||
UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"`
|
||||
|
||||
// UserGroupRateCacheTTLSeconds: 用户分组倍率热路径缓存 TTL(秒)
|
||||
UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"`
|
||||
// ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒)
|
||||
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
|
||||
}
|
||||
|
||||
// GatewayUsageRecordConfig 使用量记录异步队列配置
|
||||
type GatewayUsageRecordConfig struct {
|
||||
// WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限)
|
||||
WorkerCount int `mapstructure:"worker_count"`
|
||||
// QueueSize: 队列容量(有界)
|
||||
QueueSize int `mapstructure:"queue_size"`
|
||||
// TaskTimeoutSeconds: 单个使用量记录任务超时(秒)
|
||||
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
|
||||
// OverflowPolicy: 队列满时策略(drop/sample/sync)
|
||||
OverflowPolicy string `mapstructure:"overflow_policy"`
|
||||
// OverflowSamplePercent: sample 策略下,同步回写采样百分比(1-100)
|
||||
OverflowSamplePercent int `mapstructure:"overflow_sample_percent"`
|
||||
|
||||
// AutoScaleEnabled: 是否启用 worker 自动扩缩容
|
||||
AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"`
|
||||
// AutoScaleMinWorkers: 自动扩缩容最小 worker 数
|
||||
AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"`
|
||||
// AutoScaleMaxWorkers: 自动扩缩容最大 worker 数
|
||||
AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"`
|
||||
// AutoScaleUpQueuePercent: 队列占用率达到该阈值时触发扩容
|
||||
AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"`
|
||||
// AutoScaleDownQueuePercent: 队列占用率低于该阈值时触发缩容
|
||||
AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"`
|
||||
// AutoScaleUpStep: 每次扩容步长
|
||||
AutoScaleUpStep int `mapstructure:"auto_scale_up_step"`
|
||||
// AutoScaleDownStep: 每次缩容步长
|
||||
AutoScaleDownStep int `mapstructure:"auto_scale_down_step"`
|
||||
// AutoScaleCheckIntervalSeconds: 自动扩缩容检测间隔(秒)
|
||||
AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"`
|
||||
// AutoScaleCooldownSeconds: 自动扩缩容冷却时间(秒)
|
||||
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
|
||||
}
|
||||
|
||||
// SoraModelFiltersConfig Sora 模型过滤配置
|
||||
type SoraModelFiltersConfig struct {
|
||||
// HidePromptEnhance 是否隐藏 prompt-enhance 模型
|
||||
HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
|
||||
}
|
||||
|
||||
// TLSFingerprintConfig TLS指纹伪装配置
|
||||
@@ -479,8 +682,9 @@ type OpsMetricsCollectorCacheConfig struct {
|
||||
type JWTConfig struct {
|
||||
Secret string `mapstructure:"secret"`
|
||||
ExpireHour int `mapstructure:"expire_hour"`
|
||||
// AccessTokenExpireMinutes: Access Token有效期(分钟),默认15分钟
|
||||
// 短有效期减少被盗用风险,配合Refresh Token实现无感续期
|
||||
// AccessTokenExpireMinutes: Access Token有效期(分钟)
|
||||
// - >0: 使用分钟配置(优先级高于 ExpireHour)
|
||||
// - =0: 回退使用 ExpireHour(向后兼容旧配置)
|
||||
AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"`
|
||||
// RefreshTokenExpireDays: Refresh Token有效期(天),默认30天
|
||||
RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"`
|
||||
@@ -525,6 +729,20 @@ type APIKeyAuthCacheConfig struct {
|
||||
Singleflight bool `mapstructure:"singleflight"`
|
||||
}
|
||||
|
||||
// SubscriptionCacheConfig 订阅认证 L1 缓存配置
|
||||
type SubscriptionCacheConfig struct {
|
||||
L1Size int `mapstructure:"l1_size"`
|
||||
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
|
||||
JitterPercent int `mapstructure:"jitter_percent"`
|
||||
}
|
||||
|
||||
// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置。
|
||||
// 用于将“请求路径触发的维护动作”有界化,避免高并发下 goroutine 膨胀。
|
||||
type SubscriptionMaintenanceConfig struct {
|
||||
WorkerCount int `mapstructure:"worker_count"`
|
||||
QueueSize int `mapstructure:"queue_size"`
|
||||
}
|
||||
|
||||
// DashboardCacheConfig 仪表盘统计缓存配置
|
||||
type DashboardCacheConfig struct {
|
||||
// Enabled: 是否启用仪表盘缓存
|
||||
@@ -588,7 +806,19 @@ func NormalizeRunMode(value string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// Load 读取并校验完整配置(要求 jwt.secret 已显式提供)。
|
||||
func Load() (*Config, error) {
|
||||
return load(false)
|
||||
}
|
||||
|
||||
// LoadForBootstrap 读取启动阶段配置。
|
||||
//
|
||||
// 启动阶段允许 jwt.secret 先留空,后续由数据库初始化流程补齐并再次完整校验。
|
||||
func LoadForBootstrap() (*Config, error) {
|
||||
return load(true)
|
||||
}
|
||||
|
||||
func load(allowMissingJWTSecret bool) (*Config, error) {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
|
||||
@@ -630,6 +860,7 @@ func Load() (*Config, error) {
|
||||
if cfg.Server.Mode == "" {
|
||||
cfg.Server.Mode = "debug"
|
||||
}
|
||||
cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL)
|
||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
|
||||
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
|
||||
@@ -648,15 +879,12 @@ func Load() (*Config, error) {
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
|
||||
|
||||
if cfg.JWT.Secret == "" {
|
||||
secret, err := generateJWTSecret(64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate jwt secret error: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
||||
}
|
||||
cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level))
|
||||
cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format))
|
||||
cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName)
|
||||
cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment)
|
||||
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
|
||||
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
|
||||
|
||||
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
||||
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
||||
@@ -667,29 +895,39 @@ func Load() (*Config, error) {
|
||||
}
|
||||
cfg.Totp.EncryptionKey = key
|
||||
cfg.Totp.EncryptionKeyConfigured = false
|
||||
log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
|
||||
slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.")
|
||||
} else {
|
||||
cfg.Totp.EncryptionKeyConfigured = true
|
||||
}
|
||||
|
||||
originalJWTSecret := cfg.JWT.Secret
|
||||
if allowMissingJWTSecret && originalJWTSecret == "" {
|
||||
// 启动阶段允许先无 JWT 密钥,后续在数据库初始化后补齐。
|
||||
cfg.JWT.Secret = strings.Repeat("0", 32)
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate config error: %w", err)
|
||||
}
|
||||
|
||||
if allowMissingJWTSecret && originalJWTSecret == "" {
|
||||
cfg.JWT.Secret = ""
|
||||
}
|
||||
|
||||
if !cfg.Security.URLAllowlist.Enabled {
|
||||
log.Println("Warning: security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
|
||||
slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).")
|
||||
}
|
||||
if !cfg.Security.ResponseHeaders.Enabled {
|
||||
log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
|
||||
slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).")
|
||||
}
|
||||
|
||||
if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) {
|
||||
log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.")
|
||||
slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.")
|
||||
}
|
||||
if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 {
|
||||
log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v",
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed,
|
||||
cfg.Security.ResponseHeaders.ForceRemove,
|
||||
slog.Info("response header policy configured",
|
||||
"additional_allowed", cfg.Security.ResponseHeaders.AdditionalAllowed,
|
||||
"force_remove", cfg.Security.ResponseHeaders.ForceRemove,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -702,7 +940,8 @@ func setDefaults() {
|
||||
// Server
|
||||
viper.SetDefault("server.host", "0.0.0.0")
|
||||
viper.SetDefault("server.port", 8080)
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.mode", "release")
|
||||
viper.SetDefault("server.frontend_url", "")
|
||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
viper.SetDefault("server.trusted_proxies", []string{})
|
||||
@@ -715,6 +954,25 @@ func setDefaults() {
|
||||
viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB
|
||||
viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB
|
||||
|
||||
// Log
|
||||
viper.SetDefault("log.level", "info")
|
||||
viper.SetDefault("log.format", "console")
|
||||
viper.SetDefault("log.service_name", "sub2api")
|
||||
viper.SetDefault("log.env", "production")
|
||||
viper.SetDefault("log.caller", true)
|
||||
viper.SetDefault("log.stacktrace_level", "error")
|
||||
viper.SetDefault("log.output.to_stdout", true)
|
||||
viper.SetDefault("log.output.to_file", true)
|
||||
viper.SetDefault("log.output.file_path", "")
|
||||
viper.SetDefault("log.rotation.max_size_mb", 100)
|
||||
viper.SetDefault("log.rotation.max_backups", 10)
|
||||
viper.SetDefault("log.rotation.max_age_days", 7)
|
||||
viper.SetDefault("log.rotation.compress", true)
|
||||
viper.SetDefault("log.rotation.local_time", true)
|
||||
viper.SetDefault("log.sampling.enabled", false)
|
||||
viper.SetDefault("log.sampling.initial", 100)
|
||||
viper.SetDefault("log.sampling.thereafter", 100)
|
||||
|
||||
// CORS
|
||||
viper.SetDefault("cors.allowed_origins", []string{})
|
||||
viper.SetDefault("cors.allow_credentials", true)
|
||||
@@ -737,7 +995,7 @@ func setDefaults() {
|
||||
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
|
||||
viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
|
||||
viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
|
||||
viper.SetDefault("security.response_headers.enabled", false)
|
||||
viper.SetDefault("security.response_headers.enabled", true)
|
||||
viper.SetDefault("security.response_headers.additional_allowed", []string{})
|
||||
viper.SetDefault("security.response_headers.force_remove", []string{})
|
||||
viper.SetDefault("security.csp.enabled", true)
|
||||
@@ -775,9 +1033,9 @@ func setDefaults() {
|
||||
viper.SetDefault("database.user", "postgres")
|
||||
viper.SetDefault("database.password", "postgres")
|
||||
viper.SetDefault("database.dbname", "sub2api")
|
||||
viper.SetDefault("database.sslmode", "disable")
|
||||
viper.SetDefault("database.max_open_conns", 50)
|
||||
viper.SetDefault("database.max_idle_conns", 10)
|
||||
viper.SetDefault("database.sslmode", "prefer")
|
||||
viper.SetDefault("database.max_open_conns", 256)
|
||||
viper.SetDefault("database.max_idle_conns", 128)
|
||||
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
|
||||
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
|
||||
|
||||
@@ -789,8 +1047,8 @@ func setDefaults() {
|
||||
viper.SetDefault("redis.dial_timeout_seconds", 5)
|
||||
viper.SetDefault("redis.read_timeout_seconds", 3)
|
||||
viper.SetDefault("redis.write_timeout_seconds", 3)
|
||||
viper.SetDefault("redis.pool_size", 128)
|
||||
viper.SetDefault("redis.min_idle_conns", 10)
|
||||
viper.SetDefault("redis.pool_size", 1024)
|
||||
viper.SetDefault("redis.min_idle_conns", 128)
|
||||
viper.SetDefault("redis.enable_tls", false)
|
||||
|
||||
// Ops (vNext)
|
||||
@@ -810,9 +1068,9 @@ func setDefaults() {
|
||||
// JWT
|
||||
viper.SetDefault("jwt.secret", "")
|
||||
viper.SetDefault("jwt.expire_hour", 24)
|
||||
viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期
|
||||
viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期
|
||||
viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新
|
||||
viper.SetDefault("jwt.access_token_expire_minutes", 0) // 0 表示回退到 expire_hour
|
||||
viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期
|
||||
viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新
|
||||
|
||||
// TOTP
|
||||
viper.SetDefault("totp.encryption_key", "")
|
||||
@@ -849,6 +1107,11 @@ func setDefaults() {
|
||||
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
|
||||
viper.SetDefault("api_key_auth_cache.singleflight", true)
|
||||
|
||||
// Subscription auth L1 cache
|
||||
viper.SetDefault("subscription_cache.l1_size", 16384)
|
||||
viper.SetDefault("subscription_cache.l1_ttl_seconds", 10)
|
||||
viper.SetDefault("subscription_cache.jitter_percent", 10)
|
||||
|
||||
// Dashboard cache
|
||||
viper.SetDefault("dashboard_cache.enabled", true)
|
||||
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
|
||||
@@ -874,6 +1137,16 @@ func setDefaults() {
|
||||
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
|
||||
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
|
||||
|
||||
// Idempotency
|
||||
viper.SetDefault("idempotency.observe_only", true)
|
||||
viper.SetDefault("idempotency.default_ttl_seconds", 86400)
|
||||
viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600)
|
||||
viper.SetDefault("idempotency.processing_timeout_seconds", 30)
|
||||
viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5)
|
||||
viper.SetDefault("idempotency.max_stored_response_len", 64*1024)
|
||||
viper.SetDefault("idempotency.cleanup_interval_seconds", 60)
|
||||
viper.SetDefault("idempotency.cleanup_batch_size", 500)
|
||||
|
||||
// Gateway
|
||||
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||
viper.SetDefault("gateway.log_upstream_error_body", true)
|
||||
@@ -882,13 +1155,25 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.failover_on_400", false)
|
||||
viper.SetDefault("gateway.max_account_switches", 10)
|
||||
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
||||
viper.SetDefault("gateway.force_codex_cli", false)
|
||||
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
||||
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
|
||||
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
|
||||
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
|
||||
viper.SetDefault("gateway.sora_stream_mode", "force")
|
||||
viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true)
|
||||
viper.SetDefault("gateway.sora_media_require_api_key", true)
|
||||
viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900)
|
||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认)
|
||||
viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大)
|
||||
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
|
||||
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认)
|
||||
viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+)
|
||||
viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
|
||||
viper.SetDefault("gateway.max_upstream_clients", 5000)
|
||||
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
|
||||
@@ -912,16 +1197,65 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
|
||||
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
|
||||
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
|
||||
viper.SetDefault("gateway.usage_record.worker_count", 128)
|
||||
viper.SetDefault("gateway.usage_record.queue_size", 16384)
|
||||
viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5)
|
||||
viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample)
|
||||
viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_enabled", true)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10)
|
||||
viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30)
|
||||
viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15)
|
||||
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
|
||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||
viper.SetDefault("concurrency.ping_interval", 10)
|
||||
|
||||
// Sora 直连配置
|
||||
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
|
||||
viper.SetDefault("sora.client.timeout_seconds", 120)
|
||||
viper.SetDefault("sora.client.max_retries", 3)
|
||||
viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
|
||||
viper.SetDefault("sora.client.poll_interval_seconds", 2)
|
||||
viper.SetDefault("sora.client.max_poll_attempts", 600)
|
||||
viper.SetDefault("sora.client.recent_task_limit", 50)
|
||||
viper.SetDefault("sora.client.recent_task_limit_max", 200)
|
||||
viper.SetDefault("sora.client.debug", false)
|
||||
viper.SetDefault("sora.client.use_openai_token_provider", false)
|
||||
viper.SetDefault("sora.client.headers", map[string]string{})
|
||||
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
|
||||
viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
|
||||
|
||||
viper.SetDefault("sora.storage.type", "local")
|
||||
viper.SetDefault("sora.storage.local_path", "")
|
||||
viper.SetDefault("sora.storage.fallback_to_upstream", true)
|
||||
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
|
||||
viper.SetDefault("sora.storage.download_timeout_seconds", 120)
|
||||
viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
|
||||
viper.SetDefault("sora.storage.debug", false)
|
||||
viper.SetDefault("sora.storage.cleanup.enabled", true)
|
||||
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
|
||||
viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *")
|
||||
|
||||
// TokenRefresh
|
||||
viper.SetDefault("token_refresh.enabled", true)
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
|
||||
|
||||
// Gemini OAuth - configure via environment variables or config file
|
||||
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
||||
@@ -930,9 +1264,106 @@ func setDefaults() {
|
||||
viper.SetDefault("gemini.oauth.client_secret", "")
|
||||
viper.SetDefault("gemini.oauth.scopes", "")
|
||||
viper.SetDefault("gemini.quota.policy", "")
|
||||
|
||||
// Security - proxy fallback
|
||||
viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
|
||||
|
||||
// Subscription Maintenance (bounded queue + worker pool)
|
||||
viper.SetDefault("subscription_maintenance.worker_count", 2)
|
||||
viper.SetDefault("subscription_maintenance.queue_size", 1024)
|
||||
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
jwtSecret := strings.TrimSpace(c.JWT.Secret)
|
||||
if jwtSecret == "" {
|
||||
return fmt.Errorf("jwt.secret is required")
|
||||
}
|
||||
// NOTE: 按 UTF-8 编码后的字节长度计算。
|
||||
// 选择 bytes 而不是 rune 计数,确保二进制/随机串的长度语义更接近“熵”而非“字符数”。
|
||||
if len([]byte(jwtSecret)) < 32 {
|
||||
return fmt.Errorf("jwt.secret must be at least 32 bytes")
|
||||
}
|
||||
switch c.Log.Level {
|
||||
case "debug", "info", "warn", "error":
|
||||
case "":
|
||||
return fmt.Errorf("log.level is required")
|
||||
default:
|
||||
return fmt.Errorf("log.level must be one of: debug/info/warn/error")
|
||||
}
|
||||
switch c.Log.Format {
|
||||
case "json", "console":
|
||||
case "":
|
||||
return fmt.Errorf("log.format is required")
|
||||
default:
|
||||
return fmt.Errorf("log.format must be one of: json/console")
|
||||
}
|
||||
switch c.Log.StacktraceLevel {
|
||||
case "none", "error", "fatal":
|
||||
case "":
|
||||
return fmt.Errorf("log.stacktrace_level is required")
|
||||
default:
|
||||
return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal")
|
||||
}
|
||||
if !c.Log.Output.ToStdout && !c.Log.Output.ToFile {
|
||||
return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false")
|
||||
}
|
||||
if c.Log.Rotation.MaxSizeMB <= 0 {
|
||||
return fmt.Errorf("log.rotation.max_size_mb must be positive")
|
||||
}
|
||||
if c.Log.Rotation.MaxBackups < 0 {
|
||||
return fmt.Errorf("log.rotation.max_backups must be non-negative")
|
||||
}
|
||||
if c.Log.Rotation.MaxAgeDays < 0 {
|
||||
return fmt.Errorf("log.rotation.max_age_days must be non-negative")
|
||||
}
|
||||
if c.Log.Sampling.Enabled {
|
||||
if c.Log.Sampling.Initial <= 0 {
|
||||
return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled")
|
||||
}
|
||||
if c.Log.Sampling.Thereafter <= 0 {
|
||||
return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled")
|
||||
}
|
||||
} else {
|
||||
if c.Log.Sampling.Initial < 0 {
|
||||
return fmt.Errorf("log.sampling.initial must be non-negative")
|
||||
}
|
||||
if c.Log.Sampling.Thereafter < 0 {
|
||||
return fmt.Errorf("log.sampling.thereafter must be non-negative")
|
||||
}
|
||||
}
|
||||
|
||||
if c.SubscriptionMaintenance.WorkerCount < 0 {
|
||||
return fmt.Errorf("subscription_maintenance.worker_count must be non-negative")
|
||||
}
|
||||
if c.SubscriptionMaintenance.QueueSize < 0 {
|
||||
return fmt.Errorf("subscription_maintenance.queue_size must be non-negative")
|
||||
}
|
||||
|
||||
// Gemini OAuth 配置校验:client_id 与 client_secret 必须同时设置或同时留空。
|
||||
// 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。
|
||||
geminiClientID := strings.TrimSpace(c.Gemini.OAuth.ClientID)
|
||||
geminiClientSecret := strings.TrimSpace(c.Gemini.OAuth.ClientSecret)
|
||||
if (geminiClientID == "") != (geminiClientSecret == "") {
|
||||
return fmt.Errorf("gemini.oauth.client_id and gemini.oauth.client_secret must be both set or both empty")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(c.Server.FrontendURL) != "" {
|
||||
if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil {
|
||||
return fmt.Errorf("server.frontend_url invalid: %w", err)
|
||||
}
|
||||
u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL))
|
||||
if err != nil {
|
||||
return fmt.Errorf("server.frontend_url invalid: %w", err)
|
||||
}
|
||||
if u.RawQuery != "" || u.ForceQuery {
|
||||
return fmt.Errorf("server.frontend_url invalid: must not include query")
|
||||
}
|
||||
if u.User != nil {
|
||||
return fmt.Errorf("server.frontend_url invalid: must not include userinfo")
|
||||
}
|
||||
warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL)
|
||||
}
|
||||
if c.JWT.ExpireHour <= 0 {
|
||||
return fmt.Errorf("jwt.expire_hour must be positive")
|
||||
}
|
||||
@@ -940,20 +1371,20 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)")
|
||||
}
|
||||
if c.JWT.ExpireHour > 24 {
|
||||
log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour)
|
||||
slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", c.JWT.ExpireHour)
|
||||
}
|
||||
// JWT Refresh Token配置验证
|
||||
if c.JWT.AccessTokenExpireMinutes <= 0 {
|
||||
return fmt.Errorf("jwt.access_token_expire_minutes must be positive")
|
||||
if c.JWT.AccessTokenExpireMinutes < 0 {
|
||||
return fmt.Errorf("jwt.access_token_expire_minutes must be non-negative")
|
||||
}
|
||||
if c.JWT.AccessTokenExpireMinutes > 720 {
|
||||
log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes)
|
||||
slog.Warn("jwt.access_token_expire_minutes is high; consider shorter expiration for security", "access_token_expire_minutes", c.JWT.AccessTokenExpireMinutes)
|
||||
}
|
||||
if c.JWT.RefreshTokenExpireDays <= 0 {
|
||||
return fmt.Errorf("jwt.refresh_token_expire_days must be positive")
|
||||
}
|
||||
if c.JWT.RefreshTokenExpireDays > 90 {
|
||||
log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays)
|
||||
slog.Warn("jwt.refresh_token_expire_days is high; consider shorter expiration for security", "refresh_token_expire_days", c.JWT.RefreshTokenExpireDays)
|
||||
}
|
||||
if c.JWT.RefreshWindowMinutes < 0 {
|
||||
return fmt.Errorf("jwt.refresh_window_minutes must be non-negative")
|
||||
@@ -1159,9 +1590,116 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
|
||||
}
|
||||
}
|
||||
if c.Idempotency.DefaultTTLSeconds <= 0 {
|
||||
return fmt.Errorf("idempotency.default_ttl_seconds must be positive")
|
||||
}
|
||||
if c.Idempotency.SystemOperationTTLSeconds <= 0 {
|
||||
return fmt.Errorf("idempotency.system_operation_ttl_seconds must be positive")
|
||||
}
|
||||
if c.Idempotency.ProcessingTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("idempotency.processing_timeout_seconds must be positive")
|
||||
}
|
||||
if c.Idempotency.FailedRetryBackoffSeconds <= 0 {
|
||||
return fmt.Errorf("idempotency.failed_retry_backoff_seconds must be positive")
|
||||
}
|
||||
if c.Idempotency.MaxStoredResponseLen <= 0 {
|
||||
return fmt.Errorf("idempotency.max_stored_response_len must be positive")
|
||||
}
|
||||
if c.Idempotency.CleanupIntervalSeconds <= 0 {
|
||||
return fmt.Errorf("idempotency.cleanup_interval_seconds must be positive")
|
||||
}
|
||||
if c.Idempotency.CleanupBatchSize <= 0 {
|
||||
return fmt.Errorf("idempotency.cleanup_batch_size must be positive")
|
||||
}
|
||||
if c.Gateway.MaxBodySize <= 0 {
|
||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||
}
|
||||
if c.Gateway.UpstreamResponseReadMaxBytes <= 0 {
|
||||
return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive")
|
||||
}
|
||||
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
|
||||
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
|
||||
}
|
||||
if c.Gateway.SoraMaxBodySize < 0 {
|
||||
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
|
||||
}
|
||||
if c.Gateway.SoraStreamTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Gateway.SoraRequestTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 {
|
||||
return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative")
|
||||
}
|
||||
if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" {
|
||||
switch mode {
|
||||
case "force", "error":
|
||||
default:
|
||||
return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
|
||||
}
|
||||
}
|
||||
if c.Sora.Client.TimeoutSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.MaxRetries < 0 {
|
||||
return fmt.Errorf("sora.client.max_retries must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.PollIntervalSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.MaxPollAttempts < 0 {
|
||||
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimit < 0 {
|
||||
return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimitMax < 0 {
|
||||
return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
|
||||
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
|
||||
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
|
||||
}
|
||||
if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
|
||||
}
|
||||
if !c.Sora.Client.CurlCFFISidecar.Enabled {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
|
||||
}
|
||||
if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
|
||||
return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
|
||||
}
|
||||
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
|
||||
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.MaxDownloadBytes < 0 {
|
||||
return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.Cleanup.Enabled {
|
||||
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
|
||||
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
|
||||
}
|
||||
if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" {
|
||||
return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled")
|
||||
}
|
||||
} else {
|
||||
if c.Sora.Storage.Cleanup.RetentionDays < 0 {
|
||||
return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative")
|
||||
}
|
||||
}
|
||||
if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" {
|
||||
return fmt.Errorf("sora.storage.type must be 'local'")
|
||||
}
|
||||
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
|
||||
switch c.Gateway.ConnectionPoolIsolation {
|
||||
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
|
||||
@@ -1183,7 +1721,7 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
|
||||
}
|
||||
if c.Gateway.IdleConnTimeoutSeconds > 180 {
|
||||
log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds)
|
||||
slog.Warn("gateway.idle_conn_timeout_seconds is high; consider 60-120 seconds for better connection reuse", "idle_conn_timeout_seconds", c.Gateway.IdleConnTimeoutSeconds)
|
||||
}
|
||||
if c.Gateway.MaxUpstreamClients <= 0 {
|
||||
return fmt.Errorf("gateway.max_upstream_clients must be positive")
|
||||
@@ -1214,6 +1752,70 @@ func (c *Config) Validate() error {
|
||||
if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 {
|
||||
return fmt.Errorf("gateway.max_line_size must be at least 1MB")
|
||||
}
|
||||
if c.Gateway.UsageRecord.WorkerCount <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.worker_count must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.QueueSize <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.queue_size must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.TaskTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.task_timeout_seconds must be positive")
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy)) {
|
||||
case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync:
|
||||
default:
|
||||
return fmt.Errorf("gateway.usage_record.overflow_policy must be one of: %s/%s/%s",
|
||||
UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync)
|
||||
}
|
||||
if c.Gateway.UsageRecord.OverflowSamplePercent < 0 || c.Gateway.UsageRecord.OverflowSamplePercent > 100 {
|
||||
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be between 0-100")
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy), UsageRecordOverflowPolicySample) &&
|
||||
c.Gateway.UsageRecord.OverflowSamplePercent <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be positive when overflow_policy=sample")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleEnabled {
|
||||
if c.Gateway.UsageRecord.AutoScaleMinWorkers <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_min_workers must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleMaxWorkers <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleMaxWorkers < c.Gateway.UsageRecord.AutoScaleMinWorkers {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be >= auto_scale_min_workers")
|
||||
}
|
||||
if c.Gateway.UsageRecord.WorkerCount < c.Gateway.UsageRecord.AutoScaleMinWorkers ||
|
||||
c.Gateway.UsageRecord.WorkerCount > c.Gateway.UsageRecord.AutoScaleMaxWorkers {
|
||||
return fmt.Errorf("gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleUpQueuePercent <= 0 || c.Gateway.UsageRecord.AutoScaleUpQueuePercent > 100 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_up_queue_percent must be between 1-100")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent < 0 || c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 100 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be between 0-99")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= c.Gateway.UsageRecord.AutoScaleUpQueuePercent {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be less than auto_scale_up_queue_percent")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleUpStep <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_up_step must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleDownStep <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_down_step must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_check_interval_seconds must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleCooldownSeconds < 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative")
|
||||
}
|
||||
}
|
||||
if c.Gateway.UserGroupRateCacheTTLSeconds <= 0 {
|
||||
return fmt.Errorf("gateway.user_group_rate_cache_ttl_seconds must be positive")
|
||||
}
|
||||
if c.Gateway.ModelsListCacheTTLSeconds < 10 || c.Gateway.ModelsListCacheTTLSeconds > 30 {
|
||||
return fmt.Errorf("gateway.models_list_cache_ttl_seconds must be between 10-30")
|
||||
}
|
||||
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
|
||||
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
|
||||
}
|
||||
@@ -1420,6 +2022,6 @@ func warnIfInsecureURL(field, raw string) {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(u.Scheme, "http") {
|
||||
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
|
||||
slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,25 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func resetViperWithJWTSecret(t *testing.T) {
|
||||
t.Helper()
|
||||
viper.Reset()
|
||||
t.Setenv("JWT_SECRET", strings.Repeat("x", 32))
|
||||
}
|
||||
|
||||
func TestLoadForBootstrapAllowsMissingJWTSecret(t *testing.T) {
|
||||
viper.Reset()
|
||||
t.Setenv("JWT_SECRET", "")
|
||||
|
||||
cfg, err := LoadForBootstrap()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadForBootstrap() error: %v", err)
|
||||
}
|
||||
if cfg.JWT.Secret != "" {
|
||||
t.Fatalf("LoadForBootstrap() should keep empty jwt.secret during bootstrap")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeRunMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
@@ -29,7 +48,7 @@ func TestNormalizeRunMode(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -56,8 +75,44 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Idempotency.ObserveOnly {
|
||||
t.Fatalf("Idempotency.ObserveOnly = false, want true")
|
||||
}
|
||||
if cfg.Idempotency.DefaultTTLSeconds != 86400 {
|
||||
t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 86400", cfg.Idempotency.DefaultTTLSeconds)
|
||||
}
|
||||
if cfg.Idempotency.SystemOperationTTLSeconds != 3600 {
|
||||
t.Fatalf("Idempotency.SystemOperationTTLSeconds = %d, want 3600", cfg.Idempotency.SystemOperationTTLSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadIdempotencyConfigFromEnv(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
t.Setenv("IDEMPOTENCY_OBSERVE_ONLY", "false")
|
||||
t.Setenv("IDEMPOTENCY_DEFAULT_TTL_SECONDS", "600")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
if cfg.Idempotency.ObserveOnly {
|
||||
t.Fatalf("Idempotency.ObserveOnly = true, want false")
|
||||
}
|
||||
if cfg.Idempotency.DefaultTTLSeconds != 600 {
|
||||
t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 600", cfg.Idempotency.DefaultTTLSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
|
||||
|
||||
cfg, err := Load()
|
||||
@@ -71,7 +126,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLoadDefaultSecurityToggles(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -87,13 +142,69 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
|
||||
if !cfg.Security.URLAllowlist.AllowPrivateHosts {
|
||||
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
|
||||
}
|
||||
if cfg.Security.ResponseHeaders.Enabled {
|
||||
t.Fatalf("ResponseHeaders.Enabled = true, want false")
|
||||
if !cfg.Security.ResponseHeaders.Enabled {
|
||||
t.Fatalf("ResponseHeaders.Enabled = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultServerMode(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Mode != "release" {
|
||||
t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultJWTAccessTokenExpireMinutes(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.JWT.ExpireHour != 24 {
|
||||
t.Fatalf("JWT.ExpireHour = %d, want 24", cfg.JWT.ExpireHour)
|
||||
}
|
||||
if cfg.JWT.AccessTokenExpireMinutes != 0 {
|
||||
t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 0", cfg.JWT.AccessTokenExpireMinutes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadJWTAccessTokenExpireMinutesFromEnv(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
t.Setenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "90")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.JWT.AccessTokenExpireMinutes != 90 {
|
||||
t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 90", cfg.JWT.AccessTokenExpireMinutes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Database.SSLMode != "prefer" {
|
||||
t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -118,7 +229,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -143,7 +254,7 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -168,7 +279,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -188,7 +299,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -207,7 +318,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -244,7 +355,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -263,7 +374,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -282,7 +393,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -307,7 +418,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -326,7 +437,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -424,6 +535,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateServerFrontendURL(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Server.FrontendURL = "https://example.com"
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate() frontend_url valid error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Server.FrontendURL = "https://example.com/path"
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate() frontend_url with path valid error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Server.FrontendURL = "https://example.com?utm=1"
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatalf("Validate() should reject server.frontend_url with query")
|
||||
}
|
||||
|
||||
cfg.Server.FrontendURL = "https://user:pass@example.com"
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatalf("Validate() should reject server.frontend_url with userinfo")
|
||||
}
|
||||
|
||||
cfg.Server.FrontendURL = "/relative"
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatalf("Validate() should reject relative server.frontend_url")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFrontendRedirectURL(t *testing.T) {
|
||||
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
|
||||
@@ -445,6 +590,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) {
|
||||
func TestWarnIfInsecureURL(t *testing.T) {
|
||||
warnIfInsecureURL("test", "http://example.com")
|
||||
warnIfInsecureURL("test", "bad://url")
|
||||
warnIfInsecureURL("test", "://invalid")
|
||||
}
|
||||
|
||||
func TestGenerateJWTSecretDefaultLength(t *testing.T) {
|
||||
@@ -458,7 +604,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -476,7 +622,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateConcurrencyPingInterval(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -493,14 +639,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProvideConfig(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
if _, err := ProvideConfig(); err != nil {
|
||||
t.Fatalf("ProvideConfig() error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
@@ -544,6 +690,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseDSNWithTimezone_WithPassword(t *testing.T) {
|
||||
d := &DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "u",
|
||||
Password: "p",
|
||||
DBName: "db",
|
||||
SSLMode: "prefer",
|
||||
}
|
||||
got := d.DSNWithTimezone("UTC")
|
||||
if !strings.Contains(got, "password=p") {
|
||||
t.Fatalf("DSNWithTimezone should include password: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "TimeZone=UTC") {
|
||||
t.Fatalf("DSNWithTimezone should include TimeZone=UTC: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
|
||||
if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
|
||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
|
||||
@@ -566,10 +730,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) {
|
||||
warnIfInsecureURL("secure", "https://example.com")
|
||||
}
|
||||
|
||||
func TestValidateJWTSecret_UTF8Bytes(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
// 31 bytes (< 32) even though it's 31 characters.
|
||||
cfg.JWT.Secret = strings.Repeat("a", 31)
|
||||
err = cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("Validate() should reject 31-byte secret")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "at least 32 bytes") {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
|
||||
// 32 bytes OK.
|
||||
cfg.JWT.Secret = strings.Repeat("a", 32)
|
||||
err = cfg.Validate()
|
||||
if err != nil {
|
||||
t.Fatalf("Validate() should accept 32-byte secret: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfigErrors(t *testing.T) {
|
||||
buildValid := func(t *testing.T) *Config {
|
||||
t.Helper()
|
||||
viper.Reset()
|
||||
resetViperWithJWTSecret(t)
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
@@ -582,6 +771,26 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate func(*Config)
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "jwt secret required",
|
||||
mutate: func(c *Config) { c.JWT.Secret = "" },
|
||||
wantErr: "jwt.secret is required",
|
||||
},
|
||||
{
|
||||
name: "jwt secret min bytes",
|
||||
mutate: func(c *Config) { c.JWT.Secret = strings.Repeat("a", 31) },
|
||||
wantErr: "jwt.secret must be at least 32 bytes",
|
||||
},
|
||||
{
|
||||
name: "subscription maintenance worker_count non-negative",
|
||||
mutate: func(c *Config) { c.SubscriptionMaintenance.WorkerCount = -1 },
|
||||
wantErr: "subscription_maintenance.worker_count",
|
||||
},
|
||||
{
|
||||
name: "subscription maintenance queue_size non-negative",
|
||||
mutate: func(c *Config) { c.SubscriptionMaintenance.QueueSize = -1 },
|
||||
wantErr: "subscription_maintenance.queue_size",
|
||||
},
|
||||
{
|
||||
name: "jwt expire hour positive",
|
||||
mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
|
||||
@@ -592,6 +801,11 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.JWT.ExpireHour = 200 },
|
||||
wantErr: "jwt.expire_hour must be <= 168",
|
||||
},
|
||||
{
|
||||
name: "jwt access token expire minutes non-negative",
|
||||
mutate: func(c *Config) { c.JWT.AccessTokenExpireMinutes = -1 },
|
||||
wantErr: "jwt.access_token_expire_minutes must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "csp policy required",
|
||||
mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" },
|
||||
@@ -799,6 +1013,84 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
|
||||
wantErr: "gateway.max_line_size must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record worker count",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.WorkerCount = 0 },
|
||||
wantErr: "gateway.usage_record.worker_count",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record queue size",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.QueueSize = 0 },
|
||||
wantErr: "gateway.usage_record.queue_size",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record timeout",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.TaskTimeoutSeconds = 0 },
|
||||
wantErr: "gateway.usage_record.task_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record overflow policy",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowPolicy = "invalid" },
|
||||
wantErr: "gateway.usage_record.overflow_policy",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record sample percent range",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowSamplePercent = 101 },
|
||||
wantErr: "gateway.usage_record.overflow_sample_percent",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record sample percent required for sample policy",
|
||||
mutate: func(c *Config) {
|
||||
c.Gateway.UsageRecord.OverflowPolicy = UsageRecordOverflowPolicySample
|
||||
c.Gateway.UsageRecord.OverflowSamplePercent = 0
|
||||
},
|
||||
wantErr: "gateway.usage_record.overflow_sample_percent must be positive",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record auto scale max gte min",
|
||||
mutate: func(c *Config) {
|
||||
c.Gateway.UsageRecord.AutoScaleMinWorkers = 256
|
||||
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 128
|
||||
},
|
||||
wantErr: "gateway.usage_record.auto_scale_max_workers",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record worker in auto scale range",
|
||||
mutate: func(c *Config) {
|
||||
c.Gateway.UsageRecord.AutoScaleMinWorkers = 200
|
||||
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 300
|
||||
c.Gateway.UsageRecord.WorkerCount = 128
|
||||
},
|
||||
wantErr: "gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record auto scale queue thresholds order",
|
||||
mutate: func(c *Config) {
|
||||
c.Gateway.UsageRecord.AutoScaleUpQueuePercent = 50
|
||||
c.Gateway.UsageRecord.AutoScaleDownQueuePercent = 50
|
||||
},
|
||||
wantErr: "gateway.usage_record.auto_scale_down_queue_percent must be less",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record auto scale up step",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleUpStep = 0 },
|
||||
wantErr: "gateway.usage_record.auto_scale_up_step",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record auto scale interval",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 },
|
||||
wantErr: "gateway.usage_record.auto_scale_check_interval_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway user group rate cache ttl",
|
||||
mutate: func(c *Config) { c.Gateway.UserGroupRateCacheTTLSeconds = 0 },
|
||||
wantErr: "gateway.user_group_rate_cache_ttl_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway models list cache ttl range",
|
||||
mutate: func(c *Config) { c.Gateway.ModelsListCacheTTLSeconds = 31 },
|
||||
wantErr: "gateway.models_list_cache_ttl_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling sticky waiting",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
|
||||
@@ -822,6 +1114,37 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
},
|
||||
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
|
||||
},
|
||||
{
|
||||
name: "log level invalid",
|
||||
mutate: func(c *Config) { c.Log.Level = "trace" },
|
||||
wantErr: "log.level",
|
||||
},
|
||||
{
|
||||
name: "log format invalid",
|
||||
mutate: func(c *Config) { c.Log.Format = "plain" },
|
||||
wantErr: "log.format",
|
||||
},
|
||||
{
|
||||
name: "log output disabled",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Output.ToStdout = false
|
||||
c.Log.Output.ToFile = false
|
||||
},
|
||||
wantErr: "log.output.to_stdout and log.output.to_file cannot both be false",
|
||||
},
|
||||
{
|
||||
name: "log rotation size",
|
||||
mutate: func(c *Config) { c.Log.Rotation.MaxSizeMB = 0 },
|
||||
wantErr: "log.rotation.max_size_mb",
|
||||
},
|
||||
{
|
||||
name: "log sampling enabled invalid",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Sampling.Enabled = true
|
||||
c.Log.Sampling.Initial = 0
|
||||
},
|
||||
wantErr: "log.sampling.initial",
|
||||
},
|
||||
{
|
||||
name: "ops metrics collector ttl",
|
||||
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
|
||||
@@ -850,3 +1173,234 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Gateway.UsageRecord.AutoScaleEnabled = false
|
||||
cfg.Gateway.UsageRecord.WorkerCount = 64
|
||||
|
||||
// 自动扩缩容关闭时,这些字段应被忽略,不应导致校验失败。
|
||||
cfg.Gateway.UsageRecord.AutoScaleMinWorkers = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleMaxWorkers = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent = 100
|
||||
cfg.Gateway.UsageRecord.AutoScaleUpStep = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleDownStep = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds = -1
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate() should ignore auto scale fields when disabled: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
mutate func(*Config)
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "log level required",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Level = ""
|
||||
},
|
||||
wantErr: "log.level is required",
|
||||
},
|
||||
{
|
||||
name: "log format required",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Format = ""
|
||||
},
|
||||
wantErr: "log.format is required",
|
||||
},
|
||||
{
|
||||
name: "log stacktrace required",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.StacktraceLevel = ""
|
||||
},
|
||||
wantErr: "log.stacktrace_level is required",
|
||||
},
|
||||
{
|
||||
name: "log max backups non-negative",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Rotation.MaxBackups = -1
|
||||
},
|
||||
wantErr: "log.rotation.max_backups must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "log max age non-negative",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Rotation.MaxAgeDays = -1
|
||||
},
|
||||
wantErr: "log.rotation.max_age_days must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "sampling thereafter non-negative when disabled",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Sampling.Enabled = false
|
||||
c.Log.Sampling.Thereafter = -1
|
||||
},
|
||||
wantErr: "log.sampling.thereafter must be non-negative",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
tt.mutate(cfg)
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
|
||||
t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
|
||||
}
|
||||
if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
|
||||
t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
|
||||
t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
|
||||
t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
|
||||
}
|
||||
if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
|
||||
t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
|
||||
}
|
||||
if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
|
||||
t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.Enabled = false
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
|
||||
t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
|
||||
t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.WorkerCount != 128 {
|
||||
t.Fatalf("worker_count = %d, want 128", cfg.Gateway.UsageRecord.WorkerCount)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.QueueSize != 16384 {
|
||||
t.Fatalf("queue_size = %d, want 16384", cfg.Gateway.UsageRecord.QueueSize)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 {
|
||||
t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample {
|
||||
t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySample)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 {
|
||||
t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent)
|
||||
}
|
||||
if !cfg.Gateway.UsageRecord.AutoScaleEnabled {
|
||||
t.Fatalf("auto_scale_enabled = false, want true")
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleMinWorkers != 128 {
|
||||
t.Fatalf("auto_scale_min_workers = %d, want 128", cfg.Gateway.UsageRecord.AutoScaleMinWorkers)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleMaxWorkers != 512 {
|
||||
t.Fatalf("auto_scale_max_workers = %d, want 512", cfg.Gateway.UsageRecord.AutoScaleMaxWorkers)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent != 70 {
|
||||
t.Fatalf("auto_scale_up_queue_percent = %d, want 70", cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent != 15 {
|
||||
t.Fatalf("auto_scale_down_queue_percent = %d, want 15", cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleUpStep != 32 {
|
||||
t.Fatalf("auto_scale_up_step = %d, want 32", cfg.Gateway.UsageRecord.AutoScaleUpStep)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleDownStep != 16 {
|
||||
t.Fatalf("auto_scale_down_step = %d, want 16", cfg.Gateway.UsageRecord.AutoScaleDownStep)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds != 3 {
|
||||
t.Fatalf("auto_scale_check_interval_seconds = %d, want 3", cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds != 10 {
|
||||
t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,5 +9,5 @@ var ProviderSet = wire.NewSet(
|
||||
|
||||
// ProvideConfig 提供应用配置
|
||||
func ProvideConfig() (*Config, error) {
|
||||
return Load()
|
||||
return LoadForBootstrap()
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ const (
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformSora = "sora"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
|
||||
@@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
dataPayload := req.Data
|
||||
if err := validateDataHeader(dataPayload); err != nil {
|
||||
if err := validateDataHeader(req.Data); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
return h.importData(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) {
|
||||
skipDefaultGroupBind := true
|
||||
if req.SkipDefaultGroupBind != nil {
|
||||
skipDefaultGroupBind = *req.SkipDefaultGroupBind
|
||||
}
|
||||
|
||||
dataPayload := req.Data
|
||||
result := DataImportResult{}
|
||||
existingProxies, err := h.listAllProxies(c.Request.Context())
|
||||
|
||||
existingProxies, err := h.listAllProxies(ctx)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return result, err
|
||||
}
|
||||
|
||||
proxyKeyToID := make(map[string]int64, len(existingProxies))
|
||||
@@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
proxyKeyToID[key] = existingID
|
||||
result.ProxyReused++
|
||||
if normalizedStatus != "" {
|
||||
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus {
|
||||
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{
|
||||
if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus {
|
||||
_, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{
|
||||
Status: normalizedStatus,
|
||||
})
|
||||
}
|
||||
@@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||
created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
|
||||
Name: defaultProxyName(item.Name),
|
||||
Protocol: item.Protocol,
|
||||
Host: item.Host,
|
||||
@@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
Username: item.Username,
|
||||
Password: item.Password,
|
||||
})
|
||||
if err != nil {
|
||||
if createErr != nil {
|
||||
result.ProxyFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "proxy",
|
||||
Name: item.Name,
|
||||
ProxyKey: key,
|
||||
Message: err.Error(),
|
||||
Message: createErr.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
@@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
result.ProxyCreated++
|
||||
|
||||
if normalizedStatus != "" && normalizedStatus != created.Status {
|
||||
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{
|
||||
_, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{
|
||||
Status: normalizedStatus,
|
||||
})
|
||||
}
|
||||
@@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
SkipDefaultGroupBind: skipDefaultGroupBind,
|
||||
}
|
||||
|
||||
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil {
|
||||
if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil {
|
||||
result.AccountFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "account",
|
||||
@@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
result.AccountCreated++
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {
|
||||
|
||||
@@ -2,7 +2,13 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -142,6 +148,44 @@ type AccountWithConcurrency struct {
|
||||
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
|
||||
}
|
||||
|
||||
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
||||
item := AccountWithConcurrency{
|
||||
Account: dto.AccountFromService(account),
|
||||
CurrentConcurrency: 0,
|
||||
}
|
||||
if account == nil {
|
||||
return item
|
||||
}
|
||||
|
||||
if h.concurrencyService != nil {
|
||||
if counts, err := h.concurrencyService.GetAccountConcurrencyBatch(ctx, []int64{account.ID}); err == nil {
|
||||
item.CurrentConcurrency = counts[account.ID]
|
||||
}
|
||||
}
|
||||
|
||||
if account.IsAnthropicOAuthOrSetupToken() {
|
||||
if h.accountUsageService != nil && account.GetWindowCostLimit() > 0 {
|
||||
startTime := account.GetCurrentWindowStartTime()
|
||||
if stats, err := h.accountUsageService.GetAccountWindowStats(ctx, account.ID, startTime); err == nil && stats != nil {
|
||||
cost := stats.StandardCost
|
||||
item.CurrentWindowCost = &cost
|
||||
}
|
||||
}
|
||||
|
||||
if h.sessionLimitCache != nil && account.GetMaxSessions() > 0 {
|
||||
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
idleTimeouts := map[int64]time.Duration{account.ID: idleTimeout}
|
||||
if sessions, err := h.sessionLimitCache.GetActiveSessionCountBatch(ctx, []int64{account.ID}, idleTimeouts); err == nil {
|
||||
if count, ok := sessions[account.ID]; ok {
|
||||
item.ActiveSessions = &count
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return item
|
||||
}
|
||||
|
||||
// List handles listing all accounts with pagination
|
||||
// GET /api/v1/admin/accounts
|
||||
func (h *AccountHandler) List(c *gin.Context) {
|
||||
@@ -262,9 +306,71 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
result[i] = item
|
||||
}
|
||||
|
||||
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search)
|
||||
if etag != "" {
|
||||
c.Header("ETag", etag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), etag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
func buildAccountsListETag(
|
||||
items []AccountWithConcurrency,
|
||||
total int64,
|
||||
page, pageSize int,
|
||||
platform, accountType, status, search string,
|
||||
) string {
|
||||
payload := struct {
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Platform string `json:"platform"`
|
||||
AccountType string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Search string `json:"search"`
|
||||
Items []AccountWithConcurrency `json:"items"`
|
||||
}{
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Platform: platform,
|
||||
AccountType: accountType,
|
||||
Status: status,
|
||||
Search: search,
|
||||
Items: items,
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(raw)
|
||||
return "\"" + hex.EncodeToString(sum[:]) + "\""
|
||||
}
|
||||
|
||||
func ifNoneMatchMatched(ifNoneMatch, etag string) bool {
|
||||
if etag == "" || ifNoneMatch == "" {
|
||||
return false
|
||||
}
|
||||
for _, token := range strings.Split(ifNoneMatch, ",") {
|
||||
candidate := strings.TrimSpace(token)
|
||||
if candidate == "*" {
|
||||
return true
|
||||
}
|
||||
if candidate == etag {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(candidate, "W/") && strings.TrimPrefix(candidate, "W/") == etag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetByID handles getting an account by ID
|
||||
// GET /api/v1/admin/accounts/:id
|
||||
func (h *AccountHandler) GetByID(c *gin.Context) {
|
||||
@@ -280,7 +386,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// Create handles creating a new account
|
||||
@@ -299,21 +405,27 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: req.Name,
|
||||
Notes: req.Notes,
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: req.Name,
|
||||
Notes: req.Notes,
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
return h.buildAccountResponseWithRuntime(ctx, account), nil
|
||||
})
|
||||
if err != nil {
|
||||
// 检查是否为混合渠道错误
|
||||
@@ -334,11 +446,17 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
|
||||
// Update handles updating an account
|
||||
@@ -402,7 +520,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// Delete handles deleting an account
|
||||
@@ -660,7 +778,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
||||
}
|
||||
|
||||
// GetStats handles getting account statistics
|
||||
@@ -718,7 +836,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// BatchCreate handles batch creating accounts
|
||||
@@ -732,61 +850,62 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
success := 0
|
||||
failed := 0
|
||||
results := make([]gin.H, 0, len(req.Accounts))
|
||||
executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
success := 0
|
||||
failed := 0
|
||||
results := make([]gin.H, 0, len(req.Accounts))
|
||||
|
||||
for _, item := range req.Accounts {
|
||||
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
|
||||
failed++
|
||||
for _, item := range req.Accounts {
|
||||
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": "rate_multiplier must be >= 0",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
Platform: item.Platform,
|
||||
Type: item.Type,
|
||||
Credentials: item.Credentials,
|
||||
Extra: item.Extra,
|
||||
ProxyID: item.ProxyID,
|
||||
Concurrency: item.Concurrency,
|
||||
Priority: item.Priority,
|
||||
RateMultiplier: item.RateMultiplier,
|
||||
GroupIDs: item.GroupIDs,
|
||||
ExpiresAt: item.ExpiresAt,
|
||||
AutoPauseOnExpired: item.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": "rate_multiplier must be >= 0",
|
||||
"id": account.ID,
|
||||
"success": true,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
Platform: item.Platform,
|
||||
Type: item.Type,
|
||||
Credentials: item.Credentials,
|
||||
Extra: item.Extra,
|
||||
ProxyID: item.ProxyID,
|
||||
Concurrency: item.Concurrency,
|
||||
Priority: item.Priority,
|
||||
RateMultiplier: item.RateMultiplier,
|
||||
GroupIDs: item.GroupIDs,
|
||||
ExpiresAt: item.ExpiresAt,
|
||||
AutoPauseOnExpired: item.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"id": account.ID,
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
return gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -824,57 +943,58 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
success := 0
|
||||
failed := 0
|
||||
results := []gin.H{}
|
||||
|
||||
// 阶段一:预验证所有账号存在,收集 credentials
|
||||
type accountUpdate struct {
|
||||
ID int64
|
||||
Credentials map[string]any
|
||||
}
|
||||
updates := make([]accountUpdate, 0, len(req.AccountIDs))
|
||||
for _, accountID := range req.AccountIDs {
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": "Account not found",
|
||||
})
|
||||
continue
|
||||
response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
|
||||
return
|
||||
}
|
||||
|
||||
// Update credentials field
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
|
||||
account.Credentials[req.Field] = req.Value
|
||||
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
|
||||
}
|
||||
|
||||
// Update account
|
||||
updateInput := &service.UpdateAccountInput{
|
||||
Credentials: account.Credentials,
|
||||
}
|
||||
|
||||
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
|
||||
if err != nil {
|
||||
// 阶段二:依次更新,返回每个账号的成功/失败明细,便于调用方重试
|
||||
success := 0
|
||||
failed := 0
|
||||
successIDs := make([]int64, 0, len(updates))
|
||||
failedIDs := make([]int64, 0, len(updates))
|
||||
results := make([]gin.H, 0, len(updates))
|
||||
for _, u := range updates {
|
||||
updateInput := &service.UpdateAccountInput{Credentials: u.Credentials}
|
||||
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
|
||||
failed++
|
||||
failedIDs = append(failedIDs, u.ID)
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"account_id": u.ID,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
success++
|
||||
successIDs = append(successIDs, u.ID)
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"account_id": u.ID,
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"success_ids": successIDs,
|
||||
"failed_ids": failedIDs,
|
||||
"results": results,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1109,7 +1229,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// GetTempUnschedulable handles getting temporary unschedulable status
|
||||
@@ -1199,7 +1325,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// GetAvailableModels handles getting available models for an account
|
||||
@@ -1325,6 +1451,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Sora accounts
|
||||
if account.Platform == service.PlatformSora {
|
||||
response.Success(c, service.DefaultSoraModels(nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Claude/Anthropic accounts
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
adminSvc := newStubAdminService()
|
||||
handler := NewAccountHandler(
|
||||
adminSvc,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/admin/accounts", handler.Create)
|
||||
|
||||
body := map[string]any{
|
||||
"name": "anthropic-key-1",
|
||||
"platform": "anthropic",
|
||||
"type": "apikey",
|
||||
"credentials": map[string]any{
|
||||
"api_key": "sk-ant-xxx",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
},
|
||||
"extra": map[string]any{
|
||||
"anthropic_passthrough": true,
|
||||
},
|
||||
"concurrency": 1,
|
||||
"priority": 1,
|
||||
}
|
||||
raw, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(raw))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Len(t, adminSvc.createdAccounts, 1)
|
||||
|
||||
created := adminSvc.createdAccounts[0]
|
||||
require.Equal(t, "anthropic", created.Platform)
|
||||
require.Equal(t, "apikey", created.Type)
|
||||
require.NotNil(t, created.Extra)
|
||||
require.Equal(t, true, created.Extra["anthropic_passthrough"])
|
||||
}
|
||||
@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
|
||||
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
|
||||
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
|
||||
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
|
||||
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
|
||||
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
|
||||
|
||||
@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
@@ -58,6 +58,96 @@ func TestParseOpsDuration(t *testing.T) {
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseOpsOpenAITokenStatsDuration(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want time.Duration
|
||||
ok bool
|
||||
}{
|
||||
{input: "30m", want: 30 * time.Minute, ok: true},
|
||||
{input: "1h", want: time.Hour, ok: true},
|
||||
{input: "1d", want: 24 * time.Hour, ok: true},
|
||||
{input: "15d", want: 15 * 24 * time.Hour, ok: true},
|
||||
{input: "30d", want: 30 * 24 * time.Hour, ok: true},
|
||||
{input: "7d", want: 0, ok: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, ok := parseOpsOpenAITokenStatsDuration(tt.input)
|
||||
require.Equal(t, tt.ok, ok, "input=%s", tt.input)
|
||||
require.Equal(t, tt.want, got, "input=%s", tt.input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
before := time.Now().UTC()
|
||||
filter, err := parseOpsOpenAITokenStatsFilter(c)
|
||||
after := time.Now().UTC()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, filter)
|
||||
require.Equal(t, "30d", filter.TimeRange)
|
||||
require.Equal(t, 1, filter.Page)
|
||||
require.Equal(t, 20, filter.PageSize)
|
||||
require.Equal(t, 0, filter.TopN)
|
||||
require.Nil(t, filter.GroupID)
|
||||
require.Equal(t, "", filter.Platform)
|
||||
require.True(t, filter.StartTime.Before(filter.EndTime))
|
||||
require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second)
|
||||
require.WithinDuration(t, after, filter.EndTime, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/?time_range=1h&platform=openai&group_id=12&top_n=50",
|
||||
nil,
|
||||
)
|
||||
|
||||
filter, err := parseOpsOpenAITokenStatsFilter(c)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "1h", filter.TimeRange)
|
||||
require.Equal(t, "openai", filter.Platform)
|
||||
require.NotNil(t, filter.GroupID)
|
||||
require.Equal(t, int64(12), *filter.GroupID)
|
||||
require.Equal(t, 50, filter.TopN)
|
||||
require.Equal(t, 0, filter.Page)
|
||||
require.Equal(t, 0, filter.PageSize)
|
||||
}
|
||||
|
||||
func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) {
|
||||
tests := []string{
|
||||
"/?time_range=7d",
|
||||
"/?group_id=0",
|
||||
"/?group_id=abc",
|
||||
"/?top_n=0",
|
||||
"/?top_n=101",
|
||||
"/?top_n=10&page=1",
|
||||
"/?top_n=10&page_size=20",
|
||||
"/?page=0",
|
||||
"/?page_size=0",
|
||||
"/?page_size=101",
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
for _, rawURL := range tests {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil)
|
||||
|
||||
_, err := parseOpsOpenAITokenStatsFilter(c)
|
||||
require.Error(t, err, "url=%s", rawURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOpsTimeRange(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -327,6 +327,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
|
||||
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
|
||||
return &service.ProxyQualityCheckResult{
|
||||
ProxyID: id,
|
||||
Score: 95,
|
||||
Grade: "A",
|
||||
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
|
||||
PassedCount: 5,
|
||||
WarnCount: 0,
|
||||
FailedCount: 0,
|
||||
ChallengeCount: 0,
|
||||
CheckedAt: time.Now().Unix(),
|
||||
Items: []service.ProxyQualityCheckItem{
|
||||
{Target: "base_connectivity", Status: "pass", Message: "ok"},
|
||||
{Target: "openai", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "gemini", Status: "pass", HTTPStatus: 200},
|
||||
{Target: "sora", Status: "pass", HTTPStatus: 401},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
|
||||
return s.redeems, int64(len(s.redeems)), nil
|
||||
}
|
||||
|
||||
208
backend/internal/handler/admin/batch_update_credentials_test.go
Normal file
208
backend/internal/handler/admin/batch_update_credentials_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。
|
||||
type failingAdminService struct {
|
||||
*stubAdminService
|
||||
failOnAccountID int64
|
||||
updateCallCount atomic.Int64
|
||||
}
|
||||
|
||||
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||
f.updateCallCount.Add(1)
|
||||
if id == f.failOnAccountID {
|
||||
return nil, errors.New("database error")
|
||||
}
|
||||
return f.stubAdminService.UpdateAccount(ctx, id, input)
|
||||
}
|
||||
|
||||
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
|
||||
return router, handler
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "account_uuid",
|
||||
Value: "test-uuid",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
|
||||
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_PartialFailure(t *testing.T) {
|
||||
// 让第 2 个账号(ID=2)更新时失败
|
||||
svc := &failingAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
failOnAccountID: 2,
|
||||
}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "org_uuid",
|
||||
Value: "test-org",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细
|
||||
require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细")
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
data := resp["data"].(map[string]any)
|
||||
require.Equal(t, float64(2), data["success"], "应有 2 个成功")
|
||||
require.Equal(t, float64(1), data["failed"], "应有 1 个失败")
|
||||
|
||||
// 所有 3 个账号都会被尝试更新(非 fail-fast)
|
||||
require.Equal(t, int64(3), svc.updateCallCount.Load(),
|
||||
"应调用 3 次 UpdateAccount(逐个尝试,失败后继续)")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
|
||||
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
|
||||
svc := &getAccountFailingService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
failOnAccountID: 1,
|
||||
}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "account_uuid",
|
||||
Value: "test",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
|
||||
}
|
||||
|
||||
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
|
||||
type getAccountFailingService struct {
|
||||
*stubAdminService
|
||||
failOnAccountID int64
|
||||
}
|
||||
|
||||
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if id == f.failOnAccountID {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return f.stubAdminService.GetAccount(ctx, id)
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// intercept_warmup_requests 传入非 bool 类型(string),应返回 400
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "intercept_warmup_requests",
|
||||
"value": "not-a-bool",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code,
|
||||
"intercept_warmup_requests 传入非 bool 值应返回 400")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "intercept_warmup_requests",
|
||||
"value": true,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code,
|
||||
"intercept_warmup_requests 传入合法 bool 值应返回 200")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// account_uuid 传入非 string 类型(number),应返回 400
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "account_uuid",
|
||||
"value": 12345,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code,
|
||||
"account_uuid 传入非 string 值应返回 400")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// account_uuid 传入 null(设置为空),应正常通过
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "account_uuid",
|
||||
"value": nil,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code,
|
||||
"account_uuid 传入 null 应返回 200")
|
||||
}
|
||||
@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage stats")
|
||||
return
|
||||
@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
|
||||
@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
@@ -38,6 +38,10 @@ type CreateGroupRequest struct {
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
@@ -55,7 +59,7 @@ type CreateGroupRequest struct {
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
@@ -67,6 +71,10 @@ type UpdateGroupRequest struct {
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
@@ -179,6 +187,10 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
SoraImagePrice360: req.SoraImagePrice360,
|
||||
SoraImagePrice540: req.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||
@@ -225,6 +237,10 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
SoraImagePrice360: req.SoraImagePrice360,
|
||||
SoraImagePrice540: req.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||
|
||||
115
backend/internal/handler/admin/idempotency_helper.go
Normal file
115
backend/internal/handler/admin/idempotency_helper.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type idempotencyStoreUnavailableMode int
|
||||
|
||||
const (
|
||||
idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota
|
||||
idempotencyStoreUnavailableFailOpen
|
||||
)
|
||||
|
||||
func executeAdminIdempotent(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) (*service.IdempotencyExecuteResult, error) {
|
||||
coordinator := service.DefaultIdempotencyCoordinator()
|
||||
if coordinator == nil {
|
||||
data, err := execute(c.Request.Context())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &service.IdempotencyExecuteResult{Data: data}, nil
|
||||
}
|
||||
|
||||
actorScope := "admin:0"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
|
||||
return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
|
||||
Scope: scope,
|
||||
ActorScope: actorScope,
|
||||
Method: c.Request.Method,
|
||||
Route: c.FullPath(),
|
||||
IdempotencyKey: c.GetHeader("Idempotency-Key"),
|
||||
Payload: payload,
|
||||
RequireKey: true,
|
||||
TTL: ttl,
|
||||
}, execute)
|
||||
}
|
||||
|
||||
func executeAdminIdempotentJSON(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute)
|
||||
}
|
||||
|
||||
func executeAdminIdempotentJSONFailOpenOnStoreUnavailable(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute)
|
||||
}
|
||||
|
||||
func executeAdminIdempotentJSONWithMode(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
mode idempotencyStoreUnavailableMode,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
result, err := executeAdminIdempotent(c, scope, payload, ttl, execute)
|
||||
if err != nil {
|
||||
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
|
||||
strategy := "fail_close"
|
||||
if mode == idempotencyStoreUnavailableFailOpen {
|
||||
strategy = "fail_open"
|
||||
}
|
||||
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy)
|
||||
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy)
|
||||
if mode == idempotencyStoreUnavailableFailOpen {
|
||||
data, fallbackErr := execute(c.Request.Context())
|
||||
if fallbackErr != nil {
|
||||
response.ErrorFrom(c, fallbackErr)
|
||||
return
|
||||
}
|
||||
c.Header("X-Idempotency-Degraded", "store-unavailable")
|
||||
response.Success(c, data)
|
||||
return
|
||||
}
|
||||
}
|
||||
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
285
backend/internal/handler/admin/idempotency_helper_test.go
Normal file
285
backend/internal/handler/admin/idempotency_helper_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type storeUnavailableRepoStub struct{}
|
||||
|
||||
func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
|
||||
return nil, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
|
||||
return 0, errors.New("store unavailable")
|
||||
}
|
||||
|
||||
func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "test-key-1")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable")
|
||||
}
|
||||
|
||||
func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "test-key-2")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded"))
|
||||
require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue")
|
||||
}
|
||||
|
||||
type memoryIdempotencyRepoStub struct {
|
||||
mu sync.Mutex
|
||||
nextID int64
|
||||
data map[string]*service.IdempotencyRecord
|
||||
}
|
||||
|
||||
func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub {
|
||||
return &memoryIdempotencyRepoStub{
|
||||
nextID: 1,
|
||||
data: make(map[string]*service.IdempotencyRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string {
|
||||
return scope + "|" + keyHash
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := *in
|
||||
if in.LockedUntil != nil {
|
||||
v := *in.LockedUntil
|
||||
out.LockedUntil = &v
|
||||
}
|
||||
if in.ResponseBody != nil {
|
||||
v := *in.ResponseBody
|
||||
out.ResponseBody = &v
|
||||
}
|
||||
if in.ResponseStatus != nil {
|
||||
v := *in.ResponseStatus
|
||||
out.ResponseStatus = &v
|
||||
}
|
||||
if in.ErrorReason != nil {
|
||||
v := *in.ErrorReason
|
||||
out.ErrorReason = &v
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
k := r.key(record.Scope, record.IdempotencyKeyHash)
|
||||
if _, ok := r.data[k]; ok {
|
||||
return false, nil
|
||||
}
|
||||
cp := r.clone(record)
|
||||
cp.ID = r.nextID
|
||||
r.nextID++
|
||||
r.data[k] = cp
|
||||
record.ID = cp.ID
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.clone(r.data[r.key(scope, keyHash)]), nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != fromStatus {
|
||||
return false, nil
|
||||
}
|
||||
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
|
||||
return false, nil
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusProcessing
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
rec.ErrorReason = nil
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
|
||||
return false, nil
|
||||
}
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusSucceeded
|
||||
rec.LockedUntil = nil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ResponseStatus = &responseStatus
|
||||
rec.ResponseBody = &responseBody
|
||||
rec.ErrorReason = nil
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusFailedRetryable
|
||||
rec.LockedUntil = &lockedUntil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ErrorReason = &errorReason
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := newMemoryIdempotencyRepoStub()
|
||||
cfg := service.DefaultIdempotencyConfig()
|
||||
cfg.ProcessingTimeout = 2 * time.Second
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed atomic.Int32
|
||||
router := gin.New()
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed.Add(1)
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
call := func() (int, http.Header) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "same-key")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
return rec.Code, rec.Header()
|
||||
}
|
||||
|
||||
var status1, status2 int
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
status1, _ = call()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
status2, _ = call()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
|
||||
require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once")
|
||||
|
||||
status3, headers3 := call()
|
||||
require.Equal(t, http.StatusOK, status3)
|
||||
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
|
||||
require.Equal(t, int32(1), executed.Load())
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
func oauthPlatformFromPath(c *gin.Context) string {
|
||||
if strings.Contains(c.FullPath(), "/admin/sora/") {
|
||||
return service.PlatformSora
|
||||
}
|
||||
return service.PlatformOpenAI
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||
return &OpenAIOAuthHandler{
|
||||
@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
type OpenAIExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
State: req.State,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
|
||||
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||
type OpenAIRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
RT string `json:"rt"`
|
||||
ClientID string `json:"client_id"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
// POST /api/v1/admin/openai/refresh-token
|
||||
// POST /api/v1/admin/sora/rt2at
|
||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req OpenAIRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
refreshToken := strings.TrimSpace(req.RefreshToken)
|
||||
if refreshToken == "" {
|
||||
refreshToken = strings.TrimSpace(req.RT)
|
||||
}
|
||||
if refreshToken == "" {
|
||||
response.BadRequest(c, "refresh_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if req.ProxyID != nil {
|
||||
@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||
// ExchangeSoraSessionToken exchanges Sora session token to access token
|
||||
// POST /api/v1/admin/sora/st2at
|
||||
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionToken string `json:"session_token"`
|
||||
ST string `json:"st"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := strings.TrimSpace(req.SessionToken)
|
||||
if sessionToken == "" {
|
||||
sessionToken = strings.TrimSpace(req.ST)
|
||||
}
|
||||
if sessionToken == "" {
|
||||
response.BadRequest(c, "session_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
|
||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||
// POST /api/v1/admin/sora/accounts/:id/refresh
|
||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure account is OpenAI platform
|
||||
if !account.IsOpenAI() {
|
||||
response.BadRequest(c, "Account is not an OpenAI account")
|
||||
platform := oauthPlatformFromPath(c)
|
||||
if account.Platform != platform {
|
||||
response.BadRequest(c, "Account platform does not match OAuth endpoint")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
|
||||
// POST /api/v1/admin/openai/create-from-oauth
|
||||
// POST /api/v1/admin/sora/create-from-oauth
|
||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Name string `json:"name"`
|
||||
@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
State: req.State,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
// Build credentials from token info
|
||||
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
platform := oauthPlatformFromPath(c)
|
||||
|
||||
// Use email as default name if not provided
|
||||
name := req.Name
|
||||
if name == "" && tokenInfo.Email != "" {
|
||||
name = tokenInfo.Email
|
||||
}
|
||||
if name == "" {
|
||||
name = "OpenAI OAuth Account"
|
||||
if platform == service.PlatformSora {
|
||||
name = "Sora OAuth Account"
|
||||
} else {
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
}
|
||||
|
||||
// Create account
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: name,
|
||||
Platform: "openai",
|
||||
Platform: platform,
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
ProxyID: req.ProxyID,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) {
|
||||
response.Success(c, data)
|
||||
}
|
||||
|
||||
// GetDashboardOpenAITokenStats returns OpenAI token efficiency stats grouped by model.
|
||||
// GET /api/v1/admin/ops/dashboard/openai-token-stats
|
||||
func (h *OpsHandler) GetDashboardOpenAITokenStats(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
filter, err := parseOpsOpenAITokenStatsFilter(c)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.opsService.GetOpenAITokenStats(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, data)
|
||||
}
|
||||
|
||||
func parseOpsOpenAITokenStatsFilter(c *gin.Context) (*service.OpsOpenAITokenStatsFilter, error) {
|
||||
if c == nil {
|
||||
return nil, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
timeRange := strings.TrimSpace(c.Query("time_range"))
|
||||
if timeRange == "" {
|
||||
timeRange = "30d"
|
||||
}
|
||||
dur, ok := parseOpsOpenAITokenStatsDuration(timeRange)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid time_range")
|
||||
}
|
||||
end := time.Now().UTC()
|
||||
start := end.Add(-dur)
|
||||
|
||||
filter := &service.OpsOpenAITokenStatsFilter{
|
||||
TimeRange: timeRange,
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: strings.TrimSpace(c.Query("platform")),
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||
id, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
return nil, fmt.Errorf("invalid group_id")
|
||||
}
|
||||
filter.GroupID = &id
|
||||
}
|
||||
|
||||
topNRaw := strings.TrimSpace(c.Query("top_n"))
|
||||
pageRaw := strings.TrimSpace(c.Query("page"))
|
||||
pageSizeRaw := strings.TrimSpace(c.Query("page_size"))
|
||||
if topNRaw != "" && (pageRaw != "" || pageSizeRaw != "") {
|
||||
return nil, fmt.Errorf("invalid query: top_n cannot be used with page/page_size")
|
||||
}
|
||||
|
||||
if topNRaw != "" {
|
||||
topN, err := strconv.Atoi(topNRaw)
|
||||
if err != nil || topN < 1 || topN > 100 {
|
||||
return nil, fmt.Errorf("invalid top_n")
|
||||
}
|
||||
filter.TopN = topN
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
filter.Page = 1
|
||||
filter.PageSize = 20
|
||||
if pageRaw != "" {
|
||||
page, err := strconv.Atoi(pageRaw)
|
||||
if err != nil || page < 1 {
|
||||
return nil, fmt.Errorf("invalid page")
|
||||
}
|
||||
filter.Page = page
|
||||
}
|
||||
if pageSizeRaw != "" {
|
||||
pageSize, err := strconv.Atoi(pageSizeRaw)
|
||||
if err != nil || pageSize < 1 || pageSize > 100 {
|
||||
return nil, fmt.Errorf("invalid page_size")
|
||||
}
|
||||
filter.PageSize = pageSize
|
||||
}
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func parseOpsOpenAITokenStatsDuration(v string) (time.Duration, bool) {
|
||||
switch strings.TrimSpace(v) {
|
||||
case "30m":
|
||||
return 30 * time.Minute, true
|
||||
case "1h":
|
||||
return time.Hour, true
|
||||
case "1d":
|
||||
return 24 * time.Hour, true
|
||||
case "15d":
|
||||
return 15 * 24 * time.Hour, true
|
||||
case "30d":
|
||||
return 30 * 24 * time.Hour, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func pickThroughputBucketSeconds(window time.Duration) int {
|
||||
// Keep buckets predictable and avoid huge responses.
|
||||
switch {
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type testSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func newTestSettingRepo() *testSettingRepo {
|
||||
return &testSettingRepo{values: map[string]string{}}
|
||||
}
|
||||
|
||||
func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||
v, err := s.GetValue(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &service.Setting{Key: key, Value: v}, nil
|
||||
}
|
||||
func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
|
||||
v, ok := s.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
func (s *testSettingRepo) Set(ctx context.Context, key, value string) error {
|
||||
s.values[key] = value
|
||||
return nil
|
||||
}
|
||||
func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, k := range keys {
|
||||
if v, ok := s.values[k]; ok {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
for k, v := range settings {
|
||||
s.values[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
out := make(map[string]string, len(s.values))
|
||||
for k, v := range s.values {
|
||||
out[k] = v
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *testSettingRepo) Delete(ctx context.Context, key string) error {
|
||||
delete(s.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
if withUser {
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7})
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
r.GET("/runtime/logging", handler.GetRuntimeLogConfig)
|
||||
r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig)
|
||||
r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig)
|
||||
return r
|
||||
}
|
||||
|
||||
func newRuntimeOpsService(t *testing.T) *service.OpsService {
|
||||
t.Helper()
|
||||
if err := logger.Init(logger.InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: false,
|
||||
ToFile: false,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("init logger: %v", err)
|
||||
}
|
||||
|
||||
settingRepo := newTestSettingRepo()
|
||||
cfg := &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: true},
|
||||
Log: config.LogConfig{
|
||||
Level: "info",
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
r := newOpsRuntimeRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
r := newOpsRuntimeRouter(h, false)
|
||||
|
||||
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
r := newOpsRuntimeRouter(h, true)
|
||||
|
||||
payload := map[string]any{
|
||||
"level": "debug",
|
||||
"enable_sampling": false,
|
||||
"sampling_initial": 100,
|
||||
"sampling_thereafter": 100,
|
||||
"caller": true,
|
||||
"stacktrace_level": "error",
|
||||
"retention_days": 30,
|
||||
}
|
||||
raw, _ := json.Marshal(payload)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// GetRuntimeLogConfig returns runtime log config (DB-backed).
|
||||
// GET /api/v1/admin/ops/runtime/logging
|
||||
func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config")
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately.
|
||||
// PUT /api/v1/admin/ops/runtime/logging
|
||||
func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var req service.OpsRuntimeLogConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline.
|
||||
// POST /api/v1/admin/ops/runtime/logging/reset
|
||||
func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// GetAdvancedSettings returns Ops advanced settings (DB-backed).
|
||||
// GET /api/v1/admin/ops/advanced-settings
|
||||
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {
|
||||
|
||||
174
backend/internal/handler/admin/ops_system_log_handler.go
Normal file
174
backend/internal/handler/admin/ops_system_log_handler.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type opsSystemLogCleanupRequest struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
|
||||
Level string `json:"level"`
|
||||
Component string `json:"component"`
|
||||
RequestID string `json:"request_id"`
|
||||
ClientRequestID string `json:"client_request_id"`
|
||||
UserID *int64 `json:"user_id"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
Platform string `json:"platform"`
|
||||
Model string `json:"model"`
|
||||
Query string `json:"q"`
|
||||
}
|
||||
|
||||
// ListSystemLogs returns indexed system logs.
|
||||
// GET /api/v1/admin/ops/system-logs
|
||||
func (h *OpsHandler) ListSystemLogs(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
if pageSize > 200 {
|
||||
pageSize = 200
|
||||
}
|
||||
|
||||
start, end, err := parseOpsTimeRange(c, "1h")
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsSystemLogFilter{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
StartTime: &start,
|
||||
EndTime: &end,
|
||||
Level: strings.TrimSpace(c.Query("level")),
|
||||
Component: strings.TrimSpace(c.Query("component")),
|
||||
RequestID: strings.TrimSpace(c.Query("request_id")),
|
||||
ClientRequestID: strings.TrimSpace(c.Query("client_request_id")),
|
||||
Platform: strings.TrimSpace(c.Query("platform")),
|
||||
Model: strings.TrimSpace(c.Query("model")),
|
||||
Query: strings.TrimSpace(c.Query("q")),
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
|
||||
id, parseErr := strconv.ParseInt(v, 10, 64)
|
||||
if parseErr != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
filter.UserID = &id
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
|
||||
id, parseErr := strconv.ParseInt(v, 10, 64)
|
||||
if parseErr != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
filter.AccountID = &id
|
||||
}
|
||||
|
||||
result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize)
|
||||
}
|
||||
|
||||
// CleanupSystemLogs deletes indexed system logs by filter.
|
||||
// POST /api/v1/admin/ops/system-logs/cleanup
|
||||
func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
var req opsSystemLogCleanupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
parseTS := func(raw string) (*time.Time, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
|
||||
return &t, nil
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
start, err := parseTS(req.StartTime)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_time")
|
||||
return
|
||||
}
|
||||
end, err := parseTS(req.EndTime)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_time")
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsSystemLogCleanupFilter{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Level: strings.TrimSpace(req.Level),
|
||||
Component: strings.TrimSpace(req.Component),
|
||||
RequestID: strings.TrimSpace(req.RequestID),
|
||||
ClientRequestID: strings.TrimSpace(req.ClientRequestID),
|
||||
UserID: req.UserID,
|
||||
AccountID: req.AccountID,
|
||||
Platform: strings.TrimSpace(req.Platform),
|
||||
Model: strings.TrimSpace(req.Model),
|
||||
Query: strings.TrimSpace(req.Query),
|
||||
}
|
||||
|
||||
deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": deleted})
|
||||
}
|
||||
|
||||
// GetSystemLogIngestionHealth returns sink health metrics.
|
||||
// GET /api/v1/admin/ops/system-logs/health
|
||||
func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, h.opsService.GetSystemLogSinkHealth())
|
||||
}
|
||||
233
backend/internal/handler/admin/ops_system_log_handler_test.go
Normal file
233
backend/internal/handler/admin/ops_system_log_handler_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type responseEnvelope struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
if withUser {
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99})
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
r.GET("/logs", handler.ListSystemLogs)
|
||||
r.POST("/logs/cleanup", handler.CleanupSystemLogs)
|
||||
r.GET("/logs/health", handler.GetSystemLogIngestionHealth)
|
||||
return r
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
|
||||
h := NewOpsHandler(nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status=%d, want 503", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d, want 200", w.Code)
|
||||
}
|
||||
|
||||
var resp responseEnvelope
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
if resp.Code != 0 {
|
||||
t.Fatalf("unexpected response code: %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status=%d, want 503", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_Health(t *testing.T) {
|
||||
sink := service.NewOpsSystemLogSink(nil)
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
|
||||
h := NewOpsHandler(nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status=%d, want 503", w.Code)
|
||||
}
|
||||
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h = NewOpsHandler(svc)
|
||||
r = newOpsSystemLogTestRouter(h, false)
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package admin
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -252,7 +252,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
|
||||
stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
|
||||
if err != nil || stats == nil {
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] refresh: get window stats failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -278,7 +278,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
|
||||
|
||||
msg, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] refresh: marshal payload failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -338,7 +338,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
|
||||
|
||||
// Reserve a global slot before upgrading the connection to keep the limit strict.
|
||||
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
|
||||
log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
|
||||
return
|
||||
}
|
||||
@@ -350,7 +350,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
|
||||
|
||||
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
|
||||
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
|
||||
log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
|
||||
return
|
||||
}
|
||||
@@ -359,7 +359,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
|
||||
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] upgrade failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -452,7 +452,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
|
||||
conn.SetReadLimit(qpsWSMaxReadBytes)
|
||||
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
|
||||
log.Printf("[OpsWS] set read deadline failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err)
|
||||
return
|
||||
}
|
||||
conn.SetPongHandler(func(string) error {
|
||||
@@ -471,7 +471,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||
log.Printf("[OpsWS] read failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -508,7 +508,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
continue
|
||||
}
|
||||
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
|
||||
log.Printf("[OpsWS] write failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err)
|
||||
cancel()
|
||||
closeConn()
|
||||
wg.Wait()
|
||||
@@ -517,7 +517,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
|
||||
case <-pingTicker.C:
|
||||
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
|
||||
log.Printf("[OpsWS] ping failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err)
|
||||
cancel()
|
||||
closeConn()
|
||||
wg.Wait()
|
||||
@@ -666,14 +666,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
|
||||
if parsed, err := strconv.ParseBool(v); err == nil {
|
||||
cfg.TrustProxy = parsed
|
||||
} else {
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
|
||||
}
|
||||
}
|
||||
|
||||
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
|
||||
prefixes, invalid := parseTrustedProxyList(raw)
|
||||
if len(invalid) > 0 {
|
||||
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
|
||||
}
|
||||
cfg.TrustedProxies = prefixes
|
||||
}
|
||||
@@ -684,7 +684,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
|
||||
case OriginPolicyStrict, OriginPolicyPermissive:
|
||||
cfg.OriginPolicy = normalized
|
||||
default:
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -701,14 +701,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
|
||||
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
|
||||
cfg.MaxConns = int32(parsed)
|
||||
} else {
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
|
||||
}
|
||||
}
|
||||
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
|
||||
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
|
||||
cfg.MaxConnsPerIP = int32(parsed)
|
||||
} else {
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Protocol: strings.TrimSpace(req.Protocol),
|
||||
Host: strings.TrimSpace(req.Host),
|
||||
Port: req.Port,
|
||||
Username: strings.TrimSpace(req.Username),
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Protocol: strings.TrimSpace(req.Protocol),
|
||||
Host: strings.TrimSpace(req.Host),
|
||||
Port: req.Port,
|
||||
Username: strings.TrimSpace(req.Username),
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dto.ProxyFromService(proxy), nil
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ProxyFromService(proxy))
|
||||
}
|
||||
|
||||
// Update handles updating a proxy
|
||||
@@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// CheckQuality handles checking proxy quality across common AI targets.
|
||||
// POST /api/v1/admin/proxies/:id/quality-check
|
||||
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
|
||||
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid proxy ID")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetStats handles getting proxy statistics
|
||||
// GET /api/v1/admin/proxies/:id/stats
|
||||
func (h *ProxyHandler) GetStats(c *gin.Context) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"strconv"
|
||||
@@ -88,23 +89,24 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{
|
||||
Count: req.Count,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
|
||||
Count: req.Count,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
|
||||
// Delete handles deleting a redeem code
|
||||
|
||||
97
backend/internal/handler/admin/search_truncate_test.go
Normal file
97
backend/internal/handler/admin/search_truncate_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
|
||||
func truncateSearchByRune(search string, maxRunes int) string {
|
||||
if runes := []rune(search); len(runes) > maxRunes {
|
||||
return string(runes[:maxRunes])
|
||||
}
|
||||
return search
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxRunes int
|
||||
wantLen int // 期望的 rune 长度
|
||||
}{
|
||||
{
|
||||
name: "纯中文超长",
|
||||
input: string(make([]rune, 150)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "纯 ASCII 超长",
|
||||
input: string(make([]byte, 150)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
input: "",
|
||||
maxRunes: 100,
|
||||
wantLen: 0,
|
||||
},
|
||||
{
|
||||
name: "恰好 100 个字符",
|
||||
input: string(make([]rune, 100)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "不足 100 字符不截断",
|
||||
input: "hello世界",
|
||||
maxRunes: 100,
|
||||
wantLen: 7,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := truncateSearchByRune(tc.input, tc.maxRunes)
|
||||
require.Equal(t, tc.wantLen, len([]rune(result)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
|
||||
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
|
||||
input := ""
|
||||
for i := 0; i < 101; i++ {
|
||||
input += "中"
|
||||
}
|
||||
result := truncateSearchByRune(input, 100)
|
||||
|
||||
require.Equal(t, 100, len([]rune(result)))
|
||||
// 验证截断结果是有效的 UTF-8(每个中文字符 3 字节)
|
||||
require.Equal(t, 300, len(result))
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
|
||||
// 50 个 ASCII + 51 个中文 = 101 个 rune
|
||||
input := ""
|
||||
for i := 0; i < 50; i++ {
|
||||
input += "a"
|
||||
}
|
||||
for i := 0; i < 51; i++ {
|
||||
input += "中"
|
||||
}
|
||||
result := truncateSearchByRune(input, 100)
|
||||
|
||||
runes := []rune(result)
|
||||
require.Equal(t, 100, len(runes))
|
||||
// 前 50 个应该是 'a',后 50 个应该是 '中'
|
||||
require.Equal(t, 'a', runes[0])
|
||||
require.Equal(t, 'a', runes[49])
|
||||
require.Equal(t, '中', runes[50])
|
||||
require.Equal(t, '中', runes[99])
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
idempotencyPayload := struct {
|
||||
SubscriptionID int64 `json:"subscription_id"`
|
||||
Body AdjustSubscriptionRequest `json:"body"`
|
||||
}{
|
||||
SubscriptionID: subscriptionID,
|
||||
Body: req,
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days)
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
return dto.UserSubscriptionFromServiceAdmin(subscription), nil
|
||||
})
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -14,12 +18,14 @@ import (
|
||||
// SystemHandler handles system-related operations
|
||||
type SystemHandler struct {
|
||||
updateSvc *service.UpdateService
|
||||
lockSvc *service.SystemOperationLockService
|
||||
}
|
||||
|
||||
// NewSystemHandler creates a new SystemHandler
|
||||
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
|
||||
func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
|
||||
return &SystemHandler{
|
||||
updateSvc: updateSvc,
|
||||
lockSvc: lockSvc,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) {
|
||||
// PerformUpdate downloads and applies the update
|
||||
// POST /api/v1/admin/system/update
|
||||
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
|
||||
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Update completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
operationID := buildSystemOperationID(c, "update")
|
||||
payload := gin.H{"operation_id": operationID}
|
||||
executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
lock, release, err := h.acquireSystemLock(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var releaseReason string
|
||||
succeeded := false
|
||||
defer func() {
|
||||
release(releaseReason, succeeded)
|
||||
}()
|
||||
|
||||
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
|
||||
releaseReason = "SYSTEM_UPDATE_FAILED"
|
||||
return nil, err
|
||||
}
|
||||
succeeded = true
|
||||
|
||||
return gin.H{
|
||||
"message": "Update completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
"operation_id": lock.OperationID(),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// Rollback restores the previous version
|
||||
// POST /api/v1/admin/system/rollback
|
||||
func (h *SystemHandler) Rollback(c *gin.Context) {
|
||||
if err := h.updateSvc.Rollback(); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Rollback completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
operationID := buildSystemOperationID(c, "rollback")
|
||||
payload := gin.H{"operation_id": operationID}
|
||||
executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
lock, release, err := h.acquireSystemLock(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var releaseReason string
|
||||
succeeded := false
|
||||
defer func() {
|
||||
release(releaseReason, succeeded)
|
||||
}()
|
||||
|
||||
if err := h.updateSvc.Rollback(); err != nil {
|
||||
releaseReason = "SYSTEM_ROLLBACK_FAILED"
|
||||
return nil, err
|
||||
}
|
||||
succeeded = true
|
||||
|
||||
return gin.H{
|
||||
"message": "Rollback completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
"operation_id": lock.OperationID(),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// RestartService restarts the systemd service
|
||||
// POST /api/v1/admin/system/restart
|
||||
func (h *SystemHandler) RestartService(c *gin.Context) {
|
||||
// Schedule service restart in background after sending response
|
||||
// This ensures the client receives the success response before the service restarts
|
||||
go func() {
|
||||
// Wait a moment to ensure the response is sent
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
sysutil.RestartServiceAsync()
|
||||
}()
|
||||
operationID := buildSystemOperationID(c, "restart")
|
||||
payload := gin.H{"operation_id": operationID}
|
||||
executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
lock, release, err := h.acquireSystemLock(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
succeeded := false
|
||||
defer func() {
|
||||
release("", succeeded)
|
||||
}()
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "Service restart initiated",
|
||||
// Schedule service restart in background after sending response
|
||||
// This ensures the client receives the success response before the service restarts
|
||||
go func() {
|
||||
// Wait a moment to ensure the response is sent
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
sysutil.RestartServiceAsync()
|
||||
}()
|
||||
succeeded = true
|
||||
return gin.H{
|
||||
"message": "Service restart initiated",
|
||||
"operation_id": lock.OperationID(),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (h *SystemHandler) acquireSystemLock(
|
||||
ctx context.Context,
|
||||
operationID string,
|
||||
) (*service.SystemOperationLock, func(string, bool), error) {
|
||||
if h.lockSvc == nil {
|
||||
return nil, nil, service.ErrIdempotencyStoreUnavail
|
||||
}
|
||||
lock, err := h.lockSvc.Acquire(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
release := func(reason string, succeeded bool) {
|
||||
releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason)
|
||||
}
|
||||
return lock, release, nil
|
||||
}
|
||||
|
||||
func buildSystemOperationID(c *gin.Context, operation string) string {
|
||||
key := strings.TrimSpace(c.GetHeader("Idempotency-Key"))
|
||||
if key == "" {
|
||||
return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36)
|
||||
}
|
||||
actorScope := "admin:0"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key
|
||||
hash := service.HashIdempotencyKey(seed)
|
||||
if len(hash) > 24 {
|
||||
hash = hash[:24]
|
||||
}
|
||||
return "sysop-" + hash
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"log"
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
@@ -378,11 +379,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
|
||||
operator = subject.UserID
|
||||
}
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
@@ -390,7 +391,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
|
||||
for i := range tasks {
|
||||
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
|
||||
}
|
||||
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
@@ -472,29 +473,36 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
billingType = *filters.BillingType
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
userID,
|
||||
apiKeyID,
|
||||
accountID,
|
||||
groupID,
|
||||
model,
|
||||
stream,
|
||||
billingType,
|
||||
req.Timezone,
|
||||
)
|
||||
|
||||
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
idempotencyPayload := struct {
|
||||
OperatorID int64 `json:"operator_id"`
|
||||
Body CreateUsageCleanupTaskRequest `json:"body"`
|
||||
}{
|
||||
OperatorID: subject.UserID,
|
||||
Body: req,
|
||||
}
|
||||
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
userID,
|
||||
apiKeyID,
|
||||
accountID,
|
||||
groupID,
|
||||
model,
|
||||
stream,
|
||||
billingType,
|
||||
req.Timezone,
|
||||
)
|
||||
|
||||
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
||||
response.Success(c, dto.UsageCleanupTaskFromService(task))
|
||||
task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
||||
return nil, err
|
||||
}
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
||||
return dto.UsageCleanupTaskFromService(task), nil
|
||||
})
|
||||
}
|
||||
|
||||
// CancelCleanupTask handles canceling a usage cleanup task
|
||||
@@ -515,12 +523,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid task id")
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
|
||||
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
|
||||
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
|
||||
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -78,8 +79,8 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
if runes := []rune(search); len(runes) > 100 {
|
||||
search = string(runes[:100])
|
||||
}
|
||||
|
||||
filters := service.UserListFilters{
|
||||
@@ -257,13 +258,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
idempotencyPayload := struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Body UpdateBalanceRequest `json:"body"`
|
||||
}{
|
||||
UserID: userID,
|
||||
Body: req,
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes)
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
return dto.UserFromServiceAdmin(user), nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -130,13 +131,14 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
if req.Quota != nil {
|
||||
svcReq.Quota = *req.Quota
|
||||
}
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dto.APIKeyFromService(key), nil
|
||||
})
|
||||
}
|
||||
|
||||
// Update handles updating an API key
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -112,12 +113,11 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
||||
if req.VerifyCode == "" {
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
// Turnstile 验证 — 始终执行,防止绕过
|
||||
// TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
|
||||
@@ -448,17 +448,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Build frontend base URL from request
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
// Check X-Forwarded-Proto header (common in reverse proxy setups)
|
||||
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
||||
if frontendBaseURL == "" {
|
||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
||||
response.InternalError(c, "Password reset is not configured")
|
||||
return
|
||||
}
|
||||
frontendBaseURL := scheme + "://" + c.Request.Host
|
||||
|
||||
// Request password reset (async)
|
||||
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAPIKeyFromService_MapsLastUsedAt(t *testing.T) {
|
||||
lastUsed := time.Now().UTC().Truncate(time.Second)
|
||||
src := &service.APIKey{
|
||||
ID: 1,
|
||||
UserID: 2,
|
||||
Key: "sk-map-last-used",
|
||||
Name: "Mapper",
|
||||
Status: service.StatusActive,
|
||||
LastUsedAt: &lastUsed,
|
||||
}
|
||||
|
||||
out := APIKeyFromService(src)
|
||||
require.NotNil(t, out)
|
||||
require.NotNil(t, out.LastUsedAt)
|
||||
require.WithinDuration(t, lastUsed, *out.LastUsedAt, time.Second)
|
||||
}
|
||||
|
||||
func TestAPIKeyFromService_MapsNilLastUsedAt(t *testing.T) {
|
||||
src := &service.APIKey{
|
||||
ID: 1,
|
||||
UserID: 2,
|
||||
Key: "sk-map-last-used-nil",
|
||||
Name: "MapperNil",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
out := APIKeyFromService(src)
|
||||
require.NotNil(t, out)
|
||||
require.Nil(t, out.LastUsedAt)
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -77,6 +78,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
Status: k.Status,
|
||||
IPWhitelist: k.IPWhitelist,
|
||||
IPBlacklist: k.IPBlacklist,
|
||||
LastUsedAt: k.LastUsedAt,
|
||||
Quota: k.Quota,
|
||||
QuotaUsed: k.QuotaUsed,
|
||||
ExpiresAt: k.ExpiresAt,
|
||||
@@ -129,23 +131,26 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
|
||||
func groupFromServiceBase(g *service.Group) Group {
|
||||
return Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
// 无效请求兜底分组
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
SoraImagePrice360: g.SoraImagePrice360,
|
||||
SoraImagePrice540: g.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
@@ -300,6 +305,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
|
||||
CountryCode: p.CountryCode,
|
||||
Region: p.Region,
|
||||
City: p.City,
|
||||
QualityStatus: p.QualityStatus,
|
||||
QualityScore: p.QualityScore,
|
||||
QualityGrade: p.QualityGrade,
|
||||
QualitySummary: p.QualitySummary,
|
||||
QualityChecked: p.QualityChecked,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -404,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
ImageCount: l.ImageCount,
|
||||
ImageSize: l.ImageSize,
|
||||
MediaType: l.MediaType,
|
||||
UserAgent: l.UserAgent,
|
||||
CacheTTLOverridden: l.CacheTTLOverridden,
|
||||
CreatedAt: l.CreatedAt,
|
||||
@@ -532,11 +543,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
||||
for i := range r.Subscriptions {
|
||||
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
|
||||
}
|
||||
statuses := make(map[string]string, len(r.Statuses))
|
||||
for userID, status := range r.Statuses {
|
||||
statuses[strconv.FormatInt(userID, 10)] = status
|
||||
}
|
||||
return &BulkAssignResult{
|
||||
SuccessCount: r.SuccessCount,
|
||||
CreatedCount: r.CreatedCount,
|
||||
ReusedCount: r.ReusedCount,
|
||||
FailedCount: r.FailedCount,
|
||||
Subscriptions: subs,
|
||||
Errors: r.Errors,
|
||||
Statuses: statuses,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ type APIKey struct {
|
||||
Status string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist"`
|
||||
IPBlacklist []string `json:"ip_blacklist"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
|
||||
QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD
|
||||
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires)
|
||||
@@ -67,6 +68,12 @@ type Group struct {
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
|
||||
// Sora 按次计费配置
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
@@ -196,6 +203,11 @@ type ProxyWithAccountCount struct {
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
QualityStatus string `json:"quality_status,omitempty"`
|
||||
QualityScore *int `json:"quality_score,omitempty"`
|
||||
QualityGrade string `json:"quality_grade,omitempty"`
|
||||
QualitySummary string `json:"quality_summary,omitempty"`
|
||||
QualityChecked *int64 `json:"quality_checked,omitempty"`
|
||||
}
|
||||
|
||||
type ProxyAccountSummary struct {
|
||||
@@ -274,6 +286,7 @@ type UsageLog struct {
|
||||
// 图片生成字段
|
||||
ImageCount int `json:"image_count"`
|
||||
ImageSize *string `json:"image_size"`
|
||||
MediaType *string `json:"media_type"`
|
||||
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
@@ -382,9 +395,12 @@ type AdminUserSubscription struct {
|
||||
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
CreatedCount int `json:"created_count"`
|
||||
ReusedCount int `json:"reused_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []AdminUserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
Statuses map[string]string `json:"statuses,omitempty"`
|
||||
}
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
|
||||
@@ -19,11 +19,13 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// GatewayHandler handles API gateway requests
|
||||
@@ -35,10 +37,12 @@ type GatewayHandler struct {
|
||||
billingCacheService *service.BillingCacheService
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
maxAccountSwitchesGemini int
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
@@ -51,6 +55,7 @@ func NewGatewayHandler(
|
||||
billingCacheService *service.BillingCacheService,
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
cfg *config.Config,
|
||||
) *GatewayHandler {
|
||||
@@ -74,10 +79,12 @@ func NewGatewayHandler(
|
||||
billingCacheService: billingCacheService,
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,6 +103,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gateway.messages",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
@@ -122,6 +136,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
reqModel := parsedReq.Model
|
||||
reqStream := parsedReq.Stream
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||
@@ -161,9 +176,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
reqLog.Warn("gateway.user_wait_counter_increment_failed", zap.Error(err))
|
||||
// On error, allow request to proceed
|
||||
} else if !canWait {
|
||||
reqLog.Info("gateway.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
@@ -180,7 +196,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 1. 首先获取用户并发槽位
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("gateway.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -197,7 +213,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 2. 【新增】Wait后二次检查余额/订阅
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
@@ -227,6 +243,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
if sessionBoundAccountID > 0 {
|
||||
prefetchedGroupID := int64(0)
|
||||
if apiKey.GroupID != nil {
|
||||
prefetchedGroupID = *apiKey.GroupID
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
@@ -250,7 +275,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
@@ -258,7 +284,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gateway.single_account_retrying",
|
||||
zap.Int("retry_count", switchCount),
|
||||
zap.Int("max_retries", maxAccountSwitches),
|
||||
)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
@@ -274,7 +303,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
@@ -302,21 +331,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
reqLog.Info("gateway.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
// Ensure the wait counter is decremented if we exit before acquiring the slot.
|
||||
defer func() {
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
@@ -327,17 +359,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
@@ -387,7 +417,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gateway.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
@@ -395,8 +430,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -404,24 +443,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fcb,
|
||||
ForceCacheBilling: forceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("gateway.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -455,7 +499,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
@@ -463,7 +508,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gateway.single_account_retrying",
|
||||
zap.Int("retry_count", switchCount),
|
||||
zap.Int("max_retries", maxAccountSwitches),
|
||||
)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
@@ -479,7 +527,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
@@ -507,20 +555,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
reqLog.Info("gateway.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
@@ -531,16 +583,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
@@ -563,18 +614,26 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var promptTooLongErr *service.PromptTooLongError
|
||||
if errors.As(err, &promptTooLongErr) {
|
||||
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed)
|
||||
reqLog.Warn("gateway.prompt_too_long_from_antigravity",
|
||||
zap.Any("current_group_id", currentAPIKey.GroupID),
|
||||
zap.Any("fallback_group_id", fallbackGroupID),
|
||||
zap.Bool("fallback_used", fallbackUsed),
|
||||
)
|
||||
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
|
||||
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
|
||||
if err != nil {
|
||||
log.Printf("Resolve fallback group failed: %v", err)
|
||||
reqLog.Warn("gateway.resolve_fallback_group_failed", zap.Int64("fallback_group_id", *fallbackGroupID), zap.Error(err))
|
||||
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||
return
|
||||
}
|
||||
if fallbackGroup.Platform != service.PlatformAnthropic ||
|
||||
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
|
||||
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
|
||||
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType)
|
||||
reqLog.Warn("gateway.fallback_group_invalid",
|
||||
zap.Int64("fallback_group_id", fallbackGroup.ID),
|
||||
zap.String("fallback_platform", fallbackGroup.Platform),
|
||||
zap.String("fallback_subscription_type", fallbackGroup.SubscriptionType),
|
||||
)
|
||||
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||
return
|
||||
}
|
||||
@@ -625,7 +684,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gateway.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
@@ -633,8 +697,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -642,24 +710,34 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: usedAccount,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: ua,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fcb,
|
||||
ForceCacheBilling: forceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", currentAPIKey.ID),
|
||||
zap.Any("group_id", currentAPIKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("gateway.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
})
|
||||
reqLog.Debug("gateway.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Bool("fallback_used", fallbackUsed),
|
||||
)
|
||||
return
|
||||
}
|
||||
if !retryWithFallback {
|
||||
@@ -682,6 +760,17 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
groupID = &apiKey.Group.ID
|
||||
platform = apiKey.Group.Platform
|
||||
}
|
||||
if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" {
|
||||
platform = forcedPlatform
|
||||
}
|
||||
|
||||
if platform == service.PlatformSora {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": service.DefaultSoraModels(h.cfg),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get available models from account configurations (without platform filter)
|
||||
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
|
||||
@@ -942,7 +1031,11 @@ func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) b
|
||||
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||
const delay = 2 * time.Second
|
||||
|
||||
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.failover"),
|
||||
zap.Duration("delay", delay),
|
||||
zap.Int("retry_count", retryCount),
|
||||
).Info("gateway.single_account_backoff_waiting")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -1040,6 +1133,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
// errorResponse 返回Claude API格式的错误响应
|
||||
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
@@ -1067,6 +1169,12 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gateway.count_tokens",
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
@@ -1094,6 +1202,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
|
||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
||||
|
||||
@@ -1127,14 +1236,15 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err))
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
|
||||
return
|
||||
}
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 转发请求(不记录使用量)
|
||||
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
|
||||
log.Printf("Forward count_tokens request failed: %v", err)
|
||||
reqLog.Error("gateway.count_tokens_forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
// 错误响应已在 ForwardCountTokens 中处理
|
||||
return
|
||||
}
|
||||
@@ -1398,7 +1508,25 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
||||
}
|
||||
msg := pkgerrors.Message(err)
|
||||
if msg == "" {
|
||||
msg = err.Error()
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.billing"),
|
||||
zap.Error(err),
|
||||
).Warn("gateway.billing_error_missing_message")
|
||||
msg = "Billing error"
|
||||
}
|
||||
return http.StatusForbidden, "billing_error", msg
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
h.usageRecordWorkerPool.Submit(task)
|
||||
return
|
||||
}
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "error", parsed["type"])
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
@@ -4,8 +4,9 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -20,14 +21,28 @@ var claudeCodeValidator = service.NewClaudeCodeValidator()
|
||||
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
|
||||
// 返回更新后的 context
|
||||
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
||||
// 解析请求体为 map
|
||||
var bodyMap map[string]any
|
||||
if len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &bodyMap)
|
||||
if c == nil || c.Request == nil {
|
||||
return
|
||||
}
|
||||
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
|
||||
if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
|
||||
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证是否为 Claude Code 客户端
|
||||
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||
isClaudeCode := false
|
||||
if !strings.Contains(c.Request.URL.Path, "messages") {
|
||||
// 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。
|
||||
isClaudeCode = true
|
||||
} else {
|
||||
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
|
||||
var bodyMap map[string]any
|
||||
if len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &bodyMap)
|
||||
}
|
||||
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||
}
|
||||
|
||||
// 更新 request context
|
||||
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
|
||||
@@ -104,31 +119,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
|
||||
|
||||
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
|
||||
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
|
||||
// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露
|
||||
// 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。
|
||||
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
|
||||
if releaseFunc == nil {
|
||||
return nil
|
||||
}
|
||||
var once sync.Once
|
||||
quit := make(chan struct{})
|
||||
var stop func() bool
|
||||
|
||||
release := func() {
|
||||
once.Do(func() {
|
||||
if stop != nil {
|
||||
_ = stop()
|
||||
}
|
||||
releaseFunc()
|
||||
close(quit) // 通知监听 goroutine 退出
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context 取消时释放资源
|
||||
release()
|
||||
case <-quit:
|
||||
// 正常释放已完成,goroutine 退出
|
||||
return
|
||||
}
|
||||
}()
|
||||
stop = context.AfterFunc(ctx, release)
|
||||
|
||||
return release
|
||||
}
|
||||
@@ -153,6 +161,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou
|
||||
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
|
||||
}
|
||||
|
||||
// TryAcquireUserSlot 尝试立即获取用户并发槽位。
|
||||
// 返回值: (releaseFunc, acquired, error)
|
||||
func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) {
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if !result.Acquired {
|
||||
return nil, false, nil
|
||||
}
|
||||
return result.ReleaseFunc, true, nil
|
||||
}
|
||||
|
||||
// TryAcquireAccountSlot 尝试立即获取账号并发槽位。
|
||||
// 返回值: (releaseFunc, acquired, error)
|
||||
func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) {
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if !result.Acquired {
|
||||
return nil, false, nil
|
||||
}
|
||||
return result.ReleaseFunc, true, nil
|
||||
}
|
||||
|
||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
@@ -160,13 +194,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
if acquired {
|
||||
return releaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
@@ -180,13 +214,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
if acquired {
|
||||
return releaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
@@ -196,27 +230,29 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
|
||||
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false)
|
||||
}
|
||||
|
||||
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Try immediate acquire first (avoid unnecessary wait)
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
acquireSlot := func() (*service.AcquireResult, error) {
|
||||
if slotType == "user" {
|
||||
return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
|
||||
if tryImmediate {
|
||||
result, err := acquireSlot()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
@@ -242,7 +278,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
backoff := initialBackoff
|
||||
timer := time.NewTimer(backoff)
|
||||
defer timer.Stop()
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -268,15 +303,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
|
||||
case <-timer.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
result, err := acquireSlot()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -284,7 +311,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
backoff = nextBackoff(backoff, rng)
|
||||
backoff = nextBackoff(backoff)
|
||||
timer.Reset(backoff)
|
||||
}
|
||||
}
|
||||
@@ -292,26 +319,22 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
|
||||
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
|
||||
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true)
|
||||
}
|
||||
|
||||
// nextBackoff 计算下一次退避时间
|
||||
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
|
||||
// current: 当前退避时间
|
||||
// rng: 随机数生成器(可为 nil,此时不添加抖动)
|
||||
// 返回值:下一次退避时间(100ms ~ 2s 之间)
|
||||
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
|
||||
func nextBackoff(current time.Duration) time.Duration {
|
||||
// 指数退避:当前时间 * 1.5
|
||||
next := time.Duration(float64(current) * backoffMultiplier)
|
||||
if next > maxBackoff {
|
||||
next = maxBackoff
|
||||
}
|
||||
if rng == nil {
|
||||
return next
|
||||
}
|
||||
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
|
||||
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
|
||||
jitter := 0.8 + rng.Float64()*0.4
|
||||
jitter := 0.8 + rand.Float64()*0.4
|
||||
jittered := time.Duration(float64(next) * jitter)
|
||||
if jittered < initialBackoff {
|
||||
return initialBackoff
|
||||
|
||||
106
backend/internal/handler/gateway_helper_backoff_test.go
Normal file
106
backend/internal/handler/gateway_helper_backoff_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 ---
|
||||
|
||||
func TestNextBackoff_ExponentialGrowth(t *testing.T) {
|
||||
// 验证退避时间指数增长(乘数 1.5)
|
||||
// 由于有随机抖动(±20%),需要验证范围
|
||||
current := initialBackoff // 100ms
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
next := nextBackoff(current)
|
||||
|
||||
// 退避结果应在 [initialBackoff, maxBackoff] 范围内
|
||||
assert.GreaterOrEqual(t, int64(next), int64(initialBackoff),
|
||||
"第 %d 次退避不应低于初始值 %v", i, initialBackoff)
|
||||
assert.LessOrEqual(t, int64(next), int64(maxBackoff),
|
||||
"第 %d 次退避不应超过最大值 %v", i, maxBackoff)
|
||||
|
||||
// 为下一轮提供当前退避值
|
||||
current = next
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextBackoff_BoundedByMaxBackoff(t *testing.T) {
|
||||
// 即使输入非常大,输出也不超过 maxBackoff
|
||||
for i := 0; i < 100; i++ {
|
||||
result := nextBackoff(10 * time.Second)
|
||||
assert.LessOrEqual(t, int64(result), int64(maxBackoff),
|
||||
"退避值不应超过 maxBackoff")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextBackoff_BoundedByInitialBackoff(t *testing.T) {
|
||||
// 即使输入非常小,输出也不低于 initialBackoff
|
||||
for i := 0; i < 100; i++ {
|
||||
result := nextBackoff(1 * time.Millisecond)
|
||||
assert.GreaterOrEqual(t, int64(result), int64(initialBackoff),
|
||||
"退避值不应低于 initialBackoff")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextBackoff_HasJitter(t *testing.T) {
|
||||
// 验证多次调用会产生不同的值(随机抖动生效)
|
||||
// 使用相同的输入调用 50 次,收集结果
|
||||
results := make(map[time.Duration]bool)
|
||||
current := 500 * time.Millisecond
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
result := nextBackoff(current)
|
||||
results[result] = true
|
||||
}
|
||||
|
||||
// 50 次调用应该至少有 2 个不同的值(抖动存在)
|
||||
require.Greater(t, len(results), 1,
|
||||
"nextBackoff 应产生随机抖动,但所有 50 次调用结果相同")
|
||||
}
|
||||
|
||||
func TestNextBackoff_InitialValueGrows(t *testing.T) {
|
||||
// 验证从初始值开始,退避趋势是增长的
|
||||
current := initialBackoff
|
||||
var sum time.Duration
|
||||
|
||||
runs := 100
|
||||
for i := 0; i < runs; i++ {
|
||||
next := nextBackoff(current)
|
||||
sum += next
|
||||
current = next
|
||||
}
|
||||
|
||||
avg := sum / time.Duration(runs)
|
||||
// 平均退避时间应大于初始值(因为指数增长 + 上限)
|
||||
assert.Greater(t, int64(avg), int64(initialBackoff),
|
||||
"平均退避时间应大于初始退避值")
|
||||
}
|
||||
|
||||
func TestNextBackoff_ConvergesToMaxBackoff(t *testing.T) {
|
||||
// 从初始值开始,经过多次退避后应收敛到 maxBackoff 附近
|
||||
current := initialBackoff
|
||||
for i := 0; i < 20; i++ {
|
||||
current = nextBackoff(current)
|
||||
}
|
||||
|
||||
// 经过 20 次迭代后,应该已经到达 maxBackoff 区间
|
||||
// 由于抖动,允许 ±20% 的范围
|
||||
lowerBound := time.Duration(float64(maxBackoff) * 0.8)
|
||||
assert.GreaterOrEqual(t, int64(current), int64(lowerBound),
|
||||
"经过多次退避后应收敛到 maxBackoff 附近")
|
||||
}
|
||||
|
||||
func BenchmarkNextBackoff(b *testing.B) {
|
||||
current := initialBackoff
|
||||
for i := 0; i < b.N; i++ {
|
||||
current = nextBackoff(current)
|
||||
if current > maxBackoff {
|
||||
current = initialBackoff
|
||||
}
|
||||
}
|
||||
}
|
||||
114
backend/internal/handler/gateway_helper_fastpath_test.go
Normal file
114
backend/internal/handler/gateway_helper_fastpath_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type concurrencyCacheMock struct {
|
||||
acquireUserSlotFn func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
releaseUserCalled int32
|
||||
releaseAccountCalled int32
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
if m.acquireAccountSlotFn != nil {
|
||||
return m.acquireAccountSlotFn(ctx, accountID, maxConcurrency, requestID)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
atomic.AddInt32(&m.releaseAccountCalled, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
if m.acquireUserSlotFn != nil {
|
||||
return m.acquireUserSlotFn(ctx, userID, maxConcurrency, requestID)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
atomic.AddInt32(&m.releaseUserCalled, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
return map[int64]*service.AccountLoadInfo{}, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
|
||||
|
||||
release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2)
|
||||
require.NoError(t, err)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
|
||||
release()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.releaseUserCalled))
|
||||
}
|
||||
|
||||
func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) {
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
|
||||
|
||||
release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1)
|
||||
require.NoError(t, err)
|
||||
require.False(t, acquired)
|
||||
require.Nil(t, release)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.releaseAccountCalled))
|
||||
}
|
||||
269
backend/internal/handler/gateway_helper_hotpath_test.go
Normal file
269
backend/internal/handler/gateway_helper_hotpath_test.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type helperConcurrencyCacheStub struct {
|
||||
mu sync.Mutex
|
||||
|
||||
accountSeq []bool
|
||||
userSeq []bool
|
||||
|
||||
accountAcquireCalls int
|
||||
userAcquireCalls int
|
||||
accountReleaseCalls int
|
||||
userReleaseCalls int
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.accountAcquireCalls++
|
||||
if len(s.accountSeq) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
v := s.accountSeq[0]
|
||||
s.accountSeq = s.accountSeq[1:]
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.accountReleaseCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.userAcquireCalls++
|
||||
if len(s.userSeq) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
v := s.userSeq[0]
|
||||
s.userSeq = s.userSeq[1:]
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.userReleaseCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
out := make(map[int64]*service.AccountLoadInfo, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
out := make(map[int64]*service.UserLoadInfo, len(users))
|
||||
for _, user := range users {
|
||||
out[user.ID] = &service.UserLoadInfo{UserID: user.ID}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(method, path, nil)
|
||||
return c, rec
|
||||
}
|
||||
|
||||
func validClaudeCodeBodyJSON() []byte {
|
||||
return []byte(`{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
|
||||
}`)
|
||||
}
|
||||
|
||||
func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||
t.Run("non_cli_user_agent_sets_false", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "curl/8.6.0")
|
||||
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("cli_non_messages_path_sets_true", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
|
||||
SetClaudeCodeClientContext(c, nil)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
c.Request.Header.Set("X-App", "claude-code")
|
||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
// 缺少严格校验所需 header + body 字段
|
||||
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
|
||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
|
||||
cache := &helperConcurrencyCacheStub{
|
||||
accountSeq: []bool{false, true},
|
||||
userSeq: []bool{false, true},
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(cache)
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
|
||||
t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, release)
|
||||
require.False(t, streamStarted)
|
||||
release()
|
||||
require.GreaterOrEqual(t, cache.accountAcquireCalls, 2)
|
||||
require.GreaterOrEqual(t, cache.accountReleaseCalls, 1)
|
||||
})
|
||||
|
||||
t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, release)
|
||||
release()
|
||||
require.GreaterOrEqual(t, cache.userAcquireCalls, 2)
|
||||
require.GreaterOrEqual(t, cache.userReleaseCalls, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
|
||||
cache := &helperConcurrencyCacheStub{
|
||||
accountSeq: []bool{false, false, false},
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(cache)
|
||||
|
||||
t.Run("timeout_returns_concurrency_error", func(t *testing.T) {
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true)
|
||||
require.Nil(t, release)
|
||||
var cErr *ConcurrencyError
|
||||
require.ErrorAs(t, err, &cErr)
|
||||
require.True(t, cErr.IsTimeout)
|
||||
})
|
||||
|
||||
t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) {
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
|
||||
c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true)
|
||||
require.Nil(t, release)
|
||||
var cErr *ConcurrencyError
|
||||
require.ErrorAs(t, err, &cErr)
|
||||
require.True(t, cErr.IsTimeout)
|
||||
require.True(t, streamStarted)
|
||||
require.Contains(t, rec.Body.String(), ":\n\n")
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
|
||||
errCache := &helperConcurrencyCacheStubWithError{
|
||||
err: errors.New("redis unavailable"),
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(errCache)
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true)
|
||||
require.Nil(t, release)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "redis unavailable")
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) {
|
||||
cache := &helperConcurrencyCacheStub{
|
||||
accountSeq: []bool{false},
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(cache)
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
|
||||
release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted)
|
||||
require.Nil(t, release)
|
||||
var cErr *ConcurrencyError
|
||||
require.ErrorAs(t, err, &cErr)
|
||||
require.True(t, cErr.IsTimeout)
|
||||
require.GreaterOrEqual(t, cache.accountAcquireCalls, 1)
|
||||
}
|
||||
|
||||
type helperConcurrencyCacheStubWithError struct {
|
||||
helperConcurrencyCacheStub
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return false, s.err
|
||||
}
|
||||
@@ -8,11 +8,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
@@ -20,11 +18,13 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
|
||||
@@ -143,6 +143,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
googleError(c, http.StatusInternalServerError, "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gemini_v1beta.models",
|
||||
zap.Int64("user_id", authSubject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组
|
||||
if !middleware.HasForcePlatform(c) {
|
||||
@@ -159,6 +166,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
stream := action == "streamGenerateContent"
|
||||
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
@@ -187,8 +195,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
reqLog.Warn("gemini.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
reqLog.Info("gemini.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
@@ -208,6 +217,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("gemini.user_slot_acquire_failed", zap.Error(err))
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -223,6 +233,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, _, message := billingErrorDetails(err)
|
||||
googleError(c, status, message)
|
||||
return
|
||||
@@ -252,6 +263,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
if sessionBoundAccountID > 0 {
|
||||
prefetchedGroupID := int64(0)
|
||||
if apiKey.GroupID != nil {
|
||||
prefetchedGroupID = *apiKey.GroupID
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// === Gemini 内容摘要会话 Fallback 逻辑 ===
|
||||
@@ -296,8 +316,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
matchedDigestChain = foundMatchedChain
|
||||
sessionBoundAccountID = foundAccountID
|
||||
geminiSessionUUID = foundUUID
|
||||
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
|
||||
safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain))
|
||||
reqLog.Info("gemini.digest_fallback_matched",
|
||||
zap.String("session_uuid_prefix", safeShortPrefix(foundUUID, 8)),
|
||||
zap.Int64("account_id", foundAccountID),
|
||||
zap.String("digest_chain", truncateDigestChain(geminiDigestChain)),
|
||||
)
|
||||
|
||||
// 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey
|
||||
// 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号
|
||||
@@ -346,7 +369,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gemini.single_account_retrying",
|
||||
zap.Int("retry_count", switchCount),
|
||||
zap.Int("max_retries", maxAccountSwitches),
|
||||
)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
@@ -358,18 +384,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
||||
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
||||
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
||||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||
reqLog.Info("gemini.sticky_session_account_switched",
|
||||
zap.Int64("from_account_id", sessionBoundAccountID),
|
||||
zap.Int64("to_account_id", account.ID),
|
||||
zap.Bool("clean_thought_signature", true),
|
||||
)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
|
||||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
|
||||
reqLog.Info("gemini.sticky_session_binding_missing",
|
||||
zap.Bool("clean_thought_signature", true),
|
||||
)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
cleanedForUnknownBinding = true
|
||||
sessionBoundAccountID = account.ID
|
||||
@@ -388,9 +420,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
accountWaitCounted := false
|
||||
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
reqLog.Warn("gemini.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
reqLog.Info("gemini.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
@@ -412,6 +447,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("gemini.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -420,7 +456,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
accountWaitCounted = false
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
reqLog.Warn("gemini.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
@@ -454,7 +490,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
lastFailoverErr = failoverErr
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("gemini.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
@@ -463,7 +504,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
log.Printf("Gemini native forward failed: %v", err)
|
||||
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -482,31 +523,39 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
account.ID,
|
||||
matchedDigestChain,
|
||||
); err != nil {
|
||||
log.Printf("[Gemini] Failed to save digest session: %v", err)
|
||||
reqLog.Warn("gemini.digest_session_save_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 6) record usage async (Gemini 使用长上下文双倍计费)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fcb,
|
||||
ForceCacheBilling: forceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gemini_v1beta.models"),
|
||||
zap.Int64("user_id", authSubject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", modelName),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("gemini.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
})
|
||||
reqLog.Debug("gemini.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ type Handlers struct {
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
SoraGateway *SoraGatewayHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
}
|
||||
|
||||
65
backend/internal/handler/idempotency_helper.go
Normal file
65
backend/internal/handler/idempotency_helper.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func executeUserIdempotentJSON(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
coordinator := service.DefaultIdempotencyCoordinator()
|
||||
if coordinator == nil {
|
||||
data, err := execute(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, data)
|
||||
return
|
||||
}
|
||||
|
||||
actorScope := "user:0"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
actorScope = "user:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
|
||||
result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
|
||||
Scope: scope,
|
||||
ActorScope: actorScope,
|
||||
Method: c.Request.Method,
|
||||
Route: c.FullPath(),
|
||||
IdempotencyKey: c.GetHeader("Idempotency-Key"),
|
||||
Payload: payload,
|
||||
RequireKey: true,
|
||||
TTL: ttl,
|
||||
}, execute)
|
||||
if err != nil {
|
||||
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
|
||||
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_fail_close")
|
||||
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope)
|
||||
}
|
||||
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
285
backend/internal/handler/idempotency_helper_test.go
Normal file
285
backend/internal/handler/idempotency_helper_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userStoreUnavailableRepoStub struct{}
|
||||
|
||||
func (userStoreUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
|
||||
return nil, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
|
||||
return 0, errors.New("store unavailable")
|
||||
}
|
||||
|
||||
type userMemoryIdempotencyRepoStub struct {
|
||||
mu sync.Mutex
|
||||
nextID int64
|
||||
data map[string]*service.IdempotencyRecord
|
||||
}
|
||||
|
||||
func newUserMemoryIdempotencyRepoStub() *userMemoryIdempotencyRepoStub {
|
||||
return &userMemoryIdempotencyRepoStub{
|
||||
nextID: 1,
|
||||
data: make(map[string]*service.IdempotencyRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) key(scope, keyHash string) string {
|
||||
return scope + "|" + keyHash
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := *in
|
||||
if in.LockedUntil != nil {
|
||||
v := *in.LockedUntil
|
||||
out.LockedUntil = &v
|
||||
}
|
||||
if in.ResponseBody != nil {
|
||||
v := *in.ResponseBody
|
||||
out.ResponseBody = &v
|
||||
}
|
||||
if in.ResponseStatus != nil {
|
||||
v := *in.ResponseStatus
|
||||
out.ResponseStatus = &v
|
||||
}
|
||||
if in.ErrorReason != nil {
|
||||
v := *in.ErrorReason
|
||||
out.ErrorReason = &v
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
k := r.key(record.Scope, record.IdempotencyKeyHash)
|
||||
if _, ok := r.data[k]; ok {
|
||||
return false, nil
|
||||
}
|
||||
cp := r.clone(record)
|
||||
cp.ID = r.nextID
|
||||
r.nextID++
|
||||
r.data[k] = cp
|
||||
record.ID = cp.ID
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.clone(r.data[r.key(scope, keyHash)]), nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != fromStatus {
|
||||
return false, nil
|
||||
}
|
||||
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
|
||||
return false, nil
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusProcessing
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
rec.ErrorReason = nil
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
|
||||
return false, nil
|
||||
}
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusSucceeded
|
||||
rec.LockedUntil = nil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ResponseStatus = &responseStatus
|
||||
rec.ResponseBody = &responseBody
|
||||
rec.ErrorReason = nil
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusFailedRetryable
|
||||
rec.LockedUntil = &lockedUntil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ErrorReason = &errorReason
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func withUserSubject(userID int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteUserIdempotentJSONFallbackWithoutCoordinator(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.Use(withUserSubject(1))
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 1, executed)
|
||||
}
|
||||
|
||||
func TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(userStoreUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.Use(withUserSubject(2))
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "k1")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
require.Equal(t, 0, executed)
|
||||
}
|
||||
|
||||
func TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := newUserMemoryIdempotencyRepoStub()
|
||||
cfg := service.DefaultIdempotencyConfig()
|
||||
cfg.ProcessingTimeout = 2 * time.Second
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed atomic.Int32
|
||||
router := gin.New()
|
||||
router.Use(withUserSubject(3))
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed.Add(1)
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
call := func() (int, http.Header) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "same-user-key")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
return rec.Code, rec.Header()
|
||||
}
|
||||
|
||||
var status1, status2 int
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() { defer wg.Done(); status1, _ = call() }()
|
||||
go func() { defer wg.Done(); status2, _ = call() }()
|
||||
wg.Wait()
|
||||
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
|
||||
require.Equal(t, int32(1), executed.Load())
|
||||
|
||||
status3, headers3 := call()
|
||||
require.Equal(t, http.StatusOK, status3)
|
||||
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
|
||||
require.Equal(t, int32(1), executed.Load())
|
||||
}
|
||||
19
backend/internal/handler/logging.go
Normal file
19
backend/internal/handler/logging.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func requestLogger(c *gin.Context, component string, fields ...zap.Field) *zap.Logger {
|
||||
base := logger.L()
|
||||
if c != nil && c.Request != nil {
|
||||
base = logger.FromContext(c.Request.Context())
|
||||
}
|
||||
|
||||
if component != "" {
|
||||
fields = append([]zap.Field{zap.String("component", component)}, fields...)
|
||||
}
|
||||
return base.With(fields...)
|
||||
}
|
||||
@@ -6,18 +6,19 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
@@ -25,6 +26,7 @@ type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
@@ -36,6 +38,7 @@ func NewOpenAIGatewayHandler(
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
cfg *config.Config,
|
||||
) *OpenAIGatewayHandler {
|
||||
@@ -51,6 +54,7 @@ func NewOpenAIGatewayHandler(
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
@@ -60,6 +64,8 @@ func NewOpenAIGatewayHandler(
|
||||
// Responses handles OpenAI Responses API endpoint
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
requestStart := time.Now()
|
||||
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
@@ -72,6 +78,13 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.responses",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
@@ -91,57 +104,57 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
// Parse request body to map for potential modification
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
// 校验请求体 JSON 合法性
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model and stream
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if !openai.IsCodexCLIRequest(userAgent) {
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
if strings.TrimSpace(existingInstructions) == "" {
|
||||
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
|
||||
reqBody["instructions"] = instructions
|
||||
// Re-serialize body
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
streamResult := gjson.GetBytes(body, "stream")
|
||||
if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
|
||||
return
|
||||
}
|
||||
reqStream := streamResult.Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
||||
// 或带 id 且与 call_id 匹配的 item_reference。
|
||||
if service.HasFunctionCallOutput(reqBody) {
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
// 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal
|
||||
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err == nil {
|
||||
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
|
||||
if service.HasFunctionCallOutput(reqBody) {
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_call_id"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_item_reference"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -157,34 +170,48 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Get subscription info (may be nil)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. Check if wait queue is full
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
} else if !canWait {
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
// 1. First acquire user concurrency slot
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
// 0. 先尝试直接抢占用户槽位(快速路径)
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
// User slot acquired: no longer waiting.
|
||||
|
||||
waitCounted := false
|
||||
if !userAcquired {
|
||||
// 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
if waitErr != nil {
|
||||
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
|
||||
// 按现有降级语义:等待计数异常时放行后续抢槽流程
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if waitErr == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 用户槽位已获取:退出等待队列计数。
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
@@ -197,14 +224,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
@@ -213,12 +240,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
reqLog.Warn("openai.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
@@ -229,8 +259,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
@@ -239,53 +269,87 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}()
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
// 先快速尝试一次账号槽位,命中则跳过等待计数写入。
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
c.Request.Context(),
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
if fastAcquired {
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// Forward request
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
@@ -296,11 +360,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
reqLog.Warn("openai.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("openai.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -308,27 +381,72 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// Async record usage
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP)
|
||||
})
|
||||
reqLog.Debug("openai.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func getContextInt64(c *gin.Context, key string) (int64, bool) {
|
||||
if c == nil || key == "" {
|
||||
return 0, false
|
||||
}
|
||||
v, ok := c.Get(key)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
return t, true
|
||||
case int:
|
||||
return int64(t), true
|
||||
case int32:
|
||||
return int64(t), true
|
||||
case float64:
|
||||
return int64(t), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
h.usageRecordWorkerPool.Submit(task)
|
||||
return
|
||||
}
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
@@ -397,8 +515,19 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in OpenAI SSE format
|
||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
// Send error event in OpenAI SSE format with proper JSON marshaling
|
||||
errorData := map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
@@ -411,6 +540,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
|
||||
230
backend/internal/handler/openai_gateway_handler_test.go
Normal file
230
backend/internal/handler/openai_gateway_handler_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "包含双引号的消息",
|
||||
errType: "server_error",
|
||||
message: `upstream returned "invalid" response`,
|
||||
},
|
||||
{
|
||||
name: "包含反斜杠的消息",
|
||||
errType: "server_error",
|
||||
message: `path C:\Users\test\file.txt not found`,
|
||||
},
|
||||
{
|
||||
name: "包含双引号和反斜杠的消息",
|
||||
errType: "upstream_error",
|
||||
message: `error parsing "key\value": unexpected token`,
|
||||
},
|
||||
{
|
||||
name: "包含换行符的消息",
|
||||
errType: "server_error",
|
||||
message: "line1\nline2\ttab",
|
||||
},
|
||||
{
|
||||
name: "普通消息",
|
||||
errType: "upstream_error",
|
||||
message: "Upstream service temporarily unavailable",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// 验证 SSE 格式:event: error\ndata: {JSON}\n\n
|
||||
assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头")
|
||||
assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾")
|
||||
|
||||
// 提取 data 部分
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2, "应有 event 行和 data 行")
|
||||
dataLine := lines[1]
|
||||
require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头")
|
||||
jsonStr := strings.TrimPrefix(dataLine, "data: ")
|
||||
|
||||
// 验证 JSON 合法性
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal([]byte(jsonStr), &parsed)
|
||||
require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr)
|
||||
|
||||
// 验证结构
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok, "应包含 error 对象")
|
||||
assert.Equal(t, tt.errType, errorObj["type"])
|
||||
assert.Equal(t, tt.message, errorObj["message"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "test error", false)
|
||||
|
||||
// 非流式应返回 JSON 响应
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "test error", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
|
||||
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
||||
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantModel string
|
||||
wantStream bool
|
||||
}{
|
||||
{"正常提取", `{"model":"gpt-4","stream":true,"input":"hello"}`, "gpt-4", true},
|
||||
{"stream false", `{"model":"gpt-4","stream":false}`, "gpt-4", false},
|
||||
{"无 stream 字段", `{"model":"gpt-4"}`, "gpt-4", false},
|
||||
{"model 缺失", `{"stream":true}`, "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body := []byte(tt.body)
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
model := ""
|
||||
if modelResult.Type == gjson.String {
|
||||
model = modelResult.String()
|
||||
}
|
||||
stream := gjson.GetBytes(body, "stream").Bool()
|
||||
require.Equal(t, tt.wantModel, model)
|
||||
require.Equal(t, tt.wantStream, stream)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIHandler_GjsonValidation 验证修复后的 JSON 合法性和类型校验
|
||||
func TestOpenAIHandler_GjsonValidation(t *testing.T) {
|
||||
// 非法 JSON 被 gjson.ValidBytes 拦截
|
||||
require.False(t, gjson.ValidBytes([]byte(`{invalid json`)))
|
||||
|
||||
// model 为数字 → 类型不是 gjson.String,应被拒绝
|
||||
body := []byte(`{"model":123}`)
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
require.True(t, modelResult.Exists())
|
||||
require.NotEqual(t, gjson.String, modelResult.Type)
|
||||
|
||||
// model 为 null → 类型不是 gjson.String,应被拒绝
|
||||
body2 := []byte(`{"model":null}`)
|
||||
modelResult2 := gjson.GetBytes(body2, "model")
|
||||
require.True(t, modelResult2.Exists())
|
||||
require.NotEqual(t, gjson.String, modelResult2.Type)
|
||||
|
||||
// stream 为 string → 类型既不是 True 也不是 False,应被拒绝
|
||||
body3 := []byte(`{"model":"gpt-4","stream":"true"}`)
|
||||
streamResult := gjson.GetBytes(body3, "stream")
|
||||
require.True(t, streamResult.Exists())
|
||||
require.NotEqual(t, gjson.True, streamResult.Type)
|
||||
require.NotEqual(t, gjson.False, streamResult.Type)
|
||||
|
||||
// stream 为 int → 同上
|
||||
body4 := []byte(`{"model":"gpt-4","stream":1}`)
|
||||
streamResult2 := gjson.GetBytes(body4, "stream")
|
||||
require.True(t, streamResult2.Exists())
|
||||
require.NotEqual(t, gjson.True, streamResult2.Type)
|
||||
require.NotEqual(t, gjson.False, streamResult2.Type)
|
||||
}
|
||||
|
||||
// TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑
|
||||
func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
|
||||
// 测试 1:无 instructions → 注入
|
||||
body := []byte(`{"model":"gpt-4"}`)
|
||||
existing := gjson.GetBytes(body, "instructions").String()
|
||||
require.Empty(t, existing)
|
||||
newBody, err := sjson.SetBytes(body, "instructions", "test instruction")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test instruction", gjson.GetBytes(newBody, "instructions").String())
|
||||
|
||||
// 测试 2:已有 instructions → 不覆盖
|
||||
body2 := []byte(`{"model":"gpt-4","instructions":"existing"}`)
|
||||
existing2 := gjson.GetBytes(body2, "instructions").String()
|
||||
require.Equal(t, "existing", existing2)
|
||||
|
||||
// 测试 3:空白 instructions → 注入
|
||||
body3 := []byte(`{"model":"gpt-4","instructions":" "}`)
|
||||
existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String())
|
||||
require.Empty(t, existing3)
|
||||
|
||||
// 测试 4:sjson.SetBytes 返回错误时不应 panic
|
||||
// 正常 JSON 不会产生 sjson 错误,验证返回值被正确处理
|
||||
validBody := []byte(`{"model":"gpt-4"}`)
|
||||
result, setErr := sjson.SetBytes(validBody, "instructions", "hello")
|
||||
require.NoError(t, setErr)
|
||||
require.True(t, gjson.ValidBytes(result))
|
||||
}
|
||||
@@ -41,9 +41,8 @@ const (
|
||||
)
|
||||
|
||||
type opsErrorLogJob struct {
|
||||
ops *service.OpsService
|
||||
entry *service.OpsInsertErrorLogInput
|
||||
requestBody []byte
|
||||
ops *service.OpsService
|
||||
entry *service.OpsInsertErrorLogInput
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -58,6 +57,7 @@ var (
|
||||
opsErrorLogEnqueued atomic.Int64
|
||||
opsErrorLogDropped atomic.Int64
|
||||
opsErrorLogProcessed atomic.Int64
|
||||
opsErrorLogSanitized atomic.Int64
|
||||
|
||||
opsErrorLogLastDropLogAt atomic.Int64
|
||||
|
||||
@@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() {
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = job.ops.RecordError(ctx, job.entry, job.requestBody)
|
||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
||||
cancel()
|
||||
opsErrorLogProcessed.Add(1)
|
||||
}()
|
||||
@@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() {
|
||||
}
|
||||
}
|
||||
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) {
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||
if ops == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo
|
||||
}
|
||||
|
||||
select {
|
||||
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}:
|
||||
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}:
|
||||
opsErrorLogQueueLen.Add(1)
|
||||
opsErrorLogEnqueued.Add(1)
|
||||
default:
|
||||
@@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 {
|
||||
return opsErrorLogProcessed.Load()
|
||||
}
|
||||
|
||||
func OpsErrorLogSanitizedTotal() int64 {
|
||||
return opsErrorLogSanitized.Load()
|
||||
}
|
||||
|
||||
func maybeLogOpsErrorLogDrop() {
|
||||
now := time.Now().Unix()
|
||||
|
||||
@@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() {
|
||||
queueCap := OpsErrorLogQueueCapacity()
|
||||
|
||||
log.Printf(
|
||||
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)",
|
||||
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d sanitized_total=%d)",
|
||||
queued,
|
||||
queueCap,
|
||||
opsErrorLogEnqueued.Load(),
|
||||
opsErrorLogDropped.Load(),
|
||||
opsErrorLogProcessed.Load(),
|
||||
opsErrorLogSanitized.Load(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -255,18 +260,49 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
c.Set(opsModelKey, model)
|
||||
c.Set(opsStreamKey, stream)
|
||||
if len(requestBody) > 0 {
|
||||
c.Set(opsRequestBodyKey, requestBody)
|
||||
}
|
||||
if c.Request != nil && model != "" {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func setOpsSelectedAccount(c *gin.Context, accountID int64) {
|
||||
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||
if c == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
v, ok := c.Get(opsRequestBodyKey)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
raw, ok := v.([]byte)
|
||||
if !ok || len(raw) == 0 {
|
||||
return
|
||||
}
|
||||
entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw)
|
||||
opsErrorLogSanitized.Add(1)
|
||||
}
|
||||
|
||||
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
|
||||
if c == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
c.Set(opsAccountIDKey, accountID)
|
||||
if c.Request != nil {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.AccountID, accountID)
|
||||
if len(platform) > 0 {
|
||||
p := strings.TrimSpace(platform[0])
|
||||
if p != "" {
|
||||
ctx = context.WithValue(ctx, ctxkey.Platform, p)
|
||||
}
|
||||
}
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
type opsCaptureWriter struct {
|
||||
@@ -507,6 +543,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
RetryCount: 0,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
applyOpsLatencyFieldsFromContext(c, entry)
|
||||
|
||||
if apiKey != nil {
|
||||
entry.APIKeyID = &apiKey.ID
|
||||
@@ -528,14 +565,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
entry.ClientIP = &clientIP
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if v, ok := c.Get(opsRequestBodyKey); ok {
|
||||
if b, ok := v.([]byte); ok && len(b) > 0 {
|
||||
requestBody = b
|
||||
}
|
||||
}
|
||||
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
|
||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
// Skip logging if a passthrough rule with skip_monitoring=true matched.
|
||||
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
|
||||
@@ -544,7 +576,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
enqueueOpsErrorLog(ops, entry, requestBody)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -632,6 +664,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
RetryCount: 0,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
applyOpsLatencyFieldsFromContext(c, entry)
|
||||
|
||||
// Capture upstream error context set by gateway services (if present).
|
||||
// This does NOT affect the client response; it enriches Ops troubleshooting data.
|
||||
@@ -707,17 +740,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
entry.ClientIP = &clientIP
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if v, ok := c.Get(opsRequestBodyKey); ok {
|
||||
if b, ok := v.([]byte); ok && len(b) > 0 {
|
||||
requestBody = b
|
||||
}
|
||||
}
|
||||
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
|
||||
// Do NOT store Authorization/Cookie/etc.
|
||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
enqueueOpsErrorLog(ops, entry, requestBody)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -760,6 +788,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||
if c == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
entry.AuthLatencyMs = getContextLatencyMs(c, service.OpsAuthLatencyMsKey)
|
||||
entry.RoutingLatencyMs = getContextLatencyMs(c, service.OpsRoutingLatencyMsKey)
|
||||
entry.UpstreamLatencyMs = getContextLatencyMs(c, service.OpsUpstreamLatencyMsKey)
|
||||
entry.ResponseLatencyMs = getContextLatencyMs(c, service.OpsResponseLatencyMsKey)
|
||||
entry.TimeToFirstTokenMs = getContextLatencyMs(c, service.OpsTimeToFirstTokenMsKey)
|
||||
}
|
||||
|
||||
func getContextLatencyMs(c *gin.Context, key string) *int64 {
|
||||
if c == nil || strings.TrimSpace(key) == "" {
|
||||
return nil
|
||||
}
|
||||
v, ok := c.Get(key)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var ms int64
|
||||
switch t := v.(type) {
|
||||
case int:
|
||||
ms = int64(t)
|
||||
case int32:
|
||||
ms = int64(t)
|
||||
case int64:
|
||||
ms = t
|
||||
case float64:
|
||||
ms = int64(t)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
if ms < 0 {
|
||||
return nil
|
||||
}
|
||||
return &ms
|
||||
}
|
||||
|
||||
type parsedOpsError struct {
|
||||
ErrorType string
|
||||
Message string
|
||||
|
||||
175
backend/internal/handler/ops_error_logger_test.go
Normal file
175
backend/internal/handler/ops_error_logger_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func resetOpsErrorLoggerStateForTest(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
opsErrorLogMu.Lock()
|
||||
ch := opsErrorLogQueue
|
||||
opsErrorLogQueue = nil
|
||||
opsErrorLogStopping = true
|
||||
opsErrorLogMu.Unlock()
|
||||
|
||||
if ch != nil {
|
||||
close(ch)
|
||||
}
|
||||
opsErrorLogWorkersWg.Wait()
|
||||
|
||||
opsErrorLogOnce = sync.Once{}
|
||||
opsErrorLogStopOnce = sync.Once{}
|
||||
opsErrorLogWorkersWg = sync.WaitGroup{}
|
||||
opsErrorLogMu = sync.RWMutex{}
|
||||
opsErrorLogStopping = false
|
||||
|
||||
opsErrorLogQueueLen.Store(0)
|
||||
opsErrorLogEnqueued.Store(0)
|
||||
opsErrorLogDropped.Store(0)
|
||||
opsErrorLogProcessed.Store(0)
|
||||
opsErrorLogSanitized.Store(0)
|
||||
opsErrorLogLastDropLogAt.Store(0)
|
||||
|
||||
opsErrorLogShutdownCh = make(chan struct{})
|
||||
opsErrorLogShutdownOnce = sync.Once{}
|
||||
opsErrorLogDrained.Store(false)
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`)
|
||||
setOpsRequestContext(c, "claude-3", false, raw)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
require.NotNil(t, entry.RequestBodyBytes)
|
||||
require.Equal(t, len(raw), *entry.RequestBodyBytes)
|
||||
require.NotNil(t, entry.RequestBodyJSON)
|
||||
require.NotContains(t, *entry.RequestBodyJSON, "secret-token")
|
||||
require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]")
|
||||
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
raw := []byte("not-json")
|
||||
setOpsRequestContext(c, "claude-3", false, raw)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.NotNil(t, entry.RequestBodyBytes)
|
||||
require.Equal(t, len(raw), *entry.RequestBodyBytes)
|
||||
require.False(t, entry.RequestBodyTruncated)
|
||||
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
|
||||
// 禁止 enqueueOpsErrorLog 触发 workers,使用测试队列验证满队列降级。
|
||||
opsErrorLogOnce.Do(func() {})
|
||||
|
||||
opsErrorLogMu.Lock()
|
||||
opsErrorLogQueue = make(chan opsErrorLogJob, 1)
|
||||
opsErrorLogMu.Unlock()
|
||||
|
||||
ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}
|
||||
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
|
||||
require.Equal(t, int64(1), OpsErrorLogEnqueuedTotal())
|
||||
require.Equal(t, int64(1), OpsErrorLogDroppedTotal())
|
||||
require.Equal(t, int64(1), OpsErrorLogQueueLength())
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(nil, entry)
|
||||
attachOpsRequestBodyToEntry(&gin.Context{}, nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 无请求体 key
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
require.False(t, entry.RequestBodyTruncated)
|
||||
|
||||
// 错误类型
|
||||
c.Set(opsRequestBodyKey, "not-bytes")
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
|
||||
// 空 bytes
|
||||
c.Set(opsRequestBodyKey, []byte{})
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
|
||||
require.Equal(t, int64(0), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
|
||||
ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}
|
||||
|
||||
// nil 入参分支
|
||||
enqueueOpsErrorLog(nil, entry)
|
||||
enqueueOpsErrorLog(ops, nil)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
|
||||
// shutdown 分支
|
||||
close(opsErrorLogShutdownCh)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
|
||||
// stopping 分支
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
opsErrorLogMu.Lock()
|
||||
opsErrorLogStopping = true
|
||||
opsErrorLogMu.Unlock()
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
|
||||
// queue nil 分支(防止启动 worker 干扰)
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
opsErrorLogOnce.Do(func() {})
|
||||
opsErrorLogMu.Lock()
|
||||
opsErrorLogQueue = nil
|
||||
opsErrorLogMu.Unlock()
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
}
|
||||
677
backend/internal/handler/sora_gateway_handler.go
Normal file
677
backend/internal/handler/sora_gateway_handler.go
Normal file
@@ -0,0 +1,677 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SoraGatewayHandler handles Sora chat completions requests
|
||||
type SoraGatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
streamMode string
|
||||
soraTLSEnabled bool
|
||||
soraMediaSigningKey string
|
||||
soraMediaRoot string
|
||||
}
|
||||
|
||||
// NewSoraGatewayHandler creates a new SoraGatewayHandler
|
||||
func NewSoraGatewayHandler(
|
||||
gatewayService *service.GatewayService,
|
||||
soraGatewayService *service.SoraGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
cfg *config.Config,
|
||||
) *SoraGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 3
|
||||
streamMode := "force"
|
||||
soraTLSEnabled := true
|
||||
signKey := ""
|
||||
mediaRoot := "/app/data/sora"
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
|
||||
streamMode = mode
|
||||
}
|
||||
soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
|
||||
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
|
||||
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
|
||||
mediaRoot = root
|
||||
}
|
||||
}
|
||||
return &SoraGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
soraGatewayService: soraGatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
streamMode: strings.ToLower(streamMode),
|
||||
soraTLSEnabled: soraTLSEnabled,
|
||||
soraMediaSigningKey: signKey,
|
||||
soraMediaRoot: mediaRoot,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletions handles Sora /v1/chat/completions endpoint
|
||||
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.sora_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
// 校验请求体 JSON 合法性
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
|
||||
msgsResult := gjson.GetBytes(body, "messages")
|
||||
if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
|
||||
return
|
||||
}
|
||||
|
||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
|
||||
if !clientStream {
|
||||
if h.streamMode == "error" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
|
||||
return
|
||||
}
|
||||
var err error
|
||||
body, err = sjson.SetBytes(body, "stream", true)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, clientStream, body)
|
||||
|
||||
platform := ""
|
||||
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||
platform = forced
|
||||
} else if apiKey.Group != nil {
|
||||
platform = apiKey.Group.Platform
|
||||
}
|
||||
if platform != service.PlatformSora {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
|
||||
return
|
||||
}
|
||||
|
||||
streamStarted := false
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := generateOpenAISessionHash(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverBody []byte
|
||||
var lastFailoverHeaders http.Header
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int("last_upstream_status", lastFailoverStatus),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("last_upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
proxyBound := account.ProxyID != nil
|
||||
proxyID := int64(0)
|
||||
if account.ProxyID != nil {
|
||||
proxyID = *account.ProxyID
|
||||
}
|
||||
tlsFingerprintEnabled := h.soraTLSEnabled
|
||||
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_wait_counter_increment_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else if !canWait {
|
||||
reqLog.Info("sora.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}()
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
clientStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_slot_acquire_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||
lastFailoverBody = failoverErr.ResponseBody
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.upstream_failover_exhausted", fields...)
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||
lastFailoverBody = failoverErr.ResponseBody
|
||||
switchCount++
|
||||
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.String("upstream_error_code", upstreamErrCode),
|
||||
zap.String("upstream_error_message", upstreamErrMsg),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.upstream_failover_switching", fields...)
|
||||
continue
|
||||
}
|
||||
reqLog.Error("sora.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("sora.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("sora.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func generateOpenAISessionHash(c *gin.Context, body []byte) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
||||
if sessionID == "" {
|
||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||
}
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
}
|
||||
if sessionID == "" {
|
||||
return ""
|
||||
}
|
||||
hash := sha256.Sum256([]byte(sessionID))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
h.usageRecordWorkerPool.Submit(task)
|
||||
return
|
||||
}
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
|
||||
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
|
||||
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
|
||||
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||
}
|
||||
|
||||
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
|
||||
if strings.EqualFold(upstreamCode, "cf_shield_429") {
|
||||
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
|
||||
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||
}
|
||||
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
|
||||
switch statusCode {
|
||||
case 401, 403, 404, 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", upstreamMessage
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
|
||||
}
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||
case 404:
|
||||
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
|
||||
}
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
func cloneHTTPHeaders(headers http.Header) http.Header {
|
||||
if headers == nil {
|
||||
return nil
|
||||
}
|
||||
return headers.Clone()
|
||||
}
|
||||
|
||||
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
|
||||
if headers != nil {
|
||||
mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
|
||||
contentType = strings.TrimSpace(headers.Get("content-type"))
|
||||
if contentType == "" {
|
||||
contentType = strings.TrimSpace(headers.Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
rayID = soraerror.ExtractCloudflareRayID(headers, body)
|
||||
return rayID, mitigated, contentType
|
||||
}
|
||||
|
||||
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
||||
}
|
||||
|
||||
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
|
||||
lower := strings.ToLower(message)
|
||||
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
||||
}
|
||||
|
||||
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
errorData := map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// MediaProxy serves local Sora media files.
|
||||
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
|
||||
h.proxySoraMedia(c, false)
|
||||
}
|
||||
|
||||
// MediaProxySigned serves local Sora media files with signature verification.
|
||||
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
|
||||
h.proxySoraMedia(c, true)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
|
||||
rawPath := c.Param("filepath")
|
||||
if rawPath == "" {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
cleaned := path.Clean(rawPath)
|
||||
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
query := c.Request.URL.Query()
|
||||
if requireSignature {
|
||||
if h.soraMediaSigningKey == "" {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "Sora 媒体签名未配置",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
expiresStr := strings.TrimSpace(query.Get("expires"))
|
||||
signature := strings.TrimSpace(query.Get("sig"))
|
||||
expires, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err != nil || expires <= time.Now().Unix() {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "authentication_error",
|
||||
"message": "Sora 媒体签名已过期",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
query.Del("sig")
|
||||
query.Del("expires")
|
||||
signingQuery := query.Encode()
|
||||
if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "authentication_error",
|
||||
"message": "Sora 媒体签名无效",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(h.soraMediaRoot) == "" {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "Sora 媒体目录未配置",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
relative := strings.TrimPrefix(cleaned, "/")
|
||||
localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative))
|
||||
if _, err := os.Stat(localPath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
c.File(localPath)
|
||||
}
|
||||
688
backend/internal/handler/sora_gateway_handler_test.go
Normal file
688
backend/internal/handler/sora_gateway_handler_test.go
Normal file
@@ -0,0 +1,688 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/testutil"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// 编译期接口断言
|
||||
var _ service.SoraClient = (*stubSoraClient)(nil)
|
||||
var _ service.AccountRepository = (*stubAccountRepo)(nil)
|
||||
var _ service.GroupRepository = (*stubGroupRepo)(nil)
|
||||
var _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
|
||||
|
||||
type stubSoraClient struct {
|
||||
imageURLs []string
|
||||
}
|
||||
|
||||
func (s *stubSoraClient) Enabled() bool { return true }
|
||||
func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) {
|
||||
return "upload", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) {
|
||||
return "task-image", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "cameo-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
|
||||
return &service.SoraCameoStatus{
|
||||
Status: "finalized",
|
||||
StatusMessage: "Completed",
|
||||
DisplayNameHint: "Character",
|
||||
UsernameHint: "user.character",
|
||||
ProfileAssetURL: "https://example.com/avatar.webp",
|
||||
}, nil
|
||||
}
|
||||
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
|
||||
return []byte("avatar"), nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "asset-pointer", nil
|
||||
}
|
||||
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
|
||||
return "character-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
|
||||
return "s_post", nil
|
||||
}
|
||||
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
|
||||
return "https://example.com/no-watermark.mp4", nil
|
||||
}
|
||||
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
return "enhanced prompt", nil
|
||||
}
|
||||
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
|
||||
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||
}
|
||||
func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraVideoTaskStatus, error) {
|
||||
return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||
}
|
||||
|
||||
type stubAccountRepo struct {
|
||||
accounts map[int64]*service.Account
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil }
|
||||
func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if acc, ok := r.accounts[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
return nil, service.ErrAccountNotFound
|
||||
}
|
||||
func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
||||
var result []*service.Account
|
||||
for _, id := range ids {
|
||||
if acc, ok := r.accounts[id]; ok {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||
_, ok := r.accounts[id]
|
||||
return ok, nil
|
||||
}
|
||||
func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
return map[string]int64{}, nil
|
||||
}
|
||||
func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil }
|
||||
func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil }
|
||||
func (r *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return r.listSchedulableByPlatform(platform), nil
|
||||
}
|
||||
func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil }
|
||||
func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
|
||||
return r.listSchedulable(), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
return r.listSchedulable(), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return r.listSchedulableByPlatform(platform), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||
return r.listSchedulableByPlatform(platform), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
for _, platform := range platforms {
|
||||
if acc.Platform == platform && acc.IsSchedulable() {
|
||||
result = append(result, *acc)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||
return r.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) listSchedulable() []service.Account {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.IsSchedulable() {
|
||||
result = append(result, *acc)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform && acc.IsSchedulable() {
|
||||
result = append(result, *acc)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type stubGroupRepo struct {
|
||||
group *service.Group
|
||||
}
|
||||
|
||||
func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil }
|
||||
func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
||||
return r.group, nil
|
||||
}
|
||||
func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
return r.group, nil
|
||||
}
|
||||
func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil }
|
||||
func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil }
|
||||
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubUsageLogRepo struct{}
|
||||
|
||||
func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
RunMode: config.RunModeSimple,
|
||||
Gateway: config.GatewayConfig{
|
||||
SoraStreamMode: "force",
|
||||
MaxAccountSwitches: 1,
|
||||
Scheduling: config.GatewaySchedulingConfig{
|
||||
LoadBatchEnabled: false,
|
||||
},
|
||||
},
|
||||
Concurrency: config.ConcurrencyConfig{PingInterval: 0},
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: "https://sora.test",
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}
|
||||
accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}}
|
||||
group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true}
|
||||
groupRepo := &stubGroupRepo{group: group}
|
||||
|
||||
usageLogRepo := &stubUsageLogRepo{}
|
||||
deferredService := service.NewDeferredService(accountRepo, nil, 0)
|
||||
billingService := service.NewBillingService(cfg, nil)
|
||||
concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{})
|
||||
billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg)
|
||||
t.Cleanup(func() {
|
||||
billingCacheService.Stop()
|
||||
})
|
||||
|
||||
gatewayService := service.NewGatewayService(
|
||||
accountRepo,
|
||||
groupRepo,
|
||||
usageLogRepo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
testutil.StubGatewayCache{},
|
||||
cfg,
|
||||
nil,
|
||||
concurrencyService,
|
||||
billingService,
|
||||
nil,
|
||||
billingCacheService,
|
||||
nil,
|
||||
nil,
|
||||
deferredService,
|
||||
nil,
|
||||
testutil.StubSessionLimitCache{},
|
||||
nil,
|
||||
)
|
||||
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
|
||||
|
||||
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
Status: service.StatusActive,
|
||||
GroupID: &group.ID,
|
||||
User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive},
|
||||
Group: group,
|
||||
}
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency})
|
||||
|
||||
handler.ChatCompletions(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp["media_url"])
|
||||
}
|
||||
|
||||
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
|
||||
func TestSoraHandler_StreamForcing(t *testing.T) {
|
||||
// 测试 1:stream=false 时 sjson 强制修改为 true
|
||||
body := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`)
|
||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||
require.False(t, clientStream)
|
||||
newBody, err := sjson.SetBytes(body, "stream", true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(newBody, "stream").Bool())
|
||||
|
||||
// 测试 2:stream=true 时不修改
|
||||
body2 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`)
|
||||
require.True(t, gjson.GetBytes(body2, "stream").Bool())
|
||||
|
||||
// 测试 3:无 stream 字段时 gjson 返回 false(零值)
|
||||
body3 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}]}`)
|
||||
require.False(t, gjson.GetBytes(body3, "stream").Bool())
|
||||
}
|
||||
|
||||
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
|
||||
func TestSoraHandler_ValidationExtraction(t *testing.T) {
|
||||
// model 缺失
|
||||
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
require.True(t, !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "")
|
||||
|
||||
// model 为数字 → 类型不是 gjson.String,应被拒绝
|
||||
body1b := []byte(`{"model":123,"messages":[{"role":"user","content":"test"}]}`)
|
||||
modelResult1b := gjson.GetBytes(body1b, "model")
|
||||
require.True(t, modelResult1b.Exists())
|
||||
require.NotEqual(t, gjson.String, modelResult1b.Type)
|
||||
|
||||
// messages 缺失
|
||||
body2 := []byte(`{"model":"sora"}`)
|
||||
require.False(t, gjson.GetBytes(body2, "messages").IsArray())
|
||||
|
||||
// messages 不是 JSON 数组(字符串)
|
||||
body3 := []byte(`{"model":"sora","messages":"not array"}`)
|
||||
require.False(t, gjson.GetBytes(body3, "messages").IsArray())
|
||||
|
||||
// messages 是对象而非数组 → IsArray 返回 false
|
||||
body4 := []byte(`{"model":"sora","messages":{}}`)
|
||||
require.False(t, gjson.GetBytes(body4, "messages").IsArray())
|
||||
|
||||
// messages 是空数组 → IsArray 为 true 但 len==0,应被拒绝
|
||||
body5 := []byte(`{"model":"sora","messages":[]}`)
|
||||
msgsResult := gjson.GetBytes(body5, "messages")
|
||||
require.True(t, msgsResult.IsArray())
|
||||
require.Equal(t, 0, len(msgsResult.Array()))
|
||||
|
||||
// 非法 JSON 被 gjson.ValidBytes 拦截
|
||||
require.False(t, gjson.ValidBytes([]byte(`{invalid`)))
|
||||
}
|
||||
|
||||
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
|
||||
func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 从 body 提取 prompt_cache_key
|
||||
body := []byte(`{"model":"sora","prompt_cache_key":"session-abc"}`)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
hash := generateOpenAISessionHash(c, body)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// 无 prompt_cache_key 且无 header → 空 hash
|
||||
body2 := []byte(`{"model":"sora"}`)
|
||||
hash2 := generateOpenAISessionHash(c, body2)
|
||||
require.Empty(t, hash2)
|
||||
|
||||
// header 优先于 body
|
||||
c.Request.Header.Set("session_id", "from-header")
|
||||
hash3 := generateOpenAISessionHash(c, body)
|
||||
require.NotEmpty(t, hash3)
|
||||
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
|
||||
}
|
||||
|
||||
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "包含双引号",
|
||||
errType: "upstream_error",
|
||||
message: `upstream returned "invalid" payload`,
|
||||
},
|
||||
{
|
||||
name: "包含换行和制表符",
|
||||
errType: "rate_limit_error",
|
||||
message: "line1\nline2\ttab",
|
||||
},
|
||||
{
|
||||
name: "包含反斜杠",
|
||||
errType: "upstream_error",
|
||||
message: `path C:\Users\test\file.txt not found`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
|
||||
|
||||
body := w.Body.String()
|
||||
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
|
||||
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
|
||||
require.Equal(t, "event: error", lines[0])
|
||||
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
|
||||
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok, "JSON 中应包含 error 对象")
|
||||
require.Equal(t, tt.errType, errorObj["type"])
|
||||
require.Equal(t, tt.message, errorObj["message"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
|
||||
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
|
||||
|
||||
body := w.Body.String()
|
||||
require.True(t, strings.HasPrefix(body, "event: error\n"))
|
||||
require.True(t, strings.HasSuffix(body, "\n\n"))
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errorObj["type"])
|
||||
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errorObj["type"])
|
||||
msg, _ := errorObj["message"].(string)
|
||||
require.Contains(t, msg, "Cloudflare challenge")
|
||||
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
||||
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "rate_limit_error", errorObj["type"])
|
||||
msg, _ := errorObj["message"].(string)
|
||||
require.Contains(t, msg, "Cloudflare shield")
|
||||
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
|
||||
}
|
||||
|
||||
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-mitigated", "challenge")
|
||||
headers.Set("content-type", "text/html")
|
||||
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
|
||||
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
|
||||
require.Equal(t, "9cff2d62d83bb98d", rayID)
|
||||
require.Equal(t, "challenge", mitigated)
|
||||
require.Equal(t, "text/html", contentType)
|
||||
}
|
||||
@@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
|
||||
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
136
backend/internal/handler/usage_record_submit_task_test.go
Normal file
136
backend/internal/handler/usage_record_submit_task_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newUsageRecordTestPool(t *testing.T) *service.UsageRecordWorkerPool {
|
||||
t.Helper()
|
||||
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 8,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: "drop",
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
return pool
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &GatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("task not executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
|
||||
h := &GatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
called.Store(true)
|
||||
})
|
||||
|
||||
require.True(t, called.Load())
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &GatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("task not executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
called.Store(true)
|
||||
})
|
||||
|
||||
require.True(t, called.Load())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("task not executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
called.Store(true)
|
||||
})
|
||||
|
||||
require.True(t, called.Load())
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
@@ -53,8 +53,8 @@ func ProvideAdminHandlers(
|
||||
}
|
||||
|
||||
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
|
||||
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
|
||||
return admin.NewSystemHandler(updateService)
|
||||
func ProvideSystemHandler(updateService *service.UpdateService, lockService *service.SystemOperationLockService) *admin.SystemHandler {
|
||||
return admin.NewSystemHandler(updateService, lockService)
|
||||
}
|
||||
|
||||
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
||||
@@ -74,8 +74,11 @@ func ProvideHandlers(
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
soraGatewayHandler *SoraGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
_ *service.IdempotencyCoordinator,
|
||||
_ *service.IdempotencyCleanupService,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
@@ -88,6 +91,7 @@ func ProvideHandlers(
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
SoraGateway: soraGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
}
|
||||
@@ -105,6 +109,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAnnouncementHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
NewSoraGatewayHandler,
|
||||
NewTotpHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
|
||||
@@ -21,11 +21,18 @@ var (
|
||||
// - "" (默认): 使用 /v1/messages, /v1beta/models(混合模式,可调度 antigravity 账户)
|
||||
// - "/antigravity": 使用 /antigravity/v1/messages, /antigravity/v1beta/models(非混合模式,仅 antigravity 账户)
|
||||
endpointPrefix = getEnv("ENDPOINT_PREFIX", "")
|
||||
claudeAPIKey = "sk-8e572bc3b3de92ace4f41f4256c28600ca11805732a7b693b5c44741346bbbb3"
|
||||
geminiAPIKey = "sk-5950197a2085b38bbe5a1b229cc02b8ece914963fc44cacc06d497ae8b87410f"
|
||||
testInterval = 1 * time.Second // 测试间隔,防止限流
|
||||
)
|
||||
|
||||
const (
|
||||
// 注意:E2E 测试请使用环境变量注入密钥,避免任何凭证进入仓库历史。
|
||||
// 例如:
|
||||
// export CLAUDE_API_KEY="sk-..."
|
||||
// export GEMINI_API_KEY="sk-..."
|
||||
claudeAPIKeyEnv = "CLAUDE_API_KEY"
|
||||
geminiAPIKeyEnv = "GEMINI_API_KEY"
|
||||
)
|
||||
|
||||
func getEnv(key, defaultVal string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
@@ -65,16 +72,45 @@ func TestMain(m *testing.M) {
|
||||
if endpointPrefix != "" {
|
||||
mode = "Antigravity 模式"
|
||||
}
|
||||
fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s)\n\n", baseURL, endpointPrefix, mode)
|
||||
claudeKeySet := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) != ""
|
||||
geminiKeySet := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) != ""
|
||||
fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s, %s=%v, %s=%v)\n\n",
|
||||
baseURL,
|
||||
endpointPrefix,
|
||||
mode,
|
||||
claudeAPIKeyEnv,
|
||||
claudeKeySet,
|
||||
geminiAPIKeyEnv,
|
||||
geminiKeySet,
|
||||
)
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func requireClaudeAPIKey(t *testing.T) string {
|
||||
t.Helper()
|
||||
key := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv))
|
||||
if key == "" {
|
||||
t.Skipf("未设置 %s,跳过 Claude 相关 E2E 测试", claudeAPIKeyEnv)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func requireGeminiAPIKey(t *testing.T) string {
|
||||
t.Helper()
|
||||
key := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv))
|
||||
if key == "" {
|
||||
t.Skipf("未设置 %s,跳过 Gemini 相关 E2E 测试", geminiAPIKeyEnv)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// TestClaudeModelsList 测试 GET /v1/models
|
||||
func TestClaudeModelsList(t *testing.T) {
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
url := baseURL + endpointPrefix + "/v1/models"
|
||||
|
||||
req, _ := http.NewRequest("GET", url, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeKey)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
@@ -106,10 +142,11 @@ func TestClaudeModelsList(t *testing.T) {
|
||||
|
||||
// TestGeminiModelsList 测试 GET /v1beta/models
|
||||
func TestGeminiModelsList(t *testing.T) {
|
||||
geminiKey := requireGeminiAPIKey(t)
|
||||
url := baseURL + endpointPrefix + "/v1beta/models"
|
||||
|
||||
req, _ := http.NewRequest("GET", url, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+geminiAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+geminiKey)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
@@ -137,21 +174,22 @@ func TestGeminiModelsList(t *testing.T) {
|
||||
|
||||
// TestClaudeMessages 测试 Claude /v1/messages 接口
|
||||
func TestClaudeMessages(t *testing.T) {
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
for i, model := range claudeModels {
|
||||
if i > 0 {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_非流式", func(t *testing.T) {
|
||||
testClaudeMessage(t, model, false)
|
||||
testClaudeMessage(t, claudeKey, model, false)
|
||||
})
|
||||
time.Sleep(testInterval)
|
||||
t.Run(model+"_流式", func(t *testing.T) {
|
||||
testClaudeMessage(t, model, true)
|
||||
testClaudeMessage(t, claudeKey, model, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testClaudeMessage(t *testing.T, model string, stream bool) {
|
||||
func testClaudeMessage(t *testing.T, claudeKey string, model string, stream bool) {
|
||||
url := baseURL + endpointPrefix + "/v1/messages"
|
||||
|
||||
payload := map[string]any{
|
||||
@@ -166,7 +204,7 @@ func testClaudeMessage(t *testing.T, model string, stream bool) {
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
@@ -213,21 +251,22 @@ func testClaudeMessage(t *testing.T, model string, stream bool) {
|
||||
|
||||
// TestGeminiGenerateContent 测试 Gemini /v1beta/models/:model 接口
|
||||
func TestGeminiGenerateContent(t *testing.T) {
|
||||
geminiKey := requireGeminiAPIKey(t)
|
||||
for i, model := range geminiModels {
|
||||
if i > 0 {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_非流式", func(t *testing.T) {
|
||||
testGeminiGenerate(t, model, false)
|
||||
testGeminiGenerate(t, geminiKey, model, false)
|
||||
})
|
||||
time.Sleep(testInterval)
|
||||
t.Run(model+"_流式", func(t *testing.T) {
|
||||
testGeminiGenerate(t, model, true)
|
||||
testGeminiGenerate(t, geminiKey, model, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testGeminiGenerate(t *testing.T, model string, stream bool) {
|
||||
func testGeminiGenerate(t *testing.T, geminiKey string, model string, stream bool) {
|
||||
action := "generateContent"
|
||||
if stream {
|
||||
action = "streamGenerateContent"
|
||||
@@ -254,7 +293,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) {
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+geminiAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+geminiKey)
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
@@ -301,6 +340,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) {
|
||||
// TestClaudeMessagesWithComplexTools 测试带复杂工具 schema 的请求
|
||||
// 模拟 Claude Code 发送的请求,包含需要清理的 JSON Schema 字段
|
||||
func TestClaudeMessagesWithComplexTools(t *testing.T) {
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
// 测试模型列表(只测试几个代表性模型)
|
||||
models := []string{
|
||||
"claude-opus-4-5-20251101", // Claude 模型
|
||||
@@ -312,12 +352,12 @@ func TestClaudeMessagesWithComplexTools(t *testing.T) {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_复杂工具", func(t *testing.T) {
|
||||
testClaudeMessageWithTools(t, model)
|
||||
testClaudeMessageWithTools(t, claudeKey, model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testClaudeMessageWithTools(t *testing.T, model string) {
|
||||
func testClaudeMessageWithTools(t *testing.T, claudeKey string, model string) {
|
||||
url := baseURL + endpointPrefix + "/v1/messages"
|
||||
|
||||
// 构造包含复杂 schema 的工具定义(模拟 Claude Code 的工具)
|
||||
@@ -473,7 +513,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) {
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
@@ -519,6 +559,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) {
|
||||
// 验证:当历史 assistant 消息包含 tool_use 但没有 signature 时,
|
||||
// 系统应自动添加 dummy thought_signature 避免 Gemini 400 错误
|
||||
func TestClaudeMessagesWithThinkingAndTools(t *testing.T) {
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
models := []string{
|
||||
"claude-haiku-4-5-20251001", // gemini-3-flash
|
||||
}
|
||||
@@ -527,12 +568,12 @@ func TestClaudeMessagesWithThinkingAndTools(t *testing.T) {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_thinking模式工具调用", func(t *testing.T) {
|
||||
testClaudeThinkingWithToolHistory(t, model)
|
||||
testClaudeThinkingWithToolHistory(t, claudeKey, model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
|
||||
func testClaudeThinkingWithToolHistory(t *testing.T, claudeKey string, model string) {
|
||||
url := baseURL + endpointPrefix + "/v1/messages"
|
||||
|
||||
// 模拟历史对话:用户请求 → assistant 调用工具 → 工具返回 → 继续对话
|
||||
@@ -600,7 +641,7 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
@@ -649,6 +690,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
|
||||
if endpointPrefix != "/antigravity" {
|
||||
t.Skip("仅在 Antigravity 模式下运行")
|
||||
}
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
|
||||
// 测试通过 Claude 端点调用 Gemini 模型
|
||||
geminiViaClaude := []string{
|
||||
@@ -664,11 +706,11 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_通过Claude端点", func(t *testing.T) {
|
||||
testClaudeMessage(t, model, false)
|
||||
testClaudeMessage(t, claudeKey, model, false)
|
||||
})
|
||||
time.Sleep(testInterval)
|
||||
t.Run(model+"_通过Claude端点_流式", func(t *testing.T) {
|
||||
testClaudeMessage(t, model, true)
|
||||
testClaudeMessage(t, claudeKey, model, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -676,6 +718,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
|
||||
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
|
||||
// 验证:Gemini 模型接受没有 signature 的 thinking block
|
||||
func TestClaudeMessagesWithNoSignature(t *testing.T) {
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
models := []string{
|
||||
"claude-haiku-4-5-20251001", // gemini-3-flash - 支持无 signature
|
||||
}
|
||||
@@ -684,12 +727,12 @@ func TestClaudeMessagesWithNoSignature(t *testing.T) {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_无signature", func(t *testing.T) {
|
||||
testClaudeWithNoSignature(t, model)
|
||||
testClaudeWithNoSignature(t, claudeKey, model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testClaudeWithNoSignature(t *testing.T, model string) {
|
||||
func testClaudeWithNoSignature(t *testing.T, claudeKey string, model string) {
|
||||
url := baseURL + endpointPrefix + "/v1/messages"
|
||||
|
||||
// 模拟历史对话包含 thinking block 但没有 signature
|
||||
@@ -732,7 +775,7 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
@@ -777,6 +820,7 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) {
|
||||
if endpointPrefix != "/antigravity" {
|
||||
t.Skip("仅在 Antigravity 模式下运行")
|
||||
}
|
||||
geminiKey := requireGeminiAPIKey(t)
|
||||
|
||||
// 测试通过 Gemini 端点调用 Claude 模型
|
||||
claudeViaGemini := []string{
|
||||
@@ -789,11 +833,11 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_通过Gemini端点", func(t *testing.T) {
|
||||
testGeminiGenerate(t, model, false)
|
||||
testGeminiGenerate(t, geminiKey, model, false)
|
||||
})
|
||||
time.Sleep(testInterval)
|
||||
t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) {
|
||||
testGeminiGenerate(t, model, true)
|
||||
testGeminiGenerate(t, geminiKey, model, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
48
backend/internal/integration/e2e_helpers_test.go
Normal file
48
backend/internal/integration/e2e_helpers_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
//go:build e2e
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// E2E Mock 模式支持
|
||||
// =============================================================================
|
||||
// 当 E2E_MOCK=true 时,使用本地 Mock 响应替代真实 API 调用。
|
||||
// 这允许在没有真实 API Key 的环境(如 CI)中验证基本的请求/响应流程。
|
||||
|
||||
// isMockMode 检查是否启用 Mock 模式
|
||||
func isMockMode() bool {
|
||||
return strings.EqualFold(os.Getenv("E2E_MOCK"), "true")
|
||||
}
|
||||
|
||||
// skipIfNoRealAPI 如果未配置真实 API Key 且不在 Mock 模式,则跳过测试
|
||||
func skipIfNoRealAPI(t *testing.T) {
|
||||
t.Helper()
|
||||
if isMockMode() {
|
||||
return // Mock 模式下不跳过
|
||||
}
|
||||
claudeKey := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv))
|
||||
geminiKey := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv))
|
||||
if claudeKey == "" && geminiKey == "" {
|
||||
t.Skip("未设置 API Key 且未启用 Mock 模式,跳过测试")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// API Key 脱敏(Task 6.10)
|
||||
// =============================================================================
|
||||
|
||||
// safeLogKey 安全地记录 API Key(仅显示前 8 位)
|
||||
func safeLogKey(t *testing.T, prefix string, key string) {
|
||||
t.Helper()
|
||||
key = strings.TrimSpace(key)
|
||||
if len(key) <= 8 {
|
||||
t.Logf("%s: ***(长度: %d)", prefix, len(key))
|
||||
return
|
||||
}
|
||||
t.Logf("%s: %s...(长度: %d)", prefix, key[:8], len(key))
|
||||
}
|
||||
317
backend/internal/integration/e2e_user_flow_test.go
Normal file
317
backend/internal/integration/e2e_user_flow_test.go
Normal file
@@ -0,0 +1,317 @@
|
||||
//go:build e2e
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// E2E 用户流程测试
|
||||
// 测试完整的用户操作链路:注册 → 登录 → 创建 API Key → 调用网关 → 查询用量
|
||||
|
||||
var (
|
||||
testUserEmail = "e2e-test-" + fmt.Sprintf("%d", time.Now().UnixMilli()) + "@test.local"
|
||||
testUserPassword = "E2eTest@12345"
|
||||
testUserName = "e2e-test-user"
|
||||
)
|
||||
|
||||
// TestUserRegistrationAndLogin 测试用户注册和登录流程
|
||||
func TestUserRegistrationAndLogin(t *testing.T) {
|
||||
// 步骤 1: 注册新用户
|
||||
t.Run("注册新用户", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"email": testUserEmail,
|
||||
"password": testUserPassword,
|
||||
"username": testUserName,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/auth/register", body, "")
|
||||
if err != nil {
|
||||
t.Skipf("注册接口不可用,跳过用户流程测试: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 注册可能返回 200(成功)或 400(邮箱已存在)或 403(注册已关闭)
|
||||
switch resp.StatusCode {
|
||||
case 200:
|
||||
t.Logf("✅ 用户注册成功: %s", testUserEmail)
|
||||
case 400:
|
||||
t.Logf("⚠️ 用户可能已存在: %s", string(respBody))
|
||||
case 403:
|
||||
t.Skipf("注册功能已关闭: %s", string(respBody))
|
||||
default:
|
||||
t.Logf("⚠️ 注册返回 HTTP %d: %s(继续尝试登录)", resp.StatusCode, string(respBody))
|
||||
}
|
||||
})
|
||||
|
||||
// 步骤 2: 登录获取 JWT
|
||||
var accessToken string
|
||||
t.Run("用户登录获取JWT", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"email": testUserEmail,
|
||||
"password": testUserPassword,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
|
||||
if err != nil {
|
||||
t.Fatalf("登录请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Skipf("登录失败 HTTP %d: %s(可能需要先注册用户)", resp.StatusCode, string(respBody))
|
||||
return
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
t.Fatalf("解析登录响应失败: %v", err)
|
||||
}
|
||||
|
||||
// 尝试从标准响应格式获取 token
|
||||
if token, ok := result["access_token"].(string); ok && token != "" {
|
||||
accessToken = token
|
||||
} else if data, ok := result["data"].(map[string]any); ok {
|
||||
if token, ok := data["access_token"].(string); ok {
|
||||
accessToken = token
|
||||
}
|
||||
}
|
||||
|
||||
if accessToken == "" {
|
||||
t.Skipf("未获取到 access_token,响应: %s", string(respBody))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证 token 不为空且格式基本正确
|
||||
if len(accessToken) < 10 {
|
||||
t.Fatalf("access_token 格式异常: %s", accessToken)
|
||||
}
|
||||
|
||||
t.Logf("✅ 登录成功,获取 JWT(长度: %d)", len(accessToken))
|
||||
})
|
||||
|
||||
if accessToken == "" {
|
||||
t.Skip("未获取到 JWT,跳过后续测试")
|
||||
return
|
||||
}
|
||||
|
||||
// 步骤 3: 使用 JWT 获取当前用户信息
|
||||
t.Run("获取当前用户信息", func(t *testing.T) {
|
||||
resp, err := doRequest(t, "GET", "/api/user/me", nil, accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
t.Logf("✅ 成功获取用户信息")
|
||||
})
|
||||
}
|
||||
|
||||
// TestAPIKeyLifecycle 测试 API Key 的创建和使用
|
||||
func TestAPIKeyLifecycle(t *testing.T) {
|
||||
// 先登录获取 JWT
|
||||
accessToken := loginTestUser(t)
|
||||
if accessToken == "" {
|
||||
t.Skip("无法登录,跳过 API Key 生命周期测试")
|
||||
return
|
||||
}
|
||||
|
||||
var apiKey string
|
||||
|
||||
// 步骤 1: 创建 API Key
|
||||
t.Run("创建API_Key", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"name": "e2e-test-key-" + fmt.Sprintf("%d", time.Now().UnixMilli()),
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/keys", body, accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("创建 API Key 请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Skipf("创建 API Key 失败 HTTP %d: %s", resp.StatusCode, string(respBody))
|
||||
return
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
// 从响应中提取 key
|
||||
if key, ok := result["key"].(string); ok {
|
||||
apiKey = key
|
||||
} else if data, ok := result["data"].(map[string]any); ok {
|
||||
if key, ok := data["key"].(string); ok {
|
||||
apiKey = key
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey == "" {
|
||||
t.Skipf("未获取到 API Key,响应: %s", string(respBody))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证 API Key 脱敏日志(只显示前 8 位)
|
||||
masked := apiKey
|
||||
if len(masked) > 8 {
|
||||
masked = masked[:8] + "..."
|
||||
}
|
||||
t.Logf("✅ API Key 创建成功: %s", masked)
|
||||
})
|
||||
|
||||
if apiKey == "" {
|
||||
t.Skip("未创建 API Key,跳过后续测试")
|
||||
return
|
||||
}
|
||||
|
||||
// 步骤 2: 使用 API Key 调用网关(需要 Claude 或 Gemini 可用)
|
||||
t.Run("使用API_Key调用网关", func(t *testing.T) {
|
||||
// 尝试调用 models 列表(最轻量的 API 调用)
|
||||
resp, err := doRequest(t, "GET", "/v1/models", nil, apiKey)
|
||||
if err != nil {
|
||||
t.Fatalf("网关请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 可能返回 200(成功)或 402(余额不足)或 403(无可用账户)
|
||||
switch {
|
||||
case resp.StatusCode == 200:
|
||||
t.Logf("✅ API Key 网关调用成功")
|
||||
case resp.StatusCode == 402:
|
||||
t.Logf("⚠️ 余额不足,但 API Key 认证通过")
|
||||
case resp.StatusCode == 403:
|
||||
t.Logf("⚠️ 无可用账户,但 API Key 认证通过")
|
||||
default:
|
||||
t.Logf("⚠️ 网关返回 HTTP %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
})
|
||||
|
||||
// 步骤 3: 查询用量记录
|
||||
t.Run("查询用量记录", func(t *testing.T) {
|
||||
resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("用量查询请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Logf("⚠️ 用量查询返回 HTTP %d: %s", resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("✅ 用量查询成功")
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 辅助函数
|
||||
// =============================================================================
|
||||
|
||||
func doRequest(t *testing.T, method, path string, body []byte, token string) (*http.Response, error) {
|
||||
t.Helper()
|
||||
|
||||
url := baseURL + path
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
bodyReader = bytes.NewReader(body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, url, bodyReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func loginTestUser(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
// 先尝试用管理员账户登录
|
||||
adminEmail := getEnv("ADMIN_EMAIL", "admin@sub2api.local")
|
||||
adminPassword := getEnv("ADMIN_PASSWORD", "")
|
||||
|
||||
if adminPassword == "" {
|
||||
// 尝试用测试用户
|
||||
adminEmail = testUserEmail
|
||||
adminPassword = testUserPassword
|
||||
}
|
||||
|
||||
payload := map[string]string{
|
||||
"email": adminEmail,
|
||||
"password": adminPassword,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return ""
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if token, ok := result["access_token"].(string); ok {
|
||||
return token
|
||||
}
|
||||
if data, ok := result["data"].(map[string]any); ok {
|
||||
if token, ok := data["access_token"].(string); ok {
|
||||
return token
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// redactAPIKey API Key 脱敏,只显示前 8 位
|
||||
func redactAPIKey(key string) string {
|
||||
key = strings.TrimSpace(key)
|
||||
if len(key) <= 8 {
|
||||
return "***"
|
||||
}
|
||||
return key[:8] + "..."
|
||||
}
|
||||
@@ -60,6 +60,49 @@ func TestRateLimiterFailureModes(t *testing.T) {
|
||||
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
|
||||
}
|
||||
|
||||
func TestRateLimiterDifferentIPsIndependent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
callCounts := make(map[string]int64)
|
||||
originalRun := rateLimitRun
|
||||
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||
callCounts[key]++
|
||||
return callCounts[key], false, nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
rateLimitRun = originalRun
|
||||
})
|
||||
|
||||
limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}))
|
||||
|
||||
router := gin.New()
|
||||
router.Use(limiter.Limit("api", 1, time.Second))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
// 第一个 IP 的请求应通过
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req1.RemoteAddr = "10.0.0.1:1234"
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code, "第一个 IP 的第一次请求应通过")
|
||||
|
||||
// 第二个 IP 的请求应独立通过(不受第一个 IP 的计数影响)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req2.RemoteAddr = "10.0.0.2:5678"
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code, "第二个 IP 的第一次请求应独立通过")
|
||||
|
||||
// 第一个 IP 的第二次请求应被限流
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req3.RemoteAddr = "10.0.0.1:1234"
|
||||
rec3 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec3, req3)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec3.Code, "第一个 IP 的第二次请求应被限流")
|
||||
}
|
||||
|
||||
func TestRateLimiterSuccessAndLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -204,9 +204,14 @@ func shouldFallbackToNextURL(err error, statusCode int) bool {
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
|
||||
clientSecret, err := getClientSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("client_secret", clientSecret)
|
||||
params.Set("code", code)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("grant_type", "authorization_code")
|
||||
@@ -243,9 +248,14 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*
|
||||
|
||||
// RefreshToken 刷新 access_token
|
||||
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
|
||||
clientSecret, err := getClientSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("client_secret", clientSecret)
|
||||
params.Set("refresh_token", refreshToken)
|
||||
params.Set("grant_type", "refresh_token")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,11 +6,14 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,7 +24,11 @@ const (
|
||||
|
||||
// Antigravity OAuth 客户端凭证
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
ClientSecret = "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"
|
||||
ClientSecret = ""
|
||||
|
||||
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
|
||||
// 出于安全原因,该值不得硬编码入库。
|
||||
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
|
||||
|
||||
// 固定的 redirect_uri(用户需手动复制 code)
|
||||
RedirectURI = "http://localhost:8085/callback"
|
||||
@@ -57,6 +64,17 @@ func init() {
|
||||
// GetUserAgent 返回当前配置的 User-Agent
|
||||
func GetUserAgent() string {
|
||||
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
|
||||
|
||||
func getClientSecret() (string, error) {
|
||||
if v := strings.TrimSpace(ClientSecret); v != "" {
|
||||
return v, nil
|
||||
}
|
||||
if v, ok := os.LookupEnv(AntigravityOAuthClientSecretEnv); ok {
|
||||
if vv := strings.TrimSpace(v); vv != "" {
|
||||
return vv, nil
|
||||
}
|
||||
}
|
||||
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
|
||||
}
|
||||
|
||||
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
|
||||
|
||||
704
backend/internal/pkg/antigravity/oauth_test.go
Normal file
704
backend/internal/pkg/antigravity/oauth_test.go
Normal file
@@ -0,0 +1,704 @@
|
||||
//go:build unit
|
||||
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// getClientSecret
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGetClientSecret_环境变量设置(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||
}
|
||||
if secret != "my-secret-value" {
|
||||
t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量为空(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量为空时应返回错误")
|
||||
}
|
||||
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
|
||||
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量未设置(t *testing.T) {
|
||||
// t.Setenv 会在测试结束时恢复,但我们需要确保它不存在
|
||||
// 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值
|
||||
// 当前代码中 ClientSecret = "",所以会走环境变量逻辑
|
||||
|
||||
// 明确设置再取消,确保环境变量不存在
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量未设置时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量含空格(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, " ")
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量仅含空格时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ")
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||
}
|
||||
if secret != "valid-secret" {
|
||||
t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ForwardBaseURLs
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestForwardBaseURLs_Daily优先(t *testing.T) {
|
||||
urls := ForwardBaseURLs()
|
||||
if len(urls) == 0 {
|
||||
t.Fatal("ForwardBaseURLs 返回空列表")
|
||||
}
|
||||
|
||||
// daily URL 应排在第一位
|
||||
if urls[0] != antigravityDailyBaseURL {
|
||||
t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL)
|
||||
}
|
||||
|
||||
// 应包含所有 URL
|
||||
if len(urls) != len(BaseURLs) {
|
||||
t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
|
||||
}
|
||||
|
||||
// 验证 prod URL 也在列表中
|
||||
found := false
|
||||
for _, u := range urls {
|
||||
if u == antigravityProdBaseURL {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("ForwardBaseURLs 中缺少 prod URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardBaseURLs_不修改原切片(t *testing.T) {
|
||||
originalFirst := BaseURLs[0]
|
||||
_ = ForwardBaseURLs()
|
||||
// 确保原始 BaseURLs 未被修改
|
||||
if BaseURLs[0] != originalFirst {
|
||||
t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// URLAvailability
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewURLAvailability(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
if ua == nil {
|
||||
t.Fatal("NewURLAvailability 返回 nil")
|
||||
}
|
||||
if ua.ttl != 5*time.Minute {
|
||||
t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl)
|
||||
}
|
||||
if ua.unavailable == nil {
|
||||
t.Error("unavailable map 不应为 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_MarkUnavailable(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
testURL := "https://example.com"
|
||||
|
||||
ua.MarkUnavailable(testURL)
|
||||
|
||||
if ua.IsAvailable(testURL) {
|
||||
t.Error("标记为不可用后 IsAvailable 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_MarkSuccess(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
testURL := "https://example.com"
|
||||
|
||||
// 先标记为不可用
|
||||
ua.MarkUnavailable(testURL)
|
||||
if ua.IsAvailable(testURL) {
|
||||
t.Error("标记为不可用后应不可用")
|
||||
}
|
||||
|
||||
// 标记成功后应恢复可用
|
||||
ua.MarkSuccess(testURL)
|
||||
if !ua.IsAvailable(testURL) {
|
||||
t.Error("MarkSuccess 后应恢复可用")
|
||||
}
|
||||
|
||||
// 验证 lastSuccess 被设置
|
||||
ua.mu.RLock()
|
||||
if ua.lastSuccess != testURL {
|
||||
t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL)
|
||||
}
|
||||
ua.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) {
|
||||
// 使用极短的 TTL
|
||||
ua := NewURLAvailability(1 * time.Millisecond)
|
||||
testURL := "https://example.com"
|
||||
|
||||
ua.MarkUnavailable(testURL)
|
||||
// 等待 TTL 过期
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
if !ua.IsAvailable(testURL) {
|
||||
t.Error("TTL 过期后 URL 应恢复可用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
if !ua.IsAvailable("https://never-marked.com") {
|
||||
t.Error("未标记的 URL 应默认可用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLs(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
|
||||
// 默认所有 URL 都可用
|
||||
urls := ua.GetAvailableURLs()
|
||||
if len(urls) != len(BaseURLs) {
|
||||
t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
|
||||
if len(BaseURLs) < 2 {
|
||||
t.Skip("BaseURLs 少于 2 个,跳过此测试")
|
||||
}
|
||||
|
||||
ua.MarkUnavailable(BaseURLs[0])
|
||||
urls := ua.GetAvailableURLs()
|
||||
|
||||
// 标记的 URL 不应出现在可用列表中
|
||||
for _, u := range urls {
|
||||
if u == BaseURLs[0] {
|
||||
t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
if len(urls) != 3 {
|
||||
t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
|
||||
|
||||
ua.MarkSuccess("https://c.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
if len(urls) != 3 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls))
|
||||
}
|
||||
// c.com 应排在第一位
|
||||
if urls[0] != "https://c.com" {
|
||||
t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0])
|
||||
}
|
||||
// 其余按原始顺序
|
||||
if urls[1] != "https://a.com" {
|
||||
t.Errorf("第二个应为 a.com: got %s", urls[1])
|
||||
}
|
||||
if urls[2] != "https://b.com" {
|
||||
t.Errorf("第三个应为 b.com: got %s", urls[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com"}
|
||||
|
||||
ua.MarkSuccess("https://b.com")
|
||||
ua.MarkUnavailable("https://b.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
// b.com 被标记不可用,不应出现
|
||||
if len(urls) != 1 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls))
|
||||
}
|
||||
if urls[0] != "https://a.com" {
|
||||
t.Errorf("仅 a.com 应可用: got %s", urls[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com"}
|
||||
|
||||
ua.MarkSuccess("https://not-in-list.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
// lastSuccess 不在自定义列表中,不应被添加
|
||||
if len(urls) != 2 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SessionStore
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewSessionStore(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
if store == nil {
|
||||
t.Fatal("NewSessionStore 返回 nil")
|
||||
}
|
||||
if store.sessions == nil {
|
||||
t.Error("sessions map 不应为 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_SetAndGet(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "test-state",
|
||||
CodeVerifier: "test-verifier",
|
||||
ProxyURL: "http://proxy.example.com",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
store.Set("session-1", session)
|
||||
|
||||
got, ok := store.Get("session-1")
|
||||
if !ok {
|
||||
t.Fatal("Get 应返回 true")
|
||||
}
|
||||
if got.State != "test-state" {
|
||||
t.Errorf("State 不匹配: got %s", got.State)
|
||||
}
|
||||
if got.CodeVerifier != "test-verifier" {
|
||||
t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier)
|
||||
}
|
||||
if got.ProxyURL != "http://proxy.example.com" {
|
||||
t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Get_不存在(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
_, ok := store.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("不存在的 session 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Get_过期(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "expired-state",
|
||||
CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期
|
||||
}
|
||||
|
||||
store.Set("expired-session", session)
|
||||
|
||||
_, ok := store.Get("expired-session")
|
||||
if ok {
|
||||
t.Error("过期的 session 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "to-delete",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
store.Set("del-session", session)
|
||||
store.Delete("del-session")
|
||||
|
||||
_, ok := store.Get("del-session")
|
||||
if ok {
|
||||
t.Error("删除后 Get 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete_不存在(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
// 删除不存在的 session 不应 panic
|
||||
store.Delete("nonexistent")
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
store.Stop()
|
||||
|
||||
// 多次 Stop 不应 panic
|
||||
store.Stop()
|
||||
}
|
||||
|
||||
func TestSessionStore_多个Session(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
session := &OAuthSession{
|
||||
State: "state-" + string(rune('0'+i)),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("session-"+string(rune('0'+i)), session)
|
||||
}
|
||||
|
||||
// 验证都能取到
|
||||
for i := 0; i < 10; i++ {
|
||||
_, ok := store.Get("session-" + string(rune('0'+i)))
|
||||
if !ok {
|
||||
t.Errorf("session-%d 应存在", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateRandomBytes
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateRandomBytes_长度正确(t *testing.T) {
|
||||
sizes := []int{0, 1, 16, 32, 64, 128}
|
||||
for _, size := range sizes {
|
||||
b, err := GenerateRandomBytes(size)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err)
|
||||
}
|
||||
if len(b) != size {
|
||||
t.Errorf("长度不匹配: got %d, want %d", len(b), size)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) {
|
||||
b1, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
t.Fatalf("第一次调用失败: %v", err)
|
||||
}
|
||||
b2, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
t.Fatalf("第二次调用失败: %v", err)
|
||||
}
|
||||
// 两次生成的随机字节应该不同(概率上几乎不可能相同)
|
||||
if string(b1) == string(b2) {
|
||||
t.Error("两次生成的随机字节相同,概率极低,可能有问题")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateState
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateState_返回值格式(t *testing.T) {
|
||||
state, err := GenerateState()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateState 失败: %v", err)
|
||||
}
|
||||
if state == "" {
|
||||
t.Error("GenerateState 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 +, /, =
|
||||
if strings.ContainsAny(state, "+/=") {
|
||||
t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state)
|
||||
}
|
||||
// 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充)
|
||||
if len(state) != 43 {
|
||||
t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateState_唯一性(t *testing.T) {
|
||||
s1, _ := GenerateState()
|
||||
s2, _ := GenerateState()
|
||||
if s1 == s2 {
|
||||
t.Error("两次 GenerateState 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateSessionID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateSessionID_返回值格式(t *testing.T) {
|
||||
id, err := GenerateSessionID()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSessionID 失败: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Error("GenerateSessionID 返回空字符串")
|
||||
}
|
||||
// 16 字节的 hex 编码长度应为 32
|
||||
if len(id) != 32 {
|
||||
t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id))
|
||||
}
|
||||
// 验证是合法的 hex 字符串
|
||||
if _, err := hex.DecodeString(id); err != nil {
|
||||
t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSessionID_唯一性(t *testing.T) {
|
||||
id1, _ := GenerateSessionID()
|
||||
id2, _ := GenerateSessionID()
|
||||
if id1 == id2 {
|
||||
t.Error("两次 GenerateSessionID 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeVerifier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeVerifier_返回值格式(t *testing.T) {
|
||||
verifier, err := GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCodeVerifier 失败: %v", err)
|
||||
}
|
||||
if verifier == "" {
|
||||
t.Error("GenerateCodeVerifier 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 +, /, =
|
||||
if strings.ContainsAny(verifier, "+/=") {
|
||||
t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier)
|
||||
}
|
||||
// 32 字节的 base64url 编码长度应为 43
|
||||
if len(verifier) != 43 {
|
||||
t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeVerifier_唯一性(t *testing.T) {
|
||||
v1, _ := GenerateCodeVerifier()
|
||||
v2, _ := GenerateCodeVerifier()
|
||||
if v1 == v2 {
|
||||
t.Error("两次 GenerateCodeVerifier 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeChallenge
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) {
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
|
||||
challenge := GenerateCodeChallenge(verifier)
|
||||
|
||||
// 手动计算预期值
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=")
|
||||
|
||||
if challenge != expected {
|
||||
t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不含填充字符(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("test-verifier")
|
||||
if strings.Contains(challenge, "=") {
|
||||
t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("another-verifier")
|
||||
if strings.ContainsAny(challenge, "+/") {
|
||||
t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) {
|
||||
c1 := GenerateCodeChallenge("same-verifier")
|
||||
c2 := GenerateCodeChallenge("same-verifier")
|
||||
if c1 != c2 {
|
||||
t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) {
|
||||
c1 := GenerateCodeChallenge("verifier-1")
|
||||
c2 := GenerateCodeChallenge("verifier-2")
|
||||
if c1 == c2 {
|
||||
t.Error("不同输入应产生不同输出")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BuildAuthorizationURL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBuildAuthorizationURL_参数验证(t *testing.T) {
|
||||
state := "test-state-123"
|
||||
codeChallenge := "test-challenge-abc"
|
||||
|
||||
authURL := BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
// 验证以 AuthorizeURL 开头
|
||||
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
|
||||
t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL)
|
||||
}
|
||||
|
||||
// 解析 URL 并验证参数
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
params := parsed.Query()
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": RedirectURI,
|
||||
"response_type": "code",
|
||||
"scope": Scopes,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"include_granted_scopes": "true",
|
||||
}
|
||||
|
||||
for key, want := range expectedParams {
|
||||
got := params.Get(key)
|
||||
if got != want {
|
||||
t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_参数数量(t *testing.T) {
|
||||
authURL := BuildAuthorizationURL("s", "c")
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
params := parsed.Query()
|
||||
// 应包含 10 个参数
|
||||
expectedCount := 10
|
||||
if len(params) != expectedCount {
|
||||
t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) {
|
||||
state := "state+with/special=chars"
|
||||
codeChallenge := "challenge+value"
|
||||
|
||||
authURL := BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析后应正确还原特殊字符
|
||||
if got := parsed.Query().Get("state"); got != state {
|
||||
t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 常量值验证
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestConstants_值正确(t *testing.T) {
|
||||
if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" {
|
||||
t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL)
|
||||
}
|
||||
if TokenURL != "https://oauth2.googleapis.com/token" {
|
||||
t.Errorf("TokenURL 不匹配: got %s", TokenURL)
|
||||
}
|
||||
if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" {
|
||||
t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL)
|
||||
}
|
||||
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
|
||||
t.Errorf("ClientID 不匹配: got %s", ClientID)
|
||||
}
|
||||
if ClientSecret != "" {
|
||||
t.Error("ClientSecret 应为空字符串")
|
||||
}
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if UserAgent != "antigravity/1.15.8 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", UserAgent)
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
t.Errorf("SessionTTL 不匹配: got %v", SessionTTL)
|
||||
}
|
||||
if URLAvailabilityTTL != 5*time.Minute {
|
||||
t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopes_包含必要范围(t *testing.T) {
|
||||
expectedScopes := []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
}
|
||||
|
||||
for _, scope := range expectedScopes {
|
||||
if !strings.Contains(Scopes, scope) {
|
||||
t.Errorf("Scopes 缺少 %s", scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
|
||||
@@ -341,12 +344,30 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// generateRandomID 生成随机 ID
|
||||
// fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。
|
||||
var fallbackCounter uint64
|
||||
|
||||
// generateRandomID 生成密码学安全的随机 ID
|
||||
func generateRandomID() string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, 12)
|
||||
for i := range result {
|
||||
result[i] = chars[i%len(chars)]
|
||||
id := make([]byte, 12)
|
||||
randBytes := make([]byte, 12)
|
||||
if _, err := rand.Read(randBytes); err != nil {
|
||||
// 避免在请求路径里 panic:极端情况下熵源不可用时降级为伪随机。
|
||||
// 这里主要用于生成响应/工具调用的临时 ID,安全要求不高但需尽量避免碰撞。
|
||||
cnt := atomic.AddUint64(&fallbackCounter, 1)
|
||||
seed := uint64(time.Now().UnixNano()) ^ cnt
|
||||
seed ^= uint64(len(err.Error())) << 32
|
||||
for i := range id {
|
||||
seed ^= seed << 13
|
||||
seed ^= seed >> 7
|
||||
seed ^= seed << 17
|
||||
id[i] = chars[int(seed)%len(chars)]
|
||||
}
|
||||
return string(id)
|
||||
}
|
||||
return string(result)
|
||||
for i, b := range randBytes {
|
||||
id[i] = chars[int(b)%len(chars)]
|
||||
}
|
||||
return string(id)
|
||||
}
|
||||
|
||||
109
backend/internal/pkg/antigravity/response_transformer_test.go
Normal file
109
backend/internal/pkg/antigravity/response_transformer_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
//go:build unit
|
||||
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Task 7: 验证 generateRandomID 和降级碰撞防护 ---
|
||||
|
||||
func TestGenerateRandomID_Uniqueness(t *testing.T) {
|
||||
seen := make(map[string]struct{}, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
id := generateRandomID()
|
||||
require.Len(t, id, 12, "ID 长度应为 12")
|
||||
_, dup := seen[id]
|
||||
require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackCounter_Increments(t *testing.T) {
|
||||
// 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed
|
||||
before := atomic.LoadUint64(&fallbackCounter)
|
||||
cnt1 := atomic.AddUint64(&fallbackCounter, 1)
|
||||
cnt2 := atomic.AddUint64(&fallbackCounter, 1)
|
||||
require.Equal(t, before+1, cnt1, "第一次递增应为 before+1")
|
||||
require.Equal(t, before+2, cnt2, "第二次递增应为 before+2")
|
||||
require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同")
|
||||
}
|
||||
|
||||
func TestFallbackCounter_ConcurrentIncrements(t *testing.T) {
|
||||
// 验证并发递增的原子性 — 每次递增都应产生唯一值
|
||||
const goroutines = 50
|
||||
results := make([]uint64, goroutines)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx] = atomic.AddUint64(&fallbackCounter, 1)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// 所有结果应唯一
|
||||
seen := make(map[uint64]bool, goroutines)
|
||||
for _, v := range results {
|
||||
assert.False(t, seen[v], "并发递增产生了重复值: %d", v)
|
||||
seen[v] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID_Charset(t *testing.T) {
|
||||
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
validSet := make(map[byte]struct{}, len(validChars))
|
||||
for i := 0; i < len(validChars); i++ {
|
||||
validSet[validChars[i]] = struct{}{}
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
id := generateRandomID()
|
||||
for j := 0; j < len(id); j++ {
|
||||
_, ok := validSet[id[j]]
|
||||
require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID_Length(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
id := generateRandomID()
|
||||
assert.Len(t, id, 12, "每次生成的 ID 长度应为 12")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) {
|
||||
// 验证并发调用不会产生重复 ID
|
||||
const goroutines = 100
|
||||
results := make([]string, goroutines)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx] = generateRandomID()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
seen := make(map[string]bool, goroutines)
|
||||
for _, id := range results {
|
||||
assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id)
|
||||
seen[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGenerateRandomID(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = generateRandomID()
|
||||
}
|
||||
}
|
||||
@@ -8,9 +8,21 @@ const (
|
||||
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
||||
ForcePlatform Key = "ctx_force_platform"
|
||||
|
||||
// RequestID 为服务端生成/透传的请求 ID。
|
||||
RequestID Key = "ctx_request_id"
|
||||
|
||||
// ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。
|
||||
ClientRequestID Key = "ctx_client_request_id"
|
||||
|
||||
// Model 请求模型标识(用于统一请求链路日志字段)。
|
||||
Model Key = "ctx_model"
|
||||
|
||||
// Platform 当前请求最终命中的平台(用于统一请求链路日志字段)。
|
||||
Platform Key = "ctx_platform"
|
||||
|
||||
// AccountID 当前请求最终命中的账号 ID(用于统一请求链路日志字段)。
|
||||
AccountID Key = "ctx_account_id"
|
||||
|
||||
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
|
||||
RetryCount Key = "ctx_retry_count"
|
||||
|
||||
@@ -32,4 +44,12 @@ const (
|
||||
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
|
||||
// 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。
|
||||
SingleAccountRetry Key = "ctx_single_account_retry"
|
||||
|
||||
// PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。
|
||||
// Service 层可复用该值,避免同请求链路重复读取 Redis。
|
||||
PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id"
|
||||
|
||||
// PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。
|
||||
// Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。
|
||||
PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id"
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ func DefaultModels() []Model {
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -38,8 +38,13 @@ const (
|
||||
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
|
||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||
// restrict which scopes are allowed for this client.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
GeminiCLIOAuthClientSecret = "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"
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
// GeminiCLIOAuthClientSecret is intentionally not embedded in this repository.
|
||||
// If you rely on the built-in Gemini CLI OAuth client, you MUST provide its client_secret via config/env.
|
||||
GeminiCLIOAuthClientSecret = ""
|
||||
|
||||
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
|
||||
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
|
||||
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ var DefaultModels = []Model{
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
|
||||
@@ -6,10 +6,14 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
type OAuthConfig struct {
|
||||
@@ -164,15 +168,24 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
|
||||
}
|
||||
|
||||
// Fall back to built-in Gemini CLI OAuth client when not configured.
|
||||
// SECURITY: This repo does not embed the built-in client secret; it must be provided via env.
|
||||
if effective.ClientID == "" && effective.ClientSecret == "" {
|
||||
secret := strings.TrimSpace(GeminiCLIOAuthClientSecret)
|
||||
if secret == "" {
|
||||
if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok {
|
||||
secret = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
if secret == "" {
|
||||
return OAuthConfig{}, infraerrors.Newf(http.StatusBadRequest, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING", "built-in Gemini CLI OAuth client_secret is not configured; set %s or provide a custom OAuth client", GeminiCLIOAuthClientSecretEnv)
|
||||
}
|
||||
effective.ClientID = GeminiCLIOAuthClientID
|
||||
effective.ClientSecret = GeminiCLIOAuthClientSecret
|
||||
effective.ClientSecret = secret
|
||||
} else if effective.ClientID == "" || effective.ClientSecret == "" {
|
||||
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
|
||||
return OAuthConfig{}, infraerrors.New(http.StatusBadRequest, "GEMINI_OAUTH_CLIENT_NOT_CONFIGURED", "OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
|
||||
}
|
||||
|
||||
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
|
||||
effective.ClientSecret == GeminiCLIOAuthClientSecret
|
||||
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID
|
||||
|
||||
if effective.Scopes == "" {
|
||||
// Use different default scopes based on OAuth type
|
||||
|
||||
@@ -1,11 +1,439 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SessionStore 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSessionStore_SetAndGet(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "test-state",
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("sid-1", session)
|
||||
|
||||
got, ok := store.Get("sid-1")
|
||||
if !ok {
|
||||
t.Fatal("期望 Get 返回 ok=true,实际返回 false")
|
||||
}
|
||||
if got.State != "test-state" {
|
||||
t.Errorf("期望 State=%q,实际=%q", "test-state", got.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_GetNotFound(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
_, ok := store.Get("不存在的ID")
|
||||
if ok {
|
||||
t.Error("期望不存在的 sessionID 返回 ok=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_GetExpired(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
// 创建一个已过期的 session(CreatedAt 设置为 SessionTTL+1 分钟之前)
|
||||
session := &OAuthSession{
|
||||
State: "expired-state",
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now().Add(-(SessionTTL + 1*time.Minute)),
|
||||
}
|
||||
store.Set("expired-sid", session)
|
||||
|
||||
_, ok := store.Get("expired-sid")
|
||||
if ok {
|
||||
t.Error("期望过期的 session 返回 ok=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "to-delete",
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("del-sid", session)
|
||||
|
||||
// 先确认存在
|
||||
if _, ok := store.Get("del-sid"); !ok {
|
||||
t.Fatal("删除前 session 应该存在")
|
||||
}
|
||||
|
||||
store.Delete("del-sid")
|
||||
|
||||
if _, ok := store.Get("del-sid"); ok {
|
||||
t.Error("删除后 session 不应该存在")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop_Idempotent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
// 多次调用 Stop 不应 panic
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
}
|
||||
|
||||
func TestSessionStore_ConcurrentAccess(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
const goroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines * 3)
|
||||
|
||||
// 并发写入
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sid := "concurrent-" + string(rune('A'+idx%26))
|
||||
store.Set(sid, &OAuthSession{
|
||||
State: sid,
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 并发读取
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sid := "concurrent-" + string(rune('A'+idx%26))
|
||||
store.Get(sid) // 可能找到也可能没找到,关键是不 panic
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 并发删除
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sid := "concurrent-" + string(rune('A'+idx%26))
|
||||
store.Delete(sid)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateRandomBytes 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateRandomBytes(t *testing.T) {
|
||||
tests := []int{0, 1, 16, 32, 64}
|
||||
for _, n := range tests {
|
||||
b, err := GenerateRandomBytes(n)
|
||||
if err != nil {
|
||||
t.Errorf("GenerateRandomBytes(%d) 出错: %v", n, err)
|
||||
continue
|
||||
}
|
||||
if len(b) != n {
|
||||
t.Errorf("GenerateRandomBytes(%d) 返回长度=%d,期望=%d", n, len(b), n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomBytes_Uniqueness(t *testing.T) {
|
||||
// 两次调用应该返回不同的结果(极小概率相同,32字节足够)
|
||||
a, _ := GenerateRandomBytes(32)
|
||||
b, _ := GenerateRandomBytes(32)
|
||||
if string(a) == string(b) {
|
||||
t.Error("两次 GenerateRandomBytes(32) 返回了相同结果,随机性可能有问题")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateState 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateState(t *testing.T) {
|
||||
state, err := GenerateState()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateState() 出错: %v", err)
|
||||
}
|
||||
if state == "" {
|
||||
t.Error("GenerateState() 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 padding '='
|
||||
if strings.Contains(state, "=") {
|
||||
t.Errorf("GenerateState() 结果包含 '=' padding: %s", state)
|
||||
}
|
||||
// base64url 不应包含 '+' 或 '/'
|
||||
if strings.ContainsAny(state, "+/") {
|
||||
t.Errorf("GenerateState() 结果包含非 base64url 字符: %s", state)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateSessionID 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateSessionID(t *testing.T) {
|
||||
sid, err := GenerateSessionID()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSessionID() 出错: %v", err)
|
||||
}
|
||||
// 16 字节 -> 32 个 hex 字符
|
||||
if len(sid) != 32 {
|
||||
t.Errorf("GenerateSessionID() 长度=%d,期望=32", len(sid))
|
||||
}
|
||||
// 必须是合法的 hex 字符串
|
||||
if _, err := hex.DecodeString(sid); err != nil {
|
||||
t.Errorf("GenerateSessionID() 不是合法的 hex 字符串: %s, err=%v", sid, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSessionID_Uniqueness(t *testing.T) {
|
||||
a, _ := GenerateSessionID()
|
||||
b, _ := GenerateSessionID()
|
||||
if a == b {
|
||||
t.Error("两次 GenerateSessionID() 返回了相同结果")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeVerifier 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeVerifier(t *testing.T) {
|
||||
verifier, err := GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCodeVerifier() 出错: %v", err)
|
||||
}
|
||||
if verifier == "" {
|
||||
t.Error("GenerateCodeVerifier() 返回空字符串")
|
||||
}
|
||||
// RFC 7636 要求 code_verifier 至少 43 个字符
|
||||
if len(verifier) < 43 {
|
||||
t.Errorf("GenerateCodeVerifier() 长度=%d,RFC 7636 要求至少 43 字符", len(verifier))
|
||||
}
|
||||
// base64url 编码不应包含 padding 和非 URL 安全字符
|
||||
if strings.Contains(verifier, "=") {
|
||||
t.Errorf("GenerateCodeVerifier() 包含 '=' padding: %s", verifier)
|
||||
}
|
||||
if strings.ContainsAny(verifier, "+/") {
|
||||
t.Errorf("GenerateCodeVerifier() 包含非 base64url 字符: %s", verifier)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeChallenge 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeChallenge(t *testing.T) {
|
||||
// 使用已知输入验证输出
|
||||
// RFC 7636 附录 B 示例: verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
// 预期 challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
expected := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
|
||||
challenge := GenerateCodeChallenge(verifier)
|
||||
if challenge != expected {
|
||||
t.Errorf("GenerateCodeChallenge(%q) = %q,期望 %q", verifier, challenge, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_NoPadding(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("test-verifier-string")
|
||||
if strings.Contains(challenge, "=") {
|
||||
t.Errorf("GenerateCodeChallenge() 结果包含 '=' padding: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// base64URLEncode 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBase64URLEncode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
}{
|
||||
{"空字节", []byte{}},
|
||||
{"单字节", []byte{0xff}},
|
||||
{"多字节", []byte{0x01, 0x02, 0x03, 0x04, 0x05}},
|
||||
{"全零", []byte{0x00, 0x00, 0x00}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := base64URLEncode(tt.input)
|
||||
// 不应包含 '=' padding
|
||||
if strings.Contains(result, "=") {
|
||||
t.Errorf("base64URLEncode(%v) 包含 '=' padding: %s", tt.input, result)
|
||||
}
|
||||
// 不应包含标准 base64 的 '+' 或 '/'
|
||||
if strings.ContainsAny(result, "+/") {
|
||||
t.Errorf("base64URLEncode(%v) 包含非 URL 安全字符: %s", tt.input, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// hasRestrictedScope 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHasRestrictedScope(t *testing.T) {
|
||||
tests := []struct {
|
||||
scope string
|
||||
expected bool
|
||||
}{
|
||||
// 受限 scope
|
||||
{"https://www.googleapis.com/auth/generative-language", true},
|
||||
{"https://www.googleapis.com/auth/generative-language.retriever", true},
|
||||
{"https://www.googleapis.com/auth/generative-language.tuning", true},
|
||||
{"https://www.googleapis.com/auth/drive", true},
|
||||
{"https://www.googleapis.com/auth/drive.readonly", true},
|
||||
{"https://www.googleapis.com/auth/drive.file", true},
|
||||
// 非受限 scope
|
||||
{"https://www.googleapis.com/auth/cloud-platform", false},
|
||||
{"https://www.googleapis.com/auth/userinfo.email", false},
|
||||
{"https://www.googleapis.com/auth/userinfo.profile", false},
|
||||
// 边界情况
|
||||
{"", false},
|
||||
{"random-scope", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.scope, func(t *testing.T) {
|
||||
got := hasRestrictedScope(tt.scope)
|
||||
if got != tt.expected {
|
||||
t.Errorf("hasRestrictedScope(%q) = %v,期望 %v", tt.scope, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BuildAuthorizationURL 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBuildAuthorizationURL(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
|
||||
|
||||
authURL, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"https://example.com/callback",
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
|
||||
}
|
||||
|
||||
// 检查返回的 URL 包含期望的参数
|
||||
checks := []string{
|
||||
"response_type=code",
|
||||
"client_id=" + GeminiCLIOAuthClientID,
|
||||
"redirect_uri=",
|
||||
"state=test-state",
|
||||
"code_challenge=test-challenge",
|
||||
"code_challenge_method=S256",
|
||||
"access_type=offline",
|
||||
"prompt=consent",
|
||||
"include_granted_scopes=true",
|
||||
}
|
||||
for _, check := range checks {
|
||||
if !strings.Contains(authURL, check) {
|
||||
t.Errorf("BuildAuthorizationURL() URL 缺少参数 %q\nURL: %s", check, authURL)
|
||||
}
|
||||
}
|
||||
|
||||
// 不应包含 project_id(因为传的是空字符串)
|
||||
if strings.Contains(authURL, "project_id=") {
|
||||
t.Errorf("BuildAuthorizationURL() 空 projectID 时不应包含 project_id 参数")
|
||||
}
|
||||
|
||||
// URL 应该以正确的授权端点开头
|
||||
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
|
||||
t.Errorf("BuildAuthorizationURL() URL 应以 %s? 开头,实际: %s", AuthorizeURL, authURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_EmptyRedirectURI(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
|
||||
|
||||
_, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"", // 空 redirectURI
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("BuildAuthorizationURL() 空 redirectURI 应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "redirect_uri") {
|
||||
t.Errorf("错误消息应包含 'redirect_uri',实际: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_WithProjectID(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
|
||||
|
||||
authURL, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"https://example.com/callback",
|
||||
"my-project-123",
|
||||
"code_assist",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
|
||||
}
|
||||
if !strings.Contains(authURL, "project_id=my-project-123") {
|
||||
t.Errorf("BuildAuthorizationURL() 带 projectID 时应包含 project_id 参数\nURL: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) {
|
||||
// 不设置环境变量,也不提供 client 凭据,EffectiveOAuthConfig 应该报错
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"https://example.com/callback",
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("当 EffectiveOAuthConfig 失败时,BuildAuthorizationURL 应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EffectiveOAuthConfig 测试 - 原有测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
// 内置的 Gemini CLI client secret 不嵌入在此仓库中。
|
||||
// 测试通过环境变量设置一个假的 secret 来模拟运维配置。
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input OAuthConfig
|
||||
@@ -15,7 +443,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Google One with built-in client (empty config)",
|
||||
name: "Google One 使用内置客户端(空配置)",
|
||||
input: OAuthConfig{},
|
||||
oauthType: "google_one",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
@@ -23,18 +451,18 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One always uses built-in client (even if custom credentials passed)",
|
||||
name: "Google One 使用自定义客户端(传入自定义凭据时使用自定义)",
|
||||
input: OAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client
|
||||
wantScopes: DefaultCodeAssistScopes,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with built-in client and custom scopes (should filter restricted scopes)",
|
||||
name: "Google One 内置客户端 + 自定义 scopes(应过滤受限 scopes)",
|
||||
input: OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
},
|
||||
@@ -44,7 +472,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with built-in client and only restricted scopes (should fallback to default)",
|
||||
name: "Google One 内置客户端 + 仅受限 scopes(应回退到默认)",
|
||||
input: OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
},
|
||||
@@ -54,7 +482,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Code Assist with built-in client",
|
||||
name: "Code Assist 使用内置客户端",
|
||||
input: OAuthConfig{},
|
||||
oauthType: "code_assist",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
@@ -84,7 +512,9 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
|
||||
// Test that Google One with built-in client filters out restricted scopes
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 测试 Google One + 内置客户端过滤受限 scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile",
|
||||
}, "google_one")
|
||||
@@ -93,21 +523,240 @@ func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
|
||||
// Should only contain cloud-platform, userinfo.email, and userinfo.profile
|
||||
// Should NOT contain generative-language or drive scopes
|
||||
// 应仅包含 cloud-platform、userinfo.email 和 userinfo.profile
|
||||
// 不应包含 generative-language 或 drive scopes
|
||||
if strings.Contains(cfg.Scopes, "generative-language") {
|
||||
t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes)
|
||||
t.Errorf("使用内置客户端时 Scopes 不应包含 generative-language,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if strings.Contains(cfg.Scopes, "drive") {
|
||||
t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes)
|
||||
t.Errorf("使用内置客户端时 Scopes 不应包含 drive,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "cloud-platform") {
|
||||
t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes)
|
||||
t.Errorf("Scopes 应包含 cloud-platform,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.email") {
|
||||
t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes)
|
||||
t.Errorf("Scopes 应包含 userinfo.email,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.profile") {
|
||||
t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes)
|
||||
t.Errorf("Scopes 应包含 userinfo.profile,实际: %v", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EffectiveOAuthConfig 测试 - 新增分支覆盖
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestEffectiveOAuthConfig_OnlyClientID_NoSecret(t *testing.T) {
|
||||
// 只提供 clientID 不提供 secret 应报错
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "some-client-id",
|
||||
}, "code_assist")
|
||||
if err == nil {
|
||||
t.Error("只提供 ClientID 不提供 ClientSecret 应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
|
||||
t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_OnlyClientSecret_NoID(t *testing.T) {
|
||||
// 只提供 secret 不提供 clientID 应报错
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientSecret: "some-client-secret",
|
||||
}, "code_assist")
|
||||
if err == nil {
|
||||
t.Error("只提供 ClientSecret 不提供 ClientID 应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
|
||||
t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_BuiltinClient(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// ai_studio 类型,使用内置客户端,scopes 为空 -> 应使用 DefaultCodeAssistScopes(因为内置客户端不能请求 generative-language scope)
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultCodeAssistScopes {
|
||||
t.Errorf("ai_studio + 内置客户端应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_CustomClient(t *testing.T) {
|
||||
// ai_studio 类型,使用自定义客户端,scopes 为空 -> 应使用 DefaultAIStudioScopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultAIStudioScopes {
|
||||
t.Errorf("ai_studio + 自定义客户端应使用 DefaultAIStudioScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_ScopeNormalization(t *testing.T) {
|
||||
// ai_studio 类型,旧的 generative-language scope 应被归一化为 generative-language.retriever
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/generative-language https://www.googleapis.com/auth/cloud-platform",
|
||||
}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if strings.Contains(cfg.Scopes, "auth/generative-language ") || strings.HasSuffix(cfg.Scopes, "auth/generative-language") {
|
||||
// 确保不包含未归一化的旧 scope(仅 generative-language 而非 generative-language.retriever)
|
||||
parts := strings.Fields(cfg.Scopes)
|
||||
for _, p := range parts {
|
||||
if p == "https://www.googleapis.com/auth/generative-language" {
|
||||
t.Errorf("ai_studio 应将 generative-language 归一化为 generative-language.retriever,实际 scopes: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
|
||||
t.Errorf("ai_studio 归一化后应包含 generative-language.retriever,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_CommaSeparatedScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 逗号分隔的 scopes 应被归一化为空格分隔
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/userinfo.email",
|
||||
}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
// 应该用空格分隔,而非逗号
|
||||
if strings.Contains(cfg.Scopes, ",") {
|
||||
t.Errorf("逗号分隔的 scopes 应被归一化为空格分隔,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "cloud-platform") {
|
||||
t.Errorf("归一化后应包含 cloud-platform,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.email") {
|
||||
t.Errorf("归一化后应包含 userinfo.email,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_MixedCommaAndSpaceScopes(t *testing.T) {
|
||||
// 混合逗号和空格分隔的 scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform, https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile",
|
||||
}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
parts := strings.Fields(cfg.Scopes)
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("归一化后应有 3 个 scope,实际: %d,scopes: %q", len(parts), cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) {
|
||||
// 输入中的前后空白应被清理
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: " custom-id ",
|
||||
ClientSecret: " custom-secret ",
|
||||
Scopes: " https://www.googleapis.com/auth/cloud-platform ",
|
||||
}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.ClientID != "custom-id" {
|
||||
t.Errorf("ClientID 应去除前后空白,实际: %q", cfg.ClientID)
|
||||
}
|
||||
if cfg.ClientSecret != "custom-secret" {
|
||||
t.Errorf("ClientSecret 应去除前后空白,实际: %q", cfg.ClientSecret)
|
||||
}
|
||||
if cfg.Scopes != "https://www.googleapis.com/auth/cloud-platform" {
|
||||
t.Errorf("Scopes 应去除前后空白,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
|
||||
// 不设置环境变量且不提供凭据,应该报错
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
|
||||
if err == nil {
|
||||
t.Error("没有内置 secret 且未提供凭据时应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
|
||||
t.Errorf("错误消息应提及环境变量 %s,实际: %v", GeminiCLIOAuthClientSecretEnv, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_BuiltinClient_CustomScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// ai_studio + 内置客户端 + 自定义 scopes -> 应过滤受限 scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever",
|
||||
}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
// 内置客户端应过滤 generative-language.retriever
|
||||
if strings.Contains(cfg.Scopes, "generative-language") {
|
||||
t.Errorf("ai_studio + 内置客户端应过滤受限 scopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "cloud-platform") {
|
||||
t.Errorf("应保留 cloud-platform scope,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_UnknownOAuthType_DefaultScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 未知的 oauthType 应回退到默认的 code_assist scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "unknown_type")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultCodeAssistScopes {
|
||||
t.Errorf("未知 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_EmptyOAuthType_DefaultScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 空的 oauthType 应走 default 分支,使用 DefaultCodeAssistScopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultCodeAssistScopes {
|
||||
t.Errorf("空 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_CustomClient_NoScopeFiltering(t *testing.T) {
|
||||
// 自定义客户端 + google_one + 包含受限 scopes -> 不应被过滤(因为不是内置客户端)
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
}, "google_one")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
// 自定义客户端不应过滤任何 scope
|
||||
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
|
||||
t.Errorf("自定义客户端不应过滤 generative-language.retriever,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "drive.readonly") {
|
||||
t.Errorf("自定义客户端不应过滤 drive.readonly,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
|
||||
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
|
||||
// 适用于 ACL / 风控等安全敏感场景。
|
||||
func GetTrustedClientIP(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
||||
func normalizeIP(ip string) string {
|
||||
ip = strings.TrimSpace(ip)
|
||||
@@ -54,29 +64,34 @@ func normalizeIP(ip string) string {
|
||||
return ip
|
||||
}
|
||||
|
||||
// isPrivateIP 检查 IP 是否为私有地址。
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
|
||||
var privateNets []*net.IPNet
|
||||
|
||||
// 私有 IP 范围
|
||||
privateBlocks := []string{
|
||||
func init() {
|
||||
for _, cidr := range []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"127.0.0.0/8",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
}
|
||||
|
||||
for _, block := range privateBlocks {
|
||||
_, cidr, err := net.ParseCIDR(block)
|
||||
} {
|
||||
_, block, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
continue
|
||||
panic("invalid CIDR: " + cidr)
|
||||
}
|
||||
if cidr.Contains(ip) {
|
||||
privateNets = append(privateNets, block)
|
||||
}
|
||||
}
|
||||
|
||||
// isPrivateIP 检查 IP 是否为私有地址。
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
for _, block := range privateNets {
|
||||
if block.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
75
backend/internal/pkg/ip/ip_test.go
Normal file
75
backend/internal/pkg/ip/ip_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
//go:build unit
|
||||
|
||||
package ip
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsPrivateIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
// 私有 IPv4
|
||||
{"10.x 私有地址", "10.0.0.1", true},
|
||||
{"10.x 私有地址段末", "10.255.255.255", true},
|
||||
{"172.16.x 私有地址", "172.16.0.1", true},
|
||||
{"172.31.x 私有地址", "172.31.255.255", true},
|
||||
{"192.168.x 私有地址", "192.168.1.1", true},
|
||||
{"127.0.0.1 本地回环", "127.0.0.1", true},
|
||||
{"127.x 回环段", "127.255.255.255", true},
|
||||
|
||||
// 公网 IPv4
|
||||
{"8.8.8.8 公网 DNS", "8.8.8.8", false},
|
||||
{"1.1.1.1 公网", "1.1.1.1", false},
|
||||
{"172.15.255.255 非私有", "172.15.255.255", false},
|
||||
{"172.32.0.0 非私有", "172.32.0.0", false},
|
||||
{"11.0.0.1 公网", "11.0.0.1", false},
|
||||
|
||||
// IPv6
|
||||
{"::1 IPv6 回环", "::1", true},
|
||||
{"fc00:: IPv6 私有", "fc00::1", true},
|
||||
{"fd00:: IPv6 私有", "fd00::1", true},
|
||||
{"2001:db8::1 IPv6 公网", "2001:db8::1", false},
|
||||
|
||||
// 无效输入
|
||||
{"空字符串", "", false},
|
||||
{"非法字符串", "not-an-ip", false},
|
||||
{"不完整 IP", "192.168", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := isPrivateIP(tc.ip)
|
||||
require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
require.NoError(t, r.SetTrustedProxies(nil))
|
||||
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
c.String(200, GetTrustedClientIP(c))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/t", nil)
|
||||
req.RemoteAddr = "9.9.9.9:12345"
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Equal(t, "9.9.9.9", w.Body.String())
|
||||
}
|
||||
31
backend/internal/pkg/logger/config_adapter.go
Normal file
31
backend/internal/pkg/logger/config_adapter.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package logger
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
func OptionsFromConfig(cfg config.LogConfig) InitOptions {
|
||||
return InitOptions{
|
||||
Level: cfg.Level,
|
||||
Format: cfg.Format,
|
||||
ServiceName: cfg.ServiceName,
|
||||
Environment: cfg.Environment,
|
||||
Caller: cfg.Caller,
|
||||
StacktraceLevel: cfg.StacktraceLevel,
|
||||
Output: OutputOptions{
|
||||
ToStdout: cfg.Output.ToStdout,
|
||||
ToFile: cfg.Output.ToFile,
|
||||
FilePath: cfg.Output.FilePath,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: cfg.Rotation.MaxSizeMB,
|
||||
MaxBackups: cfg.Rotation.MaxBackups,
|
||||
MaxAgeDays: cfg.Rotation.MaxAgeDays,
|
||||
Compress: cfg.Rotation.Compress,
|
||||
LocalTime: cfg.Rotation.LocalTime,
|
||||
},
|
||||
Sampling: SamplingOptions{
|
||||
Enabled: cfg.Sampling.Enabled,
|
||||
Initial: cfg.Sampling.Initial,
|
||||
Thereafter: cfg.Sampling.Thereafter,
|
||||
},
|
||||
}
|
||||
}
|
||||
519
backend/internal/pkg/logger/logger.go
Normal file
519
backend/internal/pkg/logger/logger.go
Normal file
@@ -0,0 +1,519 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
type Level = zapcore.Level
|
||||
|
||||
const (
|
||||
LevelDebug = zapcore.DebugLevel
|
||||
LevelInfo = zapcore.InfoLevel
|
||||
LevelWarn = zapcore.WarnLevel
|
||||
LevelError = zapcore.ErrorLevel
|
||||
LevelFatal = zapcore.FatalLevel
|
||||
)
|
||||
|
||||
type Sink interface {
|
||||
WriteLogEvent(event *LogEvent)
|
||||
}
|
||||
|
||||
type LogEvent struct {
|
||||
Time time.Time
|
||||
Level string
|
||||
Component string
|
||||
Message string
|
||||
LoggerName string
|
||||
Fields map[string]any
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.RWMutex
|
||||
global *zap.Logger
|
||||
sugar *zap.SugaredLogger
|
||||
atomicLevel zap.AtomicLevel
|
||||
initOptions InitOptions
|
||||
currentSink Sink
|
||||
stdLogUndo func()
|
||||
bootstrapOnce sync.Once
|
||||
)
|
||||
|
||||
func InitBootstrap() {
|
||||
bootstrapOnce.Do(func() {
|
||||
if err := Init(bootstrapOptions()); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "logger bootstrap init failed: %v\n", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Init(options InitOptions) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return initLocked(options)
|
||||
}
|
||||
|
||||
func initLocked(options InitOptions) error {
|
||||
normalized := options.normalized()
|
||||
zl, al, err := buildLogger(normalized)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prev := global
|
||||
global = zl
|
||||
sugar = zl.Sugar()
|
||||
atomicLevel = al
|
||||
initOptions = normalized
|
||||
|
||||
bridgeSlogLocked()
|
||||
bridgeStdLogLocked()
|
||||
|
||||
if prev != nil {
|
||||
_ = prev.Sync()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Reconfigure(mutator func(*InitOptions) error) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
next := initOptions
|
||||
if mutator != nil {
|
||||
if err := mutator(&next); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return initLocked(next)
|
||||
}
|
||||
|
||||
func SetLevel(level string) error {
|
||||
lv, ok := parseLevel(level)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid log level: %s", level)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
atomicLevel.SetLevel(lv)
|
||||
initOptions.Level = strings.ToLower(strings.TrimSpace(level))
|
||||
return nil
|
||||
}
|
||||
|
||||
func CurrentLevel() string {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global == nil {
|
||||
return "info"
|
||||
}
|
||||
return atomicLevel.Level().String()
|
||||
}
|
||||
|
||||
func SetSink(sink Sink) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
currentSink = sink
|
||||
}
|
||||
|
||||
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
|
||||
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
|
||||
func WriteSinkEvent(level, component, message string, fields map[string]any) {
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
if sink == nil {
|
||||
return
|
||||
}
|
||||
|
||||
level = strings.ToLower(strings.TrimSpace(level))
|
||||
if level == "" {
|
||||
level = "info"
|
||||
}
|
||||
component = strings.TrimSpace(component)
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return
|
||||
}
|
||||
|
||||
eventFields := make(map[string]any, len(fields)+1)
|
||||
for k, v := range fields {
|
||||
eventFields[k] = v
|
||||
}
|
||||
if component != "" {
|
||||
if _, ok := eventFields["component"]; !ok {
|
||||
eventFields["component"] = component
|
||||
}
|
||||
}
|
||||
|
||||
sink.WriteLogEvent(&LogEvent{
|
||||
Time: time.Now(),
|
||||
Level: level,
|
||||
Component: component,
|
||||
Message: message,
|
||||
LoggerName: component,
|
||||
Fields: eventFields,
|
||||
})
|
||||
}
|
||||
|
||||
func L() *zap.Logger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global != nil {
|
||||
return global
|
||||
}
|
||||
return zap.NewNop()
|
||||
}
|
||||
|
||||
func S() *zap.SugaredLogger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if sugar != nil {
|
||||
return sugar
|
||||
}
|
||||
return zap.NewNop().Sugar()
|
||||
}
|
||||
|
||||
func With(fields ...zap.Field) *zap.Logger {
|
||||
return L().With(fields...)
|
||||
}
|
||||
|
||||
func Sync() {
|
||||
mu.RLock()
|
||||
l := global
|
||||
mu.RUnlock()
|
||||
if l != nil {
|
||||
_ = l.Sync()
|
||||
}
|
||||
}
|
||||
|
||||
func bridgeStdLogLocked() {
|
||||
if stdLogUndo != nil {
|
||||
stdLogUndo()
|
||||
stdLogUndo = nil
|
||||
}
|
||||
|
||||
prevFlags := log.Flags()
|
||||
prevPrefix := log.Prefix()
|
||||
prevWriter := log.Writer()
|
||||
|
||||
log.SetFlags(0)
|
||||
log.SetPrefix("")
|
||||
log.SetOutput(newStdLogBridge(global.Named("stdlog")))
|
||||
|
||||
stdLogUndo = func() {
|
||||
log.SetOutput(prevWriter)
|
||||
log.SetFlags(prevFlags)
|
||||
log.SetPrefix(prevPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
func bridgeSlogLocked() {
|
||||
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
|
||||
}
|
||||
|
||||
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
|
||||
level, _ := parseLevel(options.Level)
|
||||
atomic := zap.NewAtomicLevelAt(level)
|
||||
|
||||
encoderCfg := zapcore.EncoderConfig{
|
||||
TimeKey: "time",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
MessageKey: "msg",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.CapitalLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeDuration: zapcore.MillisDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
}
|
||||
|
||||
var enc zapcore.Encoder
|
||||
if options.Format == "console" {
|
||||
enc = zapcore.NewConsoleEncoder(encoderCfg)
|
||||
} else {
|
||||
enc = zapcore.NewJSONEncoder(encoderCfg)
|
||||
}
|
||||
|
||||
sinkCore := newSinkCore()
|
||||
cores := make([]zapcore.Core, 0, 3)
|
||||
|
||||
if options.Output.ToStdout {
|
||||
infoPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
|
||||
return lvl >= atomic.Level() && lvl < zapcore.WarnLevel
|
||||
})
|
||||
errPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
|
||||
return lvl >= atomic.Level() && lvl >= zapcore.WarnLevel
|
||||
})
|
||||
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), infoPriority))
|
||||
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stderr), errPriority))
|
||||
}
|
||||
|
||||
if options.Output.ToFile {
|
||||
fileCore, filePath, fileErr := buildFileCore(enc, atomic, options)
|
||||
if fileErr != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"日志文件输出初始化失败,降级为仅标准输出\" path=%s err=%v\n",
|
||||
time.Now().Format(time.RFC3339Nano),
|
||||
filePath,
|
||||
fileErr,
|
||||
)
|
||||
} else {
|
||||
cores = append(cores, fileCore)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cores) == 0 {
|
||||
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), atomic))
|
||||
}
|
||||
|
||||
core := zapcore.NewTee(cores...)
|
||||
if options.Sampling.Enabled {
|
||||
core = zapcore.NewSamplerWithOptions(core, samplingTick(), options.Sampling.Initial, options.Sampling.Thereafter)
|
||||
}
|
||||
core = sinkCore.Wrap(core)
|
||||
|
||||
stacktraceLevel, _ := parseStacktraceLevel(options.StacktraceLevel)
|
||||
zapOpts := make([]zap.Option, 0, 5)
|
||||
if options.Caller {
|
||||
zapOpts = append(zapOpts, zap.AddCaller())
|
||||
}
|
||||
if stacktraceLevel <= zapcore.FatalLevel {
|
||||
zapOpts = append(zapOpts, zap.AddStacktrace(stacktraceLevel))
|
||||
}
|
||||
|
||||
logger := zap.New(core, zapOpts...).With(
|
||||
zap.String("service", options.ServiceName),
|
||||
zap.String("env", options.Environment),
|
||||
)
|
||||
return logger, atomic, nil
|
||||
}
|
||||
|
||||
func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, string, error) {
|
||||
filePath := options.Output.FilePath
|
||||
if strings.TrimSpace(filePath) == "" {
|
||||
filePath = resolveLogFilePath("")
|
||||
}
|
||||
|
||||
dir := filepath.Dir(filePath)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, filePath, err
|
||||
}
|
||||
lj := &lumberjack.Logger{
|
||||
Filename: filePath,
|
||||
MaxSize: options.Rotation.MaxSizeMB,
|
||||
MaxBackups: options.Rotation.MaxBackups,
|
||||
MaxAge: options.Rotation.MaxAgeDays,
|
||||
Compress: options.Rotation.Compress,
|
||||
LocalTime: options.Rotation.LocalTime,
|
||||
}
|
||||
return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), filePath, nil
|
||||
}
|
||||
|
||||
type sinkCore struct {
|
||||
core zapcore.Core
|
||||
fields []zapcore.Field
|
||||
}
|
||||
|
||||
func newSinkCore() *sinkCore {
|
||||
return &sinkCore{}
|
||||
}
|
||||
|
||||
func (s *sinkCore) Wrap(core zapcore.Core) zapcore.Core {
|
||||
cp := *s
|
||||
cp.core = core
|
||||
return &cp
|
||||
}
|
||||
|
||||
func (s *sinkCore) Enabled(level zapcore.Level) bool {
|
||||
return s.core.Enabled(level)
|
||||
}
|
||||
|
||||
func (s *sinkCore) With(fields []zapcore.Field) zapcore.Core {
|
||||
nextFields := append([]zapcore.Field{}, s.fields...)
|
||||
nextFields = append(nextFields, fields...)
|
||||
return &sinkCore{
|
||||
core: s.core.With(fields),
|
||||
fields: nextFields,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
|
||||
// Delegate to inner core (tee) so each sub-core's level enabler is respected.
|
||||
// Then add ourselves for sink forwarding only.
|
||||
ce = s.core.Check(entry, ce)
|
||||
if ce != nil {
|
||||
ce = ce.AddCore(entry, s)
|
||||
}
|
||||
return ce
|
||||
}
|
||||
|
||||
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
|
||||
// Only handle sink forwarding — the inner cores write via their own
|
||||
// Write methods (added to CheckedEntry by s.core.Check above).
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
if sink == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enc := zapcore.NewMapObjectEncoder()
|
||||
for _, f := range s.fields {
|
||||
f.AddTo(enc)
|
||||
}
|
||||
for _, f := range fields {
|
||||
f.AddTo(enc)
|
||||
}
|
||||
|
||||
event := &LogEvent{
|
||||
Time: entry.Time,
|
||||
Level: strings.ToLower(entry.Level.String()),
|
||||
Component: entry.LoggerName,
|
||||
Message: entry.Message,
|
||||
LoggerName: entry.LoggerName,
|
||||
Fields: enc.Fields,
|
||||
}
|
||||
sink.WriteLogEvent(event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sinkCore) Sync() error {
|
||||
return s.core.Sync()
|
||||
}
|
||||
|
||||
type stdLogBridge struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func newStdLogBridge(l *zap.Logger) io.Writer {
|
||||
if l == nil {
|
||||
l = zap.NewNop()
|
||||
}
|
||||
return &stdLogBridge{logger: l}
|
||||
}
|
||||
|
||||
func (b *stdLogBridge) Write(p []byte) (int, error) {
|
||||
msg := normalizeStdLogMessage(string(p))
|
||||
if msg == "" {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
level := inferStdLogLevel(msg)
|
||||
entry := b.logger.WithOptions(zap.AddCallerSkip(4))
|
||||
|
||||
switch level {
|
||||
case LevelDebug:
|
||||
entry.Debug(msg, zap.Bool("legacy_stdlog", true))
|
||||
case LevelWarn:
|
||||
entry.Warn(msg, zap.Bool("legacy_stdlog", true))
|
||||
case LevelError, LevelFatal:
|
||||
entry.Error(msg, zap.Bool("legacy_stdlog", true))
|
||||
default:
|
||||
entry.Info(msg, zap.Bool("legacy_stdlog", true))
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func normalizeStdLogMessage(raw string) string {
|
||||
msg := strings.TrimSpace(strings.ReplaceAll(raw, "\n", " "))
|
||||
if msg == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(strings.Fields(msg), " ")
|
||||
}
|
||||
|
||||
func inferStdLogLevel(msg string) Level {
|
||||
lower := strings.ToLower(strings.TrimSpace(msg))
|
||||
if lower == "" {
|
||||
return LevelInfo
|
||||
}
|
||||
|
||||
if strings.HasPrefix(lower, "[debug]") || strings.HasPrefix(lower, "debug:") {
|
||||
return LevelDebug
|
||||
}
|
||||
if strings.HasPrefix(lower, "[warn]") || strings.HasPrefix(lower, "[warning]") || strings.HasPrefix(lower, "warn:") || strings.HasPrefix(lower, "warning:") {
|
||||
return LevelWarn
|
||||
}
|
||||
if strings.HasPrefix(lower, "[error]") || strings.HasPrefix(lower, "error:") || strings.HasPrefix(lower, "fatal:") || strings.HasPrefix(lower, "panic:") {
|
||||
return LevelError
|
||||
}
|
||||
|
||||
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
|
||||
return LevelError
|
||||
}
|
||||
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
|
||||
return LevelWarn
|
||||
}
|
||||
return LevelInfo
|
||||
}
|
||||
|
||||
// LegacyPrintf 用于平滑迁移历史的 printf 风格日志到结构化 logger。
|
||||
func LegacyPrintf(component, format string, args ...any) {
|
||||
msg := normalizeStdLogMessage(fmt.Sprintf(format, args...))
|
||||
if msg == "" {
|
||||
return
|
||||
}
|
||||
|
||||
mu.RLock()
|
||||
initialized := global != nil
|
||||
mu.RUnlock()
|
||||
if !initialized {
|
||||
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
|
||||
log.Print(msg)
|
||||
return
|
||||
}
|
||||
|
||||
l := L()
|
||||
if component != "" {
|
||||
l = l.With(zap.String("component", component))
|
||||
}
|
||||
l = l.WithOptions(zap.AddCallerSkip(1))
|
||||
|
||||
switch inferStdLogLevel(msg) {
|
||||
case LevelDebug:
|
||||
l.Debug(msg, zap.Bool("legacy_printf", true))
|
||||
case LevelWarn:
|
||||
l.Warn(msg, zap.Bool("legacy_printf", true))
|
||||
case LevelError, LevelFatal:
|
||||
l.Error(msg, zap.Bool("legacy_printf", true))
|
||||
default:
|
||||
l.Info(msg, zap.Bool("legacy_printf", true))
|
||||
}
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const loggerContextKey contextKey = "ctx_logger"
|
||||
|
||||
func IntoContext(ctx context.Context, l *zap.Logger) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if l == nil {
|
||||
l = L()
|
||||
}
|
||||
return context.WithValue(ctx, loggerContextKey, l)
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) *zap.Logger {
|
||||
if ctx == nil {
|
||||
return L()
|
||||
}
|
||||
if l, ok := ctx.Value(loggerContextKey).(*zap.Logger); ok && l != nil {
|
||||
return l
|
||||
}
|
||||
return L()
|
||||
}
|
||||
192
backend/internal/pkg/logger/logger_test.go
Normal file
192
backend/internal/pkg/logger/logger_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInit_DualOutput(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
|
||||
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
err = Init(InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: true,
|
||||
FilePath: logPath,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 10,
|
||||
MaxBackups: 2,
|
||||
MaxAgeDays: 1,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
L().Info("dual-output-info")
|
||||
L().Warn("dual-output-warn")
|
||||
Sync()
|
||||
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
stdoutBytes, _ := io.ReadAll(stdoutR)
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
stdoutText := string(stdoutBytes)
|
||||
stderrText := string(stderrBytes)
|
||||
|
||||
if !strings.Contains(stdoutText, "dual-output-info") {
|
||||
t.Fatalf("stdout missing info log: %s", stdoutText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "dual-output-warn") {
|
||||
t.Fatalf("stderr missing warn log: %s", stderrText)
|
||||
}
|
||||
|
||||
fileBytes, err := os.ReadFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read log file: %v", err)
|
||||
}
|
||||
fileText := string(fileBytes)
|
||||
if !strings.Contains(fileText, "dual-output-info") || !strings.Contains(fileText, "dual-output-warn") {
|
||||
t.Fatalf("file missing logs: %s", fileText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit_FileOutputFailureDowngrade(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
_, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
err = Init(InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: true,
|
||||
FilePath: filepath.Join(os.DevNull, "logs", "sub2api.log"),
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 10,
|
||||
MaxBackups: 1,
|
||||
MaxAgeDays: 1,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Init() should downgrade instead of failing, got: %v", err)
|
||||
}
|
||||
|
||||
_ = stderrW.Close()
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
if !strings.Contains(string(stderrBytes), "日志文件输出初始化失败") {
|
||||
t.Fatalf("stderr should contain fallback warning, got: %s", string(stderrBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit_CallerShouldPointToCallsite(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
_, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
if err := Init(InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Caller: true,
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
}); err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
L().Info("caller-check")
|
||||
Sync()
|
||||
_ = stdoutW.Close()
|
||||
logBytes, _ := io.ReadAll(stdoutR)
|
||||
|
||||
var line string
|
||||
for _, item := range strings.Split(string(logBytes), "\n") {
|
||||
if strings.Contains(item, "caller-check") {
|
||||
line = item
|
||||
break
|
||||
}
|
||||
}
|
||||
if line == "" {
|
||||
t.Fatalf("log output missing caller-check: %s", string(logBytes))
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &payload); err != nil {
|
||||
t.Fatalf("parse log json failed: %v, line=%s", err, line)
|
||||
}
|
||||
caller, _ := payload["caller"].(string)
|
||||
if !strings.Contains(caller, "logger_test.go:") {
|
||||
t.Fatalf("caller should point to this test file, got: %s", caller)
|
||||
}
|
||||
}
|
||||
161
backend/internal/pkg/logger/options.go
Normal file
161
backend/internal/pkg/logger/options.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultContainerLogPath 为容器内默认日志文件路径。
|
||||
DefaultContainerLogPath = "/app/data/logs/sub2api.log"
|
||||
defaultLogFilename = "sub2api.log"
|
||||
)
|
||||
|
||||
type InitOptions struct {
|
||||
Level string
|
||||
Format string
|
||||
ServiceName string
|
||||
Environment string
|
||||
Caller bool
|
||||
StacktraceLevel string
|
||||
Output OutputOptions
|
||||
Rotation RotationOptions
|
||||
Sampling SamplingOptions
|
||||
}
|
||||
|
||||
type OutputOptions struct {
|
||||
ToStdout bool
|
||||
ToFile bool
|
||||
FilePath string
|
||||
}
|
||||
|
||||
type RotationOptions struct {
|
||||
MaxSizeMB int
|
||||
MaxBackups int
|
||||
MaxAgeDays int
|
||||
Compress bool
|
||||
LocalTime bool
|
||||
}
|
||||
|
||||
type SamplingOptions struct {
|
||||
Enabled bool
|
||||
Initial int
|
||||
Thereafter int
|
||||
}
|
||||
|
||||
func (o InitOptions) normalized() InitOptions {
|
||||
out := o
|
||||
out.Level = strings.ToLower(strings.TrimSpace(out.Level))
|
||||
if out.Level == "" {
|
||||
out.Level = "info"
|
||||
}
|
||||
out.Format = strings.ToLower(strings.TrimSpace(out.Format))
|
||||
if out.Format == "" {
|
||||
out.Format = "console"
|
||||
}
|
||||
out.ServiceName = strings.TrimSpace(out.ServiceName)
|
||||
if out.ServiceName == "" {
|
||||
out.ServiceName = "sub2api"
|
||||
}
|
||||
out.Environment = strings.TrimSpace(out.Environment)
|
||||
if out.Environment == "" {
|
||||
out.Environment = "production"
|
||||
}
|
||||
out.StacktraceLevel = strings.ToLower(strings.TrimSpace(out.StacktraceLevel))
|
||||
if out.StacktraceLevel == "" {
|
||||
out.StacktraceLevel = "error"
|
||||
}
|
||||
if !out.Output.ToStdout && !out.Output.ToFile {
|
||||
out.Output.ToStdout = true
|
||||
}
|
||||
out.Output.FilePath = resolveLogFilePath(out.Output.FilePath)
|
||||
if out.Rotation.MaxSizeMB <= 0 {
|
||||
out.Rotation.MaxSizeMB = 100
|
||||
}
|
||||
if out.Rotation.MaxBackups < 0 {
|
||||
out.Rotation.MaxBackups = 10
|
||||
}
|
||||
if out.Rotation.MaxAgeDays < 0 {
|
||||
out.Rotation.MaxAgeDays = 7
|
||||
}
|
||||
if out.Sampling.Enabled {
|
||||
if out.Sampling.Initial <= 0 {
|
||||
out.Sampling.Initial = 100
|
||||
}
|
||||
if out.Sampling.Thereafter <= 0 {
|
||||
out.Sampling.Thereafter = 100
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func resolveLogFilePath(explicit string) string {
|
||||
explicit = strings.TrimSpace(explicit)
|
||||
if explicit != "" {
|
||||
return explicit
|
||||
}
|
||||
dataDir := strings.TrimSpace(os.Getenv("DATA_DIR"))
|
||||
if dataDir != "" {
|
||||
return filepath.Join(dataDir, "logs", defaultLogFilename)
|
||||
}
|
||||
return DefaultContainerLogPath
|
||||
}
|
||||
|
||||
func bootstrapOptions() InitOptions {
|
||||
return InitOptions{
|
||||
Level: "info",
|
||||
Format: "console",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "bootstrap",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 100,
|
||||
MaxBackups: 10,
|
||||
MaxAgeDays: 7,
|
||||
Compress: true,
|
||||
LocalTime: true,
|
||||
},
|
||||
Sampling: SamplingOptions{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func parseLevel(level string) (Level, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(level)) {
|
||||
case "debug":
|
||||
return LevelDebug, true
|
||||
case "info":
|
||||
return LevelInfo, true
|
||||
case "warn":
|
||||
return LevelWarn, true
|
||||
case "error":
|
||||
return LevelError, true
|
||||
default:
|
||||
return LevelInfo, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseStacktraceLevel(level string) (Level, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(level)) {
|
||||
case "none":
|
||||
return LevelFatal + 1, true
|
||||
case "error":
|
||||
return LevelError, true
|
||||
case "fatal":
|
||||
return LevelFatal, true
|
||||
default:
|
||||
return LevelError, false
|
||||
}
|
||||
}
|
||||
|
||||
func samplingTick() time.Duration {
|
||||
return time.Second
|
||||
}
|
||||
102
backend/internal/pkg/logger/options_test.go
Normal file
102
backend/internal/pkg/logger/options_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func TestResolveLogFilePath_Default(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "")
|
||||
got := resolveLogFilePath("")
|
||||
if got != DefaultContainerLogPath {
|
||||
t.Fatalf("resolveLogFilePath() = %q, want %q", got, DefaultContainerLogPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLogFilePath_WithDataDir(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "/tmp/sub2api-data")
|
||||
got := resolveLogFilePath("")
|
||||
want := filepath.Join("/tmp/sub2api-data", "logs", "sub2api.log")
|
||||
if got != want {
|
||||
t.Fatalf("resolveLogFilePath() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLogFilePath_ExplicitPath(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "/tmp/ignore")
|
||||
got := resolveLogFilePath("/var/log/custom.log")
|
||||
if got != "/var/log/custom.log" {
|
||||
t.Fatalf("resolveLogFilePath() = %q, want explicit path", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedOptions_InvalidFallback(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "")
|
||||
opts := InitOptions{
|
||||
Level: "TRACE",
|
||||
Format: "TEXT",
|
||||
ServiceName: "",
|
||||
Environment: "",
|
||||
StacktraceLevel: "panic",
|
||||
Output: OutputOptions{
|
||||
ToStdout: false,
|
||||
ToFile: false,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 0,
|
||||
MaxBackups: -1,
|
||||
MaxAgeDays: -1,
|
||||
},
|
||||
Sampling: SamplingOptions{
|
||||
Enabled: true,
|
||||
Initial: 0,
|
||||
Thereafter: 0,
|
||||
},
|
||||
}
|
||||
out := opts.normalized()
|
||||
if out.Level != "trace" {
|
||||
// normalized 仅做 trim/lower,不做校验;校验在 config 层。
|
||||
t.Fatalf("normalized level should preserve value for upstream validation, got %q", out.Level)
|
||||
}
|
||||
if !out.Output.ToStdout {
|
||||
t.Fatalf("normalized output should fallback to stdout")
|
||||
}
|
||||
if out.Output.FilePath != DefaultContainerLogPath {
|
||||
t.Fatalf("normalized file path = %q", out.Output.FilePath)
|
||||
}
|
||||
if out.Rotation.MaxSizeMB != 100 {
|
||||
t.Fatalf("normalized max_size_mb = %d", out.Rotation.MaxSizeMB)
|
||||
}
|
||||
if out.Rotation.MaxBackups != 10 {
|
||||
t.Fatalf("normalized max_backups = %d", out.Rotation.MaxBackups)
|
||||
}
|
||||
if out.Rotation.MaxAgeDays != 7 {
|
||||
t.Fatalf("normalized max_age_days = %d", out.Rotation.MaxAgeDays)
|
||||
}
|
||||
if out.Sampling.Initial != 100 || out.Sampling.Thereafter != 100 {
|
||||
t.Fatalf("normalized sampling defaults invalid: %+v", out.Sampling)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFileCore_InvalidPathFallback(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "")
|
||||
opts := bootstrapOptions()
|
||||
opts.Output.ToFile = true
|
||||
opts.Output.FilePath = filepath.Join(os.DevNull, "logs", "sub2api.log")
|
||||
encoderCfg := zapcore.EncoderConfig{
|
||||
TimeKey: "time",
|
||||
LevelKey: "level",
|
||||
MessageKey: "msg",
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeLevel: zapcore.CapitalLevelEncoder,
|
||||
}
|
||||
encoder := zapcore.NewJSONEncoder(encoderCfg)
|
||||
_, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts)
|
||||
if err == nil {
|
||||
t.Fatalf("buildFileCore() expected error for invalid path")
|
||||
}
|
||||
}
|
||||
132
backend/internal/pkg/logger/slog_handler.go
Normal file
132
backend/internal/pkg/logger/slog_handler.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type slogZapHandler struct {
|
||||
logger *zap.Logger
|
||||
attrs []slog.Attr
|
||||
groups []string
|
||||
}
|
||||
|
||||
func newSlogZapHandler(logger *zap.Logger) slog.Handler {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &slogZapHandler{
|
||||
logger: logger,
|
||||
attrs: make([]slog.Attr, 0, 8),
|
||||
groups: make([]string, 0, 4),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) Enabled(_ context.Context, level slog.Level) bool {
|
||||
switch {
|
||||
case level >= slog.LevelError:
|
||||
return h.logger.Core().Enabled(LevelError)
|
||||
case level >= slog.LevelWarn:
|
||||
return h.logger.Core().Enabled(LevelWarn)
|
||||
case level <= slog.LevelDebug:
|
||||
return h.logger.Core().Enabled(LevelDebug)
|
||||
default:
|
||||
return h.logger.Core().Enabled(LevelInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
|
||||
fields := make([]zap.Field, 0, len(h.attrs)+record.NumAttrs()+3)
|
||||
fields = append(fields, slogAttrsToZapFields(h.groups, h.attrs)...)
|
||||
record.Attrs(func(attr slog.Attr) bool {
|
||||
fields = append(fields, slogAttrToZapField(h.groups, attr))
|
||||
return true
|
||||
})
|
||||
|
||||
entry := h.logger.With(fields...)
|
||||
switch {
|
||||
case record.Level >= slog.LevelError:
|
||||
entry.Error(record.Message)
|
||||
case record.Level >= slog.LevelWarn:
|
||||
entry.Warn(record.Message)
|
||||
case record.Level <= slog.LevelDebug:
|
||||
entry.Debug(record.Message)
|
||||
default:
|
||||
entry.Info(record.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
next := *h
|
||||
next.attrs = append(append([]slog.Attr{}, h.attrs...), attrs...)
|
||||
return &next
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) WithGroup(name string) slog.Handler {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return h
|
||||
}
|
||||
next := *h
|
||||
next.groups = append(append([]string{}, h.groups...), name)
|
||||
return &next
|
||||
}
|
||||
|
||||
func slogAttrsToZapFields(groups []string, attrs []slog.Attr) []zap.Field {
|
||||
fields := make([]zap.Field, 0, len(attrs))
|
||||
for _, attr := range attrs {
|
||||
fields = append(fields, slogAttrToZapField(groups, attr))
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func slogAttrToZapField(groups []string, attr slog.Attr) zap.Field {
|
||||
if len(groups) > 0 {
|
||||
attr.Key = strings.Join(append(append([]string{}, groups...), attr.Key), ".")
|
||||
}
|
||||
value := attr.Value.Resolve()
|
||||
switch value.Kind() {
|
||||
case slog.KindBool:
|
||||
return zap.Bool(attr.Key, value.Bool())
|
||||
case slog.KindInt64:
|
||||
return zap.Int64(attr.Key, value.Int64())
|
||||
case slog.KindUint64:
|
||||
return zap.Uint64(attr.Key, value.Uint64())
|
||||
case slog.KindFloat64:
|
||||
return zap.Float64(attr.Key, value.Float64())
|
||||
case slog.KindDuration:
|
||||
return zap.Duration(attr.Key, value.Duration())
|
||||
case slog.KindTime:
|
||||
return zap.Time(attr.Key, value.Time())
|
||||
case slog.KindString:
|
||||
return zap.String(attr.Key, value.String())
|
||||
case slog.KindGroup:
|
||||
groupFields := make([]zap.Field, 0, len(value.Group()))
|
||||
for _, nested := range value.Group() {
|
||||
groupFields = append(groupFields, slogAttrToZapField(nil, nested))
|
||||
}
|
||||
return zap.Object(attr.Key, zapObjectFields(groupFields))
|
||||
case slog.KindAny:
|
||||
if t, ok := value.Any().(time.Time); ok {
|
||||
return zap.Time(attr.Key, t)
|
||||
}
|
||||
return zap.Any(attr.Key, value.Any())
|
||||
default:
|
||||
return zap.String(attr.Key, value.String())
|
||||
}
|
||||
}
|
||||
|
||||
type zapObjectFields []zap.Field
|
||||
|
||||
func (z zapObjectFields) MarshalLogObject(enc zapcore.ObjectEncoder) error {
|
||||
for _, field := range z {
|
||||
field.AddTo(enc)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
88
backend/internal/pkg/logger/slog_handler_test.go
Normal file
88
backend/internal/pkg/logger/slog_handler_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type captureState struct {
|
||||
writes []capturedWrite
|
||||
}
|
||||
|
||||
type capturedWrite struct {
|
||||
fields []zapcore.Field
|
||||
}
|
||||
|
||||
type captureCore struct {
|
||||
state *captureState
|
||||
withFields []zapcore.Field
|
||||
}
|
||||
|
||||
func newCaptureCore() *captureCore {
|
||||
return &captureCore{state: &captureState{}}
|
||||
}
|
||||
|
||||
func (c *captureCore) Enabled(zapcore.Level) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *captureCore) With(fields []zapcore.Field) zapcore.Core {
|
||||
nextFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
|
||||
nextFields = append(nextFields, c.withFields...)
|
||||
nextFields = append(nextFields, fields...)
|
||||
return &captureCore{
|
||||
state: c.state,
|
||||
withFields: nextFields,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *captureCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
|
||||
return ce.AddCore(entry, c)
|
||||
}
|
||||
|
||||
func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
|
||||
allFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
|
||||
allFields = append(allFields, c.withFields...)
|
||||
allFields = append(allFields, fields...)
|
||||
c.state.writes = append(c.state.writes, capturedWrite{
|
||||
fields: allFields,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *captureCore) Sync() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSlogZapHandler_Handle_DoesNotAppendTimeField(t *testing.T) {
|
||||
core := newCaptureCore()
|
||||
handler := newSlogZapHandler(zap.New(core))
|
||||
|
||||
record := slog.NewRecord(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), slog.LevelInfo, "hello", 0)
|
||||
record.AddAttrs(slog.String("component", "http.access"))
|
||||
|
||||
if err := handler.Handle(context.Background(), record); err != nil {
|
||||
t.Fatalf("handle slog record: %v", err)
|
||||
}
|
||||
if len(core.state.writes) != 1 {
|
||||
t.Fatalf("write calls = %d, want 1", len(core.state.writes))
|
||||
}
|
||||
|
||||
var hasComponent bool
|
||||
for _, field := range core.state.writes[0].fields {
|
||||
if field.Key == "time" {
|
||||
t.Fatalf("unexpected duplicate time field in slog adapter output")
|
||||
}
|
||||
if field.Key == "component" {
|
||||
hasComponent = true
|
||||
}
|
||||
}
|
||||
if !hasComponent {
|
||||
t.Fatalf("component field should be preserved")
|
||||
}
|
||||
}
|
||||
165
backend/internal/pkg/logger/stdlog_bridge_test.go
Normal file
165
backend/internal/pkg/logger/stdlog_bridge_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInferStdLogLevel(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want Level
|
||||
}{
|
||||
{msg: "Warning: queue full", want: LevelWarn},
|
||||
{msg: "Forward request failed: timeout", want: LevelError},
|
||||
{msg: "[ERROR] upstream unavailable", want: LevelError},
|
||||
{msg: "service started", want: LevelInfo},
|
||||
{msg: "debug: cache miss", want: LevelDebug},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
got := inferStdLogLevel(tc.msg)
|
||||
if got != tc.want {
|
||||
t.Fatalf("inferStdLogLevel(%q)=%v want=%v", tc.msg, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStdLogMessage(t *testing.T) {
|
||||
raw := " [TokenRefresh] cycle complete \n total=1 failed=0 \n"
|
||||
got := normalizeStdLogMessage(raw)
|
||||
want := "[TokenRefresh] cycle complete total=1 failed=0"
|
||||
if got != want {
|
||||
t.Fatalf("normalizeStdLogMessage()=%q want=%q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStdLogBridgeRoutesLevels(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
if err := Init(InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
}); err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("service started")
|
||||
log.Printf("Warning: queue full")
|
||||
log.Printf("Forward request failed: timeout")
|
||||
Sync()
|
||||
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
stdoutBytes, _ := io.ReadAll(stdoutR)
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
stdoutText := string(stdoutBytes)
|
||||
stderrText := string(stderrBytes)
|
||||
|
||||
if !strings.Contains(stdoutText, "service started") {
|
||||
t.Fatalf("stdout missing info log: %s", stdoutText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "Warning: queue full") {
|
||||
t.Fatalf("stderr missing warn log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "Forward request failed: timeout") {
|
||||
t.Fatalf("stderr missing error log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "\"legacy_stdlog\":true") {
|
||||
t.Fatalf("stderr missing legacy_stdlog marker: %s", stderrText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyPrintfRoutesLevels(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
if err := Init(InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
}); err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
LegacyPrintf("service.test", "request started")
|
||||
LegacyPrintf("service.test", "Warning: queue full")
|
||||
LegacyPrintf("service.test", "forward failed: timeout")
|
||||
Sync()
|
||||
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
stdoutBytes, _ := io.ReadAll(stdoutR)
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
stdoutText := string(stdoutBytes)
|
||||
stderrText := string(stderrBytes)
|
||||
|
||||
if !strings.Contains(stdoutText, "request started") {
|
||||
t.Fatalf("stdout missing info log: %s", stdoutText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "Warning: queue full") {
|
||||
t.Fatalf("stderr missing warn log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "forward failed: timeout") {
|
||||
t.Fatalf("stderr missing error log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "\"legacy_printf\":true") {
|
||||
t.Fatalf("stderr missing legacy_printf marker: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "\"component\":\"service.test\"") {
|
||||
t.Fatalf("stderr missing component field: %s", stderrText)
|
||||
}
|
||||
}
|
||||
@@ -50,6 +50,7 @@ type OAuthSession struct {
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
@@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore {
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
|
||||
43
backend/internal/pkg/oauth/oauth_test.go
Normal file
43
backend/internal/pkg/oauth/oauth_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSessionStore_Stop_Idempotent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop_Concurrent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 50 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
store.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
@@ -15,8 +15,8 @@ type Model struct {
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"},
|
||||
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
|
||||
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
|
||||
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
const (
|
||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
|
||||
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||
@@ -47,6 +49,7 @@ type OAuthSession struct {
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
@@ -92,7 +95,9 @@ func (s *SessionStore) Delete(sessionID string) {
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions periodically
|
||||
|
||||
43
backend/internal/pkg/openai/oauth_test.go
Normal file
43
backend/internal/pkg/openai/oauth_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSessionStore_Stop_Idempotent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop_Concurrent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 50 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
store.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
|
||||
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
|
||||
var CodexCLIUserAgentPrefixes = []string{
|
||||
@@ -7,10 +9,67 @@ var CodexCLIUserAgentPrefixes = []string{
|
||||
"codex_cli_rs/",
|
||||
}
|
||||
|
||||
// CodexOfficialClientUserAgentPrefixes matches Codex 官方客户端家族 User-Agent 前缀。
|
||||
// 该列表仅用于 OpenAI OAuth `codex_cli_only` 访问限制判定。
|
||||
var CodexOfficialClientUserAgentPrefixes = []string{
|
||||
"codex_cli_rs/",
|
||||
"codex_vscode/",
|
||||
"codex_app/",
|
||||
"codex_chatgpt_desktop/",
|
||||
"codex_atlas/",
|
||||
"codex_exec/",
|
||||
"codex_sdk_ts/",
|
||||
"codex ",
|
||||
}
|
||||
|
||||
// CodexOfficialClientOriginatorPrefixes matches Codex 官方客户端家族 originator 前缀。
|
||||
// 说明:OpenAI 官方 Codex 客户端并不只使用固定的 codex_app 标识。
|
||||
// 例如 codex_cli_rs、codex_vscode、codex_chatgpt_desktop、codex_atlas、codex_exec、codex_sdk_ts 等。
|
||||
var CodexOfficialClientOriginatorPrefixes = []string{
|
||||
"codex_",
|
||||
"codex ",
|
||||
}
|
||||
|
||||
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||
func IsCodexCLIRequest(userAgent string) bool {
|
||||
for _, prefix := range CodexCLIUserAgentPrefixes {
|
||||
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
|
||||
ua := normalizeCodexClientHeader(userAgent)
|
||||
if ua == "" {
|
||||
return false
|
||||
}
|
||||
return matchCodexClientHeaderPrefixes(ua, CodexCLIUserAgentPrefixes)
|
||||
}
|
||||
|
||||
// IsCodexOfficialClientRequest checks if the User-Agent indicates a Codex 官方客户端请求。
|
||||
// 与 IsCodexCLIRequest 解耦,避免影响历史兼容逻辑。
|
||||
func IsCodexOfficialClientRequest(userAgent string) bool {
|
||||
ua := normalizeCodexClientHeader(userAgent)
|
||||
if ua == "" {
|
||||
return false
|
||||
}
|
||||
return matchCodexClientHeaderPrefixes(ua, CodexOfficialClientUserAgentPrefixes)
|
||||
}
|
||||
|
||||
// IsCodexOfficialClientOriginator checks if originator indicates a Codex 官方客户端请求。
|
||||
func IsCodexOfficialClientOriginator(originator string) bool {
|
||||
v := normalizeCodexClientHeader(originator)
|
||||
if v == "" {
|
||||
return false
|
||||
}
|
||||
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
|
||||
}
|
||||
|
||||
func normalizeCodexClientHeader(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
func matchCodexClientHeaderPrefixes(value string, prefixes []string) bool {
|
||||
for _, prefix := range prefixes {
|
||||
normalizedPrefix := normalizeCodexClientHeader(prefix)
|
||||
if normalizedPrefix == "" {
|
||||
continue
|
||||
}
|
||||
// 优先前缀匹配;若 UA/Originator 被网关拼接为复合字符串时,退化为包含匹配。
|
||||
if strings.HasPrefix(value, normalizedPrefix) || strings.Contains(value, normalizedPrefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
87
backend/internal/pkg/openai/request_test.go
Normal file
87
backend/internal/pkg/openai/request_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsCodexCLIRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
want bool
|
||||
}{
|
||||
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.1.0", want: true},
|
||||
{name: "codex_vscode 前缀", ua: "codex_vscode/1.2.3", want: true},
|
||||
{name: "大小写混合", ua: "Codex_CLI_Rs/0.1.0", want: true},
|
||||
{name: "复合 UA 包含 codex", ua: "Mozilla/5.0 codex_cli_rs/0.1.0", want: true},
|
||||
{name: "空白包裹", ua: " codex_vscode/1.2.3 ", want: true},
|
||||
{name: "非 codex", ua: "curl/8.0.1", want: false},
|
||||
{name: "空字符串", ua: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsCodexCLIRequest(tt.ua)
|
||||
if got != tt.want {
|
||||
t.Fatalf("IsCodexCLIRequest(%q) = %v, want %v", tt.ua, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCodexOfficialClientRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
want bool
|
||||
}{
|
||||
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.98.0", want: true},
|
||||
{name: "codex_vscode 前缀", ua: "codex_vscode/1.0.0", want: true},
|
||||
{name: "codex_app 前缀", ua: "codex_app/0.1.0", want: true},
|
||||
{name: "codex_chatgpt_desktop 前缀", ua: "codex_chatgpt_desktop/1.0.0", want: true},
|
||||
{name: "codex_atlas 前缀", ua: "codex_atlas/1.0.0", want: true},
|
||||
{name: "codex_exec 前缀", ua: "codex_exec/0.1.0", want: true},
|
||||
{name: "codex_sdk_ts 前缀", ua: "codex_sdk_ts/0.1.0", want: true},
|
||||
{name: "Codex 桌面 UA", ua: "Codex Desktop/1.2.3", want: true},
|
||||
{name: "复合 UA 包含 codex_app", ua: "Mozilla/5.0 codex_app/0.1.0", want: true},
|
||||
{name: "大小写混合", ua: "Codex_VSCode/1.2.3", want: true},
|
||||
{name: "非 codex", ua: "curl/8.0.1", want: false},
|
||||
{name: "空字符串", ua: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsCodexOfficialClientRequest(tt.ua)
|
||||
if got != tt.want {
|
||||
t.Fatalf("IsCodexOfficialClientRequest(%q) = %v, want %v", tt.ua, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCodexOfficialClientOriginator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
originator string
|
||||
want bool
|
||||
}{
|
||||
{name: "codex_cli_rs", originator: "codex_cli_rs", want: true},
|
||||
{name: "codex_vscode", originator: "codex_vscode", want: true},
|
||||
{name: "codex_app", originator: "codex_app", want: true},
|
||||
{name: "codex_chatgpt_desktop", originator: "codex_chatgpt_desktop", want: true},
|
||||
{name: "codex_atlas", originator: "codex_atlas", want: true},
|
||||
{name: "codex_exec", originator: "codex_exec", want: true},
|
||||
{name: "codex_sdk_ts", originator: "codex_sdk_ts", want: true},
|
||||
{name: "Codex 前缀", originator: "Codex Desktop", want: true},
|
||||
{name: "空白包裹", originator: " codex_vscode ", want: true},
|
||||
{name: "非 codex", originator: "my_client", want: false},
|
||||
{name: "空字符串", originator: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsCodexOfficialClientOriginator(tt.originator)
|
||||
if got != tt.want {
|
||||
t.Fatalf("IsCodexOfficialClientOriginator(%q) = %v, want %v", tt.originator, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -78,7 +79,7 @@ func ErrorFrom(c *gin.Context, err error) bool {
|
||||
|
||||
// Log internal errors with full details for debugging
|
||||
if statusCode >= 500 && c.Request != nil {
|
||||
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
|
||||
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error()))
|
||||
}
|
||||
|
||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
||||
|
||||
@@ -14,6 +14,44 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------- 辅助函数 ----------
|
||||
|
||||
// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体
|
||||
func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response {
|
||||
t.Helper()
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
return got
|
||||
}
|
||||
|
||||
// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData)
|
||||
func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) {
|
||||
t.Helper()
|
||||
// 先用 raw json 解析,因为 Data 是 any 类型
|
||||
var raw struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
|
||||
|
||||
var pd PaginatedData
|
||||
require.NoError(t, json.Unmarshal(raw.Data, &pd))
|
||||
|
||||
return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd
|
||||
}
|
||||
|
||||
// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination
|
||||
func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil)
|
||||
return w, c
|
||||
}
|
||||
|
||||
// ---------- 现有测试 ----------
|
||||
|
||||
func TestErrorWithDetails(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- 新增测试 ----------
|
||||
|
||||
func TestSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
wantCode int
|
||||
wantBody Response
|
||||
}{
|
||||
{
|
||||
name: "返回字符串数据",
|
||||
data: "hello",
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success", Data: "hello"},
|
||||
},
|
||||
{
|
||||
name: "返回nil数据",
|
||||
data: nil,
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success"},
|
||||
},
|
||||
{
|
||||
name: "返回map数据",
|
||||
data: map[string]string{"key": "value"},
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Success(c, tt.data)
|
||||
|
||||
require.Equal(t, tt.wantCode, w.Code)
|
||||
|
||||
// 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, 0, got.Code)
|
||||
require.Equal(t, "success", got.Message)
|
||||
|
||||
if tt.data == nil {
|
||||
require.Nil(t, got.Data)
|
||||
} else {
|
||||
require.NotNil(t, got.Data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreated(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
name: "创建成功_返回数据",
|
||||
data: map[string]int{"id": 42},
|
||||
wantCode: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "创建成功_nil数据",
|
||||
data: nil,
|
||||
wantCode: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Created(c, tt.data)
|
||||
|
||||
require.Equal(t, tt.wantCode, w.Code)
|
||||
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, 0, got.Code)
|
||||
require.Equal(t, "success", got.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "400错误",
|
||||
statusCode: http.StatusBadRequest,
|
||||
message: "bad request",
|
||||
},
|
||||
{
|
||||
name: "500错误",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
message: "internal error",
|
||||
},
|
||||
{
|
||||
name: "自定义状态码",
|
||||
statusCode: 418,
|
||||
message: "I'm a teapot",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Error(c, tt.statusCode, tt.message)
|
||||
|
||||
require.Equal(t, tt.statusCode, w.Code)
|
||||
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, tt.statusCode, got.Code)
|
||||
require.Equal(t, tt.message, got.Message)
|
||||
require.Empty(t, got.Reason)
|
||||
require.Nil(t, got.Metadata)
|
||||
require.Nil(t, got.Data)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
BadRequest(c, "参数无效")
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusBadRequest, got.Code)
|
||||
require.Equal(t, "参数无效", got.Message)
|
||||
}
|
||||
|
||||
func TestUnauthorized(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Unauthorized(c, "未登录")
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusUnauthorized, got.Code)
|
||||
require.Equal(t, "未登录", got.Message)
|
||||
}
|
||||
|
||||
func TestForbidden(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Forbidden(c, "无权限")
|
||||
|
||||
require.Equal(t, http.StatusForbidden, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusForbidden, got.Code)
|
||||
require.Equal(t, "无权限", got.Message)
|
||||
}
|
||||
|
||||
func TestNotFound(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
NotFound(c, "资源不存在")
|
||||
|
||||
require.Equal(t, http.StatusNotFound, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusNotFound, got.Code)
|
||||
require.Equal(t, "资源不存在", got.Message)
|
||||
}
|
||||
|
||||
func TestInternalError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
InternalError(c, "服务器内部错误")
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusInternalServerError, got.Code)
|
||||
require.Equal(t, "服务器内部错误", got.Message)
|
||||
}
|
||||
|
||||
func TestPaginated(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
items any
|
||||
total int64
|
||||
page int
|
||||
pageSize int
|
||||
wantPages int
|
||||
wantTotal int64
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
}{
|
||||
{
|
||||
name: "标准分页_多页",
|
||||
items: []string{"a", "b"},
|
||||
total: 25,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 3,
|
||||
wantTotal: 25,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "总数刚好整除",
|
||||
items: []string{"a"},
|
||||
total: 20,
|
||||
page: 2,
|
||||
pageSize: 10,
|
||||
wantPages: 2,
|
||||
wantTotal: 20,
|
||||
wantPage: 2,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "总数为0_pages至少为1",
|
||||
items: []string{},
|
||||
total: 0,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 1,
|
||||
wantTotal: 0,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "单页数据",
|
||||
items: []int{1, 2, 3},
|
||||
total: 3,
|
||||
page: 1,
|
||||
pageSize: 20,
|
||||
wantPages: 1,
|
||||
wantTotal: 3,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "总数为1",
|
||||
items: []string{"only"},
|
||||
total: 1,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 1,
|
||||
wantTotal: 1,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Paginated(c, tt.items, tt.total, tt.page, tt.pageSize)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp, pd := parsePaginatedBody(t, w)
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, "success", resp.Message)
|
||||
require.Equal(t, tt.wantTotal, pd.Total)
|
||||
require.Equal(t, tt.wantPage, pd.Page)
|
||||
require.Equal(t, tt.wantPageSize, pd.PageSize)
|
||||
require.Equal(t, tt.wantPages, pd.Pages)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginatedWithResult(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
items any
|
||||
pagination *PaginationResult
|
||||
wantTotal int64
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
wantPages int
|
||||
}{
|
||||
{
|
||||
name: "正常分页结果",
|
||||
items: []string{"a", "b"},
|
||||
pagination: &PaginationResult{
|
||||
Total: 50,
|
||||
Page: 3,
|
||||
PageSize: 10,
|
||||
Pages: 5,
|
||||
},
|
||||
wantTotal: 50,
|
||||
wantPage: 3,
|
||||
wantPageSize: 10,
|
||||
wantPages: 5,
|
||||
},
|
||||
{
|
||||
name: "pagination为nil_使用默认值",
|
||||
items: []string{},
|
||||
pagination: nil,
|
||||
wantTotal: 0,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
wantPages: 1,
|
||||
},
|
||||
{
|
||||
name: "单页结果",
|
||||
items: []int{1},
|
||||
pagination: &PaginationResult{
|
||||
Total: 1,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Pages: 1,
|
||||
},
|
||||
wantTotal: 1,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
wantPages: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
PaginatedWithResult(c, tt.items, tt.pagination)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp, pd := parsePaginatedBody(t, w)
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, "success", resp.Message)
|
||||
require.Equal(t, tt.wantTotal, pd.Total)
|
||||
require.Equal(t, tt.wantPage, pd.Page)
|
||||
require.Equal(t, tt.wantPageSize, pd.PageSize)
|
||||
require.Equal(t, tt.wantPages, pd.Pages)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePagination(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
}{
|
||||
{
|
||||
name: "无参数_使用默认值",
|
||||
query: "",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "仅指定page",
|
||||
query: "page=3",
|
||||
wantPage: 3,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "仅指定page_size",
|
||||
query: "page_size=50",
|
||||
wantPage: 1,
|
||||
wantPageSize: 50,
|
||||
},
|
||||
{
|
||||
name: "同时指定page和page_size",
|
||||
query: "page=2&page_size=30",
|
||||
wantPage: 2,
|
||||
wantPageSize: 30,
|
||||
},
|
||||
{
|
||||
name: "使用limit代替page_size",
|
||||
query: "limit=15",
|
||||
wantPage: 1,
|
||||
wantPageSize: 15,
|
||||
},
|
||||
{
|
||||
name: "page_size优先于limit",
|
||||
query: "page_size=25&limit=50",
|
||||
wantPage: 1,
|
||||
wantPageSize: 25,
|
||||
},
|
||||
{
|
||||
name: "page为0_使用默认值",
|
||||
query: "page=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size超过1000_使用默认值",
|
||||
query: "page_size=1001",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size恰好1000_有效",
|
||||
query: "page_size=1000",
|
||||
wantPage: 1,
|
||||
wantPageSize: 1000,
|
||||
},
|
||||
{
|
||||
name: "page为非数字_使用默认值",
|
||||
query: "page=abc",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size为非数字_使用默认值",
|
||||
query: "page_size=xyz",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit为非数字_使用默认值",
|
||||
query: "limit=abc",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size为0_使用默认值",
|
||||
query: "page_size=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit为0_使用默认值",
|
||||
query: "limit=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "大页码",
|
||||
query: "page=999&page_size=100",
|
||||
wantPage: 999,
|
||||
wantPageSize: 100,
|
||||
},
|
||||
{
|
||||
name: "page_size为1_最小有效值",
|
||||
query: "page_size=1",
|
||||
wantPage: 1,
|
||||
wantPageSize: 1,
|
||||
},
|
||||
{
|
||||
name: "混合数字和字母的page",
|
||||
query: "page=12a",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit超过1000_使用默认值",
|
||||
query: "limit=2000",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, c := newContextWithQuery(tt.query)
|
||||
|
||||
page, pageSize := ParsePagination(c)
|
||||
|
||||
require.Equal(t, tt.wantPage, page, "page 不符合预期")
|
||||
require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantVal int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常数字",
|
||||
input: "123",
|
||||
wantVal: 123,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "零",
|
||||
input: "0",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "单个数字",
|
||||
input: "5",
|
||||
wantVal: 5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "大数字",
|
||||
input: "99999",
|
||||
wantVal: 99999,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含字母_返回0",
|
||||
input: "abc",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "数字开头接字母_返回0",
|
||||
input: "12a",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含负号_返回0",
|
||||
input: "-1",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含小数点_返回0",
|
||||
input: "1.5",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含空格_返回0",
|
||||
input: "1 2",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
input: "",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := parseInt(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, tt.wantVal, val)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
|
||||
return nil, fmt.Errorf("apply TLS preset: %w", err)
|
||||
}
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build unit
|
||||
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
//
|
||||
// Unit tests for TLS fingerprint dialer.
|
||||
@@ -9,26 +11,161 @@
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FingerprintResponse represents the response from tls.peet.ws/api/all.
|
||||
type FingerprintResponse struct {
|
||||
IP string `json:"ip"`
|
||||
TLS TLSInfo `json:"tls"`
|
||||
HTTP2 any `json:"http2"`
|
||||
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
|
||||
func TestDialerBasicConnection(t *testing.T) {
|
||||
skipNetworkTest(t)
|
||||
|
||||
// Create a dialer with default profile
|
||||
profile := &Profile{
|
||||
Name: "Test Profile",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
// Create HTTP client with custom TLS dialer
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Make a request to a known HTTPS endpoint
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// TLSInfo contains TLS fingerprint details.
|
||||
type TLSInfo struct {
|
||||
JA3 string `json:"ja3"`
|
||||
JA3Hash string `json:"ja3_hash"`
|
||||
JA4 string `json:"ja4"`
|
||||
PeetPrint string `json:"peetprint"`
|
||||
PeetPrintHash string `json:"peetprint_hash"`
|
||||
ClientRandom string `json:"client_random"`
|
||||
SessionID string `json:"session_id"`
|
||||
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
|
||||
// This test uses tls.peet.ws to verify the fingerprint.
|
||||
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
|
||||
func TestJA3Fingerprint(t *testing.T) {
|
||||
skipNetworkTest(t)
|
||||
|
||||
profile := &Profile{
|
||||
Name: "Claude CLI Test",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Use tls.peet.ws fingerprint detection API
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get fingerprint: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
}
|
||||
|
||||
// Log all fingerprint information
|
||||
t.Logf("JA3: %s", fpResp.TLS.JA3)
|
||||
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
|
||||
t.Logf("JA4: %s", fpResp.TLS.JA4)
|
||||
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
|
||||
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
|
||||
|
||||
// Verify JA3 hash matches expected value
|
||||
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
|
||||
if fpResp.TLS.JA3Hash == expectedJA3Hash {
|
||||
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
|
||||
} else {
|
||||
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
|
||||
}
|
||||
|
||||
// Verify JA4 fingerprint
|
||||
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
|
||||
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
|
||||
// The suffix _a33745022dd6_1f22a2ca17c4 should match
|
||||
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
|
||||
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
|
||||
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
|
||||
}
|
||||
|
||||
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
|
||||
// d = domain (SNI present), i = IP (no SNI)
|
||||
// Since we connect to tls.peet.ws (domain), we expect 'd'
|
||||
expectedJA4Prefix := "t13d5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
|
||||
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
|
||||
} else {
|
||||
// Also accept 'i' variant for IP connections
|
||||
altPrefix := "t13i5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
|
||||
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
|
||||
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
|
||||
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
|
||||
} else {
|
||||
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
|
||||
}
|
||||
|
||||
// Verify extension list (should be 11 extensions including SNI)
|
||||
// Expected: 0-11-10-35-16-22-23-13-43-45-51
|
||||
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
|
||||
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
|
||||
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
|
||||
} else {
|
||||
t.Logf("Warning: JA3 extension list may differ")
|
||||
}
|
||||
}
|
||||
|
||||
func skipNetworkTest(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("跳过网络测试(short 模式)")
|
||||
}
|
||||
if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != "1" {
|
||||
t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialerWithProfile tests that different profiles produce different fingerprints.
|
||||
@@ -158,3 +295,137 @@ func mustParseURL(rawURL string) *url.URL {
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
// TestProfileExpectation defines expected fingerprint values for a profile.
|
||||
type TestProfileExpectation struct {
|
||||
Profile *Profile
|
||||
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
|
||||
ExpectedJA4 string // Expected full JA4 (empty = don't check)
|
||||
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
|
||||
}
|
||||
|
||||
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
|
||||
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
|
||||
func TestAllProfiles(t *testing.T) {
|
||||
skipNetworkTest(t)
|
||||
|
||||
// Define all profiles to test with their expected fingerprints
|
||||
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
|
||||
profiles := []TestProfileExpectation{
|
||||
{
|
||||
// Linux x64 Node.js v22.17.1
|
||||
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
|
||||
Profile: &Profile{
|
||||
Name: "linux_x64_node_v22171",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part
|
||||
},
|
||||
{
|
||||
// MacOS arm64 Node.js v22.18.0
|
||||
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
|
||||
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
|
||||
Profile: &Profile{
|
||||
Name: "macos_arm64_node_v22180",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range profiles {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.Profile.Name, func(t *testing.T) {
|
||||
fp := fetchFingerprint(t, tc.Profile)
|
||||
if fp == nil {
|
||||
return // fetchFingerprint already called t.Fatal
|
||||
}
|
||||
|
||||
t.Logf("Profile: %s", tc.Profile.Name)
|
||||
t.Logf(" JA3: %s", fp.JA3)
|
||||
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
|
||||
t.Logf(" JA4: %s", fp.JA4)
|
||||
t.Logf(" PeetPrint: %s", fp.PeetPrint)
|
||||
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
|
||||
|
||||
// Verify expectations
|
||||
if tc.ExpectedJA3 != "" {
|
||||
if fp.JA3Hash == tc.ExpectedJA3 {
|
||||
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.ExpectedJA4 != "" {
|
||||
if fp.JA4 == tc.ExpectedJA4 {
|
||||
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
|
||||
}
|
||||
}
|
||||
|
||||
// Check JA4 cipher hash (stable middle part)
|
||||
// JA4 format: prefix_cipherHash_extHash
|
||||
if tc.JA4CipherHash != "" {
|
||||
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
|
||||
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
|
||||
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
|
||||
t.Helper()
|
||||
|
||||
dialer := NewDialer(profile, nil)
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get fingerprint: %v", err)
|
||||
return nil
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &fpResp.TLS
|
||||
}
|
||||
|
||||
20
backend/internal/pkg/tlsfingerprint/test_types_test.go
Normal file
20
backend/internal/pkg/tlsfingerprint/test_types_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package tlsfingerprint
|
||||
|
||||
// FingerprintResponse represents the response from tls.peet.ws/api/all.
|
||||
// 共享测试类型,供 unit 和 integration 测试文件使用。
|
||||
type FingerprintResponse struct {
|
||||
IP string `json:"ip"`
|
||||
TLS TLSInfo `json:"tls"`
|
||||
HTTP2 any `json:"http2"`
|
||||
}
|
||||
|
||||
// TLSInfo contains TLS fingerprint details.
|
||||
type TLSInfo struct {
|
||||
JA3 string `json:"ja3"`
|
||||
JA3Hash string `json:"ja3_hash"`
|
||||
JA4 string `json:"ja4"`
|
||||
PeetPrint string `json:"peetprint"`
|
||||
PeetPrintHash string `json:"peetprint_hash"`
|
||||
ClientRandom string `json:"client_random"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +24,7 @@ import (
|
||||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||||
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
@@ -127,7 +127,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
account.CreatedAt = created.CreatedAt
|
||||
account.UpdatedAt = created.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -388,7 +388,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
}
|
||||
account.UpdatedAt = updated.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
@@ -429,7 +429,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -533,7 +533,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
|
||||
},
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -568,7 +568,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
|
||||
}
|
||||
payload := map[string]any{"last_used": lastUsedPayload}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -583,7 +583,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
@@ -603,11 +603,11 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
|
||||
}
|
||||
account, err := r.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||
return
|
||||
}
|
||||
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -631,7 +631,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
|
||||
}
|
||||
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -648,7 +648,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
|
||||
}
|
||||
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -721,7 +721,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
|
||||
}
|
||||
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -829,7 +829,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -876,7 +876,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -890,7 +890,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -909,7 +909,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
@@ -928,7 +928,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -944,7 +944,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -968,7 +968,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -992,7 +992,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64)
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1014,7 +1014,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
|
||||
// 触发调度器缓存更新(仅当窗口时间有变化时)
|
||||
if start != nil || end != nil {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -1029,7 +1029,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
}
|
||||
if !schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
@@ -1057,7 +1057,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
|
||||
}
|
||||
if rows > 0 {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
@@ -1093,7 +1093,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1187,7 +1187,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
if rows > 0 {
|
||||
payload := map[string]any{"account_ids": ids}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
}
|
||||
shouldSync := false
|
||||
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
|
||||
@@ -1560,3 +1560,64 @@ func joinClauses(clauses []string, sep string) string {
|
||||
func itoa(v int) string {
|
||||
return strconv.Itoa(v)
|
||||
}
|
||||
|
||||
// FindByExtraField 根据 extra 字段中的键值对查找账号。
|
||||
// 该方法限定 platform='sora',避免误查询其他平台的账号。
|
||||
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
|
||||
//
|
||||
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
|
||||
//
|
||||
// FindByExtraField finds accounts by key-value pairs in the extra field.
|
||||
// Limited to platform='sora' to avoid querying accounts from other platforms.
|
||||
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
|
||||
//
|
||||
// Use case: Finding Sora accounts linked via linked_openai_account_id.
|
||||
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||
accounts, err := r.client.Account.Query().
|
||||
Where(
|
||||
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
|
||||
dbaccount.DeletedAtIsNil(),
|
||||
func(s *entsql.Selector) {
|
||||
path := sqljson.Path(key)
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
preds := []*entsql.Predicate{sqljson.ValueEQ(dbaccount.FieldExtra, v, path)}
|
||||
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
preds = append(preds, sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path))
|
||||
}
|
||||
if len(preds) == 1 {
|
||||
s.Where(preds[0])
|
||||
} else {
|
||||
s.Where(entsql.Or(preds...))
|
||||
}
|
||||
case int:
|
||||
s.Where(entsql.Or(
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.Itoa(v), path),
|
||||
))
|
||||
case int64:
|
||||
s.Where(entsql.Or(
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.FormatInt(v, 10), path),
|
||||
))
|
||||
case json.Number:
|
||||
if parsed, err := v.Int64(); err == nil {
|
||||
s.Where(entsql.Or(
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path),
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path),
|
||||
))
|
||||
} else {
|
||||
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path))
|
||||
}
|
||||
default:
|
||||
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, value, path))
|
||||
}
|
||||
},
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||||
}
|
||||
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID).
|
||||
SetNillableLastUsedAt(key.LastUsedAt).
|
||||
SetQuota(key.Quota).
|
||||
SetQuotaUsed(key.QuotaUsed).
|
||||
SetNillableExpiresAt(key.ExpiresAt)
|
||||
@@ -48,6 +49,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
key.ID = created.ID
|
||||
key.LastUsedAt = created.LastUsedAt
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
@@ -140,6 +142,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldImagePrice1k,
|
||||
group.FieldImagePrice2k,
|
||||
group.FieldImagePrice4k,
|
||||
group.FieldSoraImagePrice360,
|
||||
group.FieldSoraImagePrice540,
|
||||
group.FieldSoraVideoPricePerRequest,
|
||||
group.FieldSoraVideoPricePerRequestHd,
|
||||
group.FieldClaudeCodeOnly,
|
||||
group.FieldFallbackGroupID,
|
||||
group.FieldFallbackGroupIDOnInvalidRequest,
|
||||
@@ -375,36 +381,34 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
|
||||
// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值
|
||||
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
// Use raw SQL for atomic increment to avoid race conditions
|
||||
// First get current value
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldQuotaUsed).
|
||||
Only(ctx)
|
||||
updated, err := r.client.APIKey.UpdateOneID(id).
|
||||
Where(apikey.DeletedAtIsNil()).
|
||||
AddQuotaUsed(amount).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return updated.QuotaUsed, nil
|
||||
}
|
||||
|
||||
newValue := m.QuotaUsed + amount
|
||||
|
||||
// Update with new value
|
||||
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetQuotaUsed(newValue).
|
||||
SetLastUsedAt(usedAt).
|
||||
SetUpdatedAt(usedAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
return newValue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
@@ -419,6 +423,7 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
Status: m.Status,
|
||||
IPWhitelist: m.IPWhitelist,
|
||||
IPBlacklist: m.IPBlacklist,
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
@@ -477,6 +482,10 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
SoraImagePrice360: g.SoraImagePrice360,
|
||||
SoraImagePrice540: g.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
|
||||
@@ -4,11 +4,14 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
@@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group
|
||||
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
|
||||
return k
|
||||
}
|
||||
|
||||
// --- IncrementQuotaUsed ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() {
|
||||
user := s.mustCreateUser("incr-basic@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil)
|
||||
|
||||
newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsed")
|
||||
s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5")
|
||||
|
||||
newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsed second")
|
||||
s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() {
|
||||
_, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0)
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
||||
user := s.mustCreateUser("incr-deleted@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil)
|
||||
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete")
|
||||
|
||||
_, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0)
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := NewAPIKeyRepository(client).(*apiKeyRepository)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建测试用户和 API Key
|
||||
u, err := client.User.Create().
|
||||
SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com").
|
||||
SetPasswordHash("hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(service.RoleUser).
|
||||
Save(ctx)
|
||||
require.NoError(t, err, "create user")
|
||||
|
||||
k := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano),
|
||||
Name: "Concurrent",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, k), "create api key")
|
||||
t.Cleanup(func() {
|
||||
_ = client.APIKey.DeleteOneID(k.ID).Exec(ctx)
|
||||
_ = client.User.DeleteOneID(u.ID).Exec(ctx)
|
||||
})
|
||||
|
||||
// 10 个 goroutine 各递增 1.0,总计应为 10.0
|
||||
const goroutines = 10
|
||||
const increment = 1.0
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, e := range errs {
|
||||
require.NoError(t, e, "goroutine %d failed", i)
|
||||
}
|
||||
|
||||
// 验证最终结果
|
||||
got, err := repo.GetByID(ctx, k.ID)
|
||||
require.NoError(t, err, "GetByID")
|
||||
require.Equal(t, float64(goroutines)*increment, got.QuotaUsed,
|
||||
"并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed)
|
||||
}
|
||||
|
||||
156
backend/internal/repository/api_key_repo_last_used_unit_test.go
Normal file
156
backend/internal/repository/api_key_repo_last_used_unit_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func newAPIKeyRepoSQLite(t *testing.T) (*apiKeyRepository, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:api_key_repo_last_used?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
return &apiKeyRepository{client: client}, client
|
||||
}
|
||||
|
||||
func mustCreateAPIKeyRepoUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *service.User {
|
||||
t.Helper()
|
||||
u, err := client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_CreateWithLastUsedAt(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "create-last-used@test.com")
|
||||
|
||||
lastUsed := time.Now().UTC().Add(-time.Hour).Truncate(time.Second)
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-last-used",
|
||||
Name: "CreateWithLastUsed",
|
||||
Status: service.StatusActive,
|
||||
LastUsedAt: &lastUsed,
|
||||
}
|
||||
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
require.NotNil(t, key.LastUsedAt)
|
||||
require.WithinDuration(t, lastUsed, *key.LastUsedAt, time.Second)
|
||||
|
||||
got, err := repo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.LastUsedAt)
|
||||
require.WithinDuration(t, lastUsed, *got.LastUsedAt, time.Second)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_UpdateLastUsed(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "update-last-used@test.com")
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update-last-used",
|
||||
Name: "UpdateLastUsed",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
|
||||
before, err := repo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, before.LastUsedAt)
|
||||
|
||||
target := time.Now().UTC().Add(2 * time.Minute).Truncate(time.Second)
|
||||
require.NoError(t, repo.UpdateLastUsed(ctx, key.ID, target))
|
||||
|
||||
after, err := repo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, after.LastUsedAt)
|
||||
require.WithinDuration(t, target, *after.LastUsedAt, time.Second)
|
||||
require.WithinDuration(t, target, after.UpdatedAt, time.Second)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_UpdateLastUsedDeletedKey(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "deleted-last-used@test.com")
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update-last-used-deleted",
|
||||
Name: "UpdateLastUsedDeleted",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
require.NoError(t, repo.Delete(ctx, key.ID))
|
||||
|
||||
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
|
||||
require.ErrorIs(t, err, service.ErrAPIKeyNotFound)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_UpdateLastUsedDBError(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "db-error-last-used@test.com")
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update-last-used-db-error",
|
||||
Name: "UpdateLastUsedDBError",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
|
||||
require.NoError(t, client.Close())
|
||||
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_CreateDuplicateKey(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "duplicate-key@test.com")
|
||||
|
||||
first := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-duplicate",
|
||||
Name: "first",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
second := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-duplicate",
|
||||
Name: "second",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
require.NoError(t, repo.Create(ctx, first))
|
||||
err := repo.Create(ctx, second)
|
||||
require.ErrorIs(t, err, service.ErrAPIKeyExists)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -16,8 +17,19 @@ const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
billingCacheJitter = 30 * time.Second
|
||||
)
|
||||
|
||||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||
func jitteredTTL() time.Duration {
|
||||
// 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL(避免上界预期被打破)。
|
||||
if billingCacheJitter <= 0 {
|
||||
return billingCacheTTL
|
||||
}
|
||||
jitter := time.Duration(rand.IntN(int(billingCacheJitter)))
|
||||
return billingCacheTTL - jitter
|
||||
}
|
||||
|
||||
// billingBalanceKey generates the Redis key for user balance cache.
|
||||
func billingBalanceKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
@@ -82,14 +94,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||
return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -163,16 +176,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
pipe.Expire(ctx, key, jitteredTTL())
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, cache service.BillingCache)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "key_not_exists_returns_nil",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
// key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
|
||||
err := cache.DeductUserBalance(ctx, 99999, 1.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing_key_deducts_successfully",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
|
||||
err := cache.DeductUserBalance(ctx, 200, 10.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance should succeed")
|
||||
|
||||
bal, err := cache.GetUserBalance(ctx, 200)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cancelled_context_propagates_error",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel() // 立即取消
|
||||
|
||||
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
tt.fn(ctx, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
|
||||
s.Run("key_not_exists_returns_nil", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
|
||||
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
|
||||
})
|
||||
|
||||
s.Run("cancelled_context_propagates_error", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBillingCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(BillingCacheSuite))
|
||||
}
|
||||
|
||||
82
backend/internal/repository/billing_cache_jitter_test.go
Normal file
82
backend/internal/repository/billing_cache_jitter_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Task 6.1 验证: math/rand/v2 迁移后 jitteredTTL 行为正确 ---
|
||||
|
||||
func TestJitteredTTL_WithinExpectedRange(t *testing.T) {
|
||||
// jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter)
|
||||
// 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内
|
||||
lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s
|
||||
upperBound := billingCacheTTL // 5min
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
ttl := jitteredTTL()
|
||||
assert.GreaterOrEqual(t, int64(ttl), int64(lowerBound),
|
||||
"TTL 不应低于 %v,实际得到 %v", lowerBound, ttl)
|
||||
assert.LessOrEqual(t, int64(ttl), int64(upperBound),
|
||||
"TTL 不应超过 %v(上界不变保证),实际得到 %v", upperBound, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_NeverExceedsBase(t *testing.T) {
|
||||
// 关键安全性测试:jitteredTTL 使用减法抖动,确保永远不超过 billingCacheTTL
|
||||
for i := 0; i < 500; i++ {
|
||||
ttl := jitteredTTL()
|
||||
assert.LessOrEqual(t, int64(ttl), int64(billingCacheTTL),
|
||||
"jitteredTTL 不应超过基础 TTL(上界预期不被打破)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_HasVariance(t *testing.T) {
|
||||
// 验证抖动确实产生了不同的值
|
||||
results := make(map[time.Duration]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
ttl := jitteredTTL()
|
||||
results[ttl] = true
|
||||
}
|
||||
|
||||
require.Greater(t, len(results), 1,
|
||||
"jitteredTTL 应产生不同的值(抖动生效),但 100 次调用结果全部相同")
|
||||
}
|
||||
|
||||
func TestJitteredTTL_AverageNearCenter(t *testing.T) {
|
||||
// 验证平均值大约在抖动范围中间
|
||||
var sum time.Duration
|
||||
runs := 1000
|
||||
for i := 0; i < runs; i++ {
|
||||
sum += jitteredTTL()
|
||||
}
|
||||
|
||||
avg := sum / time.Duration(runs)
|
||||
expectedCenter := billingCacheTTL - billingCacheJitter/2 // 4min45s
|
||||
|
||||
// 允许 ±5s 的误差
|
||||
tolerance := 5 * time.Second
|
||||
assert.InDelta(t, float64(expectedCenter), float64(avg), float64(tolerance),
|
||||
"平均 TTL 应接近抖动范围中心 %v", expectedCenter)
|
||||
}
|
||||
|
||||
func TestBillingKeyGeneration(t *testing.T) {
|
||||
t.Run("balance_key", func(t *testing.T) {
|
||||
key := billingBalanceKey(12345)
|
||||
assert.Equal(t, "billing:balance:12345", key)
|
||||
})
|
||||
|
||||
t.Run("sub_key", func(t *testing.T) {
|
||||
key := billingSubKey(100, 200)
|
||||
assert.Equal(t, "billing:sub:100:200", key)
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkJitteredTTL(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = jitteredTTL()
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ package repository
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL(t *testing.T) {
|
||||
const (
|
||||
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
|
||||
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
|
||||
)
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
ttl := jitteredTTL()
|
||||
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
|
||||
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_HasVariation(t *testing.T) {
|
||||
// 多次调用应该产生不同的值(验证抖动存在)
|
||||
seen := make(map[time.Duration]struct{}, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
seen[jitteredTTL()] = struct{}{}
|
||||
}
|
||||
// 50 次调用中应该至少有 2 个不同的值
|
||||
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user